Unverified commit df983b74, authored by Sylvain Gugger and committed by GitHub

Deprecate old past arguments (#5671)

parent cdf4cd70
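
Every file touched below applies the same backward-compatibility shim: the renamed argument is still accepted under its old keyword, but passing it that way now raises a `FutureWarning` (which, unlike `DeprecationWarning`, is shown to end users by default). A minimal, self-contained sketch of the pattern; `forward` here is a toy stand-in, not a transformers API:

```python
import warnings


def forward(input_ids, past_key_values=None, **kwargs):
    """Toy stand-in mirroring the shim added to each model's forward()."""
    if "past" in kwargs:
        warnings.warn(
            "The `past` argument is deprecated and will be removed in a future version, "
            "use `past_key_values` instead.",
            FutureWarning,
        )
        past_key_values = kwargs.pop("past")
    assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
    return input_ids, past_key_values


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    forward([1, 2, 3], past="cached-state")  # old keyword still works, but warns
assert any(issubclass(w.category, FutureWarning) for w in caught)
```

The existing `lm_labels` / `masked_lm_labels` shims are switched from `DeprecationWarning` to `FutureWarning` for the same reason: Python hides `DeprecationWarning` from end users by default.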
@@ -690,7 +690,7 @@ class AlbertForPreTraining(AlbertPreTrainedModel):
         if "masked_lm_labels" in kwargs:
             warnings.warn(
                 "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
-                DeprecationWarning,
+                FutureWarning,
             )
             labels = kwargs.pop("masked_lm_labels")
         assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
......
@@ -111,6 +111,15 @@ BART_INPUTS_DOCSTRING = r"""
             Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
             If you want to change padding behavior, you should read :func:`~transformers.modeling_bart._prepare_decoder_inputs` and modify.
             See diagram 1 in the paper for more info on the default strategy.
+        decoder_past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            Contains pre-computed key and value hidden-states of the attention blocks.
+            Can be used to speed up decoding.
+            If ``decoder_past_key_values`` are used, the user can optionally input only the last
+            ``decoder_input_ids`` (those that don't have their past key value states given to this model) of shape
+            :obj:`(batch_size, 1)` instead of all ``decoder_input_ids`` of shape :obj:`(batch_size, sequence_length)`.
+        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            If `use_cache` is True, ``decoder_past_key_values`` are returned and can be used to speed up decoding (see
+            ``decoder_past_key_values``).
         output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
             If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
         output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
@@ -482,7 +491,7 @@ class BartDecoder(nn.Module):
         encoder_padding_mask,
         decoder_padding_mask,
         decoder_causal_mask,
-        decoder_cached_states=None,
+        decoder_past_key_values=None,
         use_cache=False,
         output_attentions=False,
         output_hidden_states=False,
@@ -499,7 +508,7 @@ class BartDecoder(nn.Module):
             encoder_hidden_states: output from the encoder, used for
                 encoder-side attention
             encoder_padding_mask: for ignoring pad tokens
-            decoder_cached_states (dict or None): dictionary used for storing state during generation
+            decoder_past_key_values (dict or None): dictionary used for storing state during generation

         Returns:
             BaseModelOutputWithPast or tuple:
@@ -508,6 +517,13 @@ class BartDecoder(nn.Module):
                 - hidden states
                 - attentions
         """
+        if "decoder_cached_states" in unused:
+            warnings.warn(
+                "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
+                FutureWarning,
+            )
+            decoder_past_key_values = unused.pop("decoder_cached_states")
+
         # check attention mask and invert
         if encoder_padding_mask is not None:
             encoder_padding_mask = invert_mask(encoder_padding_mask)
@@ -541,7 +557,7 @@ class BartDecoder(nn.Module):
             if self.training and (dropout_probability < self.layerdrop):
                 continue

-            layer_state = decoder_cached_states[idx] if decoder_cached_states is not None else None
+            layer_state = decoder_past_key_values[idx] if decoder_past_key_values is not None else None
             x, layer_self_attn, layer_past = decoder_layer(
                 x,
@@ -854,11 +870,12 @@ class BartModel(PretrainedBartModel):
         decoder_input_ids=None,
         encoder_outputs: Optional[Tuple] = None,
         decoder_attention_mask=None,
-        decoder_cached_states=None,
+        decoder_past_key_values=None,
         use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
         return_tuple=None,
+        **kwargs,
     ):

         if decoder_input_ids is None:
@@ -908,7 +925,7 @@ class BartModel(PretrainedBartModel):
             attention_mask,
             decoder_padding_mask,
             decoder_causal_mask=causal_mask,
-            decoder_cached_states=decoder_cached_states,
+            decoder_past_key_values=decoder_past_key_values,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -977,7 +994,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
         encoder_outputs=None,
         decoder_input_ids=None,
         decoder_attention_mask=None,
-        decoder_cached_states=None,
+        decoder_past_key_values=None,
         labels=None,
         use_cache=None,
         output_attentions=None,
@@ -1015,9 +1032,15 @@ class BartForConditionalGeneration(PretrainedBartModel):
         if "lm_labels" in unused:
             warnings.warn(
                 "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
-                DeprecationWarning,
+                FutureWarning,
             )
             labels = unused.pop("lm_labels")
+        if "decoder_cached_states" in unused:
+            warnings.warn(
+                "The `decoder_cached_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
+                FutureWarning,
+            )
+            decoder_past_key_values = unused.pop("decoder_cached_states")
         return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple

         if labels is not None:
@@ -1029,7 +1052,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
             decoder_input_ids=decoder_input_ids,
             encoder_outputs=encoder_outputs,
             decoder_attention_mask=decoder_attention_mask,
-            decoder_cached_states=decoder_cached_states,
+            decoder_past_key_values=decoder_past_key_values,
             use_cache=use_cache,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
@@ -1061,11 +1084,11 @@ class BartForConditionalGeneration(PretrainedBartModel):
     def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs):
         assert past is not None, "past has to be defined for encoder_outputs"

-        encoder_outputs, decoder_cached_states = past
+        encoder_outputs, decoder_past_key_values = past
         return {
             "input_ids": None,  # encoder_outputs is defined. input_ids not needed
             "encoder_outputs": encoder_outputs,
-            "decoder_cached_states": decoder_cached_states,
+            "decoder_past_key_values": decoder_past_key_values,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
@@ -1092,9 +1115,9 @@ class BartForConditionalGeneration(PretrainedBartModel):
     @staticmethod
     def _reorder_cache(past, beam_idx):
-        ((enc_out, enc_mask), decoder_cached_states) = past
+        ((enc_out, enc_mask), decoder_past_key_values) = past
         reordered_past = []
-        for layer_past in decoder_cached_states:
+        for layer_past in decoder_past_key_values:
             # get the correct batch idx from decoder layer's batch dim for cross and self-attn
             layer_past_new = {
                 attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items()
......
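
For BART callers, only the keyword name changes; the cache object itself keeps the same structure. A hedged sketch of the caller side (the checkpoint name is just a common public one, and `cached` stands for a per-layer cache obtained from an earlier forward pass with `use_cache=True`):

```python
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

batch = tokenizer("Hugging Face is based in New York City.", return_tensors="pt")

# First pass: ask the decoder to return its per-layer key/value cache.
outputs = model(**batch, use_cache=True)

# On a later pass the cache is handed back under the new keyword:
#   before: model(..., decoder_cached_states=cached)      # still accepted, now warns
#   after:  model(..., decoder_past_key_values=cached)
```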
@@ -879,7 +879,7 @@ class BertForPreTraining(BertPreTrainedModel):
         if "masked_lm_labels" in kwargs:
             warnings.warn(
                 "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
-                DeprecationWarning,
+                FutureWarning,
             )
             labels = kwargs.pop("masked_lm_labels")
         assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
@@ -1076,7 +1076,7 @@ class BertForMaskedLM(BertPreTrainedModel):
         if "masked_lm_labels" in kwargs:
             warnings.warn(
                 "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
-                DeprecationWarning,
+                FutureWarning,
             )
             labels = kwargs.pop("masked_lm_labels")
         assert "lm_labels" not in kwargs, "Use `BertWithLMHead` for autoregressive language modeling task."
......
@@ -17,6 +17,7 @@
 import logging
+import warnings

 import numpy as np
 import torch
@@ -246,20 +247,22 @@ CTRL_START_DOCSTRING = r"""
 CTRL_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
-            :obj:`input_ids_length` = ``sequence_length`` if ``past`` is ``None`` else ``past[0].shape[-2]`` (``sequence_length`` of input past key value states).
+            :obj:`input_ids_length` = ``sequence_length`` if ``past_key_values`` is ``None`` else
+            ``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states).
             Indices of input sequence tokens in the vocabulary.
-            If `past` is used, only input_ids that do not have their past calculated should be passed as input_ids.
+            If ``past_key_values`` is used, only input_ids that do not have their past calculated should be passed as
+            ``input_ids``.
             Indices can be obtained using :class:`transformers.CTRLTokenizer`.
             See :func:`transformers.PreTrainedTokenizer.encode` and
             :func:`transformers.PreTrainedTokenizer.__call__` for details.

             `What are input IDs? <../glossary.html#input-ids>`__
-        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
+        past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
             Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
-            (see `past` output below). Can be used to speed up sequential decoding.
-            The input_ids which have their past given to this model should not be passed as input ids as they have already been computed.
+            (see ``past_key_values`` output below). Can be used to speed up sequential decoding.
+            The ``input_ids`` which have their past given to this model should not be passed as input ids as they have already been computed.
         attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
             Mask to avoid performing attention on padding token indices.
             Mask values selected in ``[0, 1]``:
@@ -284,10 +287,10 @@ CTRL_INPUTS_DOCSTRING = r"""
         inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
             This is useful if you want more control over how to convert `input_ids` indices into associated vectors
             than the model's internal embedding lookup matrix.
-            If `past` is used, optionally only the last `inputs_embeds` have to be input (see `past`).
+            If ``past_key_values`` is used, optionally only the last `inputs_embeds` have to be input (see ``past_key_values``).
         use_cache (:obj:`bool`):
-            If `use_cache` is True, `past` key value states are returned and
-            can be used to speed up decoding (see `past`). Defaults to `True`.
+            If `use_cache` is True, ``past_key_values`` key value states are returned and
+            can be used to speed up decoding (see ``past_key_values``). Defaults to `True`.
         output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
             If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
         output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
@@ -343,7 +346,7 @@ class CTRLModel(CTRLPreTrainedModel):
     def forward(
         self,
         input_ids=None,
-        past=None,
+        past_key_values=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
@@ -353,7 +356,16 @@ class CTRLModel(CTRLPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_tuple=None,
+        **kwargs,
     ):
+        if "past" in kwargs:
+            warnings.warn(
+                "The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
+                FutureWarning,
+            )
+            past_key_values = kwargs.pop("past")
+        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         output_hidden_states = (
@@ -373,11 +385,11 @@ class CTRLModel(CTRLPreTrainedModel):
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")

-        if past is None:
+        if past_key_values is None:
             past_length = 0
-            past = [None] * len(self.h)
+            past_key_values = [None] * len(self.h)
         else:
-            past_length = past[0][0].size(-2)
+            past_length = past_key_values[0][0].size(-2)
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
@@ -431,7 +443,7 @@ class CTRLModel(CTRLPreTrainedModel):
         presents = () if use_cache else None
         all_hidden_states = () if output_hidden_states else None
         all_attentions = [] if output_attentions else None
-        for i, (h, layer_past) in enumerate(zip(self.h, past)):
+        for i, (h, layer_past) in enumerate(zip(self.h, past_key_values)):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
             outputs = h(
@@ -492,7 +504,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
         if past:
             input_ids = input_ids[:, -1].unsqueeze(-1)
-        return {"input_ids": input_ids, "past": past, "use_cache": kwargs["use_cache"]}
+        return {"input_ids": input_ids, "past_key_values": past, "use_cache": kwargs["use_cache"]}

     @add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
@@ -504,7 +516,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
     def forward(
         self,
         input_ids=None,
-        past=None,
+        past_key_values=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
@@ -515,6 +527,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_tuple=None,
+        **kwargs,
     ):
         r"""
         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@@ -524,11 +537,18 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
             All labels set to ``-100`` are ignored (masked), the loss is only
             computed for labels in ``[0, ..., config.vocab_size]``
         """
+        if "past" in kwargs:
+            warnings.warn(
+                "The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
+                FutureWarning,
+            )
+            past_key_values = kwargs.pop("past")
+        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
         return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple

         transformer_outputs = self.transformer(
             input_ids,
-            past=past,
+            past_key_values=past_key_values,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
......
@@ -531,7 +531,7 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
         if "masked_lm_labels" in kwargs:
             warnings.warn(
                 "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
-                DeprecationWarning,
+                FutureWarning,
             )
             labels = kwargs.pop("masked_lm_labels")
         assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
......
@@ -622,7 +622,7 @@ class ElectraForMaskedLM(ElectraPreTrainedModel):
         if "masked_lm_labels" in kwargs:
             warnings.warn(
                 "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
-                DeprecationWarning,
+                FutureWarning,
             )
             labels = kwargs.pop("masked_lm_labels")
         assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
......
@@ -347,10 +347,12 @@ GPT2_START_DOCSTRING = r"""
 GPT2_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
-            :obj:`input_ids_length` = ``sequence_length`` if ``past`` is ``None`` else ``past[0].shape[-2]`` (``sequence_length`` of input past key value states).
+            :obj:`input_ids_length` = ``sequence_length`` if ``past_key_values`` is ``None`` else
+            ``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states).
             Indices of input sequence tokens in the vocabulary.
-            If `past` is used, only `input_ids` that do not have their past calculated should be passed as `input_ids`.
+            If ``past_key_values`` is used, only ``input_ids`` that do not have their past calculated should be passed
+            as ``input_ids``.
             Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
             See :func:`transformers.PreTrainedTokenizer.encode` and
@@ -358,10 +360,10 @@ GPT2_INPUTS_DOCSTRING = r"""
             `What are input IDs? <../glossary.html#input-ids>`__
-        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
+        past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
             Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
-            (see `past` output below). Can be used to speed up sequential decoding.
-            The `input_ids` which have their past given to this model should not be passed as `input_ids` as they have already been computed.
+            (see ``past_key_values`` output below). Can be used to speed up sequential decoding.
+            The ``input_ids`` which have their past given to this model should not be passed as ``input_ids`` as they have already been computed.
         attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
             Mask to avoid performing attention on padding token indices.
             Mask values selected in ``[0, 1]``:
@@ -386,9 +388,9 @@ GPT2_INPUTS_DOCSTRING = r"""
         inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
             This is useful if you want more control over how to convert `input_ids` indices into associated vectors
             than the model's internal embedding lookup matrix.
-            If `past` is used, optionally only the last `inputs_embeds` have to be input (see `past`).
+            If ``past_key_values`` is used, optionally only the last `inputs_embeds` have to be input (see ``past_key_values``).
         use_cache (:obj:`bool`):
-            If `use_cache` is True, `past` key value states are returned and can be used to speed up decoding (see `past`). Defaults to `True`.
+            If `use_cache` is True, ``past_key_values`` key value states are returned and can be used to speed up decoding (see ``past_key_values``). Defaults to `True`.
         output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`):
             If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
         output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
@@ -437,7 +439,7 @@ class GPT2Model(GPT2PreTrainedModel):
     def forward(
         self,
         input_ids=None,
-        past=None,
+        past_key_values=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
@@ -447,7 +449,16 @@ class GPT2Model(GPT2PreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_tuple=None,
+        **kwargs,
     ):
+        if "past" in kwargs:
+            warnings.warn(
+                "The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
+                FutureWarning,
+            )
+            past_key_values = kwargs.pop("past")
+        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -472,11 +483,11 @@ class GPT2Model(GPT2PreTrainedModel):
         if position_ids is not None:
             position_ids = position_ids.view(-1, input_shape[-1])

-        if past is None:
+        if past_key_values is None:
             past_length = 0
-            past = [None] * len(self.h)
+            past_key_values = [None] * len(self.h)
         else:
-            past_length = past[0][0].size(-2)
+            past_length = past_key_values[0][0].size(-2)
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
             position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
@@ -522,7 +533,7 @@ class GPT2Model(GPT2PreTrainedModel):
         presents = () if use_cache else None
         all_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
-        for i, (block, layer_past) in enumerate(zip(self.h, past)):
+        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
@@ -581,7 +592,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         if past:
             input_ids = input_ids[:, -1].unsqueeze(-1)
-        return {"input_ids": input_ids, "past": past, "use_cache": kwargs["use_cache"]}
+        return {"input_ids": input_ids, "past_key_values": past, "use_cache": kwargs["use_cache"]}

     @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
@@ -593,7 +604,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
     def forward(
         self,
         input_ids=None,
-        past=None,
+        past_key_values=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
@@ -604,6 +615,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_tuple=None,
+        **kwargs,
     ):
         r"""
         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@@ -613,11 +625,18 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
             All labels set to ``-100`` are ignored (masked), the loss is only
             computed for labels in ``[0, ..., config.vocab_size]``
         """
+        if "past" in kwargs:
+            warnings.warn(
+                "The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
+                FutureWarning,
+            )
+            past_key_values = kwargs.pop("past")
+        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
         return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple

         transformer_outputs = self.transformer(
             input_ids,
-            past=past,
+            past_key_values=past_key_values,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
@@ -680,7 +699,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
     def forward(
         self,
         input_ids=None,
-        past=None,
+        past_key_values=None,
         attention_mask=None,
         token_type_ids=None,
         position_ids=None,
@@ -693,7 +712,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_tuple=None,
-        **kwargs
+        **kwargs,
     ):
         r"""
         mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input)
@@ -741,15 +760,21 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
         if "lm_labels" in kwargs:
             warnings.warn(
                 "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
-                DeprecationWarning,
+                FutureWarning,
             )
             labels = kwargs.pop("lm_labels")
+        if "past" in kwargs:
+            warnings.warn(
+                "The `past` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
+                FutureWarning,
+            )
+            past_key_values = kwargs.pop("past")
         assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
         return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple

         transformer_outputs = self.transformer(
             input_ids,
-            past=past,
+            past_key_values=past_key_values,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,
......
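
The GPT-2 and CTRL docstrings above describe the same cache-based speed-up: run the full prompt once with `use_cache=True`, then feed only the newly generated token together with the returned cache. A sketch under the new keyword (the checkpoint and the greedy loop are illustrative; positional unpacking of the outputs is assumed, which holds when no labels are passed):

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

input_ids = tokenizer.encode("Hello, my name is", return_tensors="pt")

with torch.no_grad():
    # Full prompt once; use_cache=True asks every block for its key/value states.
    outputs = model(input_ids, use_cache=True)
    logits, past_key_values = outputs[0], outputs[1]

    generated = input_ids
    for _ in range(5):
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=-1)
        # Feed only the new token plus the cache. Passing it as `past=...`
        # still works for now, but emits a FutureWarning.
        outputs = model(next_token, past_key_values=past_key_values, use_cache=True)
        logits, past_key_values = outputs[0], outputs[1]

print(tokenizer.decode(generated[0]))
```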
@@ -1094,7 +1094,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
         if "masked_lm_labels" in kwargs:
             warnings.warn(
                 "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
-                DeprecationWarning,
+                FutureWarning,
             )
             labels = kwargs.pop("masked_lm_labels")
         assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
......
@@ -665,7 +665,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         if "lm_labels" in kwargs:
             warnings.warn(
                 "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
-                DeprecationWarning,
+                FutureWarning,
             )
             labels = kwargs.pop("lm_labels")
         assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
......
@@ -223,7 +223,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
         if "masked_lm_labels" in kwargs:
             warnings.warn(
                 "The `masked_lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
-                DeprecationWarning,
+                FutureWarning,
             )
             labels = kwargs.pop("masked_lm_labels")
         assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
......
@@ -836,27 +836,27 @@ T5_INPUTS_DOCSTRING = r"""
             Used in the cross-attention of the decoder.
         decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
             Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
-            If `decoder_past_key_value_states` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_value_states`).
+            If `decoder_past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_values`).
             To know more on how to prepare :obj:`decoder_input_ids` for pre-training take a look at
             `T5 Training <./t5.html#training>`__. If decoder_input_ids and decoder_inputs_embeds are both None,
             decoder_input_ids takes the value of input_ids.
         decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
             Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
-        decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+        decoder_past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
             Contains pre-computed key and value hidden-states of the attention blocks.
             Can be used to speed up decoding.
-            If `decoder_past_key_value_states` are used, the user can optionally input only the last `decoder_input_ids`
+            If `decoder_past_key_values` are used, the user can optionally input only the last `decoder_input_ids`
             (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
             instead of all `decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
         use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
-            If `use_cache` is True, `decoder_past_key_value_states` are returned and can be used to speed up decoding (see `decoder_past_key_value_states`).
+            If `use_cache` is True, `decoder_past_key_values` are returned and can be used to speed up decoding (see `decoder_past_key_values`).
         inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
             Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
             This is useful if you want more control over how to convert `input_ids` indices into associated vectors
             than the model's internal embedding lookup matrix.
         decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
             Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
-            If `decoder_past_key_value_states` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `decoder_past_key_value_states`).
+            If `decoder_past_key_values` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `decoder_past_key_values`).
             This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
             than the model's internal embedding lookup matrix. If decoder_input_ids and decoder_inputs_embeds are both None,
             decoder_inputs_embeds takes the value of inputs_embeds.
@@ -923,7 +923,7 @@ class T5Model(T5PreTrainedModel):
         encoder_outputs=None,
         decoder_input_ids=None,
         decoder_attention_mask=None,
-        decoder_past_key_value_states=None,
+        decoder_past_key_values=None,
         use_cache=None,
         inputs_embeds=None,
         decoder_inputs_embeds=None,
@@ -931,6 +931,7 @@ class T5Model(T5PreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_tuple=None,
+        **kwargs,
     ):
         r"""
         Returns:
@@ -947,6 +948,14 @@ class T5Model(T5PreTrainedModel):
            >>> last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
         """
+        if "decoder_past_key_value_states" in kwargs:
+            warnings.warn(
+                "The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
+                FutureWarning,
+            )
+            decoder_past_key_values = kwargs.pop("decoder_past_key_value_states")
+        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
+
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
@@ -978,7 +987,7 @@ class T5Model(T5PreTrainedModel):
         # If decoding with past key value states, only the last tokens
         # should be given as an input
-        if decoder_past_key_value_states is not None:
+        if decoder_past_key_values is not None:
             if decoder_input_ids is not None:
                 decoder_input_ids = decoder_input_ids[:, -1:]
             if decoder_inputs_embeds is not None:
@@ -989,7 +998,7 @@ class T5Model(T5PreTrainedModel):
             input_ids=decoder_input_ids,
             attention_mask=decoder_attention_mask,
             inputs_embeds=decoder_inputs_embeds,
-            past_key_value_states=decoder_past_key_value_states,
+            past_key_value_states=decoder_past_key_values,
             encoder_hidden_states=hidden_states,
             encoder_attention_mask=attention_mask,
             head_mask=head_mask,
@@ -1062,7 +1071,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         encoder_outputs=None,
         decoder_input_ids=None,
         decoder_attention_mask=None,
-        decoder_past_key_value_states=None,
+        decoder_past_key_values=None,
         use_cache=None,
         labels=None,
         inputs_embeds=None,
@@ -1071,7 +1080,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_tuple=None,
-        **kwargs
+        **kwargs,
     ):
         r"""
         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@@ -1103,9 +1112,15 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         if "lm_labels" in kwargs:
             warnings.warn(
                 "The `lm_labels` argument is deprecated and will be removed in a future version, use `labels` instead.",
-                DeprecationWarning,
+                FutureWarning,
             )
             labels = kwargs.pop("lm_labels")
+        if "decoder_past_key_value_states" in kwargs:
+            warnings.warn(
+                "The `decoder_past_key_value_states` argument is deprecated and will be removed in a future version, use `decoder_past_key_values` instead.",
+                FutureWarning,
+            )
+            decoder_past_key_values = kwargs.pop("decoder_past_key_value_states")
         assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."

         use_cache = use_cache if use_cache is not None else self.config.use_cache
@@ -1138,7 +1153,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         # If decoding with past key value states, only the last tokens
         # should be given as an input
-        if decoder_past_key_value_states is not None:
+        if decoder_past_key_values is not None:
             assert labels is None, "Decoder should not use cached key value states when training."
             if decoder_input_ids is not None:
                 decoder_input_ids = decoder_input_ids[:, -1:]
@@ -1150,7 +1165,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
             input_ids=decoder_input_ids,
             attention_mask=decoder_attention_mask,
             inputs_embeds=decoder_inputs_embeds,
-            past_key_value_states=decoder_past_key_value_states,
+            past_key_value_states=decoder_past_key_values,
             encoder_hidden_states=hidden_states,
             encoder_attention_mask=attention_mask,
             head_mask=head_mask,
@@ -1193,11 +1208,11 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
     def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, **kwargs):
         assert past is not None, "past has to be defined for encoder_outputs"

-        encoder_outputs, decoder_past_key_value_states = past
+        encoder_outputs, decoder_past_key_values = past
         return {
             "decoder_input_ids": input_ids,
-            "decoder_past_key_value_states": decoder_past_key_value_states,
+            "decoder_past_key_values": decoder_past_key_values,
             "encoder_outputs": encoder_outputs,
             "attention_mask": attention_mask,
             "use_cache": use_cache,
......
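
For the seq2seq models, callers rarely thread the cache by hand: `generate()` drives it through `prepare_inputs_for_generation`, which (per the hunk above) now repacks it under the `decoder_past_key_values` key. A short hedged sketch, with `t5-small` used purely as a convenient public checkpoint:

```python
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer.encode(
    "translate English to German: The house is wonderful.", return_tensors="pt"
)

# use_cache=True lets each decoding step reuse the key/value states computed at
# the previous steps; internally they now travel as `decoder_past_key_values`.
output_ids = model.generate(input_ids, max_length=40, use_cache=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```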