Unverified Commit 62f5ae68 authored by Patrick von Platen, committed by GitHub

[Seq2Seq] Fix a couple of bugs and clean examples (#7474)



* clean T5

* fix t5 tests

* fix index typo

* fix tf common test

* fix examples

* change positional ordering for Bart and FSMT

* add signature test

* clean docs and add tests

* add docs to encoder decoder

* clean docs

* correct two doc strings

* remove sig test for TF Electra & Funnel

* fix tf t5 slow tests

* fix input_ids to inputs in tf

* Update src/transformers/modeling_bart.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_bart.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* implement lysandre results

* make style

* fix encoder decoder typo

* fix tf slow tests

* fix slow tests

* renaming

* remove unused input
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent a42f62d3
......@@ -101,25 +101,25 @@ BART_INPUTS_DOCSTRING = r"""
Mask to avoid performing attention on padding token indices in input_ids.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`)
`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`) is a sequence of hidden-states at the output of the last layer of the encoder.
Used in the cross-attention of the decoder.
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Provide for translation and summarization training. By default, the model will create this tensor by shifting the input_ids right, following the paper.
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
See diagram 1 in the paper for more info on the default strategy
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`)
:obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`) is a sequence of hidden-states at the output of the last layer of the encoder.
Used in the cross-attention of the decoder.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains pre-computed key and value hidden-states of the attention blocks.
Can be used to speed up decoding.
If ``past_key_values`` are used, the user can optionally input only the last
If :obj:`past_key_values` are used, the user can optionally input only the last
``decoder_input_ids`` (those that don't have their past key value states given to this model) of shape
:obj:`(batch_size, 1)` instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
If `use_cache` is True, ``past_key_values`` are returned and can be used to speed up decoding (see
``past_key_values``).
If :obj:`use_cache` is True, :obj:`past_key_values` are returned and can be used to speed up decoding (see
:obj:`past_key_values`).
output_attentions (:obj:`bool`, `optional`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
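The ``use_cache`` / ``past_key_values`` description above boils down to a simple caching pattern: key/value states grow by one position per decoding step, and only the newest decoder token needs to be fed in. A toy PyTorch sketch of that idea (an illustration only, not BART's actual cache code, with made-up tensor sizes):

    import torch

    batch_size, num_heads, head_dim = 1, 4, 8
    cached_k = torch.zeros(batch_size, num_heads, 0, head_dim)  # no past yet
    cached_v = torch.zeros(batch_size, num_heads, 0, head_dim)

    for step in range(3):
        # With a cache, only the key/value states of the newest position are computed,
        # i.e. the model only needs decoder_input_ids of shape (batch_size, 1).
        new_k = torch.randn(batch_size, num_heads, 1, head_dim)
        new_v = torch.randn(batch_size, num_heads, 1, head_dim)
        cached_k = torch.cat([cached_k, new_k], dim=2)
        cached_v = torch.cat([cached_v, new_v], dim=2)

    print(cached_k.shape)  # torch.Size([1, 4, 3, 8]) -- one cached position per decoded token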
......@@ -874,8 +874,8 @@ class BartModel(PretrainedBartModel):
input_ids,
attention_mask=None,
decoder_input_ids=None,
encoder_outputs: Optional[Tuple] = None,
decoder_attention_mask=None,
encoder_outputs: Optional[Tuple] = None,
past_key_values=None,
use_cache=None,
output_attentions=None,
......@@ -1004,9 +1004,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
self,
input_ids,
attention_mask=None,
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs=None,
past_key_values=None,
labels=None,
use_cache=None,
......@@ -1171,9 +1171,9 @@ class BartForSequenceClassification(PretrainedBartModel):
self,
input_ids,
attention_mask=None,
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs=None,
labels=None,
use_cache=None,
output_attentions=None,
......@@ -1257,9 +1257,9 @@ class BartForQuestionAnswering(PretrainedBartModel):
self,
input_ids,
attention_mask=None,
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs=None,
start_positions=None,
end_positions=None,
use_cache=None,
......
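Because this commit moves ``encoder_outputs`` after ``decoder_attention_mask`` in the positional ordering of the BART forward signatures, any caller relying on positional arguments for these parameters would now pass them into the wrong slot. Passing everything by keyword, as in the minimal sketch below (checkpoint name is an assumption, not part of the diff), is unaffected by such reorderings:

    from transformers import BartModel, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    model = BartModel.from_pretrained("facebook/bart-base", return_dict=True)

    batch = tokenizer("Hello, my dog is cute", return_tensors="pt")
    outputs = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        decoder_input_ids=batch["input_ids"],  # toy choice: reuse the encoder tokens
    )
    print(outputs.last_hidden_state.shape)  # (batch_size, sequence_length, hidden_size)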
......@@ -251,11 +251,11 @@ CTRL_START_DOCSTRING = r"""
CTRL_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
:obj:`input_ids_length` = ``sequence_length`` if ``past_key_values`` is ``None`` else
:obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states).
Indices of input sequence tokens in the vocabulary.
If ``past_key_values`` is used, only input IDs that do not have their past calculated should be passed as
If :obj:`past_key_values` is used, only input IDs that do not have their past calculated should be passed as
``input_ids``.
Indices can be obtained using :class:`~transformers.CTRLTokenizer`.
......@@ -265,7 +265,7 @@ CTRL_INPUTS_DOCSTRING = r"""
`What are input IDs? <../glossary.html#input-ids>`__
past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see ``past_key_values`` output below). Can be used to speed up sequential decoding.
(see :obj:`past_key_values` output below). Can be used to speed up sequential decoding.
The ``input_ids`` which have their past given to this model should not be passed as input ids as they have
already been computed.
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
......@@ -301,8 +301,8 @@ CTRL_INPUTS_DOCSTRING = r"""
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
vectors than the model's internal embedding lookup matrix.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, ``past_key_values`` key value states are returned and can be used to speed up
decoding (see ``past_key_values``).
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
......
......@@ -69,10 +69,6 @@ ENCODER_DECODER_INPUTS_DOCSTRING = r"""
:meth:`transformers.PreTrainedTokenizer.__call__` for details.
`What are input IDs? <../glossary.html#input-ids>`__
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
vectors than the model's internal embedding lookup matrix.
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
......@@ -81,11 +77,6 @@ ENCODER_DECODER_INPUTS_DOCSTRING = r"""
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
encoder_outputs (:obj:`tuple(torch.FloatTensor)`, `optional`):
This tuple must consist of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`)
:obj:`last_hidden_state` (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`)
is a tensor of hidden-states at the output of the last layer of the encoder.
Used in the cross-attention of the decoder.
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Provide for sequence to sequence training to the decoder.
Indices can be obtained using :class:`~transformers.PretrainedTokenizer`.
......@@ -94,6 +85,21 @@ ENCODER_DECODER_INPUTS_DOCSTRING = r"""
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default.
encoder_outputs (:obj:`tuple(torch.FloatTensor)`, `optional`):
This tuple must consist of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`)
:obj:`last_hidden_state` (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`)
is a tensor of hidden-states at the output of the last layer of the encoder.
Used in the cross-attention of the decoder.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
vectors than the model's internal embedding lookup matrix.
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded
representation. This is useful if you want more control over how to convert :obj:`decoder_input_ids`
......@@ -103,6 +109,15 @@ ENCODER_DECODER_INPUTS_DOCSTRING = r"""
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with
labels in ``[0, ..., config.vocab_size]``
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`):
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
more detail.
return_dict (:obj:`bool`, `optional`):
If set to ``True``, the model will return a :class:`~transformers.file_utils.Seq2SeqLMOutput` instead of a
plain tuple.
......@@ -328,13 +343,17 @@ class EncoderDecoderModel(PreTrainedModel):
def forward(
self,
input_ids=None,
inputs_embeds=None,
attention_mask=None,
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs=None,
past_key_values=None, # TODO: (PVP) implement :obj:`use_cache`
inputs_embeds=None,
decoder_inputs_embeds=None,
labels=None,
use_cache=None, # TODO: (PVP) implement :obj:`use_cache`
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
......@@ -378,20 +397,24 @@ class EncoderDecoderModel(PreTrainedModel):
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs_encoder,
)
hidden_states = encoder_outputs[0]
encoder_hidden_states = encoder_outputs[0]
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
inputs_embeds=decoder_inputs_embeds,
attention_mask=decoder_attention_mask,
encoder_hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=attention_mask,
inputs_embeds=decoder_inputs_embeds,
labels=labels,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
**kwargs_decoder,
)
......@@ -423,7 +446,7 @@ class EncoderDecoderModel(PreTrainedModel):
"encoder_outputs": encoder_outputs,
}
# Ideally all models should have a `use_cache`
# Ideally all models should have a :obj:`use_cache`
# leave following to ifs until all have it implemented
if "use_cache" in decoder_inputs:
input_dict["decoder_use_cache"] = decoder_inputs["use_cache"]
......
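With ``inputs_embeds``, ``encoder_outputs`` and ``past_key_values`` reordered in ``EncoderDecoderModel.forward`` as well, keyword arguments are again the safe way to call it. A minimal sketch (the checkpoint names and the BERT-to-BERT pairing are assumptions, not part of the diff):

    from transformers import BertTokenizer, EncoderDecoderModel

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")

    input_ids = tokenizer("Studies show that owning a dog is good for you", return_tensors="pt").input_ids
    decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids

    outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
    logits = outputs[0]  # decoder prediction scores come first in the tuple output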
......@@ -227,10 +227,6 @@ FSMT_INPUTS_DOCSTRING = r"""
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
encoder_outputs (:obj:`Tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`)
:obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Provide for translation and summarization training. By default, the model will create this tensor by
shifting the input_ids right, following the paper.
......@@ -240,6 +236,10 @@ FSMT_INPUTS_DOCSTRING = r"""
If you want to change padding behavior, you should read
:func:`modeling_fsmt._prepare_fsmt_decoder_inputs` and modify.
See diagram 1 in the paper for more info on the default strategy
encoder_outputs (:obj:`Tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`)
:obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
past_key_values (:obj:`Tuple(torch.FloatTensor)` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden-states of the attention blocks.
Can be used to speed up decoding.
......@@ -248,8 +248,8 @@ FSMT_INPUTS_DOCSTRING = r"""
:obj:`(batch_size, 1)` instead of all :obj:`decoder_input_ids` of shape
:obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
If set to :obj:`True`, ``past_key_values`` key value states are returned and can be used to speed up
decoding (see ``past_key_values``).
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
......@@ -910,8 +910,8 @@ class FSMTModel(PretrainedFSMTModel):
input_ids,
attention_mask=None,
decoder_input_ids=None,
encoder_outputs: Optional[Tuple] = None,
decoder_attention_mask=None,
encoder_outputs: Optional[Tuple] = None,
past_key_values=None,
use_cache=None,
output_attentions=None,
......@@ -1045,9 +1045,9 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
self,
input_ids,
attention_mask=None,
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs=None,
past_key_values=None,
labels=None,
use_cache=None,
......
......@@ -187,16 +187,16 @@ class FunnelAttentionStructure(nn.Module):
# divided.
self.pooling_mult = None
def init_attention_inputs(self, input_embeds, attention_mask=None, token_type_ids=None):
def init_attention_inputs(self, inputs_embeds, attention_mask=None, token_type_ids=None):
""" Returns the attention inputs associated to the inputs of the model. """
# input_embeds has shape batch_size x seq_len x d_model
# inputs_embeds has shape batch_size x seq_len x d_model
# attention_mask and token_type_ids have shape batch_size x seq_len
self.pooling_mult = 1
self.seq_len = seq_len = input_embeds.size(1)
position_embeds = self.get_position_embeds(seq_len, input_embeds.dtype, input_embeds.device)
self.seq_len = seq_len = inputs_embeds.size(1)
position_embeds = self.get_position_embeds(seq_len, inputs_embeds.dtype, inputs_embeds.device)
token_type_mat = self.token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None
cls_mask = (
F.pad(input_embeds.new_ones([seq_len - 1, seq_len - 1]), (1, 0, 1, 0))
F.pad(inputs_embeds.new_ones([seq_len - 1, seq_len - 1]), (1, 0, 1, 0))
if self.config.separate_cls
else None
)
......
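The ``input_embeds`` to ``inputs_embeds`` rename aligns Funnel's attention-structure helper with the argument name used everywhere else in the library. As a reminder of that convention, a hedged sketch (checkpoint name assumed) of feeding a model pre-computed embeddings instead of token ids:

    import torch
    from transformers import FunnelModel, FunnelTokenizer

    tokenizer = FunnelTokenizer.from_pretrained("funnel-transformer/small")
    model = FunnelModel.from_pretrained("funnel-transformer/small", return_dict=True)

    input_ids = tokenizer("Hello, my dog is cute", return_tensors="pt").input_ids
    inputs_embeds = model.get_input_embeddings()(input_ids)  # (batch_size, seq_len, d_model)

    outputs = model(inputs_embeds=inputs_embeds)  # token ids are skipped entirely
    print(outputs.last_hidden_state.shape)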
......@@ -365,7 +365,7 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
``past_key_values`` input) to speed up sequential decoding.
:obj:`past_key_values` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -407,11 +407,11 @@ GPT2_START_DOCSTRING = r"""
GPT2_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
:obj:`input_ids_length` = ``sequence_length`` if ``past_key_values`` is ``None`` else
:obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states).
Indices of input sequence tokens in the vocabulary.
If ``past_key_values`` is used, only ``input_ids`` that do not have their past calculated should be passed
If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be passed
as ``input_ids``.
Indices can be obtained using :class:`~transformers.GPT2Tokenizer`.
......@@ -421,7 +421,7 @@ GPT2_INPUTS_DOCSTRING = r"""
`What are input IDs? <../glossary.html#input-ids>`__
past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model
(see ``past_key_values`` output below). Can be used to speed up sequential decoding.
(see :obj:`past_key_values` output below). Can be used to speed up sequential decoding.
The ``input_ids`` which have their past given to this model should not be passed as ``input_ids`` as they
have already been computed.
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
......@@ -457,11 +457,11 @@ GPT2_INPUTS_DOCSTRING = r"""
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
vectors than the model's internal embedding lookup matrix.
If ``past_key_values`` is used, optionally only the last :obj:`inputs_embeds` have to be input (see
``past_key_values``).
If :obj:`past_key_values` is used, optionally only the last :obj:`inputs_embeds` have to be input (see
:obj:`past_key_values`).
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, ``past_key_values`` key value states are returned and can be used to speed up
decoding (see ``past_key_values``).
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
......
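The GPT-2 docstring above describes the same incremental-decoding contract as the seq2seq models: once ``past_key_values`` is passed, only the tokens without a cached past should be fed in. A minimal sketch (assuming the standard ``gpt2`` checkpoint and the argument names documented above) of one greedy step with the cache:

    import torch
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2", return_dict=True)

    input_ids = tokenizer("Hello, my dog is", return_tensors="pt").input_ids
    outputs = model(input_ids, use_cache=True)
    past = outputs.past_key_values  # returned because use_cache=True

    # Next step: only the newly chosen token is passed, the rest is read from the cache.
    next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    outputs = model(input_ids=next_token, past_key_values=past, use_cache=True)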
......@@ -80,7 +80,7 @@ class BaseModelOutputWithPast(ModelOutput):
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
``past_key_values`` input) to speed up sequential decoding.
:obj:`past_key_values` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -110,13 +110,13 @@ class Seq2SeqModelOutput(ModelOutput):
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the decoder of the model.
If ``past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``past_key_values`` input) to speed up sequential decoding.
used (see :obj:`past_key_values` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -196,7 +196,7 @@ class CausalLMOutputWithPast(ModelOutput):
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
``past_key_values`` input) to speed up sequential decoding.
:obj:`past_key_values` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -261,7 +261,7 @@ class Seq2SeqLMOutput(ModelOutput):
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``past_key_values`` input) to speed up sequential decoding.
used (see :obj:`past_key_values` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -371,7 +371,7 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput):
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``past_key_values`` input) to speed up sequential decoding.
used (see :obj:`past_key_values` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -517,7 +517,7 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``past_key_values`` input) to speed up sequential decoding.
used (see :obj:`past_key_values` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......
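All of these output docstrings describe fields of the ``ModelOutput`` dataclasses that ``return_dict=True`` produces. A small sketch (using the ``t5-small`` checkpoint the T5 docstrings already use) of reading the ``loss`` and ``past_key_values`` fields from a ``Seq2SeqLMOutput``:

    from transformers import T5ForConditionalGeneration, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small", return_dict=True)

    input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
    outputs = model(input_ids=input_ids, labels=input_ids, use_cache=True)

    print(outputs.loss)                  # language-modeling loss, because labels were given
    print(len(outputs.past_key_values))  # one cached key/value entry per decoder layer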
......@@ -52,7 +52,7 @@ class RetrievAugLMMarginOutput(ModelOutput):
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains precomputed hidden-states (key and values in the attention blocks) of the decoder that can be used
(see ``past_key_values`` input) to speed up sequential decoding.
(see :obj:`past_key_values` input) to speed up sequential decoding.
retrieved_doc_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.n_docs, hidden_size)`, `optional`, returned when `output_retrieved=True`):
Embedded documents retrieved by the retriever.
Is used with ``question_encoder_last_hidden_state`` to compute the ``doc_scores``.
......@@ -137,7 +137,7 @@ class RetrievAugLMOutput(ModelOutput):
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains precomputed hidden-states (key and values in the attention blocks) of the decoder that can be used
(see ``past_key_values`` input) to speed up sequential decoding.
(see :obj:`past_key_values` input) to speed up sequential decoding.
retrieved_doc_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.n_docs, hidden_size)`, `optional`, returned when `output_retrieved=True`):
Embedded documents retrieved by the retriever.
Is used with ``question_encoder_last_hidden_state`` to compute the ``doc_scores``.
......@@ -447,8 +447,8 @@ RAG_FORWARD_INPUTS_DOCSTRING = r"""
to the forward pass. :obj:`context_attention_mask` are returned by
:meth:`~transformers.RagRetriever.__call__`.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
If set to :obj:`True`, ``past_key_values`` key value states are returned and can be used to speed up
decoding (see ``past_key_values``).
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
......
......@@ -1959,8 +1959,8 @@ REFORMER_INPUTS_DOCSTRING = r"""
Contains precomputed hidden-states and buckets (only relevant for LSH Self-Attention). Can be used to speed
up sequential decoding.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, ``past_key_values`` key value states are returned and can be used to speed up
decoding (see ``past_key_values``).
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
......
......@@ -202,8 +202,9 @@ class T5LayerFF(nn.Module):
class T5Attention(nn.Module):
def __init__(self, config: T5Config, has_relative_attention_bias=False):
def __init__(self, config: T5Config, has_relative_attention_bias=False, is_bidirectional=False):
super().__init__()
self.is_bidirectional = is_bidirectional
self.is_decoder = config.is_decoder
self.has_relative_attention_bias = has_relative_attention_bias
......@@ -293,7 +294,7 @@ class T5Attention(nn.Module):
relative_position = memory_position - context_position # shape (qlen, klen)
rp_bucket = self._relative_position_bucket(
relative_position, # shape (qlen, klen)
bidirectional=not self.is_decoder,
bidirectional=self.is_bidirectional,
num_buckets=self.relative_attention_num_buckets,
)
rp_bucket = rp_bucket.to(self.relative_attention_bias.weight.device)
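The new ``is_bidirectional`` flag decouples the direction of the relative-position bucketing from ``config.is_decoder``: further down in the diff, self-attention passes ``not config.is_decoder`` and cross-attention passes ``True``. A toy sketch of what the flag changes, deliberately much simpler than T5's real ``_relative_position_bucket`` (the bucketing rule here is an assumption for illustration only):

    import torch

    def toy_relative_bucket(relative_position, bidirectional):
        # relative_position = memory_position - context_position, as in the code above
        if bidirectional:
            # keep forward and backward offsets apart by shifting positive offsets into their own range
            return relative_position.abs() + (relative_position > 0).long() * 16
        # causal case: all future positions collapse into bucket 0
        return (-relative_position).clamp(min=0)

    offsets = torch.arange(-2, 3)  # key position minus query position: -2 .. 2
    print(toy_relative_bucket(offsets, bidirectional=True))   # tensor([ 2,  1,  0, 17, 18])
    print(toy_relative_bucket(offsets, bidirectional=False))  # tensor([2, 1, 0, 0, 0])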
......@@ -307,7 +308,7 @@ class T5Attention(nn.Module):
mask=None,
kv=None,
position_bias=None,
past_key_value_state=None,
past_key_value=None,
head_mask=None,
query_length=None,
use_cache=False,
......@@ -318,17 +319,17 @@ class T5Attention(nn.Module):
"""
# Input is (bs, qlen, dim)
# Mask is (bs, klen) (non-causal) or (bs, klen, klen)
# past_key_value_state[0] is (bs, n_heads, q_len - 1, dim_per_head)
# past_key_value[0] is (bs, n_heads, q_len - 1, dim_per_head)
bs, qlen, dim = input.size()
if past_key_value_state is not None:
if past_key_value is not None:
assert self.is_decoder is True, "Encoder cannot cache past key value states"
assert (
len(past_key_value_state) == 2
), "past_key_value_state should have 2 past states: keys and values. Got {} past states".format(
len(past_key_value_state)
len(past_key_value) == 2
), "past_key_value should have 2 past states: keys and values. Got {} past states".format(
len(past_key_value)
)
real_qlen = qlen + past_key_value_state[0].shape[2] if query_length is None else query_length
real_qlen = qlen + past_key_value[0].shape[2] if query_length is None else query_length
else:
real_qlen = qlen
......@@ -350,18 +351,18 @@ class T5Attention(nn.Module):
if kv is None:
k = shape(self.k(input)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v(input)) # (bs, n_heads, qlen, dim_per_head)
elif past_key_value_state is None:
elif past_key_value is None:
k = v = kv
k = shape(self.k(k)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v(v)) # (bs, n_heads, qlen, dim_per_head)
if past_key_value_state is not None:
if past_key_value is not None:
if kv is None:
k_, v_ = past_key_value_state
k_, v_ = past_key_value
k = torch.cat([k_, k], dim=2) # (bs, n_heads, klen, dim_per_head)
v = torch.cat([v_, v], dim=2) # (bs, n_heads, klen, dim_per_head)
else:
k, v = past_key_value_state
k, v = past_key_value
if self.is_decoder and use_cache is True:
present_key_value_state = ((k, v),)
......@@ -380,8 +381,8 @@ class T5Attention(nn.Module):
# if key and values are already calculated
# we want only the last query position bias
if past_key_value_state is not None:
position_bias = position_bias[:, :, -1:, :]
if past_key_value is not None:
position_bias = position_bias[:, :, -qlen:, :]
if mask is not None:
position_bias = position_bias + mask # (bs, n_heads, qlen, klen)
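The slicing change from ``-1:`` to ``-qlen:`` fixes the case where a cache is present but more than one query token is fed at once: the bias must keep one row per query position, not only the last row. A small illustration of the shapes involved (sizes are made up):

    import torch

    bs, n_heads, klen, qlen = 1, 8, 7, 3         # 4 cached positions + 3 new query tokens
    position_bias = torch.zeros(bs, n_heads, klen, klen)

    print(position_bias[:, :, -1:, :].shape)     # torch.Size([1, 8, 1, 7]) -- old behaviour, a single row
    print(position_bias[:, :, -qlen:, :].shape)  # torch.Size([1, 8, 3, 7]) -- one row per query token

When ``qlen`` is 1, as in step-by-step generation, the two slicings are identical, which is why the old code only misbehaved for longer query blocks.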
......@@ -411,7 +412,9 @@ class T5Attention(nn.Module):
class T5LayerSelfAttention(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
self.SelfAttention = T5Attention(
config, has_relative_attention_bias=has_relative_attention_bias, is_bidirectional=not config.is_decoder
)
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
......@@ -421,7 +424,7 @@ class T5LayerSelfAttention(nn.Module):
attention_mask=None,
position_bias=None,
head_mask=None,
past_key_value_state=None,
past_key_value=None,
use_cache=False,
output_attentions=False,
):
......@@ -431,7 +434,7 @@ class T5LayerSelfAttention(nn.Module):
mask=attention_mask,
position_bias=position_bias,
head_mask=head_mask,
past_key_value_state=past_key_value_state,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
......@@ -444,7 +447,9 @@ class T5LayerSelfAttention(nn.Module):
class T5LayerCrossAttention(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
self.EncDecAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
self.EncDecAttention = T5Attention(
config, has_relative_attention_bias=has_relative_attention_bias, is_bidirectional=True
)
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
......@@ -455,7 +460,7 @@ class T5LayerCrossAttention(nn.Module):
attention_mask=None,
position_bias=None,
head_mask=None,
past_key_value_state=None,
past_key_value=None,
use_cache=False,
query_length=None,
output_attentions=False,
......@@ -467,7 +472,7 @@ class T5LayerCrossAttention(nn.Module):
kv=kv,
position_bias=position_bias,
head_mask=head_mask,
past_key_value_state=past_key_value_state,
past_key_value=past_key_value,
use_cache=use_cache,
query_length=query_length,
output_attentions=output_attentions,
......@@ -498,33 +503,33 @@ class T5Block(nn.Module):
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
head_mask=None,
past_key_value_state=None,
past_key_value=None,
use_cache=False,
output_attentions=False,
):
if past_key_value_state is not None:
assert self.is_decoder, "Only decoder can use `past_key_value_states`"
expected_num_past_key_value_states = 2 if encoder_hidden_states is None else 4
if past_key_value is not None:
assert self.is_decoder, "Only decoder can use `past_key_values`"
expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
error_message = "There should be {} past states. 2 (past / key) for self attention.{} Got {} past key / value states".format(
expected_num_past_key_value_states,
"2 (past / key) for cross attention" if expected_num_past_key_value_states == 4 else "",
len(past_key_value_state),
expected_num_past_key_values,
"2 (past / key) for cross attention" if expected_num_past_key_values == 4 else "",
len(past_key_value),
)
assert len(past_key_value_state) == expected_num_past_key_value_states, error_message
assert len(past_key_value) == expected_num_past_key_values, error_message
self_attn_past_key_value_state = past_key_value_state[:2]
cross_attn_past_key_value_state = past_key_value_state[2:]
self_attn_past_key_value = past_key_value[:2]
cross_attn_past_key_value = past_key_value[2:]
else:
self_attn_past_key_value_state, cross_attn_past_key_value_state = None, None
self_attn_past_key_value, cross_attn_past_key_value = None, None
self_attention_outputs = self.layer[0](
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
head_mask=head_mask,
past_key_value_state=self_attn_past_key_value_state,
past_key_value=self_attn_past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
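A sketch of the per-layer cache layout the assertion above expects (tensor sizes are invented for illustration): two tensors for self-attention, plus two more when cross-attention is present:

    import torch

    bs, n_heads, past_len, d_head, src_len = 1, 8, 5, 64, 12

    self_k = torch.zeros(bs, n_heads, past_len, d_head)
    self_v = torch.zeros_like(self_k)
    cross_k = torch.zeros(bs, n_heads, src_len, d_head)  # keyed on the encoder sequence
    cross_v = torch.zeros_like(cross_k)

    past_key_value = (self_k, self_v, cross_k, cross_v)  # len == 4 when encoder_hidden_states is given
    self_attn_past_key_value = past_key_value[:2]
    cross_attn_past_key_value = past_key_value[2:]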
......@@ -545,7 +550,7 @@ class T5Block(nn.Module):
attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias,
head_mask=head_mask,
past_key_value_state=cross_attn_past_key_value_state,
past_key_value=cross_attn_past_key_value,
query_length=query_length,
use_cache=use_cache,
output_attentions=output_attentions,
......@@ -673,7 +678,7 @@ class T5Stack(T5PreTrainedModel):
encoder_attention_mask=None,
inputs_embeds=None,
head_mask=None,
past_key_value_states=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
......@@ -688,17 +693,18 @@ class T5Stack(T5PreTrainedModel):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
err_msg_prefix = "decoder_" if self.is_decoder else ""
raise ValueError(
f"You cannot specify both {err_msg_prefix}inputs and {err_msg_prefix}inputs_embeds at the same time"
)
elif input_ids is not None:
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
elif inputs_embeds is not None:
input_shape = inputs_embeds.size()[:-1]
else:
if self.is_decoder:
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
err_msg_prefix = "decoder_" if self.is_decoder else ""
raise ValueError(f"You have to specify either {err_msg_prefix}inputs or {err_msg_prefix}inputs_embeds")
if inputs_embeds is None:
assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
......@@ -706,18 +712,13 @@ class T5Stack(T5PreTrainedModel):
batch_size, seq_length = input_shape
if past_key_value_states is not None:
assert seq_length == 1, "Input shape is {}, but should be {} when using past_key_value_sates".format(
input_shape, (batch_size, 1)
)
# required mask seq length can be calculated via length of past
# key value states and seq_length = 1 for the last token
mask_seq_length = past_key_value_states[0][0].shape[2] + seq_length
else:
mask_seq_length = seq_length
mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
if use_cache is True:
assert self.is_decoder, "`use_cache` can only be set to `True` if {} is used as a decoder".format(self)
assert self.is_decoder, ":obj:`use_cache` can only be set to `True` if {} is used as a decoder".format(
self
)
if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device)
......@@ -727,9 +728,9 @@ class T5Stack(T5PreTrainedModel):
batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
)
# initialize past_key_value_states with `None` if past does not exist
if past_key_value_states is None:
past_key_value_states = [None] * len(self.block)
# initialize past_key_values with `None` if past does not exist
if past_key_values is None:
past_key_values = [None] * len(self.block)
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device)
......@@ -749,7 +750,7 @@ class T5Stack(T5PreTrainedModel):
hidden_states = self.dropout(inputs_embeds)
for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_value_states)):
for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
......@@ -761,7 +762,7 @@ class T5Stack(T5PreTrainedModel):
encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
head_mask=head_mask[i],
past_key_value_state=past_key_value_state,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
)
......@@ -845,10 +846,6 @@ T5_INPUTS_DOCSTRING = r"""
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`: `attentions`)
:obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`):
Provide for sequence to sequence training. T5 uses the :obj:`pad_token_id` as the starting token for
:obj:`decoder_input_ids` generation.
......@@ -861,15 +858,23 @@ T5_INPUTS_DOCSTRING = r"""
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`: `attentions`)
:obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, ``past_key_values`` key value states are returned and can be used to speed up
decoding (see ``past_key_values``).
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
......@@ -883,13 +888,11 @@ T5_INPUTS_DOCSTRING = r"""
associated vectors than the model's internal embedding lookup matrix.
If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both
unset, :obj:`decoder_input_embeds` takes the value of :obj:`input_embeds`.
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
unset, :obj:`decoder_inputs_embeds` takes the value of :obj:`inputs_embeds`.
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
use_cache (:obj:`bool`, `optional`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
......@@ -952,14 +955,14 @@ class T5Model(T5PreTrainedModel):
self,
input_ids=None,
attention_mask=None,
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs=None,
past_key_values=None,
use_cache=None,
head_mask=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
head_mask=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
......@@ -975,10 +978,11 @@ class T5Model(T5PreTrainedModel):
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> model = T5Model.from_pretrained('t5-small')
>>> input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt") # Batch size 1
>>> outputs = model(input_ids=input_ids)
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
>>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, return_dict=True)
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
>>> last_hidden_states = outputs.last_hidden_state
"""
if "decoder_past_key_value_states" in kwargs:
warnings.warn(
......@@ -1017,26 +1021,12 @@ class T5Model(T5PreTrainedModel):
hidden_states = encoder_outputs[0]
# If the model is only provided with either input_ids or inputs_embeds,
# use them as the inputs of the decoder. self.encoder checks for input_ids XOR inputs_embeds
if (decoder_input_ids is None) and (decoder_inputs_embeds is None):
decoder_input_ids = input_ids
decoder_inputs_embeds = inputs_embeds
# If decoding with past key value states, only the last tokens
# should be given as an input
if past_key_values is not None:
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_inputs_embeds is not None:
decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
past_key_value_states=past_key_values,
past_key_values=past_key_values,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=head_mask,
......@@ -1108,15 +1098,15 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
self,
input_ids=None,
attention_mask=None,
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs=None,
past_key_values=None,
use_cache=None,
labels=None,
head_mask=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
head_mask=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
......@@ -1139,14 +1129,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> model = T5ForConditionalGeneration.from_pretrained('t5-small', return_dict=True)
>>> input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt") # Batch size 1
>>> outputs = model(input_ids=input_ids, labels=input_ids)
>>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
>>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2> </s>', return_tensors='pt').input_ids
>>> outputs = model(input_ids=input_ids, labels=labels)
>>> loss = outputs.loss
>>> logits = outputs.logits
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> model = T5ForConditionalGeneration.from_pretrained('t5-small', return_dict=True)
>>> input_ids = tokenizer.encode("summarize: Hello, my dog is cute", return_tensors="pt") # Batch size 1
>>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="pt").input_ids # Batch size 1
>>> outputs = model.generate(input_ids)
"""
......@@ -1212,7 +1202,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
past_key_value_states=past_key_values,
past_key_values=past_key_values,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=head_mask,
......@@ -1250,6 +1240,11 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
)
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs):
# cut decoder_input_ids if past is used
if past is not None:
input_ids = input_ids[:, -1:]
return {
"decoder_input_ids": input_ids,
"past_key_values": past,
......
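``generate()`` is what exercises ``prepare_inputs_for_generation`` above: after the first step it threads a non-None ``past`` through, so only the newest decoder token is kept. A minimal end-to-end sketch with the same ``t5-small`` checkpoint the docstrings use:

    from transformers import T5ForConditionalGeneration, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")

    input_ids = tokenizer(
        "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
    ).input_ids

    summary_ids = model.generate(input_ids, max_length=20, num_beams=2)
    print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))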
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
......@@ -743,7 +744,7 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
@replace_return_docstrings(output_type=TFElectraForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
input_ids,
inputs,
attention_mask=None,
token_type_ids=None,
position_ids=None,
......@@ -753,6 +754,7 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
r"""
Returns:
......@@ -769,8 +771,15 @@ class TFElectraForPreTraining(TFElectraPreTrainedModel):
>>> scores = outputs[0]
"""
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
if inputs is None and "input_ids" in kwargs and isinstance(kwargs["input_ids"], (dict, BatchEncoding)):
warnings.warn(
"Using `input_ids` as a dictionary keyword argument is deprecated. Please use `inputs` instead."
)
inputs = kwargs["input_ids"]
discriminator_hidden_states = self.electra(
input_ids,
inputs,
attention_mask,
token_type_ids,
position_ids,
......@@ -847,7 +856,7 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
)
def call(
self,
input_ids,
inputs,
attention_mask=None,
token_type_ids=None,
position_ids=None,
......@@ -858,6 +867,7 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
......@@ -868,16 +878,22 @@ class TFElectraForMaskedLM(TFElectraPreTrainedModel, TFMaskedLanguageModelingLos
"""
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
if isinstance(input_ids, (tuple, list)):
labels = input_ids[9] if len(input_ids) > 9 else labels
if inputs is None and "input_ids" in kwargs and isinstance(kwargs["input_ids"], (dict, BatchEncoding)):
warnings.warn(
"Using `input_ids` as a dictionary keyword argument is deprecated. Please use `inputs` instead."
)
inputs = kwargs["input_ids"]
if len(input_ids) > 9:
input_ids = input_ids[:9]
elif isinstance(input_ids, (dict, BatchEncoding)):
labels = input_ids.pop("labels", labels)
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
generator_hidden_states = self.electra(
input_ids,
inputs,
attention_mask,
token_type_ids,
position_ids,
......@@ -952,7 +968,7 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
)
def call(
self,
input_ids,
inputs,
attention_mask=None,
token_type_ids=None,
position_ids=None,
......@@ -963,6 +979,7 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size,)`, `optional`):
......@@ -973,16 +990,22 @@ class TFElectraForSequenceClassification(TFElectraPreTrainedModel, TFSequenceCla
"""
return_dict = return_dict if return_dict is not None else self.electra.config.return_dict
if isinstance(input_ids, (tuple, list)):
labels = input_ids[9] if len(input_ids) > 9 else labels
if inputs is None and "input_ids" in kwargs and isinstance(kwargs["input_ids"], (dict, BatchEncoding)):
warnings.warn(
"Using `input_ids` as a dictionary keyword argument is deprecated. Please use `inputs` instead."
)
inputs = kwargs["input_ids"]
if isinstance(inputs, (tuple, list)):
labels = inputs[9] if len(inputs) > 9 else labels
if len(input_ids) > 9:
input_ids = input_ids[:9]
elif isinstance(input_ids, (dict, BatchEncoding)):
labels = input_ids.pop("labels", labels)
if len(inputs) > 9:
inputs = inputs[:9]
elif isinstance(inputs, (dict, BatchEncoding)):
labels = inputs.pop("labels", labels)
outputs = self.electra(
input_ids,
inputs,
attention_mask,
token_type_ids,
position_ids,
......
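The TF signature changes above rename the first argument from ``input_ids`` to ``inputs`` and add a deprecation path for callers that passed the tokenizer's dictionary output via ``input_ids=``. A hedged sketch of the preferred call style (checkpoint name is an assumption):

    from transformers import ElectraTokenizer, TFElectraForPreTraining

    tokenizer = ElectraTokenizer.from_pretrained("google/electra-small-discriminator")
    model = TFElectraForPreTraining.from_pretrained("google/electra-small-discriminator")

    encoded = tokenizer("The quick brown fox jumps over the lazy dog", return_tensors="tf")

    outputs = model(encoded)      # preferred: pass the dict/BatchEncoding as `inputs`
    scores = outputs[0]

    # model(input_ids=encoded)    # still accepted, but now emits a deprecation warning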
......@@ -14,6 +14,7 @@
# limitations under the License.
""" TF 2.0 Funnel model. """
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple
......@@ -173,16 +174,16 @@ class TFFunnelAttentionStructure:
# divided.
self.pooling_mult = None
def init_attention_inputs(self, input_embeds, attention_mask=None, token_type_ids=None, training=False):
def init_attention_inputs(self, inputs_embeds, attention_mask=None, token_type_ids=None, training=False):
""" Returns the attention inputs associated to the inputs of the model. """
# input_embeds has shape batch_size x seq_len x d_model
# inputs_embeds has shape batch_size x seq_len x d_model
# attention_mask and token_type_ids have shape batch_size x seq_len
self.pooling_mult = 1
self.seq_len = seq_len = input_embeds.shape[1]
position_embeds = self.get_position_embeds(seq_len, dtype=input_embeds.dtype, training=training)
self.seq_len = seq_len = inputs_embeds.shape[1]
position_embeds = self.get_position_embeds(seq_len, dtype=inputs_embeds.dtype, training=training)
token_type_mat = self.token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None
cls_mask = (
tf.pad(tf.ones([seq_len - 1, seq_len - 1], dtype=input_embeds.dtype), [[1, 0], [1, 0]])
tf.pad(tf.ones([seq_len - 1, seq_len - 1], dtype=inputs_embeds.dtype), [[1, 0], [1, 0]])
if self.separate_cls
else None
)
......@@ -1184,7 +1185,7 @@ class TFFunnelForPreTraining(TFFunnelPreTrainedModel):
@replace_return_docstrings(output_type=TFFunnelForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
def call(
self,
input_ids,
inputs,
attention_mask=None,
token_type_ids=None,
inputs_embeds=None,
......@@ -1192,6 +1193,7 @@ class TFFunnelForPreTraining(TFFunnelPreTrainedModel):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs
):
r"""
Returns:
......@@ -1209,8 +1211,14 @@ class TFFunnelForPreTraining(TFFunnelPreTrainedModel):
"""
return_dict = return_dict if return_dict is not None else self.funnel.return_dict
if inputs is None and "input_ids" in kwargs and isinstance(kwargs["input_ids"], (dict, BatchEncoding)):
warnings.warn(
"Using `input_ids` as a dictionary keyword argument is deprecated. Please use `inputs` instead."
)
inputs = kwargs["input_ids"]
discriminator_hidden_states = self.funnel(
input_ids,
inputs,
attention_mask,
token_type_ids,
inputs_embeds,
......
......@@ -427,7 +427,7 @@ class TFGPT2DoubleHeadsModelOutput(ModelOutput):
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
``past_key_values`` input) to speed up sequential decoding.
:obj:`past_key_values` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......
......@@ -84,7 +84,7 @@ class TFBaseModelOutputWithPast(ModelOutput):
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
``past_key_values`` input) to speed up sequential decoding.
:obj:`past_key_values` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -114,13 +114,13 @@ class TFSeq2SeqModelOutput(ModelOutput):
last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the decoder of the model.
If ``past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``past_key_values`` input) to speed up sequential decoding.
used (see :obj:`past_key_values` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -200,7 +200,7 @@ class TFCausalLMOutputWithPast(ModelOutput):
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
``past_key_values`` input) to speed up sequential decoding.
:obj:`past_key_values` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -265,7 +265,7 @@ class TFSeq2SeqLMOutput(ModelOutput):
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``past_key_values`` input) to speed up sequential decoding.
used (see :obj:`past_key_values` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -372,7 +372,7 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput):
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``past_key_values`` input) to speed up sequential decoding.
used (see :obj:`past_key_values` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -518,7 +518,7 @@ class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``past_key_values`` input) to speed up sequential decoding.
used (see :obj:`past_key_values` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......
......@@ -117,8 +117,9 @@ class TFT5LayerFF(tf.keras.layers.Layer):
class TFT5Attention(tf.keras.layers.Layer):
NEW_ID = itertools.count()
def __init__(self, config, has_relative_attention_bias=False, **kwargs):
def __init__(self, config, has_relative_attention_bias=False, is_bidirectional=False, **kwargs):
super().__init__(**kwargs)
self.is_bidirectional = is_bidirectional
self.layer_id = next(TFT5Attention.NEW_ID)
self.is_decoder = config.is_decoder
self.use_cache = config.use_cache
......@@ -202,7 +203,7 @@ class TFT5Attention(tf.keras.layers.Layer):
relative_position = memory_position - context_position # shape (qlen, klen)
rp_bucket = self._relative_position_bucket(
relative_position,
bidirectional=not self.is_decoder,
bidirectional=self.is_bidirectional,
num_buckets=self.relative_attention_num_buckets,
)
values = self.relative_attention_bias(rp_bucket) # shape (qlen, klen, num_heads)
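The new ``is_bidirectional`` flag (set to ``not config.is_decoder`` for self-attention and ``True`` for cross-attention later in this diff) makes the bucketing direction explicit per attention module instead of deriving it from ``config.is_decoder``. A standalone sketch of the relative positions fed into ``_relative_position_bucket`` (illustration only, mirroring the arithmetic above):

import tensorflow as tf

qlen, klen = 3, 5
context_position = tf.range(qlen)[:, None]
memory_position = tf.range(klen)[None, :]
relative_position = memory_position - context_position  # shape (qlen, klen); negative = key before query
print(relative_position.numpy())
# [[ 0  1  2  3  4]
#  [-1  0  1  2  3]
#  [-2 -1  0  1  2]]
# With bidirectional=True, positive and negative offsets get separate buckets;
# with bidirectional=False (causal decoder self-attention), positive (future) offsets
# are collapsed, since those positions are masked out anyway.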
......@@ -215,8 +216,7 @@ class TFT5Attention(tf.keras.layers.Layer):
mask=None,
kv=None,
position_bias=None,
cache=None,
past_key_value_state=None,
past_key_value=None,
head_mask=None,
query_length=None,
use_cache=False,
......@@ -228,17 +228,17 @@ class TFT5Attention(tf.keras.layers.Layer):
"""
# Input is (bs, qlen, dim)
# Mask is (bs, klen) (non-causal) or (bs, klen, klen)
# past_key_value_state[0] is (bs, n_heads, q_len - 1, dim_per_head)
# past_key_value[0] is (bs, n_heads, q_len - 1, dim_per_head)
bs, qlen, dim = shape_list(input)
if past_key_value_state is not None:
if past_key_value is not None:
assert self.is_decoder is True, "Encoder cannot cache past key value states"
assert (
len(past_key_value_state) == 2
), "past_key_value_state should have 2 past states: keys and values. Got {} past states".format(
len(past_key_value_state)
len(past_key_value) == 2
), "past_key_value should have 2 past states: keys and values. Got {} past states".format(
len(past_key_value)
)
real_qlen = qlen + shape_list(past_key_value_state[0])[2] if query_length is None else query_length
real_qlen = qlen + shape_list(past_key_value[0])[2] if query_length is None else query_length
else:
real_qlen = qlen
......@@ -260,18 +260,18 @@ class TFT5Attention(tf.keras.layers.Layer):
if kv is None:
k = shape(self.k(input)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v(input)) # (bs, n_heads, qlen, dim_per_head)
elif past_key_value_state is None:
elif past_key_value is None:
k = v = kv
k = shape(self.k(k)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v(v)) # (bs, n_heads, qlen, dim_per_head)
if past_key_value_state is not None:
if past_key_value is not None:
if kv is None:
k_, v_ = past_key_value_state
k_, v_ = past_key_value
k = tf.concat([k_, k], axis=2) # (bs, n_heads, klen, dim_per_head)
v = tf.concat([v_, v], axis=2) # (bs, n_heads, klen, dim_per_head)
else:
k, v = past_key_value_state
k, v = past_key_value
# to cope with keras serialization
if self.is_decoder and cast_bool_to_primitive(use_cache, self.use_cache) is True:
......@@ -288,8 +288,8 @@ class TFT5Attention(tf.keras.layers.Layer):
# if key and values are already calculated
# we want only the last query position bias
if past_key_value_state is not None:
position_bias = position_bias[:, :, -1:, :]
if past_key_value is not None:
position_bias = position_bias[:, :, -qlen:, :]
if mask is not None:
position_bias = position_bias + mask # (bs, n_heads, qlen, klen)
......@@ -322,6 +322,7 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
self.SelfAttention = TFT5Attention(
config,
has_relative_attention_bias=has_relative_attention_bias,
is_bidirectional=not config.is_decoder,
name="SelfAttention",
)
self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
......@@ -333,7 +334,7 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
attention_mask=None,
position_bias=None,
head_mask=None,
past_key_value_state=None,
past_key_value=None,
use_cache=False,
output_attentions=False,
training=False,
......@@ -344,7 +345,7 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
mask=attention_mask,
position_bias=position_bias,
head_mask=head_mask,
past_key_value_state=past_key_value_state,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
training=training,
......@@ -361,6 +362,7 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
self.EncDecAttention = TFT5Attention(
config,
has_relative_attention_bias=has_relative_attention_bias,
is_bidirectional=True,
name="EncDecAttention",
)
self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
......@@ -373,7 +375,7 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
attention_mask=None,
position_bias=None,
head_mask=None,
past_key_value_state=None,
past_key_value=None,
query_length=None,
use_cache=False,
output_attentions=False,
......@@ -386,7 +388,7 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
kv=kv,
position_bias=position_bias,
head_mask=head_mask,
past_key_value_state=past_key_value_state,
past_key_value=past_key_value,
query_length=query_length,
use_cache=use_cache,
output_attentions=output_attentions,
......@@ -430,34 +432,34 @@ class TFT5Block(tf.keras.layers.Layer):
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
head_mask=None,
past_key_value_state=None,
past_key_value=None,
use_cache=False,
output_attentions=False,
training=False,
):
if past_key_value_state is not None:
if past_key_value is not None:
assert self.is_decoder, "Only decoder can use `past_key_values`"
expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
error_message = "There should be {} past states. 2 (past / key) for self attention.{} Got {} past key / value states".format(
expected_num_past_key_values,
"2 (past / key) for cross attention" if expected_num_past_key_values == 4 else "",
len(past_key_value_state),
len(past_key_value),
)
assert len(past_key_value_state) == expected_num_past_key_values, error_message
assert len(past_key_value) == expected_num_past_key_values, error_message
self_attn_past_key_value_state = past_key_value_state[:2]
cross_attn_past_key_value_state = past_key_value_state[2:]
self_attn_past_key_value = past_key_value[:2]
cross_attn_past_key_value = past_key_value[2:]
else:
self_attn_past_key_value_state, cross_attn_past_key_value_state = None, None
self_attn_past_key_value, cross_attn_past_key_value = None, None
self_attention_outputs = self.layer[0](
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
head_mask=head_mask,
past_key_value_state=self_attn_past_key_value_state,
past_key_value=self_attn_past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
training=training,
......@@ -479,7 +481,7 @@ class TFT5Block(tf.keras.layers.Layer):
attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias,
head_mask=head_mask,
past_key_value_state=cross_attn_past_key_value_state,
past_key_value=cross_attn_past_key_value,
query_length=query_length,
use_cache=use_cache,
output_attentions=output_attentions,
......@@ -618,34 +620,38 @@ class TFT5MainLayer(tf.keras.layers.Layer):
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
assert len(inputs) <= 10, "Too many inputs."
if "past_key_value_states" in inputs:
if "past_key_values" in inputs:
warnings.warn(
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
"The `past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = inputs.pop("past_key_value_states")
past_key_values = inputs.pop("past_key_values")
else:
input_ids = inputs
if "past_key_value_states" in kwargs:
if "past_key_values" in kwargs:
warnings.warn(
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
"The `past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = kwargs.pop("past_key_value_states")
past_key_values = kwargs.pop("past_key_values")
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
use_cache = use_cache if use_cache is not None else self.use_cache
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both inputs and inputs_embeds at the same time")
err_msg_prefix = "decoder_" if self.is_decoder else ""
raise ValueError(
f"You cannot specify both {err_msg_prefix}inputs and {err_msg_prefix}inputs_embeds at the same time"
)
elif input_ids is not None:
input_shape = shape_list(input_ids)
input_ids = tf.reshape(input_ids, (-1, input_shape[-1]))
elif inputs_embeds is not None:
input_shape = shape_list(inputs_embeds)[:-1]
else:
raise ValueError("You have to specify either inputs or inputs_embeds")
err_msg_prefix = "decoder_" if self.is_decoder else ""
raise ValueError(f"You have to specify either {err_msg_prefix}inputs or {err_msg_prefix}inputs_embeds")
if inputs_embeds is None:
assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
......@@ -653,15 +659,10 @@ class TFT5MainLayer(tf.keras.layers.Layer):
batch_size, seq_length = input_shape
if past_key_values is not None:
assert seq_length == 1, "Input shape is {}, but should be {} when using past_key_value_sates".format(
input_shape, (batch_size, 1)
)
# required mask seq length can be calculated via length of past
# key value states and seq_length = 1 for the last token
mask_seq_length = shape_list(past_key_values[0][0])[2] + seq_length
else:
mask_seq_length = seq_length
mask_seq_length = (
shape_list(past_key_values[0][0])[2] + seq_length if past_key_values is not None else seq_length
)
if attention_mask is None:
attention_mask = tf.fill((batch_size, mask_seq_length), 1)
......@@ -692,7 +693,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
causal_mask = tf.cast(causal_mask, dtype=tf.float32)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
if past_key_values[0] is not None:
extended_attention_mask = extended_attention_mask[:, :, -1:, :]
extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :]
else:
extended_attention_mask = attention_mask[:, None, None, :]
......@@ -740,7 +741,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
hidden_states = self.dropout(inputs_embeds, training=training)
for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_values)):
for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
......@@ -752,7 +753,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
head_mask=head_mask[i],
past_key_value_state=past_key_value_state,
past_key_value=past_key_value,
use_cache=use_cache,
output_attentions=output_attentions,
training=training,
......@@ -915,22 +916,19 @@ T5_INPUTS_DOCSTRING = r"""
- 0 for tokens that are **masked**.
`What are attention masks? <../glossary.html#attention-mask>`__
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default.
encoder_outputs (:obj:`tuple(tuple(tf.FloatTensor)`, `optional`):
Tuple consists of (:obj:`last_hidden_state`, :obj:`optional`: `hidden_states`, :obj:`optional`: `attentions`)
:obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)` is a sequence of
hidden states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`):
Default behavior: generate a tensor that ignores pad tokens in :obj:`decoder_input_ids`. Causal mask will
also be used by default.
past_key_values (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
If set to :obj:`True`, ``past_key_values`` key value states are returned and can be used to speed up
decoding (see ``past_key_values``).
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
......@@ -944,7 +942,7 @@ T5_INPUTS_DOCSTRING = r"""
associated vectors than the model's internal embedding lookup matrix.
If :obj:`decoder_input_ids` and :obj:`decoder_inputs_embeds` are both
unset, :obj:`decoder_input_embeds` takes the value of :obj:`input_embeds`.
unset, :obj:`decoder_inputs_embeds` takes the value of :obj:`inputs_embeds`.
head_mask: (:obj:`tf.Tensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
......@@ -952,6 +950,9 @@ T5_INPUTS_DOCSTRING = r"""
- 1 indicates the head is **not masked**,
- 0 indicates the head is **masked**.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
decoding (see :obj:`past_key_values`).
output_attentions (:obj:`bool`, `optional`):
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
tensors for more detail.
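A hedged usage sketch of the encoder/decoder masks documented above (not part of this commit; checkpoint name and texts are illustrative only):

from transformers import T5Tokenizer, TFT5Model

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5Model.from_pretrained("t5-small")

enc = tokenizer(["translate English to German: How are you?",
                 "translate English to German: Thank you very much!"],
                padding=True, return_tensors="tf")
dec = tokenizer(["Wie geht es dir?", "Vielen Dank!"], padding=True, return_tensors="tf")

outputs = model(
    enc.input_ids,
    attention_mask=enc.attention_mask,          # 1 = real token, 0 = padding (encoder side)
    decoder_input_ids=dec.input_ids,
    decoder_attention_mask=dec.attention_mask,  # masks padded target positions; the causal mask is added internally
    return_dict=True,
)
print(outputs.last_hidden_state.shape)  # (batch_size, target_sequence_length, hidden_size)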
......@@ -1017,12 +1018,12 @@ class TFT5Model(TFT5PreTrainedModel):
self,
inputs,
attention_mask=None,
encoder_outputs=None,
inputs_embeds=None,
head_mask=None,
past_key_values=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs=None,
past_key_values=None,
head_mask=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
use_cache=None,
output_attentions=None,
......@@ -1040,20 +1041,22 @@ class TFT5Model(TFT5PreTrainedModel):
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> model = TFT5Model.from_pretrained('t5-small')
>>> inputs = tokenizer.encode("Hello, my dog is cute", return_tensors="tf") # Batch size 1
>>> outputs = model(inputs, decoder_input_ids=inputs)
>>> last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
>>> input_ids = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="tf").input_ids # Batch size 1
>>> decoder_input_ids = tokenizer("Studies show that", return_tensors="tf").input_ids # Batch size 1
>>> outputs = model(input_ids, decoder_input_ids=decoder_input_ids, return_dict=True)
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
encoder_outputs = inputs[2] if len(inputs) > 2 else encoder_outputs
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
head_mask = inputs[4] if len(inputs) > 4 else head_mask
past_key_values = inputs[5] if len(inputs) > 5 else past_key_values
decoder_input_ids = inputs[6] if len(inputs) > 6 else decoder_input_ids
decoder_attention_mask = inputs[7] if len(inputs) > 7 else decoder_attention_mask
decoder_input_ids = inputs[2] if len(inputs) > 2 else decoder_input_ids
decoder_attention_mask = inputs[3] if len(inputs) > 3 else decoder_attention_mask
encoder_outputs = inputs[4] if len(inputs) > 4 else encoder_outputs
past_key_values = inputs[5] if len(inputs) > 5 else past_key_values
head_mask = inputs[6] if len(inputs) > 6 else head_mask
inputs_embeds = inputs[7] if len(inputs) > 7 else inputs_embeds
decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds
use_cache = inputs[9] if len(inputs) > 9 else use_cache
output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
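With the reordered signature above, a list or tuple passed as the first argument is now unpacked as (input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, encoder_outputs, past_key_values, head_mask, inputs_embeds, decoder_inputs_embeds, ...). A hedged sketch of equivalent keyword and list calls (not part of the diff):

from transformers import T5Tokenizer, TFT5Model

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5Model.from_pretrained("t5-small")

enc = tokenizer("Studies have been shown that owning a dog is good for you", return_tensors="tf")
dec = tokenizer("Studies show that", return_tensors="tf")

# keyword form
outputs = model(enc.input_ids, attention_mask=enc.attention_mask,
                decoder_input_ids=dec.input_ids, decoder_attention_mask=dec.attention_mask)

# list form, following the new positional order
outputs = model([enc.input_ids, enc.attention_mask, dec.input_ids, dec.attention_mask])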
......@@ -1066,17 +1069,16 @@ class TFT5Model(TFT5PreTrainedModel):
input_ids = inputs.get("inputs")
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
head_mask = inputs.get("head_mask", head_mask)
past_key_values = inputs.get("past_key_values", past_key_values)
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
past_key_values = inputs.get("past_key_values", past_key_values)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 13, "Too many inputs."
if "past_key_value_states" in inputs:
......@@ -1096,52 +1098,43 @@ class TFT5Model(TFT5PreTrainedModel):
past_key_values = kwargs.pop("past_key_value_states")
use_cache = use_cache if use_cache is not None else self.config.use_cache
output_attentions = output_attentions if output_attentions else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states
return_dict = return_dict if return_dict is not None else self.config.return_dict
# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
encoder_outputs = self.encoder(
[
input_ids,
attention_mask,
None,
None,
inputs_embeds,
head_mask,
None,
False,
output_attentions,
output_hidden_states,
],
attention_mask=attention_mask,
encoder_hidden_states=None,
encoder_attention_mask=None,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
past_key_values=None,
use_cache=False,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)
hidden_states = encoder_outputs[0]
# If decoding with past key value states, only the last tokens
# should be given as an input
if past_key_values is not None:
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_inputs_embeds is not None:
decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
# Decode
decoder_outputs = self.decoder(
[
decoder_input_ids,
decoder_attention_mask,
hidden_states,
attention_mask,
decoder_inputs_embeds,
head_mask,
past_key_values,
use_cache,
output_attentions,
output_hidden_states,
],
attention_mask=decoder_attention_mask,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
inputs_embeds=decoder_inputs_embeds,
head_mask=head_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)
past = (
(encoder_outputs, decoder_outputs[1]) if cast_bool_to_primitive(use_cache, self.config.use_cache) else None
)
......@@ -1150,12 +1143,6 @@ class TFT5Model(TFT5PreTrainedModel):
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
return decoder_outputs + encoder_outputs
# If put before, this breaks the tf compilation.
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# This is long and annoying but if we introduce return_dict at the TFT5MainLayer level (like in PyTorch)
# TF refuses to compile anymore.
if not cast_bool_to_primitive(use_cache, self.config.use_cache):
......@@ -1227,18 +1214,18 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
self,
inputs,
attention_mask=None,
encoder_outputs=None,
inputs_embeds=None,
head_mask=None,
past_key_values=None,
decoder_input_ids=None,
decoder_attention_mask=None,
encoder_outputs=None,
past_key_values=None,
head_mask=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
labels=None,
training=False,
**kwargs,
):
......@@ -1253,33 +1240,35 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
>>> from transformers import T5Tokenizer, TFT5ForConditionalGeneration
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> model = TFT5ForConditionalGeneration.from_pretrained('t5-small', return_dict=True)
>>> model = TFT5ForConditionalGeneration.from_pretrained('t5-small')
>>> inputs = tokenizer.encode("Hello, my dog is cute", return_tensors="tf") # Batch size 1
>>> outputs = model(inputs, decoder_input_ids=inputs)
>>> prediction_scores = outputs[0]
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> model = TFT5ForConditionalGeneration.from_pretrained('t5-small')
>>> inputs = tokenizer.encode("summarize: Hello, my dog is cute", return_tensors="tf") # Batch size 1
>>> inputs = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='tf').input_ids
>>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2> </s>', return_tensors='tf').input_ids
>>> outputs = model(inputs, labels=labels)
>>> loss = outputs.loss
>>> logits = outputs.logits
>>> inputs = tokenizer("summarize: studies have shown that owning a dog is good for you ", return_tensors="tf").input_ids # Batch size 1
>>> result = model.generate(inputs)
"""
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
encoder_outputs = inputs[2] if len(inputs) > 2 else encoder_outputs
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
head_mask = inputs[4] if len(inputs) > 4 else head_mask
past_key_values = inputs[5] if len(inputs) > 5 else past_key_values
decoder_input_ids = inputs[6] if len(inputs) > 6 else decoder_input_ids
decoder_attention_mask = inputs[7] if len(inputs) > 7 else decoder_attention_mask
decoder_input_ids = inputs[2] if len(inputs) > 2 else decoder_input_ids
decoder_attention_mask = inputs[3] if len(inputs) > 3 else decoder_attention_mask
encoder_outputs = inputs[4] if len(inputs) > 4 else encoder_outputs
past_key_values = inputs[5] if len(inputs) > 5 else past_key_values
head_mask = inputs[6] if len(inputs) > 6 else head_mask
inputs_embeds = inputs[7] if len(inputs) > 7 else inputs_embeds
decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds
use_cache = inputs[9] if len(inputs) > 9 else use_cache
output_attentions = inputs[10] if len(inputs) > 10 else output_attentions
output_hidden_states = inputs[11] if len(inputs) > 11 else output_hidden_states
return_dict = inputs[12] if len(inputs) > 12 else return_dict
labels = inputs[13] if len(inputs) > 13 else labels
labels = inputs[9] if len(inputs) > 9 else labels
use_cache = inputs[10] if len(inputs) > 10 else use_cache
output_attentions = inputs[11] if len(inputs) > 11 else output_attentions
output_hidden_states = inputs[12] if len(inputs) > 12 else output_hidden_states
return_dict = inputs[13] if len(inputs) > 13 else return_dict
assert len(inputs) <= 14, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
if "inputs" in inputs:
......@@ -1287,18 +1276,18 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
input_ids = inputs.get("inputs")
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
head_mask = inputs.get("head_mask", head_mask)
past_key_values = inputs.get("past_key_values", past_key_values)
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
past_key_values = inputs.get("past_key_values", past_key_values)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds)
labels = inputs.get("labels", labels)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 14, "Too many inputs."
if "past_key_value_states" in inputs:
......@@ -1318,24 +1307,19 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
past_key_values = kwargs.pop("past_key_value_states")
use_cache = use_cache if use_cache is not None else self.config.use_cache
output_attentions = output_attentions if output_attentions else self.config.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states
return_dict = return_dict if return_dict is not None else self.config.return_dict
# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
# Convert encoder inputs in embeddings if needed
encoder_outputs = self.encoder(
[
input_ids,
attention_mask,
None,
None,
inputs_embeds,
head_mask,
None,
False,
output_attentions,
output_hidden_states,
],
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)
......@@ -1355,18 +1339,16 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
# Decode
decoder_outputs = self.decoder(
[
decoder_input_ids,
decoder_attention_mask,
hidden_states,
attention_mask,
decoder_inputs_embeds,
head_mask,
past_key_values,
use_cache,
output_attentions,
output_hidden_states,
],
attention_mask=decoder_attention_mask,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
inputs_embeds=decoder_inputs_embeds,
head_mask=head_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
training=training,
)
......@@ -1422,6 +1404,10 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
else:
encoder_outputs, past_key_values = past[0], past[1]
# cut decoder_input_ids if past is used
if past_key_values is not None:
inputs = inputs[:, -1:]
return {
"inputs": None, # inputs don't have to be defined, but still need to be passed to make Keras.layer.__call__ happy
"decoder_input_ids": inputs, # inputs are the decoder_input_ids
......
......@@ -1065,7 +1065,7 @@ XLNET_INPUTS_DOCSTRING = r"""
decoding. The token ids which have their past given to this model should not be passed as
:obj:`input_ids` as they have already been computed.
:obj:`use_cache` has to be set to :obj:`True` to make use of :obj:`mems`.
:obj:`use_cache` has to be set to :obj:`True` to make use of :obj:`mems`.
perm_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length, sequence_length)`, `optional`):
Mask to indicate the attention pattern for each input token with values selected in ``[0, 1]``:
......
......@@ -237,8 +237,15 @@ class ModuleUtilsMixin:
batch_size, seq_length = input_shape
seq_ids = torch.arange(seq_length, device=device)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
if causal_mask.shape[1] < attention_mask.shape[1]:
prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]
causal_mask = torch.cat(
[torch.ones((batch_size, seq_length, prefix_seq_len), device=device), causal_mask], axis=-1
)
# causal and attention masks must have same type with pytorch version < 1.3
causal_mask = causal_mask.to(attention_mask.dtype)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
else:
extended_attention_mask = attention_mask[:, None, None, :]
......
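A standalone sketch of the prefix handling added above, for a single new query token whose attention_mask also covers three cached positions (illustration only, reproducing the arithmetic of the snippet):

import torch

batch_size, seq_length = 1, 1                 # decoding one new token
attention_mask = torch.ones(batch_size, 4)    # 3 cached positions + 1 current token

seq_ids = torch.arange(seq_length)
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
causal_mask = causal_mask.float()

prefix_seq_len = attention_mask.shape[1] - causal_mask.shape[1]  # 3 cached positions
causal_mask = torch.cat([torch.ones(batch_size, seq_length, prefix_seq_len), causal_mask], dim=-1)

extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
print(extended_attention_mask.shape)  # torch.Size([1, 1, 1, 4]): the new token may attend to all 4 positions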
......@@ -874,7 +874,7 @@ XLNET_INPUTS_DOCSTRING = r"""
decoding. The token ids which have their past given to this model should not be passed as
:obj:`input_ids` as they have already been computed.
:obj:`use_cache` has to be set to :obj:`True` to make use of :obj:`mems`.
:obj:`use_cache` has to be set to :obj:`True` to make use of :obj:`mems`.
perm_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, sequence_length)`, `optional`):
Mask to indicate the attention pattern for each input token with values selected in ``[0, 1]``:
......@@ -997,15 +997,15 @@ class XLNetModel(XLNetPreTrainedModel):
curr_out = curr_out[: self.reuse_len]
if self.mem_len is None or self.mem_len == 0:
# If `use_cache` is active but no `mem_len` is defined, the model behaves like GPT-2 at inference time
# If :obj:`use_cache` is active but no `mem_len` is defined, the model behaves like GPT-2 at inference time
# and returns all of the past and current hidden states.
cutoff = 0
else:
# If `use_cache` is active and `mem_len` is defined, the model returns the last `mem_len` hidden
# If :obj:`use_cache` is active and `mem_len` is defined, the model returns the last `mem_len` hidden
# states. This is the preferred setting for training and long-form generation.
cutoff = -self.mem_len
if prev_mem is None:
# if `use_cache` is active and `mem_len` is defined, the model
# if :obj:`use_cache` is active and `mem_len` is defined, the model
new_mem = curr_out[cutoff:]
else:
new_mem = torch.cat([prev_mem, curr_out], dim=0)[cutoff:]
......
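For context, a hedged PyTorch sketch of feeding the cached ``mems`` back in on the next forward pass (not part of this commit; the checkpoint name and the default ``mem_len`` behaviour are assumptions):

from transformers import XLNetTokenizer, XLNetLMHeadModel

tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased")

first = tokenizer("Hello, my dog", add_special_tokens=False, return_tensors="pt")
outputs = model(**first, use_cache=True, return_dict=True)
mems = outputs.mems  # one tensor per layer holding the cached hidden states

# tokens whose states already live in `mems` must not be passed again as input_ids
second = tokenizer(" is cute", add_special_tokens=False, return_tensors="pt")
outputs = model(second.input_ids, mems=mems, use_cache=True, return_dict=True)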
......@@ -76,7 +76,7 @@ class ModelTester:
self.bos_token_id = 0
torch.manual_seed(0)
def prepare_config_and_inputs_for_common(self):
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).clamp(
3,
)
......@@ -101,6 +101,13 @@ class ModelTester:
inputs_dict = prepare_bart_inputs_dict(config, input_ids)
return config, inputs_dict
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
inputs_dict["decoder_input_ids"] = inputs_dict["input_ids"]
inputs_dict["decoder_attention_mask"] = inputs_dict["attention_mask"]
inputs_dict["use_cache"] = False
return config, inputs_dict
def prepare_bart_inputs_dict(
config,
......@@ -139,7 +146,7 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
self.config_tester.run_common_tests()
def test_initialization_more(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
model = BartModel(config)
model.to(torch_device)
model.eval()
......@@ -156,7 +163,7 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
_check_var(model.encoder.embed_positions)
def test_advanced_inputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
config.use_cache = False
inputs_dict["input_ids"][:, -2:] = config.pad_token_id
decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_bart_decoder_inputs(
......@@ -185,7 +192,7 @@ class BARTModelTest(ModelTesterMixin, unittest.TestCase):
_assert_tensors_equal(decoder_features_with_long_encoder_mask, decoder_features_with_created_mask)
def test_save_load_strict(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config, inputs_dict = self.model_tester.prepare_config_and_inputs()
for model_class in self.all_model_classes:
model = model_class(config)
......
......@@ -14,6 +14,7 @@
# limitations under the License.
import copy
import inspect
import os.path
import random
import tempfile
......@@ -158,6 +159,28 @@ class ModelTesterMixin:
max_diff = np.amax(np.abs(out_1 - out_2))
self.assertLessEqual(max_diff, 1e-5)
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
if model.config.is_encoder_decoder:
expected_arg_names = [
"input_ids",
"attention_mask",
"decoder_input_ids",
"decoder_attention_mask",
"encoder_outputs",
]
self.assertListEqual(arg_names[:5], expected_arg_names)
else:
expected_arg_names = ["input_ids"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
seq_len = getattr(self.model_tester, "seq_length", None)
......@@ -187,7 +210,7 @@ class ModelTesterMixin:
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
outputs = model(**self._prepare_for_class(inputs_dict, model_class), return_dict=True)
attentions = outputs[-1]
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
......@@ -272,10 +295,22 @@ class ModelTesterMixin:
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
inputs = self._prepare_for_class(inputs_dict, model_class)["input_ids"] # Let's keep only input_ids
inputs = self._prepare_for_class(inputs_dict, model_class)
try:
traced_gpt2 = torch.jit.trace(model, inputs)
if model.config.is_encoder_decoder:
model.config.use_cache = False # TODO: this should be deleted after bug #7474 is solved
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
decoder_input_ids = inputs["decoder_input_ids"]
decoder_attention_mask = inputs["decoder_attention_mask"]
traced_model = torch.jit.trace(
model, (input_ids, attention_mask, decoder_input_ids, decoder_attention_mask)
)
else:
input_ids = inputs["input_ids"]
traced_model = torch.jit.trace(model, input_ids)
except RuntimeError:
self.fail("Couldn't trace module.")
......@@ -283,7 +318,7 @@ class ModelTesterMixin:
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
try:
torch.jit.save(traced_gpt2, pt_file_name)
torch.jit.save(traced_model, pt_file_name)
except Exception:
self.fail("Couldn't save module.")
......