Unverified Commit 00031785 authored by demSd, committed by GitHub

BartForCausalLM analogs to `ProphetNetForCausalLM` (#9128)



* initialize bart4causalLM

* create BartDecoderWrapper, setters/getters

* delete spaces

* forward and additional methods

* update cache function, loss function, remove ngram* params in data class.

* add bartcausallm, bartdecoder testing

* correct bart for causal lm

* remove at

* add mbart as well

* up

* fix typo

* up

* correct

* add pegasusforcausallm

* add blenderbotforcausallm

* add blenderbotsmallforcausallm

* add marianforcausallm

* add test for MarianForCausalLM

* add Pegasus test

* add BlenderbotSmall test

* add blenderbot test

* fix a fail

* fix an import fail

* a fix

* fix

* Update modeling_pegasus.py

* fix models

* fix inputs_embeds setting getter

* adapt tests

* correct repo utils check

* finish test improvement

* fix tf models as well

* make style

* make fix-copies

* fix copies

* run all tests

* last changes

* fix all tests
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 7898fc03
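Editor's note: a minimal usage sketch of the new decoder-only head introduced by this commit, mirroring the docstring example further down in the diff; the checkpoint name and the loss/logits fields come from that example, everything else is an assumption rather than an official snippet.

# Hedged sketch: exercising BartForCausalLM standalone (decoder-only LM with a CLM loss).
from transformers import BartForCausalLM, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForCausalLM.from_pretrained("facebook/bart-large", add_cross_attention=False)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
# Passing labels triggers the CrossEntropyLoss path added in this diff.
outputs = model(**inputs, labels=inputs["input_ids"])
print(float(outputs.loss), tuple(outputs.logits.shape))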
@@ -130,6 +130,12 @@ BartForQuestionAnswering
.. autoclass:: transformers.BartForQuestionAnswering
:members: forward
BartForCausalLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BartForCausalLM
:members: forward
TFBartModel
......
@@ -98,6 +98,13 @@ See :obj:`transformers.BartForConditionalGeneration` for arguments to `forward`
:members: forward
BlenderbotForCausalLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BlenderbotForCausalLM
:members: forward
TFBlenderbotModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
@@ -70,6 +70,13 @@ BlenderbotSmallForConditionalGeneration
:members: forward
BlenderbotSmallForCausalLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.BlenderbotSmallForCausalLM
:members: forward
TFBlenderbotSmallModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
@@ -193,6 +193,13 @@ MarianMTModel
:members: forward
MarianForCausalLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.MarianForCausalLM
:members: forward
TFMarianModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
@@ -124,6 +124,13 @@ MBartForSequenceClassification
.. autoclass:: transformers.MBartForSequenceClassification
MBartForCausalLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.MBartForCausalLM
:members: forward
TFMBartModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
@@ -131,6 +131,13 @@ PegasusForConditionalGeneration
:members: forward
PegasusForCausalLM
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.PegasusForCausalLM
:members: forward
TFPegasusModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
@@ -431,6 +431,7 @@ if is_torch_available():
_import_structure["models.bart"].extend(
[
"BART_PRETRAINED_MODEL_ARCHIVE_LIST",
"BartForCausalLM",
"BartForConditionalGeneration",
"BartForQuestionAnswering",
"BartForSequenceClassification",
@@ -468,6 +469,7 @@ if is_torch_available():
"BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST",
"BlenderbotForConditionalGeneration",
"BlenderbotModel",
"BlenderbotForCausalLM",
]
)
_import_structure["models.blenderbot_small"].extend(
@@ -475,6 +477,7 @@ if is_torch_available():
"BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST",
"BlenderbotSmallForConditionalGeneration",
"BlenderbotSmallModel",
"BlenderbotSmallForCausalLM",
]
)
_import_structure["models.camembert"].extend(
@@ -628,9 +631,10 @@ if is_torch_available():
"LxmertXLayer",
]
)
-_import_structure["models.marian"].extend(["MarianModel", "MarianMTModel"])
_import_structure["models.marian"].extend(["MarianModel", "MarianMTModel", "MarianForCausalLM"])
_import_structure["models.mbart"].extend(
[
"MBartForCausalLM",
"MBartForConditionalGeneration",
"MBartForQuestionAnswering",
"MBartForSequenceClassification",
@@ -679,7 +683,9 @@ if is_torch_available():
"load_tf_weights_in_openai_gpt",
]
)
-_import_structure["models.pegasus"].extend(["PegasusForConditionalGeneration", "PegasusModel"])
_import_structure["models.pegasus"].extend(
["PegasusForConditionalGeneration", "PegasusModel", "PegasusForCausalLM"]
)
_import_structure["models.prophetnet"].extend(
[
"PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1517,6 +1523,7 @@ if TYPE_CHECKING:
)
from .models.bart import (
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
BartForCausalLM,
BartForConditionalGeneration,
BartForQuestionAnswering,
BartForSequenceClassification,
@@ -1546,11 +1553,13 @@ if TYPE_CHECKING:
)
from .models.blenderbot import (
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST,
BlenderbotForCausalLM,
BlenderbotForConditionalGeneration,
BlenderbotModel,
)
from .models.blenderbot_small import (
BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST,
BlenderbotSmallForCausalLM,
BlenderbotSmallForConditionalGeneration,
BlenderbotSmallModel,
)
@@ -1691,8 +1700,9 @@ if TYPE_CHECKING:
LxmertVisualFeatureEncoder,
LxmertXLayer,
)
-from .models.marian import MarianModel, MarianMTModel
from .models.marian import MarianForCausalLM, MarianModel, MarianMTModel
from .models.mbart import (
MBartForCausalLM,
MBartForConditionalGeneration,
MBartForQuestionAnswering,
MBartForSequenceClassification,
@@ -1734,7 +1744,7 @@ if TYPE_CHECKING:
OpenAIGPTPreTrainedModel,
load_tf_weights_in_openai_gpt,
)
-from .models.pegasus import PegasusForConditionalGeneration, PegasusModel
from .models.pegasus import PegasusForCausalLM, PegasusForConditionalGeneration, PegasusModel
from .models.prophetnet import (
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
ProphetNetDecoder,
......
@@ -33,6 +33,7 @@ from ..albert.modeling_albert import (
AlbertModel,
)
from ..bart.modeling_bart import (
BartForCausalLM,
BartForConditionalGeneration,
BartForQuestionAnswering,
BartForSequenceClassification,
@@ -50,8 +51,12 @@ from ..bert.modeling_bert import (
BertModel,
)
from ..bert_generation.modeling_bert_generation import BertGenerationDecoder, BertGenerationEncoder
-from ..blenderbot.modeling_blenderbot import BlenderbotForConditionalGeneration, BlenderbotModel
from ..blenderbot.modeling_blenderbot import BlenderbotForCausalLM, BlenderbotForConditionalGeneration, BlenderbotModel
-from ..blenderbot_small.modeling_blenderbot_small import BlenderbotSmallForConditionalGeneration, BlenderbotSmallModel
from ..blenderbot_small.modeling_blenderbot_small import (
BlenderbotSmallForCausalLM,
BlenderbotSmallForConditionalGeneration,
BlenderbotSmallModel,
)
from ..camembert.modeling_camembert import (
CamembertForCausalLM,
CamembertForMaskedLM,
@@ -138,8 +143,9 @@ from ..longformer.modeling_longformer import (
LongformerModel,
)
from ..lxmert.modeling_lxmert import LxmertForPreTraining, LxmertForQuestionAnswering, LxmertModel
-from ..marian.modeling_marian import MarianModel, MarianMTModel
from ..marian.modeling_marian import MarianForCausalLM, MarianModel, MarianMTModel
from ..mbart.modeling_mbart import (
MBartForCausalLM,
MBartForConditionalGeneration,
MBartForQuestionAnswering,
MBartForSequenceClassification,
@@ -165,7 +171,7 @@ from ..mpnet.modeling_mpnet import (
)
from ..mt5.modeling_mt5 import MT5ForConditionalGeneration, MT5Model
from ..openai.modeling_openai import OpenAIGPTForSequenceClassification, OpenAIGPTLMHeadModel, OpenAIGPTModel
-from ..pegasus.modeling_pegasus import PegasusForConditionalGeneration, PegasusModel
from ..pegasus.modeling_pegasus import PegasusForCausalLM, PegasusForConditionalGeneration, PegasusModel
from ..prophetnet.modeling_prophetnet import ProphetNetForCausalLM, ProphetNetForConditionalGeneration, ProphetNetModel
from ..rag.modeling_rag import (  # noqa: F401 - need to import all RagModels to be in globals() function
RagModel,
@@ -425,6 +431,12 @@ MODEL_FOR_CAUSAL_LM_MAPPING = OrderedDict(
(BertGenerationConfig, BertGenerationDecoder),
(XLMProphetNetConfig, XLMProphetNetForCausalLM),
(ProphetNetConfig, ProphetNetForCausalLM),
(BartConfig, BartForCausalLM),
(MBartConfig, MBartForCausalLM),
(PegasusConfig, PegasusForCausalLM),
(MarianConfig, MarianForCausalLM),
(BlenderbotConfig, BlenderbotForCausalLM),
(BlenderbotSmallConfig, BlenderbotSmallForCausalLM),
]
)
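With these pairs registered in MODEL_FOR_CAUSAL_LM_MAPPING, the generic auto class should be able to dispatch the corresponding configs to the new heads; a small sketch (the checkpoint name is illustrative, and the resolution behaviour is assumed from the mapping above):

# Sketch: AutoModelForCausalLM dispatching a BartConfig to the new BartForCausalLM head.
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("facebook/bart-large")
model = AutoModelForCausalLM.from_config(config)
print(type(model).__name__)  # expected: BartForCausalLM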
......
@@ -31,6 +31,7 @@ if is_tokenizers_available():
if is_torch_available():
_import_structure["modeling_bart"] = [
"BART_PRETRAINED_MODEL_ARCHIVE_LIST",
"BartForCausalLM",
"BartForConditionalGeneration",
"BartForQuestionAnswering",
"BartForSequenceClassification",
@@ -53,6 +54,7 @@ if TYPE_CHECKING:
if is_torch_available():
from .modeling_bart import (
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
BartForCausalLM,
BartForConditionalGeneration,
BartForQuestionAnswering,
BartForSequenceClassification,
......
@@ -13,8 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch BART model. """
import copy
import math
import random
import warnings
@@ -37,6 +36,7 @@ from ...file_utils import (
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
Seq2SeqQuestionAnsweringModelOutput,
@@ -843,6 +843,30 @@ class BartDecoder(BartPretrainedModel):
self.init_weights()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
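The new helper merges an additive causal mask with the expanded padding mask. The toy sketch below illustrates the idea with simplified stand-ins for _make_causal_mask/_expand_mask; it is not the library code, only the same [bsz, 1, tgt_len, src_len] layout and additive -inf convention.

# Illustration only: how a causal mask and a padding mask combine additively.
import torch

def toy_causal_mask(tgt_len, dtype):
    mask = torch.full((tgt_len, tgt_len), float("-inf"), dtype=dtype)
    mask = torch.triu(mask, diagonal=1)   # -inf strictly above the diagonal blocks future tokens
    return mask[None, None, :, :]         # [1, 1, tgt_len, tgt_len]

def toy_expand_mask(attention_mask, dtype):
    inverted = 1.0 - attention_mask[:, None, None, :].to(dtype)
    return inverted.masked_fill(inverted.bool(), float("-inf"))  # [bsz, 1, 1, src_len]

attn = torch.tensor([[1, 1, 1, 0]])       # last token is padding
combined = toy_causal_mask(4, torch.float32) + toy_expand_mask(attn, torch.float32)
print(combined.shape)                     # torch.Size([1, 1, 4, 4]); -inf entries are masked out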
def forward(
self,
input_ids=None,
@@ -945,19 +969,9 @@ class BartDecoder(BartPretrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
-# create causal mask
-# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-combined_attention_mask = None
-if input_shape[-1] > 1:
-combined_attention_mask = _make_causal_mask(
-input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-).to(self.device)
-if attention_mask is not None and combined_attention_mask is not None:
-# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-combined_attention_mask = combined_attention_mask + _expand_mask(
-attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
-)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
@@ -975,7 +989,7 @@ class BartDecoder(BartPretrainedModel):
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
-all_cross_attentions = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired
@@ -1012,7 +1026,7 @@ class BartDecoder(BartPretrainedModel):
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
-combined_attention_mask,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
@@ -1023,7 +1037,7 @@ class BartDecoder(BartPretrainedModel):
layer_outputs = decoder_layer(
hidden_states,
-attention_mask=combined_attention_mask,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
@@ -1039,7 +1053,9 @@ class BartDecoder(BartPretrainedModel):
if output_attentions:
all_self_attns += (layer_outputs[1],)
-all_cross_attentions += (layer_outputs[2],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
# add hidden states from the last decoder layer
if output_hidden_states:
@@ -1571,3 +1587,208 @@ class BartForQuestionAnswering(BartPretrainedModel):
encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions,
)
class BartDecoderWrapper(BartPretrainedModel):
"""
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
"""
def __init__(self, config):
super().__init__(config)
self.decoder = BartDecoder(config)
def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)
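As the wrapper docstring says, keeping the decoder under a model.decoder attribute is what lets seq2seq checkpoints load into the causal-LM head; below is a hedged sketch of the intended EncoderDecoderModel use, with illustrative checkpoint names and behaviour assumed from the AutoModelForCausalLM mapping added elsewhere in this diff.

# Sketch: BartForCausalLM as the decoder of an EncoderDecoderModel (cross-attention weights
# are newly initialized; checkpoint names are placeholders, not prescribed by this commit).
from transformers import EncoderDecoderModel

model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased",    # encoder
    "facebook/bart-large",  # decoder, expected to load as BartForCausalLM
)
print(type(model.decoder).__name__)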
class BartForCausalLM(BartPretrainedModel):
def __init__(self, config):
super().__init__(config)
config = copy.deepcopy(config)
config.is_decoder = True
config.is_encoder_decoder = False
self.model = BartDecoderWrapper(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.init_weights()
def get_input_embeddings(self):
return self.model.decoder.embed_tokens
def set_input_embeddings(self, value):
self.model.decoder.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model.decoder = decoder
def get_decoder(self):
return self.model.decoder
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
encoder_head_mask=None,
past_key_values=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
if the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids``
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are
ignored (masked), the loss is only computed for the tokens with labels in ``[0, ...,
config.vocab_size]``.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
for more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
Returns:
Example::
>>> from transformers import BartTokenizer, BartForCausalLM
>>> tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
>>> model = BartForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False)
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
encoder_head_mask=encoder_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = self.lm_head(outputs[0])
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"use_cache": use_cache,
}
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()
for layer_past in past:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past
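prepare_inputs_for_generation and _reorder_cache above are the hooks generate() uses to feed only the last token once a cache exists and to reorder that cache across beams; a short, assumed sketch of generation with the new head (generation arguments are illustrative):

# Sketch: cached beam-search generation with BartForCausalLM.
from transformers import BartForCausalLM, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForCausalLM.from_pretrained("facebook/bart-large", add_cross_attention=False)

input_ids = tokenizer("Hello, my dog is", return_tensors="pt").input_ids
generated = model.generate(input_ids, max_length=20, num_beams=4, use_cache=True)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))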
@@ -915,7 +915,7 @@ class TFBartDecoder(tf.keras.layers.Layer):
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
)
-if inputs["attention_mask"] is not None and input_shape[-1] > 1:
if inputs["attention_mask"] is not None:
combined_attention_mask = combined_attention_mask + _expand_mask(
inputs["attention_mask"], tgt_len=input_shape[-1]
)
......
@@ -32,6 +32,7 @@ if is_torch_available():
"BlenderbotForConditionalGeneration",
"BlenderbotModel",
"BlenderbotPreTrainedModel",
"BlenderbotForCausalLM",
]
@@ -46,6 +47,7 @@ if TYPE_CHECKING:
if is_torch_available():
from .modeling_blenderbot import (
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST,
BlenderbotForCausalLM,
BlenderbotForConditionalGeneration,
BlenderbotModel,
BlenderbotPreTrainedModel,
......
@@ -15,6 +15,7 @@
""" PyTorch Blenderbot model. """
import copy
import math
import os
import random
@@ -37,6 +38,7 @@ from ...file_utils import (
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
@@ -805,13 +807,38 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
self.init_weights()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids=None,
attention_mask=None,
-head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
encoder_head_mask=None,
past_key_values=None,
inputs_embeds=None,
@@ -838,12 +865,6 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
-head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
-Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
-- 1 indicates the head is **not masked**,
-- 0 indicates the heas is **masked**.
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, encoder_sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
of the decoder.
@@ -855,6 +876,12 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``:
@@ -907,19 +934,9 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
-# create causal mask
-# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-combined_attention_mask = None
-if input_shape[-1] > 1:
-combined_attention_mask = _make_causal_mask(
-input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-).to(self.device)
-if attention_mask is not None and combined_attention_mask is not None:
-# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-combined_attention_mask = combined_attention_mask + _expand_mask(
-attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
-)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
@@ -929,7 +946,6 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
# embed positions
positions = self.embed_positions(input_shape, past_key_values_length)
-# in constrast to Bart, Blenderbot applies layernorm on inputs_embeds
hidden_states = inputs_embeds + positions
hidden_states = F.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -937,7 +953,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
-all_cross_attentions = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired
@@ -974,7 +990,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
-combined_attention_mask,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
@@ -985,10 +1001,10 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
layer_outputs = decoder_layer(
hidden_states,
-attention_mask=combined_attention_mask,
attention_mask=attention_mask,
-layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
encoder_layer_head_mask=(encoder_head_mask[idx] if encoder_head_mask is not None else None),
past_key_value=past_key_value,
output_attentions=output_attentions,
@@ -1001,7 +1017,9 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
if output_attentions:
all_self_attns += (layer_outputs[1],)
-all_cross_attentions += (layer_outputs[2],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
# add final layer norm
hidden_states = self.layer_norm(hidden_states)
@@ -1336,3 +1354,210 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
)
return reordered_past
# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Blenderbot
class BlenderbotDecoderWrapper(BlenderbotPreTrainedModel):
"""
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
"""
def __init__(self, config):
super().__init__(config)
self.decoder = BlenderbotDecoder(config)
def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot
class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
def __init__(self, config):
super().__init__(config)
config = copy.deepcopy(config)
config.is_decoder = True
config.is_encoder_decoder = False
self.model = BlenderbotDecoderWrapper(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.init_weights()
def get_input_embeddings(self):
return self.model.decoder.embed_tokens
def set_input_embeddings(self, value):
self.model.decoder.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model.decoder = decoder
def get_decoder(self):
return self.model.decoder
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
encoder_head_mask=None,
past_key_values=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using :class:`~transformers.ProphetNetTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
if the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids``
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are
ignored (masked), the loss is only computed for the tokens with labels in ``[0, ...,
config.vocab_size]``.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
for more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
Returns:
Example::
>>> from transformers import BlenderbotTokenizer, BlenderbotForCausalLM
>>> tokenizer = BlenderbotTokenizer.from_pretrained('facebook/bart-large')
>>> model = BlenderbotForCausalLM.from_pretrained('facebook/bart-large', add_cross_attention=False)
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
encoder_head_mask=encoder_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = self.lm_head(outputs[0])
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"use_cache": use_cache,
}
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()
for layer_past in past:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past
@@ -925,7 +925,7 @@ class TFBlenderbotDecoder(tf.keras.layers.Layer):
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
)
-if inputs["attention_mask"] is not None and input_shape[-1] > 1:
if inputs["attention_mask"] is not None:
combined_attention_mask = combined_attention_mask + _expand_mask(
inputs["attention_mask"], tgt_len=input_shape[-1]
)
......
@@ -31,6 +31,7 @@ if is_torch_available():
"BlenderbotSmallForConditionalGeneration",
"BlenderbotSmallModel",
"BlenderbotSmallPreTrainedModel",
"BlenderbotSmallForCausalLM",
]
if is_tf_available():
@@ -46,6 +47,7 @@ if TYPE_CHECKING:
if is_torch_available():
from .modeling_blenderbot_small import (
BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST,
BlenderbotSmallForCausalLM,
BlenderbotSmallForConditionalGeneration,
BlenderbotSmallModel,
BlenderbotSmallPreTrainedModel,
......
@@ -15,6 +15,7 @@
""" PyTorch BlenderbotSmall model. """
import copy
import math
import random
from typing import Optional, Tuple
@@ -35,6 +36,7 @@ from ...file_utils import (
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
@@ -805,6 +807,31 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
self.init_weights()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
def forward(
self,
input_ids=None,
@@ -907,19 +934,9 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
-# create causal mask
-# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-combined_attention_mask = None
-if input_shape[-1] > 1:
-combined_attention_mask = _make_causal_mask(
-input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-).to(self.device)
-if attention_mask is not None and combined_attention_mask is not None:
-# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-combined_attention_mask = combined_attention_mask + _expand_mask(
-attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
-)
attention_mask = self._prepare_decoder_attention_mask(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
@@ -938,7 +955,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
-all_cross_attentions = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None
if head_mask is not None:
@@ -974,7 +991,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
-combined_attention_mask,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
@@ -985,7 +1002,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
layer_outputs = decoder_layer(
hidden_states,
-attention_mask=combined_attention_mask,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
@@ -1001,7 +1018,9 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
if output_attentions:
all_self_attns += (layer_outputs[1],)
-all_cross_attentions += (layer_outputs[2],)
if encoder_hidden_states is not None:
all_cross_attentions += (layer_outputs[2],)
# add hidden states from the last decoder layer
if output_hidden_states:
@@ -1310,3 +1329,210 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
)
return reordered_past
# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->BlenderbotSmall
class BlenderbotSmallDecoderWrapper(BlenderbotSmallPreTrainedModel):
"""
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
"""
def __init__(self, config):
super().__init__(config)
self.decoder = BlenderbotSmallDecoder(config)
def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->BlenderbotSmall
class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
def __init__(self, config):
super().__init__(config)
config = copy.deepcopy(config)
config.is_decoder = True
config.is_encoder_decoder = False
self.model = BlenderbotSmallDecoderWrapper(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.init_weights()
def get_input_embeddings(self):
return self.model.decoder.embed_tokens
def set_input_embeddings(self, value):
self.model.decoder.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model.decoder = decoder
def get_decoder(self):
return self.model.decoder
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
encoder_head_mask=None,
past_key_values=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using :class:`~transformers.BlenderbotSmallTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
if the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids``
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are
ignored (masked), the loss is only computed for the tokens with labels in ``[0, ...,
config.vocab_size]``.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
for more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
Returns:
Example::
>>> from transformers import BlenderbotSmallTokenizer, BlenderbotSmallForCausalLM
>>> tokenizer = BlenderbotSmallTokenizer.from_pretrained('facebook/blenderbot_small-90M')
>>> model = BlenderbotSmallForCausalLM.from_pretrained('facebook/blenderbot_small-90M', add_cross_attention=False, is_decoder=True)
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
encoder_head_mask=encoder_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = self.lm_head(outputs[0])
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"use_cache": use_cache,
}
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()
for layer_past in past:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past
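# Hedged usage sketch (not part of the patch). It assumes the public
# "facebook/blenderbot_small-90M" seq2seq checkpoint, whose decoder weights are
# reused here while the encoder weights are dropped with a loading warning.
# generate() exercises prepare_inputs_for_generation() above (only the last token is
# fed once past_key_values is populated) and _reorder_cache() during beam search.
from transformers import BlenderbotSmallTokenizer, BlenderbotSmallForCausalLM

tokenizer = BlenderbotSmallTokenizer.from_pretrained("facebook/blenderbot_small-90M")
model = BlenderbotSmallForCausalLM.from_pretrained(
    "facebook/blenderbot_small-90M", add_cross_attention=False, is_decoder=True
)
prompt = tokenizer("i like to surf", return_tensors="pt")
generated = model.generate(**prompt, max_length=20, num_beams=2)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))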
@@ -925,7 +925,7 @@ class TFBlenderbotSmallDecoder(tf.keras.layers.Layer):
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
)
-if inputs["attention_mask"] is not None and input_shape[-1] > 1:
+if inputs["attention_mask"] is not None:
combined_attention_mask = combined_attention_mask + _expand_mask(
inputs["attention_mask"], tgt_len=input_shape[-1]
)
...
@@ -39,6 +39,7 @@ if is_torch_available():
"MarianModel",
"MarianMTModel",
"MarianPreTrainedModel",
+"MarianForCausalLM",
]
if is_tf_available():
@@ -54,6 +55,7 @@ if TYPE_CHECKING:
if is_torch_available():
from .modeling_marian import (
MARIAN_PRETRAINED_MODEL_ARCHIVE_LIST,
+MarianForCausalLM,
MarianModel,
MarianMTModel,
MarianPreTrainedModel,
...
@@ -15,6 +15,7 @@
"""PyTorch MarianMTModel model, ported from the Marian C++ repo."""
+import copy
import math
import random
from typing import Optional, Tuple
@@ -36,6 +37,7 @@ from ...file_utils import (
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
+CausalLMOutputWithCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
@@ -809,6 +811,31 @@ class MarianDecoder(MarianPreTrainedModel):
self.layers = nn.ModuleList([MarianDecoderLayer(config) for _ in range(config.decoder_layers)])
self.init_weights()
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
# Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
# create causal mask
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
combined_attention_mask = None
if input_shape[-1] > 1:
combined_attention_mask = _make_causal_mask(
input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
).to(self.device)
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1])
combined_attention_mask = (
expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
)
return combined_attention_mask
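# Rough numeric illustration (plain torch, hypothetical sizes; the library helpers
# above use slightly different sentinel values) of the additive mask this method
# returns: the causal part blocks future positions, the expanded padding part blocks
# padded keys, and the two are simply summed.
import torch

tgt_len = 4
neg = torch.finfo(torch.float32).min
causal = torch.triu(torch.full((tgt_len, tgt_len), neg), diagonal=1)  # block keys j > i
padding = torch.zeros(tgt_len)
padding[-1] = neg  # pretend the last input token is padding
combined = causal + padding  # broadcasts over the key dimension, as in the sum above
print(combined)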
def forward(
self,
input_ids=None,
@@ -911,19 +938,9 @@ class MarianDecoder(MarianPreTrainedModel):
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
-# create causal mask
-# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-combined_attention_mask = None
-if input_shape[-1] > 1:
-    combined_attention_mask = _make_causal_mask(
-        input_shape, inputs_embeds.dtype, past_key_values_length=past_key_values_length
-    ).to(self.device)
-if attention_mask is not None and combined_attention_mask is not None:
-    # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
-    combined_attention_mask = combined_attention_mask + _expand_mask(
-        attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
-    )
+attention_mask = self._prepare_decoder_attention_mask(
+    attention_mask, input_shape, inputs_embeds, past_key_values_length
+)
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
@@ -940,7 +957,7 @@ class MarianDecoder(MarianPreTrainedModel):
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
-all_cross_attentions = () if output_attentions else None
+all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
next_decoder_cache = () if use_cache else None
# check if head_mask has a correct number of layers specified if desired
@@ -977,7 +994,7 @@ class MarianDecoder(MarianPreTrainedModel):
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer),
hidden_states,
-combined_attention_mask,
+attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
@@ -988,7 +1005,7 @@ class MarianDecoder(MarianPreTrainedModel):
layer_outputs = decoder_layer(
hidden_states,
-attention_mask=combined_attention_mask,
+attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
@@ -1004,7 +1021,9 @@ class MarianDecoder(MarianPreTrainedModel):
if output_attentions:
all_self_attns += (layer_outputs[1],)
-all_cross_attentions += (layer_outputs[2],)
+if encoder_hidden_states is not None:
+    all_cross_attentions += (layer_outputs[2],)
# add hidden states from the last decoder layer
if output_hidden_states:
@@ -1321,3 +1340,210 @@ class MarianMTModel(MarianPreTrainedModel):
tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],
)
return reordered_past
# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->Marian
class MarianDecoderWrapper(MarianPreTrainedModel):
"""
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
used in combination with the :class:`~transformers.EncoderDecoderModel` framework.
"""
def __init__(self, config):
super().__init__(config)
self.decoder = MarianDecoder(config)
def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Marian
class MarianForCausalLM(MarianPreTrainedModel):
def __init__(self, config):
super().__init__(config)
config = copy.deepcopy(config)
config.is_decoder = True
config.is_encoder_decoder = False
self.model = MarianDecoderWrapper(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.init_weights()
def get_input_embeddings(self):
return self.model.decoder.embed_tokens
def set_input_embeddings(self, value):
self.model.decoder.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model.decoder = decoder
def get_decoder(self):
return self.model.decoder
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
attention_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
head_mask=None,
encoder_head_mask=None,
past_key_values=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
provide it.
Indices can be obtained using :class:`~transformers.MarianTokenizer`. See
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
for details.
`What are input IDs? <../glossary.html#input-ids>`__
attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
if the model is configured as a decoder.
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used
in the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
encoder_head_mask (:obj:`torch.Tensor` of shape :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the attention modules in encoder to avoid performing cross-attention
on hidden heads. Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last ``decoder_input_ids``
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for computing the masked language modeling loss. Indices should either be in ``[0, ...,
config.vocab_size]`` or -100 (see ``input_ids`` docstring). Tokens with indices set to ``-100`` are
ignored (masked), the loss is only computed for the tokens with labels in ``[0, ...,
config.vocab_size]``.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
for more detail.
return_dict (:obj:`bool`, `optional`):
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
Returns:
Example::
>>> from transformers import MarianTokenizer, MarianForCausalLM
>>> tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-de')
>>> model = MarianForCausalLM.from_pretrained('Helsinki-NLP/opus-mt-en-de', add_cross_attention=False, is_decoder=True)
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
>>> outputs = model(**inputs)
>>> logits = outputs.logits
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consist of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
head_mask=head_mask,
encoder_head_mask=encoder_head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
logits = self.lm_head(outputs[0])
loss = None
if labels is not None:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
)
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=None, **kwargs):
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_ids.shape)
if past:
input_ids = input_ids[:, -1:]
# first step, decoder_cached_states are empty
return {
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
"attention_mask": attention_mask,
"past_key_values": past,
"use_cache": use_cache,
}
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()
for layer_past in past:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past
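# Hedged sketch of the loss path above (illustrative only; the tiny config values and
# pad id are made up, and the weights are randomly initialized):
import torch
from transformers import MarianConfig, MarianForCausalLM

tiny = MarianConfig(
    vocab_size=128, d_model=16, decoder_layers=1, decoder_attention_heads=2,
    encoder_layers=1, encoder_attention_heads=2, decoder_ffn_dim=32, encoder_ffn_dim=32,
    pad_token_id=0,
)
model = MarianForCausalLM(tiny)

input_ids = torch.randint(1, 128, (2, 7))
labels = input_ids.clone()
labels[:, :2] = -100  # positions set to -100 are ignored by the CrossEntropyLoss used above
out = model(input_ids=input_ids, labels=labels)
print(out.loss.item(), out.logits.shape)  # scalar loss and logits of shape (2, 7, 128)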
@@ -943,7 +943,7 @@ class TFMarianDecoder(tf.keras.layers.Layer):
tf.ones((input_shape[0], input_shape[1] + past_key_values_length)), tgt_len=input_shape[-1]
)
-if inputs["attention_mask"] is not None and input_shape[-1] > 1:
+if inputs["attention_mask"] is not None:
combined_attention_mask = combined_attention_mask + _expand_mask(
inputs["attention_mask"], tgt_len=input_shape[-1]
)
...