"vscode:/vscode.git/clone" did not exist on "f50b82af04e52e3c8e04cd239b19474701cc526b"
Unverified Commit afc4ece4 authored by Patrick von Platen, committed by GitHub

[Generate] Facilitate PyTorch generate using `ModelOutputs` (#6735)

* fix generate for GPT2 Double Head

* fix gpt2 double head model

* fix  bart / t5

* also add for no beam search

* fix no beam search

* fix encoder decoder

* simplify t5

* simplify t5

* fix t5 tests

* fix BART

* fix transfo-xl

* fix conflict

* integrating sylvains and sams comments

* fix tf past_decoder_key_values

* fix enc dec test
Encoder Decoder Models
------------------------
This class can wrap an encoder model, such as ``BertModel``, and a decoder model with a language modeling head, such as ``BertForMaskedLM``, into an encoder-decoder model.
The :class:`~transformers.EncoderDecoderModel` can be used to initialize a sequence-to-sequence model with any pre-trained autoencoding model as the encoder and any pre-trained autoregressive model as the decoder.
The ``EncoderDecoderModel`` class allows you to instantiate an encoder-decoder model with the ``from_encoder_decoder_pretrained`` class method, which takes a pretrained encoder and a pretrained decoder model as input.
The ``EncoderDecoderModel`` is saved with the standard ``save_pretrained()`` method and can be loaded again with the standard ``from_pretrained()`` method.
The effectiveness of initializing sequence-to-sequence models with pre-trained checkpoints for sequence generation tasks was shown in `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks <https://arxiv.org/abs/1907.12461>`__ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.
After such an :class:`~transformers.EncoderDecoderModel` has been trained / fine-tuned, it can be saved / loaded just like any other model (see Examples for more information).
An application of this architecture could be to leverage two pre-trained :obj:`transformers.BertModel` models as the encoder and decoder for a summarization model as was shown in: `Text Summarization with Pretrained Encoders <https://arxiv.org/abs/1908.08345>`_ by Yang Liu and Mirella Lapata.
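As a minimal sketch of this workflow (using the same ``bert-base-uncased`` checkpoints as the docstring example further down):

from transformers import EncoderDecoderModel

# initialize a Bert2Bert model from two pretrained checkpoints
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "bert-base-uncased"
)

# save and reload with the standard methods
model.save_pretrained("bert2bert")
model = EncoderDecoderModel.from_pretrained("bert2bert")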
``EncoderDecoderConfig``
......
......@@ -20,6 +20,7 @@ import torch
from torch import Tensor
from torch.nn import functional as F
from .file_utils import ModelOutput
from .utils import logging
......@@ -46,14 +47,6 @@ class GenerationMixin:
"""
return logits
def _use_cache(self, outputs, use_cache):
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
if len(outputs) <= 1 or use_cache is False:
return False
if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
return False
return True
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
"""
Enforce the repetition penalty (from the `CTRL paper <https://arxiv.org/abs/1909.05858>`__).
......@@ -137,7 +130,7 @@ class GenerationMixin:
attention_mask: Optional[torch.LongTensor] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
**model_specific_kwargs
**model_kwargs
) -> torch.LongTensor:
r"""
Generates sequences for models with a language modeling head. The method currently supports greedy decoding,
......@@ -208,7 +201,7 @@ class GenerationMixin:
use_cache: (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should use the past last key/values attentions (if applicable to the model) to
speed up decoding.
model_specific_kwargs:
model_kwargs:
Additional model specific kwargs will be forwarded to the :obj:`forward` function of the model.
Return:
......@@ -400,7 +393,7 @@ class GenerationMixin:
# get encoder and store encoder outputs
encoder = self.get_encoder()
encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask)
encoder_outputs: ModelOutput = encoder(input_ids, attention_mask=attention_mask, return_dict=True)
# Expand input ids if num_beams > 1 or num_return_sequences > 1
if num_return_sequences > 1 or num_beams > 1:
......@@ -428,8 +421,8 @@ class GenerationMixin:
cur_len = 1
assert (
batch_size == encoder_outputs[0].shape[0]
), f"expected encoder_outputs[0] to have 1st dimension bs={batch_size}, got {encoder_outputs[0].shape[0]} "
batch_size == encoder_outputs.last_hidden_state.shape[0]
), f"expected encoder_outputs.last_hidden_state to have 1st dimension bs={batch_size}, got {encoder_outputs.last_hidden_state.shape[0]} "
# expand batch_idx to assign correct encoder output for expanded input_ids (due to num_beams > 1 and num_return_sequences > 1)
expanded_batch_idxs = (
......@@ -439,11 +432,16 @@ class GenerationMixin:
.view(-1)
.to(input_ids.device)
)
# expand encoder_outputs
encoder_outputs = (encoder_outputs[0].index_select(0, expanded_batch_idxs), *encoder_outputs[1:])
encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
0, expanded_batch_idxs
)
# save encoder_outputs in `model_kwargs`
model_kwargs["encoder_outputs"] = encoder_outputs
else:
encoder_outputs = None
cur_len = input_ids.shape[-1]
assert (
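To make the expansion above concrete, a toy sketch (shapes invented, not the real model) of how repeated batch indices replicate each source sentence's encoder states once per beam:

import torch

batch_size, num_beams, src_len, hidden = 2, 3, 5, 8
expanded_batch_idxs = (
    torch.arange(batch_size).view(-1, 1).repeat(1, num_beams).view(-1)
)
print(expanded_batch_idxs)  # tensor([0, 0, 0, 1, 1, 1])

last_hidden_state = torch.randn(batch_size, src_len, hidden)
expanded = last_hidden_state.index_select(0, expanded_batch_idxs)
print(expanded.shape)  # torch.Size([6, 5, 8])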
......@@ -471,10 +469,9 @@ class GenerationMixin:
length_penalty=length_penalty,
num_beams=num_beams,
vocab_size=vocab_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
model_specific_kwargs=model_specific_kwargs,
model_kwargs=model_kwargs,
)
else:
output = self._generate_no_beam_search(
......@@ -492,10 +489,9 @@ class GenerationMixin:
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
batch_size=effective_batch_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
model_specific_kwargs=model_specific_kwargs,
model_kwargs=model_kwargs,
)
return output
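A hedged end-to-end sketch of the refactored path; the ``facebook/bart-large-cnn`` summarization checkpoint is an assumption of this example, not part of the diff:

from transformers import BartTokenizer, BartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

inputs = tokenizer("The tower is 324 metres (1,063 ft) tall.", return_tensors="pt")
# beam search now carries `encoder_outputs` through `model_kwargs` internally
summary_ids = model.generate(inputs["input_ids"], num_beams=4, max_length=40, early_stopping=True)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))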
......@@ -516,10 +512,9 @@ class GenerationMixin:
pad_token_id,
eos_token_id,
batch_size,
encoder_outputs,
attention_mask,
use_cache,
model_specific_kwargs,
model_kwargs,
):
"""Generate sequences for each example without beam search (num_beams == 1).
All returned sequences are generated independently.
......@@ -528,15 +523,14 @@ class GenerationMixin:
unfinished_sents = input_ids.new(batch_size).fill_(1)
sent_lengths = input_ids.new(batch_size).fill_(max_length)
past = (encoder_outputs, None) if encoder_outputs is not None else None
past = None
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
outputs = self(**model_inputs, return_dict=True)
next_token_logits = outputs.logits[:, -1, :]
scores = self.postprocess_next_token_scores(
scores=next_token_logits,
......@@ -553,8 +547,10 @@ class GenerationMixin:
)
# if model has past, then set the past variable to speed up decoding
if self._use_cache(outputs, use_cache):
past = outputs[1]
if "past_key_values" in outputs:
past = outputs.past_key_values
elif "mems" in outputs:
past = outputs.mems
if do_sample:
# Temperature (higher temperature => more likely to sample low probability tokens)
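These membership checks work because :obj:`ModelOutput` behaves like an ordered mapping whose keys are only the fields the model actually returned; a toy sketch (the ``ToyOutput`` class is invented for illustration):

import torch
from dataclasses import dataclass
from typing import Optional, Tuple

from transformers.file_utils import ModelOutput

@dataclass
class ToyOutput(ModelOutput):
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple] = None

out = ToyOutput(logits=torch.zeros(1, 2))
assert "past_key_values" not in out  # None fields are dropped from the keys

out = ToyOutput(logits=torch.zeros(1, 2), past_key_values=((torch.zeros(1),),))
assert "past_key_values" in out  # present once the model returns a cache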
......@@ -621,10 +617,9 @@ class GenerationMixin:
length_penalty,
num_beams,
vocab_size,
encoder_outputs,
attention_mask,
use_cache,
model_specific_kwargs,
model_kwargs,
):
"""Generate sequences for each example with beam search."""
......@@ -643,21 +638,24 @@ class GenerationMixin:
beam_scores = beam_scores.view(-1) # shape (batch_size * num_beams,)
# cache compute states
past = (encoder_outputs, None) if encoder_outputs is not None else None
past = None
# done sentences
done = [False for _ in range(batch_size)]
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
outputs = self(**model_inputs, return_dict=True) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs.logits[:, -1, :] # (batch_size * num_beams, vocab_size)
# if model has past, then set the past variable to speed up decoding
if self._use_cache(outputs, use_cache):
past = outputs[1]
if "past_key_values" in outputs:
past = outputs.past_key_values
elif "mems" in outputs:
past = outputs.mems
if self.config.is_encoder_decoder and do_sample is False:
# TODO (PVP) still a bit hacky here - there might be a better solution
next_token_logits = self.adjust_logits_during_generation(
......
......@@ -111,15 +111,15 @@ BART_INPUTS_DOCSTRING = r"""
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
See diagram 1 in the paper for more info on the default strategy
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains pre-computed key and value hidden-states of the attention blocks.
Can be used to speed up decoding.
If ``decoder_past_key_value_states`` are used, the user can optionally input only the last
If ``past_key_values`` are used, the user can optionally input only the last
``decoder_input_ids`` (those that don't have their past key value states given to this model) of shape
:obj:`(batch_size, 1)` instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
If `use_cache` is True, ``decoder_past_key_values`` are returned and can be used to speed up decoding (see
``decoder_past_key_values``).
If `use_cache` is True, ``past_key_values`` are returned and can be used to speed up decoding (see
``past_key_values``).
output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
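The caching contract described above can be exercised directly; a hedged sketch, where the ``facebook/bart-base`` checkpoint and the greedy next-token step are assumptions for illustration:

import torch
from transformers import BartTokenizer, BartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

input_ids = tokenizer("Hello world", return_tensors="pt").input_ids
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

first = model(input_ids, decoder_input_ids=decoder_input_ids, use_cache=True, return_dict=True)
next_token = first.logits[:, -1, :].argmax(-1, keepdim=True)

# with past_key_values, only the newest decoder token has to be passed
second = model(
    input_ids,
    decoder_input_ids=next_token,
    past_key_values=first.past_key_values,
    use_cache=True,
    return_dict=True,
)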
......@@ -502,7 +502,7 @@ class BartDecoder(nn.Module):
encoder_padding_mask,
decoder_padding_mask,
decoder_causal_mask,
decoder_past_key_values=None,
past_key_values=None,
use_cache=False,
output_attentions=False,
output_hidden_states=False,
......@@ -519,7 +519,7 @@ class BartDecoder(nn.Module):
encoder_hidden_states: output from the encoder, used for
encoder-side attention
encoder_padding_mask: for ignoring pad tokens
decoder_past_key_values (dict or None): dictionary used for storing state during generation
past_key_values (dict or None): dictionary used for storing state during generation
Returns:
BaseModelOutputWithPast or tuple:
......@@ -530,10 +530,16 @@ class BartDecoder(nn.Module):
"""
if "decoder_cached_states" in unused:
warnings.warn(
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
decoder_past_key_values = unused.pop("decoder_cached_states")
past_key_values = unused.pop("decoder_cached_states")
if "decoder_past_key_values" in unused:
warnings.warn(
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = unused.pop("decoder_past_key_values")
# check attention mask and invert
if encoder_padding_mask is not None:
......@@ -568,7 +574,7 @@ class BartDecoder(nn.Module):
if self.training and (dropout_probability < self.layerdrop):
continue
layer_state = decoder_past_key_values[idx] if decoder_past_key_values is not None else None
layer_state = past_key_values[idx] if past_key_values is not None else None
x, layer_self_attn, layer_past = decoder_layer(
x,
......@@ -594,10 +600,7 @@ class BartDecoder(nn.Module):
x = x.transpose(0, 1)
encoder_hidden_states = encoder_hidden_states.transpose(0, 1)
if use_cache:
next_cache = ((encoder_hidden_states, encoder_padding_mask), next_decoder_cache)
else:
next_cache = None
next_cache = next_decoder_cache if use_cache else None
if not return_dict:
return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None)
......@@ -869,13 +872,19 @@ class BartModel(PretrainedBartModel):
decoder_input_ids=None,
encoder_outputs: Optional[Tuple] = None,
decoder_attention_mask=None,
decoder_past_key_values=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
if "decoder_past_key_values" in kwargs:
warnings.warn(
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = kwargs.pop("decoder_past_key_values")
if decoder_input_ids is None:
use_cache = False
......@@ -924,7 +933,7 @@ class BartModel(PretrainedBartModel):
attention_mask,
decoder_padding_mask,
decoder_causal_mask=causal_mask,
decoder_past_key_values=decoder_past_key_values,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
......@@ -936,7 +945,7 @@ class BartModel(PretrainedBartModel):
return Seq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
decoder_past_key_values=decoder_outputs.past_key_values,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
......@@ -994,7 +1003,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_past_key_values=None,
past_key_values=None,
labels=None,
use_cache=None,
output_attentions=None,
......@@ -1037,10 +1046,16 @@ class BartForConditionalGeneration(PretrainedBartModel):
labels = unused.pop("lm_labels")
if "decoder_cached_states" in unused:
warnings.warn(
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
"The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
decoder_past_key_values = unused.pop("decoder_cached_states")
past_key_values = unused.pop("decoder_cached_states")
if "decoder_past_key_values" in unused:
warnings.warn(
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = unused.pop("decoder_past_key_values")
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if labels is not None:
......@@ -1054,7 +1069,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
decoder_input_ids=decoder_input_ids,
encoder_outputs=encoder_outputs,
decoder_attention_mask=decoder_attention_mask,
decoder_past_key_values=decoder_past_key_values,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
......@@ -1075,7 +1090,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
return Seq2SeqLMOutput(
loss=masked_lm_loss,
logits=lm_logits,
decoder_past_key_values=outputs.decoder_past_key_values,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
......@@ -1083,14 +1098,13 @@ class BartForConditionalGeneration(PretrainedBartModel):
encoder_attentions=outputs.encoder_attentions,
)
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs):
assert past is not None, "past has to be defined for encoder_outputs"
encoder_outputs, decoder_past_key_values = past
def prepare_inputs_for_generation(
self, decoder_input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs
):
return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"decoder_past_key_values": decoder_past_key_values,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
......@@ -1109,20 +1123,14 @@ class BartForConditionalGeneration(PretrainedBartModel):
@staticmethod
def _reorder_cache(past, beam_idx):
((enc_out, enc_mask), decoder_past_key_values) = past
reordered_past = []
for layer_past in decoder_past_key_values:
for layer_past in past:
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
layer_past_new = {
attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
}
reordered_past.append(layer_past_new)
new_enc_out = enc_out if enc_out is None else enc_out.index_select(0, beam_idx)
new_enc_mask = enc_mask if enc_mask is None else enc_mask.index_select(0, beam_idx)
past = ((new_enc_out, new_enc_mask), reordered_past)
return past
return reordered_past
def get_encoder(self):
return self.model.encoder
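With the encoder state no longer bundled into ``past``, reordering reduces to gathering every cached tensor along its batch dimension; a toy sketch of a ``_reorder_buffer``-style helper (shapes invented):

import torch

def reorder_buffer(attn_cache, beam_idx):
    # gather each cached tensor so every beam keeps the states of the
    # hypothesis it was extended from
    return {
        key: value.index_select(0, beam_idx) if value is not None else None
        for key, value in attn_cache.items()
    }

layer_past = {"prev_key": torch.randn(4, 2, 1, 8)}  # (batch * beams, heads, len, head_dim)
beam_idx = torch.tensor([1, 0, 3, 2])  # beams swapped pairwise
reordered = reorder_buffer(layer_past, beam_idx)
assert torch.equal(reordered["prev_key"][0], layer_past["prev_key"][1])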
......@@ -1208,7 +1216,7 @@ class BartForSequenceClassification(PretrainedBartModel):
return Seq2SeqSequenceClassifierOutput(
loss=loss,
logits=logits,
decoder_past_key_values=outputs.decoder_past_key_values,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
......@@ -1316,7 +1324,7 @@ class BartForQuestionAnswering(PretrainedBartModel):
loss=total_loss,
start_logits=start_logits,
end_logits=end_logits,
decoder_past_key_values=outputs.decoder_past_key_values,
past_key_values=outputs.past_key_values,
decoder_hidden_states=outputs.decoder_hidden_states,
decoder_attentions=outputs.decoder_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
......
......@@ -19,13 +19,79 @@ from typing import Optional
from .configuration_encoder_decoder import EncoderDecoderConfig
from .configuration_utils import PretrainedConfig
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable, replace_return_docstrings
from .modeling_outputs import Seq2SeqLMOutput
from .modeling_utils import PreTrainedModel
from .utils import logging
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "EncoderDecoderConfig"
ENCODER_DECODER_START_DOCSTRING = r"""
This class can be used to initialize a sequence-to-sequence model with any pretrained autoencoding model as the encoder and any pretrained autoregressive model as the decoder. The encoder is loaded via the :meth:`~transformers.AutoModel.from_pretrained` function and the decoder is loaded via the :meth:`~transformers.AutoModelForCausalLM.from_pretrained` function.
Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream generative task, *e.g.* summarization.
The effectiveness of initializing sequence-to-sequence models with pre-trained checkpoints for sequence generation tasks was shown in `Leveraging Pre-trained Checkpoints for Sequence Generation Tasks <https://arxiv.org/abs/1907.12461>`__ by Sascha Rothe, Shashi Narayan, Aliaksei Severyn.
After such an Encoder Decoder model has been trained / fine-tuned, it can be saved / loaded just like any other model (see Examples for more information).
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#module>`__ sub-class. Use it as a
regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
Parameters:
config (:class:`~transformers.EncoderDecoderConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
ENCODER_DECODER_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary for the encoder.
Indices can be obtained using :class:`~transformers.PreTrainedTokenizer`.
See :meth:`~transformers.PreTrainedTokenizer.encode` and
:meth:`~transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices for the encoder.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
encoder_outputs (:obj:`tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`):
This tuple must consist of (:obj:`last_hidden_state`, `optional`: :obj:`hidden_states`, `optional`: :obj:`attentions`)
`last_hidden_state` (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`) is a tensor of hidden-states at the output of the last layer of the encoder.
Used in the cross-attention of the decoder.
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
Provide for sequence to sequence training to the decoder.
Indices can be obtained using :class:`transformers.PreTrainedTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the masked language modeling loss for the decoder.
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`):
If set to ``True``, the model will return a :class:`~transformers.file_utils.Seq2SeqLMOutput` instead of a
plain tuple.
kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
- Without a prefix which will be input as ``**encoder_kwargs`` for the encoder forward function.
- With a `decoder_` prefix which will be input as ``**decoder_kwargs`` for the decoder forward function.
"""
@add_start_docstrings(ENCODER_DECODER_START_DOCSTRING)
class EncoderDecoderModel(PreTrainedModel):
r"""
:class:`~transformers.EncoderDecoder` is a generic model class that will be
......@@ -206,6 +272,8 @@ class EncoderDecoderModel(PreTrainedModel):
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
return cls(encoder=encoder, decoder=decoder, config=config)
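A hedged sketch of composing the config directly, mirroring what this class method does internally (the explicit ``is_decoder=True`` flag on the decoder config is an assumption of the sketch):

from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

encoder_config = BertConfig()
decoder_config = BertConfig(is_decoder=True)
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)

# weights are randomly initialized; only the architecture comes from the config
model = EncoderDecoderModel(config=config)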
@add_start_docstrings_to_callable(ENCODER_DECODER_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
......@@ -216,47 +284,11 @@ class EncoderDecoderModel(PreTrainedModel):
decoder_attention_mask=None,
decoder_inputs_embeds=None,
labels=None,
return_dict=None,
**kwargs,
):
"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary for the encoder.
Indices can be obtained using :class:`transformers.PretrainedTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices for the encoder.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
encoder_outputs (:obj:`tuple(tuple(torch.FloatTensor)`, `optional`, defaults to :obj:`None`):
Tuple consists of (`last_hidden_state`, `optional`: `hidden_states`, `optional`: `attentions`)
`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`) is a sequence of hidden-states at the output of the last layer of the encoder.
Used in the cross-attention of the decoder.
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
Provide for sequence to sequence training to the decoder.
Indices can be obtained using :class:`transformers.PretrainedTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Labels for computing the masked language modeling loss for the decoder.
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
kwargs: (`optional`) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors:
- Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function.
- With a `decoder_` prefix which will be input as `**decoder_kwargs` for the decoder forward function.
r"""
Returns:
Examples::
......@@ -264,19 +296,25 @@ class EncoderDecoderModel(PreTrainedModel):
>>> import torch
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints
>>> # forward
>>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
>>> # training
>>> loss, outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids)[:2]
>>> outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, labels=input_ids, return_dict=True)
>>> loss, logits = outputs.loss, outputs.logits
>>> # save and load from pretrained
>>> model.save_pretrained("bert2bert")
>>> model = EncoderDecoderModel.from_pretrained("bert2bert")
>>> # generation
>>> generated = model.generate(input_ids, decoder_start_token_id=model.config.decoder.pad_token_id)
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
kwargs_encoder = {argument: value for argument, value in kwargs.items() if not argument.startswith("decoder_")}
......@@ -289,7 +327,7 @@ class EncoderDecoderModel(PreTrainedModel):
input_ids=input_ids,
attention_mask=attention_mask,
inputs_embeds=inputs_embeds,
return_dict=False,
return_dict=return_dict,
**kwargs_encoder,
)
......@@ -303,23 +341,28 @@ class EncoderDecoderModel(PreTrainedModel):
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
labels=labels,
return_dict=False,
return_dict=return_dict,
**kwargs_decoder,
)
# TODO(PVP): currently it is not possible to use `past`
# with the encoder/decoder framework -> should be implemented
if not return_dict:
return decoder_outputs + encoder_outputs
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs):
assert past is not None, "past has to be defined for encoder_outputs"
return Seq2SeqLMOutput(
loss=decoder_outputs.loss,
logits=decoder_outputs.logits,
past_key_values=None, # TODO(PVP) - need to implement cache for BERT, etc... before this works
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions,
)
# first step
if type(past) is tuple:
encoder_outputs, _ = past
else:
encoder_outputs = (past,)
return decoder_outputs + encoder_outputs
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, encoder_outputs, **kwargs):
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
input_dict = {
......@@ -335,7 +378,7 @@ class EncoderDecoderModel(PreTrainedModel):
input_dict["decoder_use_cache"] = decoder_inputs["use_cache"]
if "past_key_values" in decoder_inputs:
input_dict["decoder_past_key_values"] = decoder_inputs["past_key_values"]
input_dict["past_key_values"] = decoder_inputs["past_key_values"]
return input_dict
......
......@@ -353,11 +353,11 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
Base class for outputs of models with a language modeling head and a multiple choice classification head.
Args:
lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
Language modeling loss.
mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
Multiple choice classification loss.
lm_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
......@@ -380,9 +380,9 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
heads.
"""
lm_loss: Optional[torch.FloatTensor] = None
loss: Optional[torch.FloatTensor] = None
mc_loss: Optional[torch.FloatTensor] = None
lm_logits: torch.FloatTensor = None
logits: torch.FloatTensor = None
mc_logits: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
......@@ -777,6 +777,17 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
def get_output_embeddings(self):
return self.lm_head
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
# only last token for input_ids if past is defined in kwargs
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
return {
"input_ids": input_ids,
"past_key_values": past,
"use_cache": kwargs.get("use_cache"),
}
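With ``prepare_inputs_for_generation`` in place (the point of the "fix generate for GPT2 Double Head" change above), generation can reuse the cache; a hedged sketch assuming the ``gpt2`` checkpoint:

from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2DoubleHeadsModel.from_pretrained("gpt2")

input_ids = tokenizer.encode("Hello, my dog is", return_tensors="pt")
# after the first step only the last token is fed; `past` carries the rest
generated = model.generate(input_ids, max_length=20, use_cache=True, pad_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(generated[0]))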
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=GPT2DoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
......@@ -893,9 +904,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
return ((lm_loss,) + output) if lm_loss is not None else output
return GPT2DoubleHeadsModelOutput(
lm_loss=lm_loss,
loss=lm_loss,
mc_loss=mc_loss,
lm_logits=lm_logits,
logits=lm_logits,
mc_logits=mc_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
......
......@@ -300,11 +300,11 @@ class OpenAIGPTDoubleHeadsModelOutput(ModelOutput):
Base class for outputs of models with a language modeling head and a multiple choice classification head.
Args:
lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
Language modeling loss.
mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
Multiple choice classification loss.
lm_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
......@@ -321,9 +321,9 @@ class OpenAIGPTDoubleHeadsModelOutput(ModelOutput):
heads.
"""
lm_loss: Optional[torch.FloatTensor] = None
loss: Optional[torch.FloatTensor] = None
mc_loss: Optional[torch.FloatTensor] = None
lm_logits: torch.FloatTensor = None
logits: torch.FloatTensor = None
mc_logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
......@@ -713,9 +713,9 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
return ((lm_loss,) + output) if lm_loss is not None else output
return OpenAIGPTDoubleHeadsModelOutput(
lm_loss=lm_loss,
loss=lm_loss,
mc_loss=mc_loss,
lm_logits=lm_logits,
logits=lm_logits,
mc_logits=mc_logits,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
......
......@@ -109,13 +109,13 @@ class Seq2SeqModelOutput(ModelOutput):
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the decoder of the model.
If ``decoder_past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
If ``past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
used (see ``past_key_values`` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -143,7 +143,7 @@ class Seq2SeqModelOutput(ModelOutput):
"""
last_hidden_state: torch.FloatTensor
decoder_past_key_values: Optional[List[torch.FloatTensor]] = None
past_key_values: Optional[List[torch.FloatTensor]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
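After the rename, downstream code reads the cache the same way from every seq2seq output; a hedged sketch, with the ``t5-small`` checkpoint assumed:

from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer.encode("translate English to German: Hello.", return_tensors="pt")
outputs = model(input_ids=input_ids, decoder_input_ids=input_ids, use_cache=True, return_dict=True)

# previously `outputs.decoder_past_key_values`; now uniform across models
assert outputs.past_key_values is not None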
......@@ -255,12 +255,12 @@ class Seq2SeqLMOutput(ModelOutput):
Language modeling loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
used (see ``past_key_values`` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -289,7 +289,7 @@ class Seq2SeqLMOutput(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
decoder_past_key_values: Optional[List[torch.FloatTensor]] = None
past_key_values: Optional[List[torch.FloatTensor]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
......@@ -365,12 +365,12 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
used (see ``past_key_values`` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -399,7 +399,7 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput):
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
decoder_past_key_values: Optional[List[torch.FloatTensor]] = None
past_key_values: Optional[List[torch.FloatTensor]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
......@@ -511,12 +511,12 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
Span-start scores (before SoftMax).
end_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
decoder_past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`.
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
used (see ``past_key_values`` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -546,7 +546,7 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
loss: Optional[torch.FloatTensor] = None
start_logits: torch.FloatTensor = None
end_logits: torch.FloatTensor = None
decoder_past_key_values: Optional[List[torch.FloatTensor]] = None
past_key_values: Optional[List[torch.FloatTensor]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
......
......@@ -838,27 +838,27 @@ T5_INPUTS_DOCSTRING = r"""
Used in the cross-attention of the decoder.
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
If `decoder_past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_values`).
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
To know more on how to prepare :obj:`decoder_input_ids` for pre-training take a look at
`T5 Training <./t5.html#training>`__. If decoder_input_ids and decoder_inputs_embeds are both None,
decoder_input_ids takes the value of input_ids.
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
decoder_past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains pre-computed key and value hidden-states of the attention blocks.
Can be used to speed up decoding.
If `decoder_past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all `decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
If `use_cache` is True, `decoder_past_key_values` are returned and can be used to speed up decoding (see `decoder_past_key_values`).
If `use_cache` is True, `past_key_values` are returned and can be used to speed up decoding (see `past_key_values`).
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
If `decoder_past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `decoder_past_key_values`).
If `past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `past_key_values`).
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
than the model's internal embedding lookup matrix. If decoder_input_ids and decoder_inputs_embeds are both None,
decoder_inputs_embeds takes the value of inputs_embeds.
......@@ -928,7 +928,7 @@ class T5Model(T5PreTrainedModel):
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_past_key_values=None,
past_key_values=None,
use_cache=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
......@@ -955,10 +955,16 @@ class T5Model(T5PreTrainedModel):
"""
if "decoder_past_key_value_states" in kwargs:
warnings.warn(
"The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
"The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
decoder_past_key_values = kwargs.pop("decoder_past_key_value_states")
past_key_values = kwargs.pop("decoder_past_key_value_states")
if "decoder_past_key_values" in kwargs:
warnings.warn(
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = kwargs.pop("decoder_past_key_values")
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
use_cache = use_cache if use_cache is not None else self.config.use_cache
......@@ -992,7 +998,7 @@ class T5Model(T5PreTrainedModel):
# If decoding with past key value states, only the last tokens
# should be given as an input
if decoder_past_key_values is not None:
if past_key_values is not None:
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_inputs_embeds is not None:
......@@ -1003,7 +1009,7 @@ class T5Model(T5PreTrainedModel):
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
past_key_value_states=decoder_past_key_values,
past_key_value_states=past_key_values,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=head_mask,
......@@ -1013,15 +1019,12 @@ class T5Model(T5PreTrainedModel):
return_dict=return_dict,
)
past = (encoder_outputs, decoder_outputs[1]) if use_cache is True else None
if not return_dict:
if past is not None:
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
return decoder_outputs + encoder_outputs
return Seq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
decoder_past_key_values=past,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
......@@ -1080,7 +1083,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
encoder_outputs=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_past_key_values=None,
past_key_values=None,
use_cache=None,
labels=None,
inputs_embeds=None,
......@@ -1127,10 +1130,16 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
labels = kwargs.pop("lm_labels")
if "decoder_past_key_value_states" in kwargs:
warnings.warn(
"The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
"The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = kwargs.pop("decoder_past_key_value_states")
if "decoder_past_key_values" in kwargs:
warnings.warn(
"The `decoder_past_key_values` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
decoder_past_key_values = kwargs.pop("decoder_past_key_value_states")
past_key_values = kwargs.pop("decoder_past_key_values")
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
use_cache = use_cache if use_cache is not None else self.config.use_cache
......@@ -1163,7 +1172,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
# If decoding with past key value states, only the last tokens
# should be given as an input
if decoder_past_key_values is not None:
if past_key_values is not None:
assert labels is None, "Decoder should not use cached key value states when training."
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
......@@ -1175,7 +1184,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
input_ids=decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
past_key_value_states=decoder_past_key_values,
past_key_value_states=past_key_values,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=head_mask,
......@@ -1197,17 +1206,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
past = (encoder_outputs, decoder_outputs[1]) if use_cache is True else None
if not return_dict:
if past is not None:
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
return ((loss,) + output) if loss is not None else output
return Seq2SeqLMOutput(
loss=loss,
logits=lm_logits,
decoder_past_key_values=past,
past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
......@@ -1215,14 +1221,10 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
encoder_attentions=encoder_outputs.attentions,
)
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, **kwargs):
assert past is not None, "past has to be defined for encoder_outputs"
encoder_outputs, decoder_past_key_values = past
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, encoder_outputs, **kwargs):
return {
"decoder_input_ids": input_ids,
"decoder_past_key_values": decoder_past_key_values,
"past_key_values": past,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"use_cache": use_cache,
......@@ -1231,14 +1233,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
def _reorder_cache(self, past, beam_idx):
# if decoder past is not included in output
# speedy decoding is disabled and no need to reorder
if past[1] is None:
if past is None:
logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
return past
decoder_past = past[1]
past = (past[0],)
reordered_decoder_past = ()
for layer_past_states in decoder_past:
for layer_past_states in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` is at 2nd position
reordered_layer_past_states = ()
......@@ -1252,4 +1252,4 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
assert len(reordered_layer_past_states) == len(layer_past_states)
reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
return past + (reordered_decoder_past,)
return reordered_decoder_past
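The same reorder for T5's tuple-of-tuples cache, as a toy sketch (shapes invented):

import torch

def reorder_decoder_past(past, beam_idx):
    # each layer's past is a tuple of tensors; gather every one on the batch dim
    return tuple(
        tuple(state.index_select(0, beam_idx) for state in layer_past)
        for layer_past in past
    )

past = ((torch.randn(4, 8, 1, 16),) * 4,)  # one layer, four key/value states
beam_idx = torch.tensor([2, 2, 3, 0])
reordered = reorder_decoder_past(past, beam_idx)
assert torch.equal(reordered[0][0][0], past[0][0][2])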
......@@ -431,7 +431,7 @@ class TFGPT2DoubleHeadsModelOutput(ModelOutput):
Base class for outputs of models with a language modeling head and a multiple choice classification head.
Args:
lm_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
mc_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`):
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
......@@ -454,7 +454,7 @@ class TFGPT2DoubleHeadsModelOutput(ModelOutput):
heads.
"""
lm_logits: tf.Tensor = None
logits: tf.Tensor = None
mc_logits: tf.Tensor = None
past_key_values: Optional[List[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
......@@ -794,7 +794,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
return (lm_logits, mc_logits) + transformer_outputs[1:]
return TFGPT2DoubleHeadsModelOutput(
lm_logits=lm_logits,
logits=lm_logits,
mc_logits=mc_logits,
past_key_values=transformer_outputs.past_key_values,
hidden_states=transformer_outputs.hidden_states,
......
......@@ -394,7 +394,7 @@ class TFOpenAIGPTDoubleHeadsModelOutput(ModelOutput):
Base class for outputs of models with a language modeling head and a multiple choice classification head.
Args:
lm_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
mc_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, num_choices)`):
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
......@@ -411,7 +411,7 @@ class TFOpenAIGPTDoubleHeadsModelOutput(ModelOutput):
heads.
"""
lm_logits: tf.Tensor = None
logits: tf.Tensor = None
mc_logits: tf.Tensor = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None
......@@ -719,7 +719,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
return (lm_logits, mc_logits) + transformer_outputs[1:]
return TFOpenAIGPTDoubleHeadsModelOutput(
lm_logits=lm_logits,
logits=lm_logits,
mc_logits=mc_logits,
hidden_states=transformer_outputs.hidden_states,
attentions=transformer_outputs.attentions,
......
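After the rename, callers read the language-modeling scores from the standard `logits` attribute, the same name every other `ModelOutput` uses; the multiple-choice head keeps `mc_logits`. A hedged usage sketch covering both double-heads variants (the checkpoint name is illustrative):

```python
from transformers import GPT2Tokenizer, TFGPT2DoubleHeadsModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2DoubleHeadsModel.from_pretrained("gpt2")

encoded = tokenizer("Hello, my dog is cute", return_tensors="tf")
outputs = model(encoded["input_ids"], return_dict=True)

lm_scores = outputs.logits     # formerly `outputs.lm_logits`
mc_scores = outputs.mc_logits  # unchanged
```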
......@@ -113,13 +113,13 @@ class TFSeq2SeqModelOutput(ModelOutput):
last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the decoder of the model.
If ``decoder_past_key_values`` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
decoder_past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
If ``past_key_values`` is used, only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
used (see ``past_key_values`` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -147,7 +147,7 @@ class TFSeq2SeqModelOutput(ModelOutput):
"""
last_hidden_state: tf.Tensor = None
decoder_past_key_values: Optional[List[tf.Tensor]] = None
past_key_values: Optional[List[tf.Tensor]] = None
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_last_hidden_state: Optional[tf.Tensor] = None
......@@ -259,12 +259,12 @@ class TFSeq2SeqLMOutput(ModelOutput):
Language modeling loss.
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
decoder_past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
used (see ``past_key_values`` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -293,7 +293,7 @@ class TFSeq2SeqLMOutput(ModelOutput):
loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None
decoder_past_key_values: Optional[List[tf.Tensor]] = None
past_key_values: Optional[List[tf.Tensor]] = None
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_last_hidden_state: Optional[tf.Tensor] = None
......@@ -366,12 +366,12 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
decoder_past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
used (see ``past_key_values`` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -400,7 +400,7 @@ class TFSeq2SeqSequenceClassifierOutput(ModelOutput):
loss: Optional[tf.Tensor] = None
logits: tf.Tensor = None
decoder_past_key_values: Optional[List[tf.Tensor]] = None
past_key_values: Optional[List[tf.Tensor]] = None
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_last_hidden_state: Optional[tf.Tensor] = None
......@@ -512,12 +512,12 @@ class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
Span-start scores (before SoftMax).
end_logits (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
decoder_past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
past_key_values (:obj:`List[tf.Tensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`tf.Tensor` of length :obj:`config.n_layers`, with each tensor of shape
:obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
used (see ``decoder_past_key_values`` input) to speed up sequential decoding.
used (see ``past_key_values`` input) to speed up sequential decoding.
decoder_hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -547,7 +547,7 @@ class TFSeq2SeqQuestionAnsweringModelOutput(ModelOutput):
loss: Optional[tf.Tensor] = None
start_logits: tf.Tensor = None
end_logits: tf.Tensor = None
decoder_past_key_values: Optional[List[tf.Tensor]] = None
past_key_values: Optional[List[tf.Tensor]] = None
decoder_hidden_states: Optional[Tuple[tf.Tensor]] = None
decoder_attentions: Optional[Tuple[tf.Tensor]] = None
encoder_last_hidden_state: Optional[tf.Tensor] = None
......
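All four TF seq2seq output classes now expose the decoder cache under the single name `past_key_values`, so incremental decoding can be written once against any of them. A rough sketch of one cached greedy step, assuming a TF seq2seq LM such as `TFT5ForConditionalGeneration`; the calling convention mirrors the `prepare_inputs_for_generation` change further down:

```python
import tensorflow as tf

def greedy_step(model, decoder_input_ids, encoder_outputs, past=None):
    # With a cache, only the most recent decoder token has to be fed.
    outputs = model(
        None,
        decoder_input_ids=decoder_input_ids if past is None else decoder_input_ids[:, -1:],
        encoder_outputs=encoder_outputs,
        past_key_values=past,
        use_cache=True,
        return_dict=True,
    )
    next_token = tf.argmax(outputs.logits[:, -1, :], axis=-1, output_type=tf.int32)
    return next_token, outputs.past_key_values  # one attribute name for every output class
```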
......@@ -437,15 +437,15 @@ class TFT5Block(tf.keras.layers.Layer):
):
if past_key_value_state is not None:
assert self.is_decoder, "Only decoder can use `past_key_value_states`"
expected_num_past_key_value_states = 2 if encoder_hidden_states is None else 4
assert self.is_decoder, "Only decoder can use `past_key_values`"
expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
error_message = "There should be {} past states. 2 (key / value) for self attention.{} Got {} past key / value states".format(
expected_num_past_key_value_states,
"2 (key / value) for cross attention" if expected_num_past_key_value_states == 4 else "",
expected_num_past_key_values,
"2 (key / value) for cross attention" if expected_num_past_key_values == 4 else "",
len(past_key_value_state),
)
assert len(past_key_value_state) == expected_num_past_key_value_states, error_message
assert len(past_key_value_state) == expected_num_past_key_values, error_message
self_attn_past_key_value_state = past_key_value_state[:2]
cross_attn_past_key_value_state = past_key_value_state[2:]
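The expected cache size per decoder block falls out of the attention layout: a (key, value) pair for self-attention, plus a second pair when the block also cross-attends to the encoder. Restated as a tiny self-contained check:

```python
def expected_num_past_key_values(is_decoder, has_encoder_hidden_states):
    # self-attention always caches (key, value); cross-attention adds another pair
    assert is_decoder, "Only the decoder can use `past_key_values`"
    return 4 if has_encoder_hidden_states else 2

assert expected_num_past_key_values(True, True) == 4   # self-attn + cross-attn
assert expected_num_past_key_values(True, False) == 2  # self-attn only
```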
......@@ -586,11 +586,12 @@ class TFT5MainLayer(tf.keras.layers.Layer):
encoder_attention_mask=None,
inputs_embeds=None,
head_mask=None,
past_key_value_states=None,
past_key_values=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
training=False,
**kwargs,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
......@@ -599,7 +600,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
encoder_attention_mask = inputs[3] if len(inputs) > 3 else encoder_attention_mask
inputs_embeds = inputs[4] if len(inputs) > 4 else inputs_embeds
head_mask = inputs[5] if len(inputs) > 5 else head_mask
past_key_value_states = inputs[6] if len(inputs) > 6 else past_key_value_states
past_key_values = inputs[6] if len(inputs) > 6 else past_key_values
use_cache = inputs[7] if len(inputs) > 7 else use_cache
output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
......@@ -611,13 +612,26 @@ class TFT5MainLayer(tf.keras.layers.Layer):
encoder_attention_mask = inputs.get("encoder_attention_mask", encoder_attention_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
head_mask = inputs.get("head_mask", head_mask)
past_key_value_states = inputs.get("past_key_value_states", past_key_value_states)
past_key_values = inputs.get("past_key_values", past_key_values)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
assert len(inputs) <= 10, "Too many inputs."
if "past_key_value_states" in inputs:
warnings.warn(
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = inputs.pop("past_key_value_states")
else:
input_ids = inputs
if "past_key_value_states" in kwargs:
warnings.warn(
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = kwargs.pop("past_key_value_states")
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
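The deprecation shim above follows a standard pattern: accept the old keyword through `**kwargs`, emit a `FutureWarning`, and fall through to the new argument. Shown in isolation (the function name is illustrative):

```python
import warnings

def call_sketch(past_key_values=None, **kwargs):
    # honour the deprecated `past_key_value_states` keyword for one more release
    if "past_key_value_states" in kwargs:
        warnings.warn(
            "The `past_key_value_states` argument is deprecated and will be removed "
            "in a future version, use `past_key_values` instead.",
            FutureWarning,
        )
        past_key_values = kwargs.pop("past_key_value_states")
    return past_key_values

assert call_sketch(past_key_value_states="cache") == "cache"
```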
......@@ -639,13 +653,13 @@ class TFT5MainLayer(tf.keras.layers.Layer):
batch_size, seq_length = input_shape
if past_key_value_states is not None:
if past_key_values is not None:
assert seq_length == 1, "Input shape is {}, but should be {} when using past_key_values".format(
input_shape, (batch_size, 1)
)
# required mask seq length can be calculated via length of past
# key value states and seq_length = 1 for the last token
mask_seq_length = shape_list(past_key_value_states[0][0])[2] + seq_length
mask_seq_length = shape_list(past_key_values[0][0])[2] + seq_length
else:
mask_seq_length = seq_length
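The mask arithmetic above is worth spelling out: when a cache is present the model is fed a single new token, yet the attention mask must still span the whole prefix, so the cached length is read off dimension 2 of the first layer's key tensor. A toy check of that computation:

```python
import tensorflow as tf

def mask_seq_length(past_key_values, seq_length):
    # cached length (dim 2 of the key tensor) + the new tokens
    if past_key_values is not None:
        return past_key_values[0][0].shape[2] + seq_length
    return seq_length

# toy cache: one layer, key/value of shape (batch, heads, past_len, head_dim)
past = [(tf.zeros((2, 8, 6, 64)), tf.zeros((2, 8, 6, 64)))]
assert mask_seq_length(past, 1) == 7
assert mask_seq_length(None, 5) == 5
```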
......@@ -655,9 +669,9 @@ class TFT5MainLayer(tf.keras.layers.Layer):
encoder_seq_length = shape_list(encoder_hidden_states)[1]
encoder_attention_mask = tf.fill((batch_size, encoder_seq_length), 1)
# initialize past_key_value_states with `None` if past does not exist
if past_key_value_states is None:
past_key_value_states = [None] * len(self.block)
# initialize past_key_values with `None` if past does not exist
if past_key_values is None:
past_key_values = [None] * len(self.block)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
......@@ -677,7 +691,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
)
causal_mask = tf.cast(causal_mask, dtype=tf.float32)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
if past_key_value_states[0] is not None:
if past_key_values[0] is not None:
extended_attention_mask = extended_attention_mask[:, :, -1:, :]
else:
extended_attention_mask = attention_mask[:, None, None, :]
......@@ -726,7 +740,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
hidden_states = self.dropout(inputs_embeds, training=training)
for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_value_states)):
for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_values)):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
......@@ -878,7 +892,7 @@ T5_INPUTS_DOCSTRING = r"""
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
decoder_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
If `decoder_past_key_value_states` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_value_states`).
If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
......@@ -889,13 +903,13 @@ T5_INPUTS_DOCSTRING = r"""
Used in the cross-attention of the decoder.
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
decoder_past_key_value_states (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
past_key_values (:obj:`tuple(tuple(tf.Tensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains pre-computed key and value hidden-states of the attention blocks.
Can be used to speed up decoding.
If `decoder_past_key_value_states` are used, the user can optionally input only the last `decoder_input_ids`
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
If `use_cache` is True, `decoder_past_key_value_states` are returned and can be used to speed up decoding (see `decoder_past_key_value_states`).
If `use_cache` is True, `past_key_values` are returned and can be used to speed up decoding (see `past_key_values`).
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Optionally, instead of passing :obj:`inputs` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `inputs` indices into associated vectors
......@@ -969,7 +983,7 @@ class TFT5Model(TFT5PreTrainedModel):
encoder_outputs=None,
inputs_embeds=None,
head_mask=None,
decoder_past_key_value_states=None,
past_key_values=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_inputs_embeds=None,
......@@ -978,6 +992,7 @@ class TFT5Model(TFT5PreTrainedModel):
output_hidden_states=None,
return_dict=None,
training=False,
**kwargs,
):
r"""
Returns:
......@@ -999,7 +1014,7 @@ class TFT5Model(TFT5PreTrainedModel):
encoder_outputs = inputs[2] if len(inputs) > 2 else encoder_outputs
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
head_mask = inputs[4] if len(inputs) > 4 else head_mask
decoder_past_key_value_states = inputs[5] if len(inputs) > 5 else decoder_past_key_value_states
past_key_values = inputs[5] if len(inputs) > 5 else past_key_values
decoder_input_ids = inputs[6] if len(inputs) > 6 else decoder_input_ids
decoder_attention_mask = inputs[7] if len(inputs) > 7 else decoder_attention_mask
decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds
......@@ -1017,7 +1032,7 @@ class TFT5Model(TFT5PreTrainedModel):
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
head_mask = inputs.get("head_mask", head_mask)
decoder_past_key_value_states = inputs.get("past_key_value_states", decoder_past_key_value_states)
past_key_values = inputs.get("past_key_values", past_key_values)
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds)
......@@ -1026,9 +1041,23 @@ class TFT5Model(TFT5PreTrainedModel):
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
return_dict = inputs.get("return_dict", return_dict)
assert len(inputs) <= 13, "Too many inputs."
if "past_key_value_states" in inputs:
warnings.warn(
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = inputs.pop("past_key_value_states")
else:
input_ids = inputs
if "past_key_value_states" in kwargs:
warnings.warn(
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = kwargs.pop("past_key_value_states")
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.return_dict
......@@ -1054,7 +1083,7 @@ class TFT5Model(TFT5PreTrainedModel):
# If decoding with past key value states, only the last tokens
# should be given as an input
if decoder_past_key_value_states is not None:
if past_key_values is not None:
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_inputs_embeds is not None:
......@@ -1069,7 +1098,7 @@ class TFT5Model(TFT5PreTrainedModel):
attention_mask,
decoder_inputs_embeds,
head_mask,
decoder_past_key_value_states,
past_key_values,
use_cache,
output_attentions,
output_hidden_states,
......@@ -1103,7 +1132,7 @@ class TFT5Model(TFT5PreTrainedModel):
return TFSeq2SeqModelOutput(
last_hidden_state=decoder_outputs[0],
decoder_past_key_values=past,
past_key_values=past,
decoder_hidden_states=decoder_outputs[2],
decoder_attentions=decoder_outputs[3],
encoder_last_hidden_state=encoder_outputs[0],
......@@ -1164,7 +1193,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
encoder_outputs=None,
inputs_embeds=None,
head_mask=None,
decoder_past_key_value_states=None,
past_key_values=None,
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_inputs_embeds=None,
......@@ -1174,6 +1203,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
return_dict=None,
labels=None,
training=False,
**kwargs,
):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
......@@ -1204,7 +1234,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
encoder_outputs = inputs[2] if len(inputs) > 2 else encoder_outputs
inputs_embeds = inputs[3] if len(inputs) > 3 else inputs_embeds
head_mask = inputs[4] if len(inputs) > 4 else head_mask
decoder_past_key_value_states = inputs[5] if len(inputs) > 5 else decoder_past_key_value_states
past_key_values = inputs[5] if len(inputs) > 5 else past_key_values
decoder_input_ids = inputs[6] if len(inputs) > 6 else decoder_input_ids
decoder_attention_mask = inputs[7] if len(inputs) > 7 else decoder_attention_mask
decoder_inputs_embeds = inputs[8] if len(inputs) > 8 else decoder_inputs_embeds
......@@ -1223,7 +1253,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
encoder_outputs = inputs.get("encoder_outputs", encoder_outputs)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
head_mask = inputs.get("head_mask", head_mask)
decoder_past_key_value_states = inputs.get("past_key_value_states", decoder_past_key_value_states)
past_key_values = inputs.get("past_key_values", past_key_values)
decoder_input_ids = inputs.get("decoder_input_ids", decoder_input_ids)
decoder_attention_mask = inputs.get("decoder_attention_mask", decoder_attention_mask)
decoder_inputs_embeds = inputs.get("decoder_inputs_embeds", decoder_inputs_embeds)
......@@ -1233,9 +1263,23 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
return_dict = inputs.get("return_dict", return_dict)
labels = inputs.get("labels", labels)
assert len(inputs) <= 14, "Too many inputs."
if "past_key_value_states" in inputs:
warnings.warn(
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = inputs.pop("past_key_value_states")
else:
input_ids = inputs
if "past_key_value_states" in kwargs:
warnings.warn(
"The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
FutureWarning,
)
past_key_values = kwargs.pop("past_key_value_states")
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.return_dict
......@@ -1266,7 +1310,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
# If decoding with past key value states, only the last tokens
# should be given as an input
if decoder_past_key_value_states is not None:
if past_key_values is not None:
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_inputs_embeds is not None:
......@@ -1281,7 +1325,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
attention_mask,
decoder_inputs_embeds,
head_mask,
decoder_past_key_value_states,
past_key_values,
use_cache,
output_attentions,
output_hidden_states,
......@@ -1324,7 +1368,7 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
return TFSeq2SeqLMOutput(
loss=loss,
logits=logits,
decoder_past_key_values=past,
past_key_values=past,
decoder_hidden_states=decoder_outputs[2],
decoder_attentions=decoder_outputs[3],
encoder_last_hidden_state=encoder_outputs[0],
......@@ -1337,14 +1381,14 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
# first step
if len(past) < 2:
encoder_outputs, decoder_past_key_value_states = past, None
encoder_outputs, past_key_values = past, None
else:
encoder_outputs, decoder_past_key_value_states = past[0], past[1]
encoder_outputs, past_key_values = past[0], past[1]
return {
"inputs": None, # inputs don't have to be defined, but still need to be passed to make Keras.layer.__call__ happy
"decoder_input_ids": inputs, # inputs are the decoder_input_ids
"decoder_past_key_value_states": decoder_past_key_value_states,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"use_cache": use_cache,
......
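Callers never touch this plumbing directly; `generate()` threads the returned dict (encoder outputs plus decoder cache) into each forward pass. Illustrative end-to-end usage, assuming a standard checkpoint:

```python
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer(
    "translate English to German: The house is wonderful.", return_tensors="tf"
).input_ids
# use_cache=True lets generate() reuse past_key_values between steps
output_ids = model.generate(input_ids, use_cache=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```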
......@@ -661,6 +661,15 @@ class TransfoXLLMHeadModelOutput(ModelOutput):
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
@property
def logits(self):
# prediction scores are the output of the adaptive softmax, see
# the file `modeling_transfo_xl_utilities`. Since the adaptive
# softmax returns the log softmax value, `self.prediction_scores`
# are strictly speaking not exactly `logits`, but behave the same
# way logits do.
return self.prediction_scores
TRANSFO_XL_START_DOCSTRING = r"""
......
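The property alias is a lightweight way to keep a historical field name while still satisfying the shared `ModelOutput` interface that `generate()` relies on. The same pattern in isolation (class name is illustrative):

```python
from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class LMOutputSketch:
    # keep the historical field; expose `logits` as a read-only alias
    prediction_scores: Optional[torch.FloatTensor] = None

    @property
    def logits(self):
        return self.prediction_scores

out = LMOutputSketch(prediction_scores=torch.zeros(1, 4, 10))
assert out.logits is out.prediction_scores
```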
......@@ -33,6 +33,7 @@ if is_torch_available():
from transformers import (
BertLMHeadModel,
BertModel,
BertTokenizer,
EncoderDecoderConfig,
EncoderDecoderModel,
GPT2LMHeadModel,
......@@ -128,10 +129,11 @@ class EncoderDecoderMixin:
decoder_config,
decoder_input_ids,
decoder_attention_mask,
return_dict,
**kwargs
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model}
kwargs = {"encoder_model": encoder_model, "decoder_model": decoder_model, "return_dict": return_dict}
enc_dec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(**kwargs)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
......@@ -361,7 +363,11 @@ class EncoderDecoderMixin:
def test_encoder_decoder_model_from_pretrained(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained(**input_ids_dict)
self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=False)
def test_encoder_decoder_model_from_pretrained_return_dict(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained(**input_ids_dict, return_dict=True)
def test_save_and_load_from_pretrained(self):
input_ids_dict = self.prepare_config_and_inputs()
......@@ -466,6 +472,22 @@ class BertEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
"labels": decoder_token_labels,
}
@slow
def test_bert2bert_summarization(self):
model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
model.to(torch_device)
tokenizer = BertTokenizer.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
ARTICLE = """(CNN)Sigma Alpha Epsilon is under fire for a video showing party-bound fraternity members singing a racist chant. SAE's national chapter suspended the students, but University of Oklahoma President David Boren took it a step further, saying the university's affiliation with the fraternity is permanently done. The news is shocking, but it's not the first time SAE has faced controversy. SAE was founded March 9, 1856, at the University of Alabama, five years before the American Civil War, according to the fraternity website. When the war began, the group had fewer than 400 members, of which "369 went to war for the Confederate States and seven for the Union Army," the website says. The fraternity now boasts more than 200,000 living alumni, along with about 15,000 undergraduates populating 219 chapters and 20 "colonies" seeking full membership at universities. SAE has had to work hard to change recently after a string of member deaths, many blamed on the hazing of new recruits, SAE national President Bradley Cohen wrote in a message on the fraternity's website. The fraternity's website lists more than 130 chapters cited or suspended for "health and safety incidents" since 2010. At least 30 of the incidents involved hazing, and dozens more involved alcohol. However, the list is missing numerous incidents from recent months. Among them, according to various media outlets: Yale University banned the SAEs from campus activities last month after members allegedly tried to interfere with a sexual misconduct investigation connected to an initiation rite. Stanford University in December suspended SAE housing privileges after finding sorority members attending a fraternity function were subjected to graphic sexual content. And Johns Hopkins University in November suspended the fraternity for underage drinking. "The media has labeled us as the 'nation's deadliest fraternity,' " Cohen said. In 2011, for example, a student died while being coerced into excessive alcohol consumption, according to a lawsuit. SAE's previous insurer dumped the fraternity. "As a result, we are paying Lloyd's of London the highest insurance rates in the Greek-letter world," Cohen said. Universities have turned down SAE's attempts to open new chapters, and the fraternity had to close 12 in 18 months over hazing incidents."""
EXPECTED_SUMMARY = """sae was founded in 1856, five years before the civil war. the fraternity has had to work hard to change recently. the university of oklahoma president says the university's affiliation with the fraternity is permanently done. the sae has had a string of members in recent months."""
input_ids = tokenizer(ARTICLE, return_tensors="pt").input_ids.to(torch_device)
output_ids = model.generate(input_ids)
summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)
self.assertEqual(summary, EXPECTED_SUMMARY)
class RoBertaEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
def get_encoder_decoder_model(self, config, decoder_config):
......
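The new `return_dict` toggle is exercised by both tests above; when set, forward passes return `ModelOutput` objects instead of plain tuples. A hedged sketch mirroring the test setup (checkpoint names are illustrative, and this assumes `return_dict` is forwarded to both sub-configs as the test does):

```python
from transformers import BertTokenizer, EncoderDecoderModel

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "bert-base-uncased", return_dict=True
)

input_ids = tokenizer("Hello world", return_tensors="pt").input_ids
outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
print(type(outputs).__name__)  # a ModelOutput subclass rather than a tuple
```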
......@@ -289,9 +289,9 @@ class GPT2ModelTester:
}
result = model(**inputs)
self.parent.assertEqual(result.lm_loss.shape, ())
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(
result.lm_logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size)
result.logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size)
)
self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices))
......@@ -324,7 +324,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (GPT2Model, GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
all_generative_model_classes = (
(GPT2LMHeadModel,) if is_torch_available() else ()
(GPT2LMHeadModel, GPT2DoubleHeadsModel) if is_torch_available() else ()
)
test_missing_keys = False
......
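Now that `GPT2DoubleHeadsModel` is in the generative test matrix, it goes through `generate()` like the plain LM head, since its language-modeling head drives decoding. Illustrative usage (checkpoint name assumed):

```python
from transformers import GPT2Tokenizer, GPT2DoubleHeadsModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2DoubleHeadsModel.from_pretrained("gpt2")

input_ids = tokenizer("The meaning of life is", return_tensors="pt").input_ids
output_ids = model.generate(input_ids, max_length=20)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```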
......@@ -131,8 +131,8 @@ class OpenAIGPTModelTester:
model.eval()
result = model(input_ids, token_type_ids=token_type_ids, labels=input_ids)
self.parent.assertEqual(result.lm_loss.shape, ())
self.parent.assertEqual(result.lm_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertEqual(result.loss.shape, ())
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
......
......@@ -159,17 +159,15 @@ class T5ModelTester:
)
result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
decoder_output = result.last_hidden_state
decoder_past = result.decoder_past_key_values
decoder_past = result.past_key_values
encoder_output = result.encoder_last_hidden_state
self.parent.assertEqual(encoder_output.size(), (self.batch_size, self.encoder_seq_length, self.hidden_size))
self.parent.assertEqual(decoder_output.size(), (self.batch_size, self.decoder_seq_length, self.hidden_size))
self.parent.assertEqual(len(decoder_past), 2)
self.parent.assertTrue(torch.all(decoder_past[0][0] == encoder_output))
# There should be `num_layers` key value embeddings stored in decoder_past[1]
self.parent.assertEqual(len(decoder_past[1]), config.num_layers)
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple
self.parent.assertEqual(len(decoder_past[1][0]), 4)
# There should be `num_layers` key value embeddings stored in decoder_past
self.parent.assertEqual(len(decoder_past), config.num_layers)
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple
self.parent.assertEqual(len(decoder_past[0]), 4)
def create_and_check_with_lm_head(
self,
......
......@@ -238,7 +238,7 @@ class TFGPT2ModelTester:
}
result = model(inputs)
self.parent.assertEqual(
result.lm_logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size)
result.logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size)
)
self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices))
......
......@@ -151,7 +151,7 @@ class TFOpenAIGPTModelTester:
}
result = model(inputs)
self.parent.assertEqual(
result.lm_logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size)
result.logits.shape, (self.batch_size, self.num_choices, self.seq_length, self.vocab_size)
)
self.parent.assertEqual(result.mc_logits.shape, (self.batch_size, self.num_choices))
......
......@@ -96,7 +96,7 @@ class TFT5ModelTester:
result = model(input_ids, decoder_attention_mask=input_mask, decoder_input_ids=input_ids)
decoder_output = result.last_hidden_state
decoder_past = result.decoder_past_key_values
decoder_past = result.past_key_values
encoder_output = result.encoder_last_hidden_state
self.parent.assertListEqual(list(encoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
self.parent.assertListEqual(list(decoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
......