Unverified Commit 00031785 authored by demSd's avatar demSd Committed by GitHub
Browse files

BartForCausalLM analogs to `ProphetNetForCausalLM` (#9128)



* initiliaze bart4causalLM

* create BartDecoderWrapper, setters/getters

* delete spaces

* forward and additional methods

* update cache function, loss function, remove ngram* params in data class.

* add bartcausallm, bartdecoder testing

* correct bart for causal lm

* remove at

* add mbart as well

* up

* fix typo

* up

* correct

* add pegasusforcausallm

* add blenderbotforcausallm

* add blenderbotsmallforcausallm

* add marianforcausallm

* add test for MarianForCausalLM

* add Pegasus test

* add BlenderbotSmall test

* add blenderbot test

* fix a fail

* fix an import fail

* a fix

* fix

* Update modeling_pegasus.py

* fix models

* fix inputs_embeds setting getter

* adapt tests

* correct repo utils check

* finish test improvement

* fix tf models as well

* make style

* make fix-copies

* fix copies

* run all tests

* last changes

* fix all tests
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 7898fc03
...@@ -39,6 +39,7 @@ if is_tokenizers_available(): ...@@ -39,6 +39,7 @@ if is_tokenizers_available():
if is_torch_available(): if is_torch_available():
_import_structure["modeling_mbart"] = [ _import_structure["modeling_mbart"] = [
"MBART_PRETRAINED_MODEL_ARCHIVE_LIST", "MBART_PRETRAINED_MODEL_ARCHIVE_LIST",
"MBartForCausalLM",
"MBartForConditionalGeneration", "MBartForConditionalGeneration",
"MBartForQuestionAnswering", "MBartForQuestionAnswering",
"MBartForSequenceClassification", "MBartForSequenceClassification",
...@@ -62,6 +63,7 @@ if TYPE_CHECKING: ...@@ -62,6 +63,7 @@ if TYPE_CHECKING:
if is_torch_available(): if is_torch_available():
from .modeling_mbart import ( from .modeling_mbart import (
MBART_PRETRAINED_MODEL_ARCHIVE_LIST, MBART_PRETRAINED_MODEL_ARCHIVE_LIST,
MBartForCausalLM,
MBartForConditionalGeneration, MBartForConditionalGeneration,
MBartForQuestionAnswering, MBartForQuestionAnswering,
MBartForSequenceClassification, MBartForSequenceClassification,
......
...@@ -13,8 +13,7 @@ ...@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" PyTorch MBART model. """ """ PyTorch MBART model. """
import copy
import math import math
import random import random
from typing import Optional, Tuple from typing import Optional, Tuple
...@@ -36,6 +35,7 @@ from ...file_utils import ( ...@@ -36,6 +35,7 @@ from ...file_utils import (
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutput, BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
Seq2SeqLMOutput, Seq2SeqLMOutput,
Seq2SeqModelOutput, Seq2SeqModelOutput,
Seq2SeqQuestionAnsweringModelOutput, Seq2SeqQuestionAnsweringModelOutput,
...@@ -852,6 +852,31 @@ class MBartDecoder(MBartPreTrainedModel): ...@@ -852,6 +852,31 @@ class MBartDecoder(MBartPreTrainedModel):
self.init_weights() self.init_weights()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward( def forward(
self, self,
input_ids=None, input_ids=None,
...@@ -954,19 +979,9 @@ class MBartDecoder(MBartPreTrainedModel): ...@@ -954,19 +979,9 @@ class MBartDecoder(MBartPreTrainedModel):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
# create causal mask attention_mask = self._prepare_decoder_attention_mask(
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask, input_shape, inputs_embeds, past_key_values_length
combined_attention_mask = None )
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device)
if attention_mask is not None and combined_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = combined_attention_mask + _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)
# expand encoder attention mask # expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None: if encoder_hidden_states is not None and encoder_attention_mask is not None:
...@@ -984,7 +999,7 @@ class MBartDecoder(MBartPreTrainedModel): ...@@ -984,7 +999,7 @@ class MBartDecoder(MBartPreTrainedModel):
# decoder layers # decoder layers
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None all_self_attns = () if output_attentions else None
all_cross_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask has a correct number of layers specified if desired
...@@ -1021,7 +1036,7 @@ class MBartDecoder(MBartPreTrainedModel): ...@@ -1021,7 +1036,7 @@ class MBartDecoder(MBartPreTrainedModel):
layer_outputs = torch.utils.checkpoint.checkpoint( layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer), create_custom_forward(decoder_layer),
hidden_states, hidden_states,
combined_attention_mask, attention_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
...@@ -1032,7 +1047,7 @@ class MBartDecoder(MBartPreTrainedModel): ...@@ -1032,7 +1047,7 @@ class MBartDecoder(MBartPreTrainedModel):
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
hidden_states, hidden_states,
attention_mask=combined_attention_mask, attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None), layer_head_mask=(head_mask[idx] if head_mask is not None else None),
...@@ -1048,7 +1063,9 @@ class MBartDecoder(MBartPreTrainedModel): ...@@ -1048,7 +1063,9 @@ class MBartDecoder(MBartPreTrainedModel):
if output_attentions: if output_attentions:
all_self_attns += (layer_outputs[1],) all_self_attns += (layer_outputs[1],)
all_cross_attentions += (layer_outputs[2],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
...@@ -1570,3 +1587,210 @@ class MBartForQuestionAnswering(MBartPreTrainedModel): ...@@ -1570,3 +1587,210 @@ class MBartForQuestionAnswering(MBartPreTrainedModel):
encoder_hidden_states=outputs.encoder_hidden_states, encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions, encoder_attentions=outputs.encoder_attentions,
) )
# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->MBart
class MBartDecoderWrapper(MBartPreTrainedModel):
"""
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
"""
def __init__(self, config):
super().__init__(config)
self.decoder = MBartDecoder(config)
def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->MBart
class MBartForCausalLM(MBartPreTrainedModel):
def __init__(self, config):
super().__init__(config)
config = copy.deepcopy(config)
config.is_decoder = True
config.is_encoder_decoder = False
self.model = MBartDecoderWrapper(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.init_weights()
def get_input_embeddings(self):
return self.model.decoder.embed_tokens
def set_input_embeddings(self, value):
self.model.decoder.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model.decoder = decoder
def get_decoder(self):
return self.model.decoder
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
encoder_head_mask=None,
past_key_values=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
if the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids``
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are
ignored (masked), the loss is only computed for the tokens with labels in ``[0, ...,
config.vocab_size]``.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
for more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
Returns:
Example::
>>> from transformers import MBartTokenizer, MBartForCausalLM
>>> tokenizer = MBartTokenizer.from_pretrained('facebook/bart-large')
>>> model = MBartForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False)
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
encoder_head_mask=encoder_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = self.lm_head(outputs[0])
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"use_cache": use_cache,
}
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()
for layer_past in past:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past
...@@ -938,7 +938,7 @@ class TFMBartDecoder(tf.keras.layers.Layer): ...@@ -938,7 +938,7 @@ class TFMBartDecoder(tf.keras.layers.Layer):
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
) )
if inputs["attention_mask"] is not None and input_shape[-1] > 1: if inputs["attention_mask"] is not None:
combined_attention_mask = combined_attention_mask + _expand_mask( combined_attention_mask = combined_attention_mask + _expand_mask(
inputs["attention_mask"], tgt_len=input_shape[-1] inputs["attention_mask"], tgt_len=input_shape[-1]
) )
......
...@@ -42,6 +42,7 @@ if is_torch_available(): ...@@ -42,6 +42,7 @@ if is_torch_available():
"PegasusForConditionalGeneration", "PegasusForConditionalGeneration",
"PegasusModel", "PegasusModel",
"PegasusPreTrainedModel", "PegasusPreTrainedModel",
"PegasusForCausalLM",
] ]
if is_tf_available(): if is_tf_available():
...@@ -60,6 +61,7 @@ if TYPE_CHECKING: ...@@ -60,6 +61,7 @@ if TYPE_CHECKING:
if is_torch_available(): if is_torch_available():
from .modeling_pegasus import ( from .modeling_pegasus import (
PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST, PEGASUS_PRETRAINED_MODEL_ARCHIVE_LIST,
PegasusForCausalLM,
PegasusForConditionalGeneration, PegasusForConditionalGeneration,
PegasusModel, PegasusModel,
PegasusPreTrainedModel, PegasusPreTrainedModel,
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
""" PyTorch PEGASUS model. """ """ PyTorch PEGASUS model. """
import copy
import math import math
import random import random
from typing import Optional, Tuple from typing import Optional, Tuple
...@@ -36,6 +36,7 @@ from ...file_utils import ( ...@@ -36,6 +36,7 @@ from ...file_utils import (
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutput, BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
Seq2SeqLMOutput, Seq2SeqLMOutput,
Seq2SeqModelOutput, Seq2SeqModelOutput,
) )
...@@ -817,6 +818,31 @@ class PegasusDecoder(PegasusPreTrainedModel): ...@@ -817,6 +818,31 @@ class PegasusDecoder(PegasusPreTrainedModel):
self.init_weights() self.init_weights()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward( def forward(
self, self,
input_ids=None, input_ids=None,
...@@ -919,19 +945,9 @@ class PegasusDecoder(PegasusPreTrainedModel): ...@@ -919,19 +945,9 @@ class PegasusDecoder(PegasusPreTrainedModel):
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
# create causal mask attention_mask = self._prepare_decoder_attention_mask(
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask, input_shape, inputs_embeds, past_key_values_length
combined_attention_mask = None )
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device)
if attention_mask is not None and combined_attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = combined_attention_mask + _expand_mask(
attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
)
# expand encoder attention mask # expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None: if encoder_hidden_states is not None and encoder_attention_mask is not None:
...@@ -948,7 +964,7 @@ class PegasusDecoder(PegasusPreTrainedModel): ...@@ -948,7 +964,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
# decoder layers # decoder layers
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None all_self_attns = () if output_attentions else None
all_cross_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired # check if head_mask has a correct number of layers specified if desired
...@@ -985,7 +1001,7 @@ class PegasusDecoder(PegasusPreTrainedModel): ...@@ -985,7 +1001,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
layer_outputs = torch.utils.checkpoint.checkpoint( layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer), create_custom_forward(decoder_layer),
hidden_states, hidden_states,
combined_attention_mask, attention_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
head_mask[idx] if head_mask is not None else None, head_mask[idx] if head_mask is not None else None,
...@@ -996,7 +1012,7 @@ class PegasusDecoder(PegasusPreTrainedModel): ...@@ -996,7 +1012,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
layer_outputs = decoder_layer( layer_outputs = decoder_layer(
hidden_states, hidden_states,
attention_mask=combined_attention_mask, attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask, encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None), layer_head_mask=(head_mask[idx] if head_mask is not None else None),
...@@ -1012,7 +1028,9 @@ class PegasusDecoder(PegasusPreTrainedModel): ...@@ -1012,7 +1028,9 @@ class PegasusDecoder(PegasusPreTrainedModel):
if output_attentions: if output_attentions:
all_self_attns += (layer_outputs[1],) all_self_attns += (layer_outputs[1],)
all_cross_attentions += (layer_outputs[2],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
hidden_states = self.layer_norm(hidden_states) hidden_states = self.layer_norm(hidden_states)
...@@ -1325,3 +1343,210 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel): ...@@ -1325,3 +1343,210 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
) )
return reordered_past return reordered_past
# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Pegasus
class PegasusDecoderWrapper(PegasusPreTrainedModel):
"""
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
"""
def __init__(self, config):
super().__init__(config)
self.decoder = PegasusDecoder(config)
def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Pegasus
class PegasusForCausalLM(PegasusPreTrainedModel):
def __init__(self, config):
super().__init__(config)
config = copy.deepcopy(config)
config.is_decoder = True
config.is_encoder_decoder = False
self.model = PegasusDecoderWrapper(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.init_weights()
def get_input_embeddings(self):
return self.model.decoder.embed_tokens
def set_input_embeddings(self, value):
self.model.decoder.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model.decoder = decoder
def get_decoder(self):
return self.model.decoder
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
encoder_head_mask=None,
past_key_values=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
if the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the heas is **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids``
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are
ignored (masked), the loss is only computed for the tokens with labels in ``[0, ...,
config.vocab_size]``.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
for more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
Returns:
Example::
>>> from transformers import PegasusTokenizer, PegasusForCausalLM
>>> tokenizer = PegasusTokenizer.from_pretrained('facebook/bart-large')
>>> model = PegasusForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False)
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
encoder_head_mask=encoder_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = self.lm_head(outputs[0])
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"use_cache": use_cache,
}
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()
for layer_past in past:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past
...@@ -955,7 +955,7 @@ class TFPegasusDecoder(tf.keras.layers.Layer): ...@@ -955,7 +955,7 @@ class TFPegasusDecoder(tf.keras.layers.Layer):
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1] tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
) )
if inputs["attention_mask"] is not None and input_shape[-1] > 1: if inputs["attention_mask"] is not None:
combined_attention_mask = combined_attention_mask + _expand_mask( combined_attention_mask = combined_attention_mask + _expand_mask(
inputs["attention_mask"], tgt_len=input_shape[-1] inputs["attention_mask"], tgt_len=input_shape[-1]
) )
......
...@@ -1330,7 +1330,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): ...@@ -1330,7 +1330,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
>>> import torch >>> import torch
>>> tokenizer = ProphetNetTokenizer.from_pretrained('microsoft/prophetnet-large-uncased') >>> tokenizer = ProphetNetTokenizer.from_pretrained('microsoft/prophetnet-large-uncased')
>>> model = ProphetNetDecoder.from_pretrained('patrickvonplaten/prophetnet-large-uncased-standalone', add_cross_attention=False) >>> model = ProphetNetDecoder.from_pretrained('microsoft/prophetnet-large-uncased', add_cross_attention=False)
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs) >>> outputs = model(**inputs)
......
...@@ -431,6 +431,11 @@ class AutoModelWithLMHead: ...@@ -431,6 +431,11 @@ class AutoModelWithLMHead:
BART_PRETRAINED_MODEL_ARCHIVE_LIST = None BART_PRETRAINED_MODEL_ARCHIVE_LIST = None
class BartForCausalLM:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class BartForConditionalGeneration: class BartForConditionalGeneration:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_pytorch(self) requires_pytorch(self)
...@@ -596,6 +601,11 @@ def load_tf_weights_in_bert_generation(*args, **kwargs): ...@@ -596,6 +601,11 @@ def load_tf_weights_in_bert_generation(*args, **kwargs):
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = None BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = None
class BlenderbotForCausalLM:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class BlenderbotForConditionalGeneration: class BlenderbotForConditionalGeneration:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_pytorch(self) requires_pytorch(self)
...@@ -617,6 +627,11 @@ class BlenderbotModel: ...@@ -617,6 +627,11 @@ class BlenderbotModel:
BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST = None BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST = None
class BlenderbotSmallForCausalLM:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class BlenderbotSmallForConditionalGeneration: class BlenderbotSmallForConditionalGeneration:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_pytorch(self) requires_pytorch(self)
...@@ -1464,6 +1479,11 @@ class LxmertXLayer: ...@@ -1464,6 +1479,11 @@ class LxmertXLayer:
requires_pytorch(self) requires_pytorch(self)
class MarianForCausalLM:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class MarianModel: class MarianModel:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_pytorch(self) requires_pytorch(self)
...@@ -1482,6 +1502,11 @@ class MarianMTModel: ...@@ -1482,6 +1502,11 @@ class MarianMTModel:
requires_pytorch(self) requires_pytorch(self)
class MBartForCausalLM:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class MBartForConditionalGeneration: class MBartForConditionalGeneration:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_pytorch(self) requires_pytorch(self)
...@@ -1772,6 +1797,11 @@ def load_tf_weights_in_openai_gpt(*args, **kwargs): ...@@ -1772,6 +1797,11 @@ def load_tf_weights_in_openai_gpt(*args, **kwargs):
requires_pytorch(load_tf_weights_in_openai_gpt) requires_pytorch(load_tf_weights_in_openai_gpt)
class PegasusForCausalLM:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class PegasusForConditionalGeneration: class PegasusForConditionalGeneration:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_pytorch(self) requires_pytorch(self)
......
...@@ -488,7 +488,7 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers ...@@ -488,7 +488,7 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor from .test_modeling_common import ModelTesterMixin, ids_tensor, floats_tensor
if is_torch_available(): if is_torch_available():
...@@ -847,4 +847,220 @@ class {{cookiecutter.camelcase_modelname}}ModelIntegrationTests(unittest.TestCas ...@@ -847,4 +847,220 @@ class {{cookiecutter.camelcase_modelname}}ModelIntegrationTests(unittest.TestCas
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
) )
assert generated == EXPECTED assert generated == EXPECTED
class {{cookiecutter.camelcase_modelname}}StandaloneDecoderModelTester:
def __init__(
self,
parent,
vocab_size=99,
batch_size=13,
d_model=16,
decoder_seq_length=7,
is_training=True,
is_decoder=True,
use_attention_mask=True,
use_cache=False,
use_labels=True,
decoder_start_token_id=2,
decoder_ffn_dim=32,
decoder_layers=4,
encoder_attention_heads=4,
decoder_attention_heads=4,
max_position_embeddings=30,
is_encoder_decoder=False,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.decoder_seq_length = decoder_seq_length
# For common tests
self.seq_length = self.decoder_seq_length
self.is_training = is_training
self.use_attention_mask = use_attention_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.d_model = d_model
self.hidden_size = d_model
self.num_hidden_layers = decoder_layers
self.decoder_layers = decoder_layers
self.decoder_ffn_dim = decoder_ffn_dim
self.encoder_attention_heads = encoder_attention_heads
self.decoder_attention_heads = decoder_attention_heads
self.num_attention_heads = decoder_attention_heads
self.eos_token_id = eos_token_id
self.bos_token_id = bos_token_id
self.pad_token_id = pad_token_id
self.decoder_start_token_id = decoder_start_token_id
self.use_cache = use_cache
self.max_position_embeddings = max_position_embeddings
self.is_encoder_decoder = is_encoder_decoder
self.scope = None
self.decoder_key_length = decoder_seq_length
self.base_model_out_len = 2
self.decoder_attention_idx = 1
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
attention_mask = None
if self.use_attention_mask:
attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
lm_labels = None
if self.use_labels:
lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
config = {{cookiecutter.camelcase_modelname}}Config(
vocab_size=self.vocab_size,
d_model=self.d_model,
decoder_layers=self.decoder_layers,
decoder_ffn_dim=self.decoder_ffn_dim,
encoder_attention_heads=self.encoder_attention_heads,
decoder_attention_heads=self.decoder_attention_heads,
eos_token_id=self.eos_token_id,
bos_token_id=self.bos_token_id,
use_cache=self.use_cache,
pad_token_id=self.pad_token_id,
decoder_start_token_id=self.decoder_start_token_id,
max_position_embeddings=self.max_position_embeddings,
is_encoder_decoder=self.is_encoder_decoder,
)
return (
config,
input_ids,
attention_mask,
lm_labels,
)
def create_and_check_decoder_model_past(
self,
config,
input_ids,
attention_mask,
lm_labels,
):
config.use_cache = True
model = {{cookiecutter.camelcase_modelname}}Decoder(config=config).to(torch_device).eval()
# first forward pass
outputs = model(input_ids, use_cache=True)
outputs_use_cache_conf = model(input_ids)
outputs_no_past = model(input_ids, use_cache=False)
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
past_key_values = outputs["past_key_values"]
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# append to next input_ids and
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
output_from_no_past = model(next_input_ids)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
def create_and_check_decoder_model_attention_mask_past(
self,
config,
input_ids,
attention_mask,
lm_labels,
):
model = {{cookiecutter.camelcase_modelname}}Decoder(config=config).to(torch_device).eval()
# create attention mask
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
half_seq_length = input_ids.shape[-1] // 2
attn_mask[:, half_seq_length:] = 0
# first forward pass
past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"]
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# change a random masked slice from input_ids
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
# append to next input_ids and attn_mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
attn_mask = torch.cat(
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
dim=1,
)
# get two different outputs
output_from_no_past = model(next_input_ids)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
attention_mask,
lm_labels,
) = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
return config, inputs_dict
@require_torch
class {{cookiecutter.camelcase_modelname}}StandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = ({{cookiecutter.camelcase_modelname}}Decoder, {{cookiecutter.camelcase_modelname}}ForCausalLM) if is_torch_available() else ()
all_generative_model_classes = ({{cookiecutter.camelcase_modelname}}ForCausalLM,) if is_torch_available() else ()
test_pruning = False
is_encoder_decoder = False
def setUp(
self,
):
self.model_tester = {{cookiecutter.camelcase_modelname}}StandaloneDecoderModelTester(self, is_training=False)
self.config_tester = ConfigTester(self, config_class={{cookiecutter.camelcase_modelname}}Config)
def test_config(self):
self.config_tester.run_common_tests()
def test_decoder_model_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
def test_decoder_model_attn_mask_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
def test_retain_grad_hidden_states_attentions(self):
# decoder cannot keep gradients
return
{% endif -%} {% endif -%}
...@@ -27,7 +27,7 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers ...@@ -27,7 +27,7 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_generation_utils import GenerationTesterMixin from .test_generation_utils import GenerationTesterMixin
from .test_modeling_common import ModelTesterMixin, ids_tensor from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
if is_torch_available(): if is_torch_available():
...@@ -36,6 +36,7 @@ if is_torch_available(): ...@@ -36,6 +36,7 @@ if is_torch_available():
from transformers import ( from transformers import (
AutoModelForSequenceClassification, AutoModelForSequenceClassification,
BartConfig, BartConfig,
BartForCausalLM,
BartForConditionalGeneration, BartForConditionalGeneration,
BartForQuestionAnswering, BartForQuestionAnswering,
BartForSequenceClassification, BartForSequenceClassification,
...@@ -178,7 +179,7 @@ class BartModelTester: ...@@ -178,7 +179,7 @@ class BartModelTester:
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice # test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2)) self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def check_encoder_decoder_model_standalone(self, config, inputs_dict): def check_encoder_decoder_model_standalone(self, config, inputs_dict):
model = BartModel(config=config).to(torch_device).eval() model = BartModel(config=config).to(torch_device).eval()
...@@ -719,3 +720,242 @@ class BartModelIntegrationTests(unittest.TestCase): ...@@ -719,3 +720,242 @@ class BartModelIntegrationTests(unittest.TestCase):
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
) )
assert generated_summaries == EXPECTED assert generated_summaries == EXPECTED
class BartStandaloneDecoderModelTester:
def __init__(
self,
parent,
vocab_size=99,
batch_size=13,
d_model=16,
decoder_seq_length=7,
is_training=True,
is_decoder=True,
use_attention_mask=True,
use_cache=False,
use_labels=True,
decoder_start_token_id=2,
decoder_ffn_dim=32,
decoder_layers=4,
encoder_attention_heads=4,
decoder_attention_heads=4,
max_position_embeddings=30,
is_encoder_decoder=False,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.decoder_seq_length = decoder_seq_length
# For common tests
self.seq_length = self.decoder_seq_length
self.is_training = is_training
self.use_attention_mask = use_attention_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.d_model = d_model
self.hidden_size = d_model
self.num_hidden_layers = decoder_layers
self.decoder_layers = decoder_layers
self.decoder_ffn_dim = decoder_ffn_dim
self.encoder_attention_heads = encoder_attention_heads
self.decoder_attention_heads = decoder_attention_heads
self.num_attention_heads = decoder_attention_heads
self.eos_token_id = eos_token_id
self.bos_token_id = bos_token_id
self.pad_token_id = pad_token_id
self.decoder_start_token_id = decoder_start_token_id
self.use_cache = use_cache
self.max_position_embeddings = max_position_embeddings
self.is_encoder_decoder = is_encoder_decoder
self.scope = None
self.decoder_key_length = decoder_seq_length
self.base_model_out_len = 2
self.decoder_attention_idx = 1
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
attention_mask = None
if self.use_attention_mask:
attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
lm_labels = None
if self.use_labels:
lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
config = BartConfig(
vocab_size=self.vocab_size,
d_model=self.d_model,
encoder_layers=self.decoder_layers,
decoder_layers=self.decoder_layers,
decoder_ffn_dim=self.decoder_ffn_dim,
encoder_attention_heads=self.encoder_attention_heads,
decoder_attention_heads=self.decoder_attention_heads,
eos_token_id=self.eos_token_id,
bos_token_id=self.bos_token_id,
use_cache=self.use_cache,
pad_token_id=self.pad_token_id,
decoder_start_token_id=self.decoder_start_token_id,
max_position_embeddings=self.max_position_embeddings,
is_encoder_decoder=self.is_encoder_decoder,
)
return (
config,
input_ids,
attention_mask,
lm_labels,
)
def prepare_config_and_inputs_for_decoder(self):
(
config,
input_ids,
attention_mask,
lm_labels,
) = self.prepare_config_and_inputs()
encoder_hidden_states = floats_tensor([self.batch_size, self.decoder_seq_length, self.hidden_size])
encoder_attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
return (
config,
input_ids,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
lm_labels,
)
def create_and_check_decoder_model_past(
self,
config,
input_ids,
attention_mask,
lm_labels,
):
config.use_cache = True
model = BartDecoder(config=config).to(torch_device).eval()
# first forward pass
outputs = model(input_ids, use_cache=True)
outputs_use_cache_conf = model(input_ids)
outputs_no_past = model(input_ids, use_cache=False)
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
past_key_values = outputs["past_key_values"]
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# append to next input_ids and
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
output_from_no_past = model(next_input_ids)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
def create_and_check_decoder_model_attention_mask_past(
self,
config,
input_ids,
attention_mask,
lm_labels,
):
model = BartDecoder(config=config).to(torch_device).eval()
# create attention mask
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
half_seq_length = input_ids.shape[-1] // 2
attn_mask[:, half_seq_length:] = 0
# first forward pass
past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"]
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# change a random masked slice from input_ids
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
# append to next input_ids and attn_mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
attn_mask = torch.cat(
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
dim=1,
)
# get two different outputs
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[
"last_hidden_state"
]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
attention_mask,
lm_labels,
) = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
return config, inputs_dict
@require_torch
class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (BartDecoder, BartForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (BartForCausalLM,) if is_torch_available() else ()
test_pruning = False
is_encoder_decoder = False
def setUp(
self,
):
self.model_tester = BartStandaloneDecoderModelTester(self, is_training=False)
self.config_tester = ConfigTester(self, config_class=BartConfig)
def test_config(self):
self.config_tester.run_common_tests()
def test_decoder_model_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
def test_decoder_model_attn_mask_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
def test_retain_grad_hidden_states_attentions(self):
# decoder cannot keep gradients
return
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" Testing suite for the PyTorch Blenderbot model. """ """ Testing suite for the PyTorch Blenderbot model. """
import tempfile import tempfile
import unittest import unittest
...@@ -31,7 +30,11 @@ if is_torch_available(): ...@@ -31,7 +30,11 @@ if is_torch_available():
import torch import torch
from transformers import BlenderbotConfig, BlenderbotForConditionalGeneration, BlenderbotModel, BlenderbotTokenizer from transformers import BlenderbotConfig, BlenderbotForConditionalGeneration, BlenderbotModel, BlenderbotTokenizer
from transformers.models.blenderbot.modeling_blenderbot import BlenderbotDecoder, BlenderbotEncoder from transformers.models.blenderbot.modeling_blenderbot import (
BlenderbotDecoder,
BlenderbotEncoder,
BlenderbotForCausalLM,
)
def prepare_blenderbot_inputs_dict( def prepare_blenderbot_inputs_dict(
...@@ -165,7 +168,7 @@ class BlenderbotModelTester: ...@@ -165,7 +168,7 @@ class BlenderbotModelTester:
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice # test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2)) self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def check_encoder_decoder_model_standalone(self, config, inputs_dict): def check_encoder_decoder_model_standalone(self, config, inputs_dict):
model = BlenderbotModel(config=config).to(torch_device).eval() model = BlenderbotModel(config=config).to(torch_device).eval()
...@@ -300,3 +303,222 @@ class Blenderbot3BIntegrationTests(unittest.TestCase): ...@@ -300,3 +303,222 @@ class Blenderbot3BIntegrationTests(unittest.TestCase):
assert "I think it's because we are so worried about what people think of us." == reply.strip() assert "I think it's because we are so worried about what people think of us." == reply.strip()
del model del model
class BlenderbotStandaloneDecoderModelTester:
def __init__(
self,
parent,
vocab_size=99,
batch_size=13,
d_model=16,
decoder_seq_length=7,
is_training=True,
is_decoder=True,
use_attention_mask=True,
use_cache=False,
use_labels=True,
decoder_start_token_id=2,
decoder_ffn_dim=32,
decoder_layers=4,
encoder_attention_heads=4,
decoder_attention_heads=4,
max_position_embeddings=30,
is_encoder_decoder=False,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.decoder_seq_length = decoder_seq_length
# For common tests
self.seq_length = self.decoder_seq_length
self.is_training = is_training
self.use_attention_mask = use_attention_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.d_model = d_model
self.hidden_size = d_model
self.num_hidden_layers = decoder_layers
self.decoder_layers = decoder_layers
self.decoder_ffn_dim = decoder_ffn_dim
self.encoder_attention_heads = encoder_attention_heads
self.decoder_attention_heads = decoder_attention_heads
self.num_attention_heads = decoder_attention_heads
self.eos_token_id = eos_token_id
self.bos_token_id = bos_token_id
self.pad_token_id = pad_token_id
self.decoder_start_token_id = decoder_start_token_id
self.use_cache = use_cache
self.max_position_embeddings = max_position_embeddings
self.is_encoder_decoder = is_encoder_decoder
self.scope = None
self.decoder_key_length = decoder_seq_length
self.base_model_out_len = 2
self.decoder_attention_idx = 1
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
attention_mask = None
if self.use_attention_mask:
attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
lm_labels = None
if self.use_labels:
lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
config = BlenderbotConfig(
vocab_size=self.vocab_size,
d_model=self.d_model,
decoder_layers=self.decoder_layers,
decoder_ffn_dim=self.decoder_ffn_dim,
encoder_attention_heads=self.encoder_attention_heads,
decoder_attention_heads=self.decoder_attention_heads,
eos_token_id=self.eos_token_id,
bos_token_id=self.bos_token_id,
use_cache=self.use_cache,
pad_token_id=self.pad_token_id,
decoder_start_token_id=self.decoder_start_token_id,
max_position_embeddings=self.max_position_embeddings,
is_encoder_decoder=self.is_encoder_decoder,
)
return (
config,
input_ids,
attention_mask,
lm_labels,
)
def create_and_check_decoder_model_past(
self,
config,
input_ids,
attention_mask,
lm_labels,
):
config.use_cache = True
model = BlenderbotDecoder(config=config).to(torch_device).eval()
# first forward pass
outputs = model(input_ids, use_cache=True)
outputs_use_cache_conf = model(input_ids)
outputs_no_past = model(input_ids, use_cache=False)
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
past_key_values = outputs["past_key_values"]
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# append to next input_ids and
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
output_from_no_past = model(next_input_ids)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
def create_and_check_decoder_model_attention_mask_past(
self,
config,
input_ids,
attention_mask,
lm_labels,
):
model = BlenderbotDecoder(config=config).to(torch_device).eval()
# create attention mask
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
half_seq_length = input_ids.shape[-1] // 2
attn_mask[:, half_seq_length:] = 0
# first forward pass
past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"]
# past_key_values = model(input_ids, use_cache=True)["past_key_values"]
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# change a random masked slice from input_ids
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
# append to next input_ids and attn_mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
attn_mask = torch.cat(
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
dim=1,
)
# get two different outputs
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[
"last_hidden_state"
]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
attention_mask,
lm_labels,
) = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
return config, inputs_dict
@require_torch
class BlenderbotStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (BlenderbotDecoder, BlenderbotForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (BlenderbotForCausalLM,) if is_torch_available() else ()
test_pruning = False
is_encoder_decoder = False
def setUp(
self,
):
self.model_tester = BlenderbotStandaloneDecoderModelTester(self, is_training=False)
self.config_tester = ConfigTester(self, config_class=BlenderbotConfig)
def test_config(self):
self.config_tester.run_common_tests()
def test_decoder_model_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
def test_decoder_model_attn_mask_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
def test_retain_grad_hidden_states_attentions(self):
# decoder cannot keep gradients
return
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" Testing suite for the PyTorch BlenderbotSmall model. """ """ Testing suite for the PyTorch BlenderbotSmall model. """
import tempfile import tempfile
import unittest import unittest
...@@ -39,6 +38,7 @@ if is_torch_available(): ...@@ -39,6 +38,7 @@ if is_torch_available():
from transformers.models.blenderbot_small.modeling_blenderbot_small import ( from transformers.models.blenderbot_small.modeling_blenderbot_small import (
BlenderbotSmallDecoder, BlenderbotSmallDecoder,
BlenderbotSmallEncoder, BlenderbotSmallEncoder,
BlenderbotSmallForCausalLM,
) )
...@@ -173,7 +173,7 @@ class BlenderbotSmallModelTester: ...@@ -173,7 +173,7 @@ class BlenderbotSmallModelTester:
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice # test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2)) self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def check_encoder_decoder_model_standalone(self, config, inputs_dict): def check_encoder_decoder_model_standalone(self, config, inputs_dict):
model = BlenderbotSmallModel(config=config).to(torch_device).eval() model = BlenderbotSmallModel(config=config).to(torch_device).eval()
...@@ -317,3 +317,221 @@ class Blenderbot90MIntegrationTests(unittest.TestCase): ...@@ -317,3 +317,221 @@ class Blenderbot90MIntegrationTests(unittest.TestCase):
"have you ever been to a sam club? it's a great club in the south.", "have you ever been to a sam club? it's a great club in the south.",
"have you ever heard of sam harris? he's an american singer, songwriter, and actor.", "have you ever heard of sam harris? he's an american singer, songwriter, and actor.",
) )
class BlenderbotSmallStandaloneDecoderModelTester:
def __init__(
self,
parent,
vocab_size=99,
batch_size=13,
d_model=16,
decoder_seq_length=7,
is_training=True,
is_decoder=True,
use_attention_mask=True,
use_cache=False,
use_labels=True,
decoder_start_token_id=2,
decoder_ffn_dim=32,
decoder_layers=4,
encoder_attention_heads=4,
decoder_attention_heads=4,
max_position_embeddings=30,
is_encoder_decoder=False,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.decoder_seq_length = decoder_seq_length
# For common tests
self.seq_length = self.decoder_seq_length
self.is_training = is_training
self.use_attention_mask = use_attention_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.d_model = d_model
self.hidden_size = d_model
self.num_hidden_layers = decoder_layers
self.decoder_layers = decoder_layers
self.decoder_ffn_dim = decoder_ffn_dim
self.encoder_attention_heads = encoder_attention_heads
self.decoder_attention_heads = decoder_attention_heads
self.num_attention_heads = decoder_attention_heads
self.eos_token_id = eos_token_id
self.bos_token_id = bos_token_id
self.pad_token_id = pad_token_id
self.decoder_start_token_id = decoder_start_token_id
self.use_cache = use_cache
self.max_position_embeddings = max_position_embeddings
self.is_encoder_decoder = is_encoder_decoder
self.scope = None
self.decoder_key_length = decoder_seq_length
self.base_model_out_len = 2
self.decoder_attention_idx = 1
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
attention_mask = None
if self.use_attention_mask:
attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
lm_labels = None
if self.use_labels:
lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
config = BlenderbotSmallConfig(
vocab_size=self.vocab_size,
d_model=self.d_model,
decoder_layers=self.decoder_layers,
decoder_ffn_dim=self.decoder_ffn_dim,
encoder_attention_heads=self.encoder_attention_heads,
decoder_attention_heads=self.decoder_attention_heads,
eos_token_id=self.eos_token_id,
bos_token_id=self.bos_token_id,
use_cache=self.use_cache,
pad_token_id=self.pad_token_id,
decoder_start_token_id=self.decoder_start_token_id,
max_position_embeddings=self.max_position_embeddings,
is_encoder_decoder=self.is_encoder_decoder,
)
return (
config,
input_ids,
attention_mask,
lm_labels,
)
def create_and_check_decoder_model_past(
self,
config,
input_ids,
attention_mask,
lm_labels,
):
config.use_cache = True
model = BlenderbotSmallDecoder(config=config).to(torch_device).eval()
# first forward pass
outputs = model(input_ids, use_cache=True)
outputs_use_cache_conf = model(input_ids)
outputs_no_past = model(input_ids, use_cache=False)
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
past_key_values = outputs["past_key_values"]
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# append to next input_ids and
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
output_from_no_past = model(next_input_ids)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
def create_and_check_decoder_model_attention_mask_past(
self,
config,
input_ids,
attention_mask,
lm_labels,
):
model = BlenderbotSmallDecoder(config=config).to(torch_device).eval()
# create attention mask
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
half_seq_length = input_ids.shape[-1] // 2
attn_mask[:, half_seq_length:] = 0
# first forward pass
past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"]
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# change a random masked slice from input_ids
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
# append to next input_ids and attn_mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
attn_mask = torch.cat(
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
dim=1,
)
# get two different outputs
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[
"last_hidden_state"
]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
attention_mask,
lm_labels,
) = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
return config, inputs_dict
@require_torch
class BlenderbotSmallStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (BlenderbotSmallDecoder, BlenderbotSmallForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (BlenderbotSmallForCausalLM,) if is_torch_available() else ()
test_pruning = False
is_encoder_decoder = False
def setUp(
self,
):
self.model_tester = BlenderbotSmallStandaloneDecoderModelTester(self, is_training=False)
self.config_tester = ConfigTester(self, config_class=BlenderbotSmallConfig)
def test_config(self):
self.config_tester.run_common_tests()
def test_decoder_model_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
def test_decoder_model_attn_mask_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
def test_retain_grad_hidden_states_attentions(self):
# decoder cannot keep gradients
return
...@@ -20,6 +20,7 @@ import unittest ...@@ -20,6 +20,7 @@ import unittest
from transformers import is_torch_available from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device from transformers.testing_utils import require_torch, slow, torch_device
from .test_modeling_bart import BartStandaloneDecoderModelTester
from .test_modeling_bert import BertModelTester from .test_modeling_bert import BertModelTester
from .test_modeling_bert_generation import BertGenerationEncoderTester from .test_modeling_bert_generation import BertGenerationEncoderTester
from .test_modeling_common import ids_tensor from .test_modeling_common import ids_tensor
...@@ -34,6 +35,7 @@ if is_torch_available(): ...@@ -34,6 +35,7 @@ if is_torch_available():
from transformers import ( from transformers import (
AutoTokenizer, AutoTokenizer,
BartForCausalLM,
BertGenerationDecoder, BertGenerationDecoder,
BertGenerationEncoder, BertGenerationEncoder,
BertLMHeadModel, BertLMHeadModel,
...@@ -828,3 +830,57 @@ class ProphetNetEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase): ...@@ -828,3 +830,57 @@ class ProphetNetEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
def test_encoder_decoder_model_shared_weights(self): def test_encoder_decoder_model_shared_weights(self):
pass pass
@require_torch
class BartEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = BertModel(config)
decoder_model = BartForCausalLM(decoder_config)
return encoder_model, decoder_model
def prepare_config_and_inputs(self):
model_tester_encoder = BertModelTester(self, batch_size=13)
model_tester_decoder = BartStandaloneDecoderModelTester(
self, batch_size=13, d_model=32, max_position_embeddings=512
)
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = encoder_config_and_inputs
(
decoder_config,
decoder_input_ids,
decoder_attention_mask,
encoder_hidden_states,
encoder_attention_mask,
lm_labels,
) = decoder_config_and_inputs
# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
# disable cache for now
decoder_config.use_cache = False
return {
"config": config,
"input_ids": input_ids,
"attention_mask": input_mask,
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"encoder_hidden_states": encoder_hidden_states,
"labels": lm_labels,
}
def get_pretrained_model(self):
return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "facebook/bart-large")
def test_encoder_decoder_model_shared_weights(self):
pass
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" Testing suite for the PyTorch Marian model. """ """ Testing suite for the PyTorch Marian model. """
import tempfile import tempfile
import unittest import unittest
...@@ -45,7 +44,12 @@ if is_torch_available(): ...@@ -45,7 +44,12 @@ if is_torch_available():
convert_hf_name_to_opus_name, convert_hf_name_to_opus_name,
convert_opus_name_to_hf_name, convert_opus_name_to_hf_name,
) )
from transformers.models.marian.modeling_marian import MarianDecoder, MarianEncoder, shift_tokens_right from transformers.models.marian.modeling_marian import (
MarianDecoder,
MarianEncoder,
MarianForCausalLM,
shift_tokens_right,
)
def prepare_marian_inputs_dict( def prepare_marian_inputs_dict(
...@@ -182,7 +186,7 @@ class MarianModelTester: ...@@ -182,7 +186,7 @@ class MarianModelTester:
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice # test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2)) self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def check_encoder_decoder_model_standalone(self, config, inputs_dict): def check_encoder_decoder_model_standalone(self, config, inputs_dict):
model = MarianModel(config=config).to(torch_device).eval() model = MarianModel(config=config).to(torch_device).eval()
...@@ -546,3 +550,221 @@ class TestConversionUtils(unittest.TestCase): ...@@ -546,3 +550,221 @@ class TestConversionUtils(unittest.TestCase):
"en-de", "en-de",
] ]
self.assertListEqual(expected_opus_names, converted_opus_names) self.assertListEqual(expected_opus_names, converted_opus_names)
class MarianStandaloneDecoderModelTester:
def __init__(
self,
parent,
vocab_size=99,
batch_size=13,
d_model=16,
decoder_seq_length=7,
is_training=True,
is_decoder=True,
use_attention_mask=True,
use_cache=False,
use_labels=True,
decoder_start_token_id=2,
decoder_ffn_dim=32,
decoder_layers=4,
encoder_attention_heads=4,
decoder_attention_heads=4,
max_position_embeddings=30,
is_encoder_decoder=False,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.decoder_seq_length = decoder_seq_length
# For common tests
self.seq_length = self.decoder_seq_length
self.is_training = is_training
self.use_attention_mask = use_attention_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.d_model = d_model
self.hidden_size = d_model
self.num_hidden_layers = decoder_layers
self.decoder_layers = decoder_layers
self.decoder_ffn_dim = decoder_ffn_dim
self.encoder_attention_heads = encoder_attention_heads
self.decoder_attention_heads = decoder_attention_heads
self.num_attention_heads = decoder_attention_heads
self.eos_token_id = eos_token_id
self.bos_token_id = bos_token_id
self.pad_token_id = pad_token_id
self.decoder_start_token_id = decoder_start_token_id
self.use_cache = use_cache
self.max_position_embeddings = max_position_embeddings
self.is_encoder_decoder = is_encoder_decoder
self.scope = None
self.decoder_key_length = decoder_seq_length
self.base_model_out_len = 2
self.decoder_attention_idx = 1
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
attention_mask = None
if self.use_attention_mask:
attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
lm_labels = None
if self.use_labels:
lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
config = MarianConfig(
vocab_size=self.vocab_size,
d_model=self.d_model,
decoder_layers=self.decoder_layers,
decoder_ffn_dim=self.decoder_ffn_dim,
encoder_attention_heads=self.encoder_attention_heads,
decoder_attention_heads=self.decoder_attention_heads,
eos_token_id=self.eos_token_id,
bos_token_id=self.bos_token_id,
use_cache=self.use_cache,
pad_token_id=self.pad_token_id,
decoder_start_token_id=self.decoder_start_token_id,
max_position_embeddings=self.max_position_embeddings,
is_encoder_decoder=self.is_encoder_decoder,
)
return (
config,
input_ids,
attention_mask,
lm_labels,
)
def create_and_check_decoder_model_past(
self,
config,
input_ids,
attention_mask,
lm_labels,
):
config.use_cache = True
model = MarianDecoder(config=config).to(torch_device).eval()
# first forward pass
outputs = model(input_ids, use_cache=True)
outputs_use_cache_conf = model(input_ids)
outputs_no_past = model(input_ids, use_cache=False)
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
past_key_values = outputs["past_key_values"]
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# append to next input_ids and
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
output_from_no_past = model(next_input_ids)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
def create_and_check_decoder_model_attention_mask_past(
self,
config,
input_ids,
attention_mask,
lm_labels,
):
model = MarianDecoder(config=config).to(torch_device).eval()
# create attention mask
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
half_seq_length = input_ids.shape[-1] // 2
attn_mask[:, half_seq_length:] = 0
# first forward pass
past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"]
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# change a random masked slice from input_ids
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
# append to next input_ids and attn_mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
attn_mask = torch.cat(
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
dim=1,
)
# get two different outputs
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[
"last_hidden_state"
]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
attention_mask,
lm_labels,
) = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
return config, inputs_dict
@require_torch
class MarianStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (MarianDecoder, MarianForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (MarianForCausalLM,) if is_torch_available() else ()
test_pruning = False
is_encoder_decoder = False
def setUp(
self,
):
self.model_tester = MarianStandaloneDecoderModelTester(self, is_training=False)
self.config_tester = ConfigTester(self, config_class=MarianConfig)
def test_config(self):
self.config_tester.run_common_tests()
def test_decoder_model_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
def test_decoder_model_attn_mask_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
def test_retain_grad_hidden_states_attentions(self):
# decoder cannot keep gradients
return
...@@ -35,6 +35,7 @@ if is_torch_available(): ...@@ -35,6 +35,7 @@ if is_torch_available():
AutoTokenizer, AutoTokenizer,
BatchEncoding, BatchEncoding,
MBartConfig, MBartConfig,
MBartForCausalLM,
MBartForConditionalGeneration, MBartForConditionalGeneration,
MBartForQuestionAnswering, MBartForQuestionAnswering,
MBartForSequenceClassification, MBartForSequenceClassification,
...@@ -174,7 +175,7 @@ class MBartModelTester: ...@@ -174,7 +175,7 @@ class MBartModelTester:
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice # test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2)) self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def check_encoder_decoder_model_standalone(self, config, inputs_dict): def check_encoder_decoder_model_standalone(self, config, inputs_dict):
model = MBartModel(config=config).to(torch_device).eval() model = MBartModel(config=config).to(torch_device).eval()
...@@ -431,3 +432,221 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest): ...@@ -431,3 +432,221 @@ class MBartCC25IntegrationTest(AbstractSeq2SeqIntegrationTest):
outputs, clean_up_tokenization_spaces=True, skip_special_tokens=True outputs, clean_up_tokenization_spaces=True, skip_special_tokens=True
)[0] )[0]
self.assertEqual(prediction, "of the best books I ever read!") self.assertEqual(prediction, "of the best books I ever read!")
class MBartStandaloneDecoderModelTester:
def __init__(
self,
parent,
vocab_size=99,
batch_size=13,
d_model=16,
decoder_seq_length=7,
is_training=True,
is_decoder=True,
use_attention_mask=True,
use_cache=False,
use_labels=True,
decoder_start_token_id=2,
decoder_ffn_dim=32,
decoder_layers=4,
encoder_attention_heads=4,
decoder_attention_heads=4,
max_position_embeddings=30,
is_encoder_decoder=False,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.decoder_seq_length = decoder_seq_length
# For common tests
self.seq_length = self.decoder_seq_length
self.is_training = is_training
self.use_attention_mask = use_attention_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.d_model = d_model
self.hidden_size = d_model
self.num_hidden_layers = decoder_layers
self.decoder_layers = decoder_layers
self.decoder_ffn_dim = decoder_ffn_dim
self.encoder_attention_heads = encoder_attention_heads
self.decoder_attention_heads = decoder_attention_heads
self.num_attention_heads = decoder_attention_heads
self.eos_token_id = eos_token_id
self.bos_token_id = bos_token_id
self.pad_token_id = pad_token_id
self.decoder_start_token_id = decoder_start_token_id
self.use_cache = use_cache
self.max_position_embeddings = max_position_embeddings
self.is_encoder_decoder = is_encoder_decoder
self.scope = None
self.decoder_key_length = decoder_seq_length
self.base_model_out_len = 2
self.decoder_attention_idx = 1
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
attention_mask = None
if self.use_attention_mask:
attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
lm_labels = None
if self.use_labels:
lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
config = MBartConfig(
vocab_size=self.vocab_size,
d_model=self.d_model,
decoder_layers=self.decoder_layers,
decoder_ffn_dim=self.decoder_ffn_dim,
encoder_attention_heads=self.encoder_attention_heads,
decoder_attention_heads=self.decoder_attention_heads,
eos_token_id=self.eos_token_id,
bos_token_id=self.bos_token_id,
use_cache=self.use_cache,
pad_token_id=self.pad_token_id,
decoder_start_token_id=self.decoder_start_token_id,
max_position_embeddings=self.max_position_embeddings,
is_encoder_decoder=self.is_encoder_decoder,
)
return (
config,
input_ids,
attention_mask,
lm_labels,
)
def create_and_check_decoder_model_past(
self,
config,
input_ids,
attention_mask,
lm_labels,
):
config.use_cache = True
model = MBartDecoder(config=config).to(torch_device).eval()
# first forward pass
outputs = model(input_ids, use_cache=True)
outputs_use_cache_conf = model(input_ids)
outputs_no_past = model(input_ids, use_cache=False)
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
past_key_values = outputs["past_key_values"]
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# append to next input_ids and
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
output_from_no_past = model(next_input_ids)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
def create_and_check_decoder_model_attention_mask_past(
self,
config,
input_ids,
attention_mask,
lm_labels,
):
model = MBartDecoder(config=config).to(torch_device).eval()
# create attention mask
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
half_seq_length = input_ids.shape[-1] // 2
attn_mask[:, half_seq_length:] = 0
# first forward pass
past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"]
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# change a random masked slice from input_ids
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
# append to next input_ids and attn_mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
attn_mask = torch.cat(
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
dim=1,
)
# get two different outputs
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[
"last_hidden_state"
]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
attention_mask,
lm_labels,
) = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
return config, inputs_dict
@require_torch
class MBartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (MBartDecoder, MBartForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (MBartForCausalLM,) if is_torch_available() else ()
test_pruning = False
is_encoder_decoder = False
def setUp(
self,
):
self.model_tester = MBartStandaloneDecoderModelTester(self, is_training=False)
self.config_tester = ConfigTester(self, config_class=MBartConfig)
def test_config(self):
self.config_tester.run_common_tests()
def test_decoder_model_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
def test_decoder_model_attn_mask_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
def test_retain_grad_hidden_states_attentions(self):
# decoder cannot keep gradients
return
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" Testing suite for the PyTorch PEGASUS model. """ """ Testing suite for the PyTorch PEGASUS model. """
import tempfile import tempfile
import unittest import unittest
...@@ -32,7 +31,7 @@ if is_torch_available(): ...@@ -32,7 +31,7 @@ if is_torch_available():
import torch import torch
from transformers import AutoModelForSeq2SeqLM, PegasusConfig, PegasusForConditionalGeneration, PegasusModel from transformers import AutoModelForSeq2SeqLM, PegasusConfig, PegasusForConditionalGeneration, PegasusModel
from transformers.models.pegasus.modeling_pegasus import PegasusDecoder, PegasusEncoder from transformers.models.pegasus.modeling_pegasus import PegasusDecoder, PegasusEncoder, PegasusForCausalLM
def prepare_pegasus_inputs_dict( def prepare_pegasus_inputs_dict(
...@@ -166,7 +165,7 @@ class PegasusModelTester: ...@@ -166,7 +165,7 @@ class PegasusModelTester:
self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1]) self.parent.assertTrue(output_from_past_slice.shape[1] == next_tokens.shape[1])
# test that outputs are equal for slice # test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-2)) self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
def check_encoder_decoder_model_standalone(self, config, inputs_dict): def check_encoder_decoder_model_standalone(self, config, inputs_dict):
model = PegasusModel(config=config).to(torch_device).eval() model = PegasusModel(config=config).to(torch_device).eval()
...@@ -308,3 +307,221 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest): ...@@ -308,3 +307,221 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
"California's largest electricity provider has begun", "California's largest electricity provider has begun",
"N-Dubz have revealed they were", "N-Dubz have revealed they were",
] ]
class PegasusStandaloneDecoderModelTester:
def __init__(
self,
parent,
vocab_size=99,
batch_size=13,
d_model=16,
decoder_seq_length=7,
is_training=True,
is_decoder=True,
use_attention_mask=True,
use_cache=False,
use_labels=True,
decoder_start_token_id=2,
decoder_ffn_dim=32,
decoder_layers=4,
encoder_attention_heads=4,
decoder_attention_heads=4,
max_position_embeddings=30,
is_encoder_decoder=False,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.decoder_seq_length = decoder_seq_length
# For common tests
self.seq_length = self.decoder_seq_length
self.is_training = is_training
self.use_attention_mask = use_attention_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.d_model = d_model
self.hidden_size = d_model
self.num_hidden_layers = decoder_layers
self.decoder_layers = decoder_layers
self.decoder_ffn_dim = decoder_ffn_dim
self.encoder_attention_heads = encoder_attention_heads
self.decoder_attention_heads = decoder_attention_heads
self.num_attention_heads = decoder_attention_heads
self.eos_token_id = eos_token_id
self.bos_token_id = bos_token_id
self.pad_token_id = pad_token_id
self.decoder_start_token_id = decoder_start_token_id
self.use_cache = use_cache
self.max_position_embeddings = max_position_embeddings
self.is_encoder_decoder = is_encoder_decoder
self.scope = None
self.decoder_key_length = decoder_seq_length
self.base_model_out_len = 2
self.decoder_attention_idx = 1
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
attention_mask = None
if self.use_attention_mask:
attention_mask = ids_tensor([self.batch_size, self.decoder_seq_length], vocab_size=2)
lm_labels = None
if self.use_labels:
lm_labels = ids_tensor([self.batch_size, self.decoder_seq_length], self.vocab_size)
config = PegasusConfig(
vocab_size=self.vocab_size,
d_model=self.d_model,
decoder_layers=self.decoder_layers,
decoder_ffn_dim=self.decoder_ffn_dim,
encoder_attention_heads=self.encoder_attention_heads,
decoder_attention_heads=self.decoder_attention_heads,
eos_token_id=self.eos_token_id,
bos_token_id=self.bos_token_id,
use_cache=self.use_cache,
pad_token_id=self.pad_token_id,
decoder_start_token_id=self.decoder_start_token_id,
max_position_embeddings=self.max_position_embeddings,
is_encoder_decoder=self.is_encoder_decoder,
)
return (
config,
input_ids,
attention_mask,
lm_labels,
)
def create_and_check_decoder_model_past(
self,
config,
input_ids,
attention_mask,
lm_labels,
):
config.use_cache = True
model = PegasusDecoder(config=config).to(torch_device).eval()
# first forward pass
outputs = model(input_ids, use_cache=True)
outputs_use_cache_conf = model(input_ids)
outputs_no_past = model(input_ids, use_cache=False)
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
past_key_values = outputs["past_key_values"]
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# append to next input_ids and
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
output_from_no_past = model(next_input_ids)["last_hidden_state"]
output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
def create_and_check_decoder_model_attention_mask_past(
self,
config,
input_ids,
attention_mask,
lm_labels,
):
model = PegasusDecoder(config=config).to(torch_device).eval()
# create attention mask
attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
half_seq_length = input_ids.shape[-1] // 2
attn_mask[:, half_seq_length:] = 0
# first forward pass
past_key_values = model(input_ids, attention_mask=attn_mask, use_cache=True)["past_key_values"]
# create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# change a random masked slice from input_ids
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).item() + 1
random_other_next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
input_ids[:, -random_seq_idx_to_change] = random_other_next_tokens
# append to next input_ids and attn_mask
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
attn_mask = torch.cat(
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)],
dim=1,
)
# get two different outputs
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[
"last_hidden_state"
]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
output_from_no_past_slice = output_from_no_past[:, next_input_ids.shape[-1] - 1, random_slice_idx].detach()
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
assert torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
attention_mask,
lm_labels,
) = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
return config, inputs_dict
@require_torch
class PegasusStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_model_classes = (PegasusDecoder, PegasusForCausalLM) if is_torch_available() else ()
all_generative_model_classes = (PegasusForCausalLM,) if is_torch_available() else ()
test_pruning = False
is_encoder_decoder = False
def setUp(
self,
):
self.model_tester = PegasusStandaloneDecoderModelTester(self, is_training=False)
self.config_tester = ConfigTester(self, config_class=PegasusConfig)
def test_config(self):
self.config_tester.run_common_tests()
def test_decoder_model_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past(*config_and_inputs)
def test_decoder_model_attn_mask_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
def test_retain_grad_hidden_states_attentions(self):
# decoder cannot keep gradients
return
...@@ -32,17 +32,17 @@ IGNORE_NON_TESTED = [ ...@@ -32,17 +32,17 @@ IGNORE_NON_TESTED = [
# models to ignore for not tested # models to ignore for not tested
"LEDEncoder", # Building part of bigger (tested) model. "LEDEncoder", # Building part of bigger (tested) model.
"LEDDecoder", # Building part of bigger (tested) model. "LEDDecoder", # Building part of bigger (tested) model.
"BartDecoder", # Building part of bigger (tested) model. "BartDecoderWrapper", # Building part of bigger (tested) model.
"BartEncoder", # Building part of bigger (tested) model. "BartEncoder", # Building part of bigger (tested) model.
"BertLMHeadModel", # Needs to be setup as decoder. "BertLMHeadModel", # Needs to be setup as decoder.
"BlenderbotSmallEncoder", # Building part of bigger (tested) model. "BlenderbotSmallEncoder", # Building part of bigger (tested) model.
"BlenderbotSmallDecoder", # Building part of bigger (tested) model. "BlenderbotSmallDecoderWrapper", # Building part of bigger (tested) model.
"BlenderbotEncoder", # Building part of bigger (tested) model. "BlenderbotEncoder", # Building part of bigger (tested) model.
"BlenderbotDecoder", # Building part of bigger (tested) model. "BlenderbotDecoderWrapper", # Building part of bigger (tested) model.
"MBartEncoder", # Building part of bigger (tested) model. "MBartEncoder", # Building part of bigger (tested) model.
"MBartDecoder", # Building part of bigger (tested) model. "MBartDecoderWrapper", # Building part of bigger (tested) model.
"PegasusEncoder", # Building part of bigger (tested) model. "PegasusEncoder", # Building part of bigger (tested) model.
"PegasusDecoder", # Building part of bigger (tested) model. "PegasusDecoderWrapper", # Building part of bigger (tested) model.
"DPREncoder", # Building part of bigger (tested) model. "DPREncoder", # Building part of bigger (tested) model.
"DPRSpanPredictor", # Building part of bigger (tested) model. "DPRSpanPredictor", # Building part of bigger (tested) model.
"ProphetNetDecoderWrapper", # Building part of bigger (tested) model. "ProphetNetDecoderWrapper", # Building part of bigger (tested) model.
...@@ -78,11 +78,14 @@ IGNORE_NON_AUTO_CONFIGURED = [ ...@@ -78,11 +78,14 @@ IGNORE_NON_AUTO_CONFIGURED = [
"LEDEncoder", "LEDEncoder",
"LEDDecoder", "LEDDecoder",
"BartDecoder", "BartDecoder",
"BartDecoderWrapper",
"BartEncoder", "BartEncoder",
"BlenderbotSmallEncoder", "BlenderbotSmallEncoder",
"BlenderbotSmallDecoder", "BlenderbotSmallDecoder",
"BlenderbotSmallDecoderWrapper",
"BlenderbotEncoder", "BlenderbotEncoder",
"BlenderbotDecoder", "BlenderbotDecoder",
"BlenderbotDecoderWrapper",
"DPRContextEncoder", "DPRContextEncoder",
"DPREncoder", "DPREncoder",
"DPRReader", "DPRReader",
...@@ -93,9 +96,11 @@ IGNORE_NON_AUTO_CONFIGURED = [ ...@@ -93,9 +96,11 @@ IGNORE_NON_AUTO_CONFIGURED = [
"MT5EncoderModel", "MT5EncoderModel",
"MBartEncoder", "MBartEncoder",
"MBartDecoder", "MBartDecoder",
"MBartDecoderWrapper",
"OpenAIGPTDoubleHeadsModel", "OpenAIGPTDoubleHeadsModel",
"PegasusEncoder", "PegasusEncoder",
"PegasusDecoder", "PegasusDecoder",
"PegasusDecoderWrapper",
"ProphetNetDecoder", "ProphetNetDecoder",
"ProphetNetEncoder", "ProphetNetEncoder",
"ProphetNetDecoderWrapper", "ProphetNetDecoderWrapper",
...@@ -205,9 +210,8 @@ def find_tested_models(test_file): ...@@ -205,9 +210,8 @@ def find_tested_models(test_file):
with open(os.path.join(PATH_TO_TESTS, test_file), "r", encoding="utf-8", newline="\n") as f: with open(os.path.join(PATH_TO_TESTS, test_file), "r", encoding="utf-8", newline="\n") as f:
content = f.read() content = f.read()
all_models = re.findall(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content) all_models = re.findall(r"all_model_classes\s+=\s+\(\s*\(([^\)]*)\)", content)
# Check with one less parenthesis # Check with one less parenthesis as well
if len(all_models) == 0: all_models += re.findall(r"all_model_classes\s+=\s+\(([^\)]*)\)", content)
all_models = re.findall(r"all_model_classes\s+=\s+\(([^\)]*)\)", content)
if len(all_models) > 0: if len(all_models) > 0:
model_tested = [] model_tested = []
for entry in all_models: for entry in all_models:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment