"...git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "b13539c3ddcf370977dade661702e5e62801ba0d"
Unverified commit 01c37dcd, authored by Patrick von Platen and committed by GitHub

[Config, Caching] Remove `output_past` everywhere and replace by `use_cache` argument (#3734)

* remove output_past from pt

* make style

* add optional input length for gpt2

* add use cache to prepare input

* save memory in gpt2

* correct gpt2 test inputs

* make past input optional for gpt2

* finish use_cache for all models

* make style

* delete modeling_gpt2 change in test file

* correct docstring

* correct is true statements for gpt2
parent 092cf881
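At a glance: caching of past key/value states is no longer toggled through the `output_past` configuration attribute but requested per call with a `use_cache` argument. A minimal sketch of the resulting usage after this change, using GPT-2 as an example (the checkpoint name and prompt are illustrative, not part of this commit):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tokenizer.encode("Hello, my dog is cute", return_tensors="pt")

with torch.no_grad():
    # use_cache=True (the default) returns the past key/value states as the second element
    logits, past = model(input_ids, use_cache=True)[:2]
    # use_cache=False skips building the cache, which saves memory when past is not needed
    logits_only = model(input_ids, use_cache=False)[0]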
@@ -19,7 +19,7 @@ class BartSystem(BaseTransformer):
    mode = "language-modeling"

    def __init__(self, hparams):
-        super().__init__(hparams, num_labels=None, mode=self.mode, output_past=False)
+        super().__init__(hparams, num_labels=None, mode=self.mode)

    def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, lm_labels=None):
        return self.model(
...
@@ -59,7 +59,7 @@ class PretrainedConfig(object):
        # Attributes with defaults
        self.output_attentions = kwargs.pop("output_attentions", False)
        self.output_hidden_states = kwargs.pop("output_hidden_states", False)
-        self.output_past = kwargs.pop("output_past", True)  # Not used by all models
+        self.use_cache = kwargs.pop("use_cache", True)  # Not used by all models
        self.torchscript = kwargs.pop("torchscript", False)  # Only used by PyTorch models
        self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
        self.pruned_heads = kwargs.pop("pruned_heads", {})
...
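In `PretrainedConfig` the new attribute keeps the old default, so existing configurations keep caching enabled unless a caller opts out. A small sketch of the expected behaviour (`GPT2Config` is just one concrete subclass, used here for illustration):

from transformers import GPT2Config

config = GPT2Config()
print(config.use_cache)   # True by default, mirroring the old output_past default

config = GPT2Config(use_cache=False)
print(config.use_cache)   # False when overridden through kwargs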
@@ -933,7 +933,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
        return outputs

-    def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, **kwargs):
+    def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs):
        assert past is not None, "past has to be defined for encoder_outputs"
        # first step, decoder_cached_states are empty
@@ -947,7 +947,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
            "decoder_cached_states": decoder_cached_states,
            "decoder_input_ids": decoder_input_ids,
            "attention_mask": attention_mask,
-            "use_cache": True,  # change this to avoid caching (presumably for debugging)
+            "use_cache": use_cache,  # change this to avoid caching (presumably for debugging)
        }

    def prepare_scores_for_generation(self, scores, cur_len, max_length):
@@ -980,10 +980,6 @@ class BartForConditionalGeneration(PretrainedBartModel):
    def get_output_embeddings(self):
        return _make_linear_from_emb(self.model.shared)  # make it on the fly

-    def _do_output_past(self, *args, **kwargs):
-        """ We should always use the cache in generate."""
-        return True
-
@add_start_docstrings(
    """Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,
...
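With Bart's `_do_output_past` override gone, caching during generation is controlled through `prepare_inputs_for_generation`, which now receives `use_cache` and forwards it inside the model inputs. A simplified, self-contained sketch of that hand-off (a toy class, not the actual library code):

import torch


class ToyGenerativeModel:
    """Illustrates the use_cache hand-off used by the generation loop; not the real transformers code."""

    def prepare_inputs_for_generation(self, input_ids, past=None, use_cache=True, **kwargs):
        if past is not None:
            # once key/value states are cached, only the newest token is needed
            input_ids = input_ids[:, -1:]
        return {"input_ids": input_ids, "past": past, "use_cache": use_cache}

    def __call__(self, input_ids, past=None, use_cache=True):
        # stand-in for a real forward pass: random logits over a 10-token vocabulary
        logits = torch.randn(input_ids.shape[0], input_ids.shape[1], 10)
        new_past = ("cached key/value states",) if use_cache else None
        return logits, new_past


model = ToyGenerativeModel()
ids = torch.tensor([[1, 2, 3]])
model_inputs = model.prepare_inputs_for_generation(ids, past=None, use_cache=True)
logits, past = model(**model_inputs)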
@@ -98,7 +98,7 @@ class MultiHeadAttention(torch.nn.Module):
        x = x.reshape(batch_size, -1, self.num_heads, self.depth)
        return x.permute([0, 2, 1, 3])

-    def forward(self, v, k, q, mask, layer_past=None, attention_mask=None, head_mask=None):
+    def forward(self, v, k, q, mask, layer_past=None, attention_mask=None, head_mask=None, use_cache=False):
        batch_size = q.shape[0]

        q = self.Wq(q)
@@ -112,7 +112,11 @@ class MultiHeadAttention(torch.nn.Module):
            past_key, past_value = layer_past[0], layer_past[1]
            k = torch.cat((past_key, k), dim=-2)
            v = torch.cat((past_value, v), dim=-2)
-        present = torch.stack((k, v))
+
+        if use_cache is True:
+            present = torch.stack((k, v))
+        else:
+            present = (None,)

        output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
        scaled_attention = output[0].permute([0, 2, 1, 3])
@@ -143,10 +147,17 @@ class EncoderLayer(torch.nn.Module):
        self.dropout1 = torch.nn.Dropout(rate)
        self.dropout2 = torch.nn.Dropout(rate)

-    def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None):
+    def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None, use_cache=False):
        normed = self.layernorm1(x)
        attn_outputs = self.multi_head_attention(
-            normed, normed, normed, mask, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask
+            normed,
+            normed,
+            normed,
+            mask,
+            layer_past=layer_past,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            use_cache=use_cache,
        )
        attn_output = attn_outputs[0]
        attn_output = self.dropout1(attn_output)
@@ -199,6 +210,7 @@ CTRL_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.
+            If `past` is used, optionally only the last `input_ids` have to be input (see `past`).

            Indices can be obtained using :class:`transformers.CTRLTokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
@@ -207,8 +219,10 @@ CTRL_INPUTS_DOCSTRING = r"""
            `What are input IDs? <../glossary.html#input-ids>`__
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
-            (see `past` output below). Can be used to speed up sequential decoding. The token ids which have their past given to this model
-            should not be passed as input ids as they have already been computed.
+            (see `past` output below). Can be used to speed up sequential decoding.
+            If `past` is used, the user can optionally input only the last `input_ids`
+            (those that don't have their past given to this model) of shape :obj:`(batch_size, 1)`
+            instead of all `input_ids` of shape :obj:`(batch_size, sequence_length)`.
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
@@ -219,6 +233,7 @@ CTRL_INPUTS_DOCSTRING = r"""
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token
+            If `past` is used, optionally only the last `token_type_ids` have to be input (see `past`).

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@@ -234,6 +249,10 @@ CTRL_INPUTS_DOCSTRING = r"""
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
+            If `past` is used, optionally only the last `input_embeds` have to be input (see `past`).
+        use_cache (:obj:`bool`):
+            If `use_cache` is True, `past` key value states are returned and
+            can be used to speed up decoding (see `past`). Defaults to `True`.
"""
@@ -246,7 +265,6 @@ class CTRLModel(CTRLPreTrainedModel):
        super().__init__(config)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
-        self.output_past = config.output_past

        self.d_model_size = config.n_embd
        self.num_layers = config.n_layer
@@ -289,6 +307,7 @@ class CTRLModel(CTRLPreTrainedModel):
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
+        use_cache=True,
    ):
        r"""
    Return:
@@ -297,8 +316,7 @@ class CTRLModel(CTRLPreTrainedModel):
            Sequence of hidden-states at the last layer of the model.
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
-            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
-            should not be passed as input ids as they have already been computed.
+            Can be used (see `past` input) to speed up sequential decoding.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -325,6 +343,17 @@ class CTRLModel(CTRLPreTrainedModel):
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

        """
+
+        # If using past key value states, only the last tokens
+        # should be given as an input
+        if past is not None:
+            if input_ids is not None:
+                input_ids = input_ids[:, -1:]
+            if inputs_embeds is not None:
+                inputs_embeds = inputs_embeds[:, -1:]
+            if token_type_ids is not None:
+                token_type_ids = token_type_ids[:, -1:]
+
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
@@ -414,10 +443,15 @@ class CTRLModel(CTRLPreTrainedModel):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
            outputs = h(
-                hidden_states, mask, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask[i]
+                hidden_states,
+                mask,
+                layer_past=layer_past,
+                attention_mask=attention_mask,
+                head_mask=head_mask[i],
+                use_cache=use_cache,
            )
            hidden_states, present = outputs[:2]
-            if self.output_past:
+            if use_cache is True:
                presents = presents + (present,)

            if self.output_attentions:
@@ -429,7 +463,7 @@ class CTRLModel(CTRLPreTrainedModel):
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
-        if self.output_past:
+        if use_cache is True:
            outputs = outputs + (presents,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
@@ -462,7 +496,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)

-        return {"input_ids": input_ids, "past": past}
+        return {"input_ids": input_ids, "past": past, "use_cache": kwargs["use_cache"]}

    @add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
    def forward(
@@ -475,6 +509,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
        head_mask=None,
        inputs_embeds=None,
        labels=None,
+        use_cache=True,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@@ -492,8 +527,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
-            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
-            should not be passed as input ids as they have already been computed.
+            Can be used (see `past` input) to speed up sequential decoding.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -527,6 +561,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
        )

        hidden_states = transformer_outputs[0]
...
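Because `use_cache` is now a forward argument, a caller that only needs logits (for example when scoring a fixed sequence) can skip building the `presents` tuple per call instead of flipping a config flag. A minimal sketch with CTRL (the checkpoint name and prompt are illustrative):

import torch
from transformers import CTRLLMHeadModel, CTRLTokenizer

tokenizer = CTRLTokenizer.from_pretrained("ctrl")
model = CTRLLMHeadModel.from_pretrained("ctrl")
input_ids = tokenizer.encode("Links Hello world", return_tensors="pt")

with torch.no_grad():
    outputs = model(input_ids, use_cache=False)  # no key/value cache is built or returned
logits = outputs[0]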
@@ -177,7 +177,7 @@ class Attention(nn.Module):
        else:
            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

-    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
+    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
@@ -187,7 +187,11 @@ class Attention(nn.Module):
            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
-        present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
+
+        if use_cache is True:
+            present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
+        else:
+            present = (None,)

        attn_outputs = self._attn(query, key, value, attention_mask, head_mask)
        a = attn_outputs[0]
@@ -224,9 +228,13 @@ class Block(nn.Module):
        self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)

-    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
+    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False):
        output_attn = self.attn(
-            self.ln_1(x), layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask
+            self.ln_1(x),
+            layer_past=layer_past,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            use_cache=use_cache,
        )
        a = output_attn[0]  # output_attn: a, present, (attentions)
@@ -279,10 +287,9 @@ GPT2_START_DOCSTRING = r"""

GPT2_INPUTS_DOCSTRING = r"""
    Args:
-        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
-            `input_ids_length` = `sequence_length if `past` is None else 1
+        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.
-            If using `past` as an input make sure that `input_ids` are those of the last position.
+            If `past` is used, optionally only the last `input_ids` have to be input (see `past`).

            Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
@@ -292,8 +299,8 @@ GPT2_INPUTS_DOCSTRING = r"""
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
-            (see `past` output below). Can be used to speed up sequential decoding. The token ids which have their past given to this model
-            should not be passed as input ids as they have already been computed.
+            (see `past` output below). Can be used to speed up sequential decoding.
+            If `past` is used, the user can optionally input only the last `input_ids` (those that don't have their past given to this model) of shape :obj:`(batch_size, 1)` instead of all `input_ids` of shape :obj:`(batch_size, sequence_length)`.
        attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
@@ -305,7 +312,7 @@ GPT2_INPUTS_DOCSTRING = r"""
            Segment token indices to indicate first and second portions of the inputs.
            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
            corresponds to a `sentence B` token
-            If using `past` as an input make sure that `token_type_ids` correspond to the `input_ids` of the last position.
+            If `past` is used, optionally only the last `token_type_ids` have to be input (see `past`).

            `What are token type IDs? <../glossary.html#token-type-ids>`_
        position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@@ -321,6 +328,9 @@ GPT2_INPUTS_DOCSTRING = r"""
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
+            If `past` is used, optionally only the last `input_embeds` have to be input (see `past`).
+        use_cache (:obj:`bool`):
+            If `use_cache` is True, `past` key value states are returned and can be used to speed up decoding (see `past`). Defaults to `True`.
"""
@@ -333,7 +343,6 @@ class GPT2Model(GPT2PreTrainedModel):
        super().__init__(config)
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
-        self.output_past = config.output_past

        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
@@ -366,16 +375,17 @@ class GPT2Model(GPT2PreTrainedModel):
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
+        use_cache=True,
    ):
        r"""
    Return:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the last layer of the model.
+            If `past` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
-            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
-            should not be passed as input ids as they have already been computed.
+            Can be used (see `past` input) to speed up sequential decoding.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -400,6 +410,17 @@ class GPT2Model(GPT2PreTrainedModel):
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

        """
+
+        # If using past key value states, only the last tokens
+        # should be given as an input
+        if past is not None:
+            if input_ids is not None:
+                input_ids = input_ids[:, -1:]
+            if inputs_embeds is not None:
+                inputs_embeds = inputs_embeds[:, -1:]
+            if token_type_ids is not None:
+                token_type_ids = token_type_ids[:, -1:]
+
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
@@ -484,11 +505,15 @@ class GPT2Model(GPT2PreTrainedModel):
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            outputs = block(
-                hidden_states, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask[i]
+                hidden_states,
+                layer_past=layer_past,
+                attention_mask=attention_mask,
+                head_mask=head_mask[i],
+                use_cache=use_cache,
            )

            hidden_states, present = outputs[:2]
-            if self.output_past:
+            if use_cache is True:
                presents = presents + (present,)

            if self.output_attentions:
@@ -502,7 +527,7 @@ class GPT2Model(GPT2PreTrainedModel):
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
-        if self.output_past:
+        if use_cache is True:
            outputs = outputs + (presents,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
@@ -535,7 +560,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)

-        return {"input_ids": input_ids, "past": past}
+        return {"input_ids": input_ids, "past": past, "use_cache": kwargs["use_cache"]}

    @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
    def forward(
@@ -548,6 +573,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
        head_mask=None,
        inputs_embeds=None,
        labels=None,
+        use_cache=True,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@@ -565,8 +591,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
-            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
-            should not be passed as input ids as they have already been computed.
+            Can be used (see `past` input) to speed up sequential decoding.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -600,6 +625,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
        )

        hidden_states = transformer_outputs[0]
@@ -652,6 +678,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
        mc_token_ids=None,
        lm_labels=None,
        mc_labels=None,
+        use_cache=True,
    ):
        r"""
        mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input)
@@ -680,8 +707,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
            Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
        past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
            Contains pre-computed hidden-states (key and values in the attention blocks).
-            Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
-            should not be passed as input ids as they have already been computed.
+            Can be used (see `past` input) to speed up sequential decoding.
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -726,6 +752,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
        )

        hidden_states = transformer_outputs[0]
...
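The `past`/`use_cache` pair is what makes incremental decoding cheap: each step reuses the cached key/value states and only the newly generated token is fed back in. A minimal greedy-decoding sketch under that assumption (checkpoint name, prompt, and step count are illustrative):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

generated = tokenizer.encode("The quick brown fox", return_tensors="pt")
past = None
next_token = generated

with torch.no_grad():
    for _ in range(20):
        # once past is given, only the last token needs to be passed in
        logits, past = model(next_token, past=past, use_cache=True)[:2]
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=-1)

print(tokenizer.decode(generated[0]))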
@@ -188,7 +188,6 @@ class T5Attention(nn.Module):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
-        self.output_past = config.output_past
        self.output_attentions = config.output_attentions
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
@@ -300,6 +299,7 @@ class T5Attention(nn.Module):
        past_key_value_state=None,
        head_mask=None,
        query_length=None,
+        use_cache=False,
    ):
        """
        Self-attention (if kv is None) or attention over source sentence (provided by kv).
@@ -351,7 +351,7 @@ class T5Attention(nn.Module):
            else:
                k, v = past_key_value_state

-        if self.is_decoder and self.output_past:
+        if self.is_decoder and use_cache:
            present_key_value_state = ((k, v),)
        else:
            present_key_value_state = (None,)
@@ -385,14 +385,8 @@ class T5Attention(nn.Module):
        context = self.o(context)

-        outputs = (context,)
-
-        if self.output_past is False or self.is_decoder is False:
-            assert (
-                present_key_value_state[0] is None
-            ), "Key/Value projections should not be stored if {} is not decoder or output_past is False".format(self)
-        outputs = outputs + present_key_value_state
+        outputs = (context,) + present_key_value_state

        if self.output_attentions:
            outputs = outputs + (weights,)
        if self.has_relative_attention_bias:
@@ -408,7 +402,13 @@ class T5LayerSelfAttention(nn.Module):
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(
-        self, hidden_states, attention_mask=None, position_bias=None, head_mask=None, past_key_value_state=None
+        self,
+        hidden_states,
+        attention_mask=None,
+        position_bias=None,
+        head_mask=None,
+        past_key_value_state=None,
+        use_cache=False,
    ):
        norm_x = self.layer_norm(hidden_states)
        attention_output = self.SelfAttention(
@@ -417,6 +417,7 @@ class T5LayerSelfAttention(nn.Module):
            position_bias=position_bias,
            head_mask=head_mask,
            past_key_value_state=past_key_value_state,
+            use_cache=use_cache,
        )
        y = attention_output[0]
        layer_output = hidden_states + self.dropout(y)
@@ -439,6 +440,7 @@ class T5LayerCrossAttention(nn.Module):
        position_bias=None,
        head_mask=None,
        past_key_value_state=None,
+        use_cache=False,
        query_length=None,
    ):
        norm_x = self.layer_norm(hidden_states)
@@ -449,6 +451,7 @@ class T5LayerCrossAttention(nn.Module):
            position_bias=position_bias,
            head_mask=head_mask,
            past_key_value_state=past_key_value_state,
+            use_cache=use_cache,
            query_length=query_length,
        )
        y = attention_output[0]
@@ -460,7 +463,6 @@ class T5LayerCrossAttention(nn.Module):

class T5Block(nn.Module):
    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
-        self.output_past = config.output_past
        self.is_decoder = config.is_decoder
        self.layer = nn.ModuleList()
        self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
@@ -479,6 +481,7 @@ class T5Block(nn.Module):
        encoder_decoder_position_bias=None,
        head_mask=None,
        past_key_value_state=None,
+        use_cache=False,
    ):

        if past_key_value_state is not None:
@@ -499,6 +502,7 @@ class T5Block(nn.Module):
            position_bias=position_bias,
            head_mask=head_mask,
            past_key_value_state=self_attn_past_key_value_state,
+            use_cache=use_cache,
        )
        hidden_states, present_key_value_state = self_attention_outputs[:2]
        attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights
@@ -519,6 +523,7 @@ class T5Block(nn.Module):
                head_mask=head_mask,
                past_key_value_state=cross_attn_past_key_value_state,
                query_length=query_length,
+                use_cache=use_cache,
            )
            hidden_states = cross_attention_outputs[0]
            # Combine self attn and cross attn key value states
@@ -620,7 +625,6 @@ class T5Stack(T5PreTrainedModel):
        self.embed_tokens = embed_tokens
        self.is_decoder = config.is_decoder
-        self.output_past = config.output_past and self.is_decoder

        self.block = nn.ModuleList(
            [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
@@ -648,6 +652,7 @@ class T5Stack(T5PreTrainedModel):
        inputs_embeds=None,
        head_mask=None,
        past_key_value_states=None,
+        use_cache=False,
    ):

        if input_ids is not None and inputs_embeds is not None:
@@ -699,7 +704,7 @@ class T5Stack(T5PreTrainedModel):
            causal_mask = seq_ids[None, None, :].repeat(batch_size, mask_seq_length, 1) <= seq_ids[None, :, None]
            causal_mask = causal_mask.to(attention_mask)
            extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
-            if self.output_past and past_key_value_states[0] is not None:
+            if past_key_value_states[0] is not None:
                extended_attention_mask = extended_attention_mask[:, :, -1:, :]
        else:
            extended_attention_mask = attention_mask[:, None, None, :]
@@ -776,6 +781,7 @@ class T5Stack(T5PreTrainedModel):
                encoder_decoder_position_bias=encoder_decoder_position_bias,
                head_mask=head_mask[i],
                past_key_value_state=past_key_value_state,
+                use_cache=use_cache,
            )
            # layer_outputs is a tuple with:
            # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
@@ -800,7 +806,8 @@ class T5Stack(T5PreTrainedModel):
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
-        if self.is_decoder and self.output_past:
+        if use_cache is True:
+            assert self.is_decoder, "`use_cache` can only be set to `True` if {} is used as a decoder".format(self)
            outputs = outputs + (present_key_value_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
@@ -833,7 +840,7 @@ T5_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.
-            T5 is a model with relative position embeddings so you should be able to pad the inputs on both the right and the left. If `decoder_past_key_value_states` is used, optionally only the last `input_ids` have to be input (see `decoder_past_key_value_states`).
+            T5 is a model with relative position embeddings so you should be able to pad the inputs on both the right and the left.
            Indices can be obtained using :class:`transformers.T5Tokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
@@ -849,19 +856,26 @@ T5_INPUTS_DOCSTRING = r"""
            Used in the cross-attention of the decoder.
        decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
            Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
+            If `decoder_past_key_value_states` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_value_states`).
            To know more on how to prepare :obj:`decoder_input_ids` for pre-training take a look at
            `T5 Training <./t5.html#training>`_ .
        decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
            Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
        decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
            Contains pre-computed key and value hidden-states of the attention blocks.
-            Can be used to speed up decoding. If `decoder_past_key_value_states` are used, the user can optionally input only the last `decoder_input_ids` of shape :obj:`(batch_size, 1)` instead of all `decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+            Can be used to speed up decoding.
+            If `decoder_past_key_value_states` are used, the user can optionally input only the last `decoder_input_ids`
+            (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
+            instead of all `decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
+        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
+            If `use_cache` is True, `decoder_past_key_value_states` are returned and can be used to speed up decoding (see `decoder_past_key_value_states`).
        inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
+            If `decoder_past_key_value_states` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `decoder_past_key_value_states`).
            This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
        head_mask: (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
@@ -897,14 +911,6 @@ class T5Model(T5PreTrainedModel):
        self.encoder.set_input_embeddings(new_embeddings)
        self.decoder.set_input_embeddings(new_embeddings)

-    def set_output_past(self, do_output_past: bool):
-        self.config.output_past = do_output_past
-        self.decoder.output_past = do_output_past
-        for block in self.decoder.block:
-            block.output_past = do_output_past
-            block.layer[0].SelfAttention.output_past = do_output_past
-            block.layer[1].EncDecAttention.output_past = do_output_past
-
    def get_encoder(self):
        return self.encoder
@@ -928,6 +934,7 @@ class T5Model(T5PreTrainedModel):
        decoder_input_ids=None,
        decoder_attention_mask=None,
        decoder_past_key_value_states=None,
+        use_cache=True,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        head_mask=None,
@@ -938,7 +945,7 @@ class T5Model(T5PreTrainedModel):
        last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
            If `decoder_past_key_value_states` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
-        decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``config.output_past=True``):
+        decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``use_cache=True``):
            Contains pre-computed key and value hidden-states of the attention blocks.
            Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input).
            Note that when using `decoder_past_key_value_states`, the model only outputs the last `hidden-state` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`.
@@ -976,7 +983,7 @@ class T5Model(T5PreTrainedModel):

        # If decoding with past key value states, only the last tokens
        # should be given as an input
-        if decoder_past_key_value_states is not None and self.decoder.output_past is True:
+        if decoder_past_key_value_states is not None:
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids[:, -1:]
            if decoder_inputs_embeds is not None:
@@ -991,9 +998,10 @@ class T5Model(T5PreTrainedModel):
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=head_mask,
+            use_cache=use_cache,
        )

-        if self.decoder.output_past:
+        if use_cache is True:
            past = ((encoder_outputs, decoder_outputs[1]),)
            decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:]
@@ -1022,14 +1030,6 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
    def get_input_embeddings(self):
        return self.shared

-    def set_output_past(self, do_output_past: bool):
-        self.config.output_past = do_output_past
-        self.decoder.output_past = do_output_past
-        for block in self.decoder.block:
-            block.output_past = do_output_past
-            block.layer[0].SelfAttention.output_past = do_output_past
-            block.layer[1].EncDecAttention.output_past = do_output_past
-
    def set_input_embeddings(self, new_embeddings):
        self.shared = new_embeddings
        self.encoder.set_input_embeddings(new_embeddings)
@@ -1053,6 +1053,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
        decoder_input_ids=None,
        decoder_attention_mask=None,
        decoder_past_key_value_states=None,
+        use_cache=True,
        lm_labels=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
@@ -1072,7 +1073,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
        prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
            If `past_key_value_states` is used only the last prediction_scores of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
-        decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``config.output_past=True``):
+        decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``use_cache=True``):
            Contains pre-computed key and value hidden-states of the attention blocks.
            Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input).
            Note that when using `decoder_past_key_value_states`, the model only outputs the last `prediction_score` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`.
@@ -1116,10 +1117,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):

        # If decoding with past key value states, only the last tokens
        # should be given as an input
-        if decoder_past_key_value_states is not None and self.decoder.output_past is True:
-            assert (
-                lm_labels is None
-            ), "Decoder should not use cached key value states when training. Also consider setting model.set_output_past(False) for less memory consumption"
+        if decoder_past_key_value_states is not None:
+            assert lm_labels is None, "Decoder should not use cached key value states when training."
            if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids[:, -1:] decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_inputs_embeds is not None: if decoder_inputs_embeds is not None:
...@@ -1134,11 +1133,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel): ...@@ -1134,11 +1133,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
encoder_hidden_states=hidden_states, encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
head_mask=head_mask, head_mask=head_mask,
use_cache=use_cache,
) )
# insert decoder past at right place # insert decoder past at right place
# to speed up decoding # to speed up decoding
if self.decoder.output_past: if use_cache is True:
past = ((encoder_outputs, decoder_outputs[1]),) past = ((encoder_outputs, decoder_outputs[1]),)
decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:] decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:]
...@@ -1157,7 +1157,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): ...@@ -1157,7 +1157,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
return decoder_outputs + encoder_outputs return decoder_outputs + encoder_outputs
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs): def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, **kwargs):
assert past is not None, "past has to be defined for encoder_outputs" assert past is not None, "past has to be defined for encoder_outputs"
# first step # first step
...@@ -1171,13 +1171,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel): ...@@ -1171,13 +1171,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
"decoder_past_key_value_states": decoder_past_key_value_states, "decoder_past_key_value_states": decoder_past_key_value_states,
"encoder_outputs": encoder_outputs, "encoder_outputs": encoder_outputs,
"attention_mask": attention_mask, "attention_mask": attention_mask,
"use_cache": use_cache,
} }
def _reorder_cache(self, past, beam_idx): def _reorder_cache(self, past, beam_idx):
# if decoder past is not included in output # if decoder past is not included in output
# speedy decoding is disabled and no need to reorder # speedy decoding is disabled and no need to reorder
if len(past) < 2: if len(past) < 2:
logger.warning("You might want to consider setting model.set_output_past(True) to speed up decoding") logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
return past return past
decoder_past = past[1] decoder_past = past[1]
......
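With `set_output_past` removed above, callers now toggle caching per call instead of mutating the model. A hedged migration sketch; the `t5-small` checkpoint and the prompt are placeholders, not taken from this diff:

from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer.encode("translate English to German: How are you?", return_tensors="pt")

# Before this change (no longer available):
#   model.set_output_past(False)
#   generated = model.generate(input_ids, max_length=20)
# After this change, caching is a per-call argument:
generated = model.generate(input_ids, max_length=20, use_cache=False)
print(tokenizer.decode(generated[0]))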
...@@ -94,7 +94,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer): ...@@ -94,7 +94,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
return tf.transpose(x, perm=[0, 2, 1, 3]) return tf.transpose(x, perm=[0, 2, 1, 3])
def call(self, inputs, training=False): def call(self, inputs, training=False):
v, k, q, mask, layer_past, attention_mask, head_mask = inputs v, k, q, mask, layer_past, attention_mask, head_mask, use_cache = inputs
batch_size = shape_list(q)[0] batch_size = shape_list(q)[0]
q = self.Wq(q) q = self.Wq(q)
...@@ -104,11 +104,25 @@ class TFMultiHeadAttention(tf.keras.layers.Layer): ...@@ -104,11 +104,25 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
q = self.split_into_heads(q, batch_size) q = self.split_into_heads(q, batch_size)
k = self.split_into_heads(k, batch_size) k = self.split_into_heads(k, batch_size)
v = self.split_into_heads(v, batch_size) v = self.split_into_heads(v, batch_size)
if layer_past is not None: if layer_past is not None:
past_key, past_value = tf.unstack(layer_past, axis=0) past_key, past_value = tf.unstack(layer_past, axis=0)
k = tf.concat((past_key, k), axis=-2) k = tf.concat((past_key, k), axis=-2)
v = tf.concat((past_value, v), axis=-2) v = tf.concat((past_value, v), axis=-2)
present = tf.stack((k, v), axis=0)
# to cope with keras serialization
# we need to cast `use_cache` to correct bool
# if it is a tensor
if tf.is_tensor(use_cache):
if hasattr(use_cache, "numpy"):
use_cache = bool(use_cache.numpy())
else:
use_cache = True
if use_cache is True:
present = tf.stack((k, v), axis=0)
else:
present = (None,)

output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask) output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
scaled_attention = tf.transpose(output[0], perm=[0, 2, 1, 3]) scaled_attention = tf.transpose(output[0], perm=[0, 2, 1, 3])
...@@ -147,10 +161,10 @@ class TFEncoderLayer(tf.keras.layers.Layer): ...@@ -147,10 +161,10 @@ class TFEncoderLayer(tf.keras.layers.Layer):
self.dropout2 = tf.keras.layers.Dropout(rate) self.dropout2 = tf.keras.layers.Dropout(rate)
def call(self, inputs, training=False): def call(self, inputs, training=False):
x, mask, layer_past, attention_mask, head_mask = inputs x, mask, layer_past, attention_mask, head_mask, use_cache = inputs
normed = self.layernorm1(x) normed = self.layernorm1(x)
attn_outputs = self.multi_head_attention( attn_outputs = self.multi_head_attention(
[normed, normed, normed, mask, layer_past, attention_mask, head_mask], training=training [normed, normed, normed, mask, layer_past, attention_mask, head_mask, use_cache], training=training
) )
attn_output = attn_outputs[0] attn_output = attn_outputs[0]
attn_output = self.dropout1(attn_output, training=training) attn_output = self.dropout1(attn_output, training=training)
...@@ -173,7 +187,6 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): ...@@ -173,7 +187,6 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
super().__init__(**kwargs) super().__init__(**kwargs)
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
self.output_past = config.output_past
self.d_model_size = config.n_embd self.d_model_size = config.n_embd
self.num_layers = config.n_layer self.num_layers = config.n_layer
...@@ -220,8 +233,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): ...@@ -220,8 +233,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=True,
training=False, training=False,
): ):
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
input_ids = inputs[0] input_ids = inputs[0]
past = inputs[1] if len(inputs) > 1 else past past = inputs[1] if len(inputs) > 1 else past
...@@ -230,7 +245,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): ...@@ -230,7 +245,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
position_ids = inputs[4] if len(inputs) > 4 else position_ids position_ids = inputs[4] if len(inputs) > 4 else position_ids
head_mask = inputs[5] if len(inputs) > 5 else head_mask head_mask = inputs[5] if len(inputs) > 5 else head_mask
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
assert len(inputs) <= 7, "Too many inputs." use_cache = inputs[7] if len(inputs) > 7 else use_cache
assert len(inputs) <= 8, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)): elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids") input_ids = inputs.get("input_ids")
past = inputs.get("past", past) past = inputs.get("past", past)
...@@ -239,10 +255,21 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): ...@@ -239,10 +255,21 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
position_ids = inputs.get("position_ids", position_ids) position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask) head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 7, "Too many inputs." use_cache = inputs.get("use_cache", use_cache)
assert len(inputs) <= 8, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
# If using past key value states, only the last tokens
# should be given as an input
if past is not None:
if input_ids is not None:
input_ids = input_ids[:, -1:]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:, -1:]
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1:]
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
...@@ -319,10 +346,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): ...@@ -319,10 +346,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
for i, (h, layer_past) in enumerate(zip(self.h, past)): for i, (h, layer_past) in enumerate(zip(self.h, past)):
if self.output_hidden_states: if self.output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = h([hidden_states, mask, layer_past, attention_mask, head_mask[i]], training=training) outputs = h([hidden_states, mask, layer_past, attention_mask, head_mask[i], use_cache], training=training)
hidden_states, present = outputs[:2] hidden_states, present = outputs[:2]
if self.output_past: if use_cache is True:
presents = presents + (present,) presents = presents + (present,)
if self.output_attentions: if self.output_attentions:
...@@ -334,7 +361,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer): ...@@ -334,7 +361,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,) outputs = (hidden_states,)
if self.output_past: if use_cache is True:
outputs = outputs + (presents,) outputs = outputs + (presents,)
if self.output_hidden_states: if self.output_hidden_states:
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
...@@ -386,6 +413,7 @@ CTRL_INPUTS_DOCSTRING = r""" ...@@ -386,6 +413,7 @@ CTRL_INPUTS_DOCSTRING = r"""
Args: Args:
input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`): input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Indices of input sequence tokens in the vocabulary.
If `past` is used, optionally only the last `input_ids` have to be input (see `past`).
Indices can be obtained using :class:`transformers.CTRLTokenizer`. Indices can be obtained using :class:`transformers.CTRLTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and See :func:`transformers.PreTrainedTokenizer.encode` and
...@@ -394,8 +422,10 @@ CTRL_INPUTS_DOCSTRING = r""" ...@@ -394,8 +422,10 @@ CTRL_INPUTS_DOCSTRING = r"""
`What are input IDs? <../glossary.html#input-ids>`__ `What are input IDs? <../glossary.html#input-ids>`__
past (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers`): past (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `past` output below). Can be used to speed up sequential decoding. The token ids which have their past given to this model (see `past` output below). Can be used to speed up sequential decoding.
should not be passed as input ids as they have already been computed. If `past` is used, the user can optionally input only the last `input_ids`
(those that don't have their past given to this model) of shape :obj:`(batch_size, 1)`
instead of all `input_ids` of shape :obj:`(batch_size, sequence_length)`.
attention_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): attention_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices. Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``: Mask values selected in ``[0, 1]``:
...@@ -406,6 +436,7 @@ CTRL_INPUTS_DOCSTRING = r""" ...@@ -406,6 +436,7 @@ CTRL_INPUTS_DOCSTRING = r"""
Segment token indices to indicate first and second portions of the inputs. Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token corresponds to a `sentence B` token
If `past` is used, optionally only the last `token_type_ids` have to be input (see `past`).
`What are token type IDs? <../glossary.html#token-type-ids>`_ `What are token type IDs? <../glossary.html#token-type-ids>`_
position_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): position_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
...@@ -421,6 +452,10 @@ CTRL_INPUTS_DOCSTRING = r""" ...@@ -421,6 +452,10 @@ CTRL_INPUTS_DOCSTRING = r"""
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix. than the model's internal embedding lookup matrix.
If `past` is used, optionally only the last `inputs_embeds` have to be input (see `past`).
use_cache (:obj:`bool`):
If `use_cache` is True, `past` key value states are returned and
can be used to speed up decoding (see `past`). Defaults to `True`.
training (:obj:`boolean`, `optional`, defaults to :obj:`False`): training (:obj:`boolean`, `optional`, defaults to :obj:`False`):
Whether to activate dropout modules (if set to :obj:`True`) during training or to de-activate them Whether to activate dropout modules (if set to :obj:`True`) during training or to de-activate them
(if set to :obj:`False`) for evaluation. (if set to :obj:`False`) for evaluation.
...@@ -514,7 +549,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel): ...@@ -514,7 +549,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
if past: if past:
inputs = tf.expand_dims(inputs[:, -1], -1) inputs = tf.expand_dims(inputs[:, -1], -1)
return {"inputs": inputs, "past": past} return {"inputs": inputs, "past": past, "use_cache": kwargs["use_cache"]}
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING) @add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
def call(self, inputs, **kwargs): def call(self, inputs, **kwargs):
......
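A rough sketch of the two-step cached decoding pattern the TF CTRL changes above enable, using a tiny randomly initialised model so it runs quickly; the config sizes are arbitrary assumptions, and real checkpoints such as `ctrl` behave the same way. The cast of `use_cache` to a Python bool above exists because Keras may hand the layer a tensor rather than a plain bool.

import tensorflow as tf
from transformers import CTRLConfig, TFCTRLLMHeadModel

config = CTRLConfig(n_layer=2, n_head=2, n_embd=32, dff=64, vocab_size=100)
model = TFCTRLLMHeadModel(config)

input_ids = tf.constant([[1, 2, 3, 4]])
logits, past = model(input_ids, use_cache=True)[:2]  # `past` is returned because use_cache=True

# Only the newest token is needed on the next call; the slicing added above
# would drop earlier tokens anyway when `past` is given.
next_token = tf.constant([[5]])
logits, past = model(next_token, past=past, use_cache=True)[:2]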
...@@ -134,7 +134,7 @@ class TFAttention(tf.keras.layers.Layer): ...@@ -134,7 +134,7 @@ class TFAttention(tf.keras.layers.Layer):
return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features) return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features)
def call(self, inputs, training=False): def call(self, inputs, training=False):
x, layer_past, attention_mask, head_mask = inputs x, layer_past, attention_mask, head_mask, use_cache = inputs
x = self.c_attn(x) x = self.c_attn(x)
query, key, value = tf.split(x, 3, axis=2) query, key, value = tf.split(x, 3, axis=2)
...@@ -145,7 +145,20 @@ class TFAttention(tf.keras.layers.Layer): ...@@ -145,7 +145,20 @@ class TFAttention(tf.keras.layers.Layer):
past_key, past_value = tf.unstack(layer_past, axis=0) past_key, past_value = tf.unstack(layer_past, axis=0)
key = tf.concat([past_key, key], axis=-2) key = tf.concat([past_key, key], axis=-2)
value = tf.concat([past_value, value], axis=-2) value = tf.concat([past_value, value], axis=-2)
present = tf.stack([key, value], axis=0)
# to cope with keras serialization
# we need to cast `use_cache` to correct bool
# if it is a tensor
if tf.is_tensor(use_cache):
if hasattr(use_cache, "numpy"):
use_cache = bool(use_cache.numpy())
else:
use_cache = True
if use_cache is True:
present = tf.stack([key, value], axis=0)
else:
present = (None,)
attn_outputs = self._attn([query, key, value, attention_mask, head_mask], training=training) attn_outputs = self._attn([query, key, value, attention_mask, head_mask], training=training)
a = attn_outputs[0] a = attn_outputs[0]
...@@ -184,10 +197,10 @@ class TFBlock(tf.keras.layers.Layer): ...@@ -184,10 +197,10 @@ class TFBlock(tf.keras.layers.Layer):
self.mlp = TFMLP(4 * nx, config, name="mlp") self.mlp = TFMLP(4 * nx, config, name="mlp")
def call(self, inputs, training=False): def call(self, inputs, training=False):
x, layer_past, attention_mask, head_mask = inputs x, layer_past, attention_mask, head_mask, use_cache = inputs
a = self.ln_1(x) a = self.ln_1(x)
output_attn = self.attn([a, layer_past, attention_mask, head_mask], training=training) output_attn = self.attn([a, layer_past, attention_mask, head_mask, use_cache], training=training)
a = output_attn[0] # output_attn: a, present, (attentions) a = output_attn[0] # output_attn: a, present, (attentions)
x = x + a x = x + a
...@@ -245,6 +258,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): ...@@ -245,6 +258,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
position_ids=None, position_ids=None,
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=True,
training=False, training=False,
): ):
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
...@@ -255,7 +269,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): ...@@ -255,7 +269,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
position_ids = inputs[4] if len(inputs) > 4 else position_ids position_ids = inputs[4] if len(inputs) > 4 else position_ids
head_mask = inputs[5] if len(inputs) > 5 else head_mask head_mask = inputs[5] if len(inputs) > 5 else head_mask
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
assert len(inputs) <= 7, "Too many inputs." use_cache = inputs[7] if len(inputs) > 7 else use_cache
assert len(inputs) <= 8, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)): elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids") input_ids = inputs.get("input_ids")
past = inputs.get("past", past) past = inputs.get("past", past)
...@@ -264,10 +279,21 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): ...@@ -264,10 +279,21 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
position_ids = inputs.get("position_ids", position_ids) position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask) head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 7, "Too many inputs." use_cache = inputs.get("use_cache", use_cache)
assert len(inputs) <= 8, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
# If using past key value states, only the last tokens
# should be given as an input
if past is not None:
if input_ids is not None:
input_ids = input_ids[:, -1:]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:, -1:]
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1:]
if input_ids is not None and inputs_embeds is not None: if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None: elif input_ids is not None:
...@@ -338,7 +364,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): ...@@ -338,7 +364,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
if self.output_hidden_states: if self.output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),) all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = block([hidden_states, layer_past, attention_mask, head_mask[i]], training=training) outputs = block([hidden_states, layer_past, attention_mask, head_mask[i], use_cache], training=training)
hidden_states, present = outputs[:2] hidden_states, present = outputs[:2]
presents = presents + (present,) presents = presents + (present,)
...@@ -353,7 +379,10 @@ class TFGPT2MainLayer(tf.keras.layers.Layer): ...@@ -353,7 +379,10 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
if self.output_hidden_states: if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states, presents) outputs = (hidden_states,)
if use_cache is True:
outputs = outputs + (presents,)
if self.output_hidden_states: if self.output_hidden_states:
outputs = outputs + (all_hidden_states,) outputs = outputs + (all_hidden_states,)
if self.output_attentions: if self.output_attentions:
...@@ -404,6 +433,7 @@ GPT2_INPUTS_DOCSTRING = r""" ...@@ -404,6 +433,7 @@ GPT2_INPUTS_DOCSTRING = r"""
Args: Args:
input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`): input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary. Indices of input sequence tokens in the vocabulary.
If `past` is used, optionally only the last `input_ids` have to be input (see `past`).
Indices can be obtained using :class:`transformers.GPT2Tokenizer`. Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and See :func:`transformers.PreTrainedTokenizer.encode` and
...@@ -424,6 +454,7 @@ GPT2_INPUTS_DOCSTRING = r""" ...@@ -424,6 +454,7 @@ GPT2_INPUTS_DOCSTRING = r"""
Segment token indices to indicate first and second portions of the inputs. Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token corresponds to a `sentence B` token
If `past` is used, optionally only the last `token_type_ids` have to be input (see `past`).
`What are token type IDs? <../glossary.html#token-type-ids>`_ `What are token type IDs? <../glossary.html#token-type-ids>`_
position_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`): position_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
...@@ -439,6 +470,7 @@ GPT2_INPUTS_DOCSTRING = r""" ...@@ -439,6 +470,7 @@ GPT2_INPUTS_DOCSTRING = r"""
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix. than the model's internal embedding lookup matrix.
If `past` is used, optionally only the last `inputs_embeds` have to be input (see `past`).
training (:obj:`boolean`, `optional`, defaults to :obj:`False`): training (:obj:`boolean`, `optional`, defaults to :obj:`False`):
Whether to activate dropout modules (if set to :obj:`True`) during training or to de-activate them Whether to activate dropout modules (if set to :obj:`True`) during training or to de-activate them
(if set to :obj:`False`) for evaluation. (if set to :obj:`False`) for evaluation.
...@@ -511,7 +543,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel): ...@@ -511,7 +543,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
if past: if past:
inputs = tf.expand_dims(inputs[:, -1], -1) inputs = tf.expand_dims(inputs[:, -1], -1)
return {"inputs": inputs, "past": past} return {"inputs": inputs, "past": past, "use_cache": kwargs["use_cache"]}
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING) @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
def call(self, inputs, **kwargs): def call(self, inputs, **kwargs):
...@@ -590,6 +622,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel): ...@@ -590,6 +622,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
mc_token_ids=None, mc_token_ids=None,
use_cache=True,
training=False, training=False,
): ):
r""" r"""
...@@ -656,7 +689,8 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel): ...@@ -656,7 +689,8 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
head_mask = inputs[5] if len(inputs) > 5 else head_mask head_mask = inputs[5] if len(inputs) > 5 else head_mask
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
mc_token_ids = inputs[7] if len(inputs) > 7 else mc_token_ids mc_token_ids = inputs[7] if len(inputs) > 7 else mc_token_ids
assert len(inputs) <= 8, "Too many inputs." use_cache = inputs[8] if len(inputs) > 8 else use_cache
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, dict): elif isinstance(inputs, dict):
input_ids = inputs.get("input_ids") input_ids = inputs.get("input_ids")
past = inputs.get("past", past) past = inputs.get("past", past)
...@@ -666,7 +700,8 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel): ...@@ -666,7 +700,8 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
head_mask = inputs.get("head_mask", head_mask) head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
mc_token_ids = inputs.get("mc_token_ids", mc_token_ids) mc_token_ids = inputs.get("mc_token_ids", mc_token_ids)
assert len(inputs) <= 8, "Too many inputs." use_cache = inputs.get("use_cache", use_cache)
assert len(inputs) <= 9, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
...@@ -690,6 +725,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel): ...@@ -690,6 +725,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
flat_position_ids, flat_position_ids,
head_mask, head_mask,
inputs_embeds, inputs_embeds,
use_cache,
] ]
transformer_outputs = self.transformer(flat_inputs, training=training) transformer_outputs = self.transformer(flat_inputs, training=training)
......
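The GPT-2 main-layer change above only appends `presents` to the outputs when caching is requested, so disabling it saves memory. A small sketch with a tiny randomly initialised config (the sizes are arbitrary assumptions):

import tensorflow as tf
from transformers import GPT2Config, TFGPT2LMHeadModel

config = GPT2Config(n_layer=2, n_head=2, n_embd=32, vocab_size=100)
model = TFGPT2LMHeadModel(config)
input_ids = tf.constant([[1, 2, 3]])

outputs = model(input_ids, use_cache=True)
logits, presents = outputs[:2]   # cache is included in the outputs

outputs = model(input_ids, use_cache=False)
assert len(outputs) == 1         # only the logits are returned, saving memory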
...@@ -444,16 +444,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ...@@ -444,16 +444,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
def prepare_inputs_for_generation(self, inputs, **kwargs): def prepare_inputs_for_generation(self, inputs, **kwargs):
return {"inputs": inputs} return {"inputs": inputs}
def _do_output_past(self, outputs): def _use_cache(self, outputs, use_cache):
has_output_past = hasattr(self.config, "output_past") and self.config.output_past """During generation, decide whether to pass the `past` variable to the next forward pass."""
has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len if len(outputs) <= 1 or use_cache is False:
return False
if has_output_past and not has_mem_len and len(outputs) > 1: if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
return True return False
elif has_mem_len and self.config.mem_len > 0 and len(outputs) > 1: return True
return True
return False
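Restated outside the class for clarity: the replacement helper caches only when the caller asked for it, the forward pass actually returned something beyond the logits, and (for memory models such as XLNet) `mem_len` is not zero. The function name `should_pass_past` is illustrative only and not part of the library:

def should_pass_past(outputs, use_cache, mem_len=None):
    """Illustrative re-statement of the `_use_cache` decision above."""
    if len(outputs) <= 1 or use_cache is False:
        return False      # nothing cached, or caching disabled by the caller
    if mem_len == 0:
        return False      # XLNet-style model configured with no memory
    return True

assert should_pass_past(("logits", "past"), use_cache=True) is True
assert should_pass_past(("logits",), use_cache=True) is False
assert should_pass_past(("logits", "mems"), use_cache=True, mem_len=0) is False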
def generate( def generate(
self, self,
...@@ -476,6 +473,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ...@@ -476,6 +473,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
num_return_sequences=None, num_return_sequences=None,
attention_mask=None, attention_mask=None,
decoder_start_token_id=None, decoder_start_token_id=None,
use_cache=None,
): ):
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
and beam-search. and beam-search.
...@@ -551,6 +549,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ...@@ -551,6 +549,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
If an encoder-decoder model starts decoding with a different token than BOS. If an encoder-decoder model starts decoding with a different token than BOS.
Defaults to `None` and is changed to `BOS` later. Defaults to `None` and is changed to `BOS` later.
use_cache: (`optional`) bool
If `use_cache` is True, past key values are used to speed up decoding if applicable to the model. Defaults to `True`.
Return: Return:
output: `tf.Tensor` of `dtype=tf.int32` shape `(batch_size * num_return_sequences, sequence_length)` output: `tf.Tensor` of `dtype=tf.int32` shape `(batch_size * num_return_sequences, sequence_length)`
...@@ -605,6 +606,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ...@@ -605,6 +606,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
min_length = min_length if min_length is not None else self.config.min_length min_length = min_length if min_length is not None else self.config.min_length
do_sample = do_sample if do_sample is not None else self.config.do_sample do_sample = do_sample if do_sample is not None else self.config.do_sample
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
use_cache = use_cache if use_cache is not None else self.config.use_cache
num_beams = num_beams if num_beams is not None else self.config.num_beams num_beams = num_beams if num_beams is not None else self.config.num_beams
temperature = temperature if temperature is not None else self.config.temperature temperature = temperature if temperature is not None else self.config.temperature
top_k = top_k if top_k is not None else self.config.top_k top_k = top_k if top_k is not None else self.config.top_k
...@@ -634,6 +636,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ...@@ -634,6 +636,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer." assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
assert isinstance(do_sample, bool), "`do_sample` should be a boolean." assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean." assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer." assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
assert temperature > 0, "`temperature` should be strictly positive." assert temperature > 0, "`temperature` should be strictly positive."
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer." assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
...@@ -782,6 +785,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ...@@ -782,6 +785,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
vocab_size=vocab_size, vocab_size=vocab_size,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
attention_mask=attention_mask, attention_mask=attention_mask,
use_cache=use_cache,
) )
else: else:
output = self._generate_no_beam_search( output = self._generate_no_beam_search(
...@@ -804,6 +808,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ...@@ -804,6 +808,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
vocab_size=vocab_size, vocab_size=vocab_size,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
attention_mask=attention_mask, attention_mask=attention_mask,
use_cache=use_cache,
) )
return output return output
...@@ -829,6 +834,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ...@@ -829,6 +834,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
vocab_size, vocab_size,
encoder_outputs, encoder_outputs,
attention_mask, attention_mask,
use_cache,
): ):
""" Generate sequences for each example without beam search (num_beams == 1). """ Generate sequences for each example without beam search (num_beams == 1).
All returned sequences are generated independently. All returned sequences are generated independently.
...@@ -841,12 +847,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ...@@ -841,12 +847,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models
while cur_len < max_length: while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask) model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
)
outputs = self(**model_inputs) outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :] next_token_logits = outputs[0][:, -1, :]
# if model has past, then set the past variable to speed up decoding # if model has past, then set the past variable to speed up decoding
if self._do_output_past(outputs): if self._use_cache(outputs, use_cache):
past = outputs[1] past = outputs[1]
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
...@@ -993,6 +1001,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ...@@ -993,6 +1001,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
vocab_size, vocab_size,
encoder_outputs, encoder_outputs,
attention_mask, attention_mask,
use_cache,
): ):
""" Generate sequences for each example with beam search. """ Generate sequences for each example with beam search.
""" """
...@@ -1020,12 +1029,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin): ...@@ -1020,12 +1029,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
done = [False for _ in range(batch_size)] done = [False for _ in range(batch_size)]
while cur_len < max_length: while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask) model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size) outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size) next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
# if model has past, then set the past variable to speed up decoding # if model has past, then set the past variable to speed up decoding
if self._do_output_past(outputs): if self._use_cache(outputs, use_cache):
past = outputs[1] past = outputs[1]
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858) # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
......
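At the user level, the new argument threads through the TF `generate` loop above; for greedy decoding the generated tokens should be the same with or without the cache, only speed and memory differ. A hedged sketch assuming the public `gpt2` checkpoint and an arbitrary prompt:

from transformers import GPT2Tokenizer, TFGPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = TFGPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer.encode("The cache makes decoding", return_tensors="tf")

cached = model.generate(input_ids, max_length=20, use_cache=True)     # default behaviour
uncached = model.generate(input_ids, max_length=20, use_cache=False)  # slower, uses less memory
print(tokenizer.decode(cached[0]))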
...@@ -358,7 +358,6 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -358,7 +358,6 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
super().__init__(**kwargs) super().__init__(**kwargs)
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.output_past = config.output_past
self.mem_len = config.mem_len self.mem_len = config.mem_len
self.reuse_len = config.reuse_len self.reuse_len = config.reuse_len
...@@ -503,6 +502,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -503,6 +502,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
input_mask=None, input_mask=None,
head_mask=None, head_mask=None,
inputs_embeds=None, inputs_embeds=None,
use_cache=True,
training=False, training=False,
): ):
if isinstance(inputs, (tuple, list)): if isinstance(inputs, (tuple, list)):
...@@ -515,7 +515,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -515,7 +515,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
input_mask = inputs[6] if len(inputs) > 6 else input_mask input_mask = inputs[6] if len(inputs) > 6 else input_mask
head_mask = inputs[7] if len(inputs) > 7 else head_mask head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
assert len(inputs) <= 9, "Too many inputs." use_cache = inputs[9] if len(inputs) > 9 else use_cache
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)): elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids") input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask) attention_mask = inputs.get("attention_mask", attention_mask)
...@@ -526,7 +527,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -526,7 +527,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
input_mask = inputs.get("input_mask", input_mask) input_mask = inputs.get("input_mask", input_mask)
head_mask = inputs.get("head_mask", head_mask) head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds) inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 9, "Too many inputs." use_cache = inputs.get("use_cache", use_cache)
assert len(inputs) <= 10, "Too many inputs."
else: else:
input_ids = inputs input_ids = inputs
...@@ -657,7 +659,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -657,7 +659,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
hidden_states = [] hidden_states = []
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
# cache new mems # cache new mems
if self.mem_len is not None and self.mem_len > 0 and self.output_past: if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),) new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
if self.output_hidden_states: if self.output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h) hidden_states.append((output_h, output_g) if output_g is not None else output_h)
...@@ -679,7 +681,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer): ...@@ -679,7 +681,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method) # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
outputs = (tf.transpose(output, perm=(1, 0, 2)),) outputs = (tf.transpose(output, perm=(1, 0, 2)),)
if self.mem_len is not None and self.mem_len > 0 and self.output_past: if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
outputs = outputs + (new_mems,) outputs = outputs + (new_mems,)
if self.output_hidden_states: if self.output_hidden_states:
...@@ -783,6 +785,8 @@ XLNET_INPUTS_DOCSTRING = r""" ...@@ -783,6 +785,8 @@ XLNET_INPUTS_DOCSTRING = r"""
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation. Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix. than the model's internal embedding lookup matrix.
use_cache (:obj:`bool`):
If `use_cache` is True, `mems` are returned and can be used to speed up decoding (see `mems`). Defaults to `True`.
""" """
...@@ -848,7 +852,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel): ...@@ -848,7 +852,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
def get_output_embeddings(self): def get_output_embeddings(self):
return self.lm_loss.input_embeddings return self.lm_loss.input_embeddings
def prepare_inputs_for_generation(self, inputs, past, **model_kwargs): def prepare_inputs_for_generation(self, inputs, past, **kwargs):
# Add dummy token at the end (no attention on this one) # Add dummy token at the end (no attention on this one)
effective_batch_size = inputs.shape[0] effective_batch_size = inputs.shape[0]
...@@ -866,7 +870,12 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel): ...@@ -866,7 +870,12 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
target_mapping_seq_end = tf.ones((effective_batch_size, 1, 1), dtype=tf.float32) target_mapping_seq_end = tf.ones((effective_batch_size, 1, 1), dtype=tf.float32)
target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1) target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)
inputs = {"inputs": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping} inputs = {
"inputs": inputs,
"perm_mask": perm_mask,
"target_mapping": target_mapping,
"use_cache": kwargs["use_cache"],
}
# if past is defined in model kwargs then use it for faster decoding # if past is defined in model kwargs then use it for faster decoding
if past: if past:
......
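For XLNet the cache takes the form of `mems`; as the condition above shows, they are only produced when the config's `mem_len` is non-zero and `use_cache` stays on. A rough sketch with a tiny randomly initialised model; all sizes are arbitrary assumptions:

import tensorflow as tf
from transformers import XLNetConfig, TFXLNetLMHeadModel

config = XLNetConfig(vocab_size=100, d_model=32, n_layer=2, n_head=2, d_inner=64, mem_len=64)
model = TFXLNetLMHeadModel(config)

input_ids = tf.constant([[1, 2, 3, 4]])
outputs = model(input_ids, use_cache=True)
logits, mems = outputs[:2]            # mems returned because mem_len > 0 and use_cache=True

# feed the memories back in on the next segment
outputs = model(tf.constant([[5, 6]]), mems=mems, use_cache=True)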
...@@ -652,15 +652,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -652,15 +652,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
def prepare_scores_for_generation(self, scores, **kwargs): def prepare_scores_for_generation(self, scores, **kwargs):
return scores return scores
def _do_output_past(self, outputs): def _use_cache(self, outputs, use_cache):
"""During generation, decide whether to pass the `past` variable to the next forward pass.""" """During generation, decide whether to pass the `past` variable to the next forward pass."""
has_output_past = getattr(self.config, "output_past", False) if len(outputs) <= 1 or use_cache is False:
mem_len = getattr(self.config, "mem_len", 0)
if len(outputs) <= 1:
return False return False
if mem_len > 0 or has_output_past: if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
return True return False
return False return True
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty): def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
"""repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """ """repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
...@@ -694,6 +692,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -694,6 +692,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
num_return_sequences=None, num_return_sequences=None,
attention_mask=None, attention_mask=None,
decoder_start_token_id=None, decoder_start_token_id=None,
use_cache=None,
): ):
r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling. r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
...@@ -768,6 +767,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -768,6 +767,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
If an encoder-decoder model starts decoding with a different token than BOS. If an encoder-decoder model starts decoding with a different token than BOS.
Defaults to `None` and is changed to `BOS` later. Defaults to `None` and is changed to `BOS` later.
use_cache: (`optional`) bool
If `use_cache` is True, past key values are used to speed up decoding if applicable to the model. Defaults to `True`.
Return: Return:
output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)` output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
...@@ -822,6 +824,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -822,6 +824,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
min_length = min_length if min_length is not None else self.config.min_length min_length = min_length if min_length is not None else self.config.min_length
do_sample = do_sample if do_sample is not None else self.config.do_sample do_sample = do_sample if do_sample is not None else self.config.do_sample
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
use_cache = use_cache if use_cache is not None else self.config.use_cache
num_beams = num_beams if num_beams is not None else self.config.num_beams num_beams = num_beams if num_beams is not None else self.config.num_beams
temperature = temperature if temperature is not None else self.config.temperature temperature = temperature if temperature is not None else self.config.temperature
top_k = top_k if top_k is not None else self.config.top_k top_k = top_k if top_k is not None else self.config.top_k
...@@ -851,6 +854,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -851,6 +854,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer." assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
assert isinstance(do_sample, bool), "`do_sample` should be a boolean." assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean." assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer." assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
assert temperature > 0, "`temperature` should be strictly positive." assert temperature > 0, "`temperature` should be strictly positive."
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer." assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
...@@ -1011,6 +1015,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -1011,6 +1015,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
vocab_size=vocab_size, vocab_size=vocab_size,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
attention_mask=attention_mask, attention_mask=attention_mask,
use_cache=use_cache,
) )
else: else:
output = self._generate_no_beam_search( output = self._generate_no_beam_search(
...@@ -1032,6 +1037,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -1032,6 +1037,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
batch_size=effective_batch_size, batch_size=effective_batch_size,
encoder_outputs=encoder_outputs, encoder_outputs=encoder_outputs,
attention_mask=attention_mask, attention_mask=attention_mask,
use_cache=use_cache,
) )
return output return output
...@@ -1056,6 +1062,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -1056,6 +1062,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
batch_size, batch_size,
encoder_outputs, encoder_outputs,
attention_mask, attention_mask,
use_cache,
): ):
""" Generate sequences for each example without beam search (num_beams == 1). """ Generate sequences for each example without beam search (num_beams == 1).
All returned sequences are generated independently. All returned sequences are generated independently.
...@@ -1067,13 +1074,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -1067,13 +1074,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models
while cur_len < max_length: while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask) model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
)
outputs = self(**model_inputs) outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :] next_token_logits = outputs[0][:, -1, :]
# if model has past, then set the past variable to speed up decoding # if model has past, then set the past variable to speed up decoding
if self._do_output_past(outputs): if self._use_cache(outputs, use_cache):
past = outputs[1] past = outputs[1]
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
...@@ -1178,6 +1187,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -1178,6 +1187,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
vocab_size, vocab_size,
encoder_outputs, encoder_outputs,
attention_mask, attention_mask,
use_cache,
): ):
""" Generate sequences for each example with beam search. """ Generate sequences for each example with beam search.
""" """
...@@ -1203,12 +1213,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin): ...@@ -1203,12 +1213,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
done = [False for _ in range(batch_size)] done = [False for _ in range(batch_size)]
while cur_len < max_length: while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask) model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size) outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size) next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
# if model has past, then set the past variable to speed up decoding # if model has past, then set the past variable to speed up decoding
if self._do_output_past(outputs): if self._use_cache(outputs, use_cache):
past = outputs[1] past = outputs[1]
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858) # repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
......
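The PyTorch `generate` gains the same argument and threads it into both the greedy and the beam-search loops above. A short sketch assuming the public `gpt2` checkpoint and an arbitrary prompt:

from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

input_ids = tokenizer.encode("Caching makes beam search", return_tensors="pt")

# use_cache falls back to config.use_cache (True) when not passed explicitly
beam_output = model.generate(input_ids, max_length=20, num_beams=2, use_cache=True)
print(tokenizer.decode(beam_output[0]))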
@@ -524,6 +524,7 @@ XLNET_INPUTS_DOCSTRING = r"""
            Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
            (see `mems` output below). Can be used to speed up sequential decoding. The token ids which have their mems
            given to this model should not be passed as input ids as they have already been computed.
+           `use_cache` has to be set to `True` to make use of `mems`.
        perm_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to indicate the attention pattern for each input token with values selected in ``[0, 1]``:
            If ``perm_mask[k, i, j] = 0``, i attend to j in batch k;
@@ -555,6 +556,8 @@ XLNET_INPUTS_DOCSTRING = r"""
            Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
            This is useful if you want more control over how to convert `input_ids` indices into associated vectors
            than the model's internal embedding lookup matrix.
+       use_cache (:obj:`bool`):
+           If `use_cache` is True, `mems` are returned and can be used to speed up decoding (see `mems`). Defaults to `True`.
"""
@@ -567,7 +570,6 @@ class XLNetModel(XLNetPreTrainedModel):
        super().__init__(config)
        self.output_attentions = config.output_attentions
        self.output_hidden_states = config.output_hidden_states
-       self.output_past = config.output_past
        self.mem_len = config.mem_len
        self.reuse_len = config.reuse_len
@@ -698,6 +700,7 @@ class XLNetModel(XLNetPreTrainedModel):
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
+       use_cache=True,
    ):
        r"""
        Return:
@@ -864,7 +867,7 @@ class XLNetModel(XLNetPreTrainedModel):
        attentions = []
        hidden_states = []
        for i, layer_module in enumerate(self.layer):
-           if self.mem_len is not None and self.mem_len > 0 and self.output_past:
+           if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
                # cache new mems
                new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
            if self.output_hidden_states:
@@ -894,7 +897,7 @@ class XLNetModel(XLNetPreTrainedModel):
        # Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
        outputs = (output.permute(1, 0, 2).contiguous(),)
-       if self.mem_len is not None and self.mem_len > 0 and self.output_past:
+       if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
            outputs = outputs + (new_mems,)
        if self.output_hidden_states:
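`cache_mem` is what turns the per-layer hidden states into the `mems` returned above. Conceptually it keeps a sliding window of the most recent `mem_len` positions per layer; a rough sketch of that behaviour (ignoring the `reuse_len` handling in the real method):

    import torch

    def cache_mem_sketch(curr_out, prev_mem, mem_len):
        # append the freshly computed hidden states to the running memory and
        # keep only the last `mem_len` positions (time dimension comes first)
        if prev_mem is None:
            new_mem = curr_out
        else:
            new_mem = torch.cat([prev_mem, curr_out], dim=0)
        return new_mem[-mem_len:].detach()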
@@ -935,7 +938,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
    def get_output_embeddings(self):
        return self.lm_loss

-   def prepare_inputs_for_generation(self, input_ids, past, **model_kwargs):
+   def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
        # Add dummy token at the end (no attention on this one)
        effective_batch_size = input_ids.shape[0]
@@ -955,7 +958,12 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
        )
        target_mapping[0, 0, -1] = 1.0

-       inputs = {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping}
+       inputs = {
+           "input_ids": input_ids,
+           "perm_mask": perm_mask,
+           "target_mapping": target_mapping,
+           "use_cache": kwargs["use_cache"],
+       }

        # if past is defined in model kwargs then use it for faster decoding
        if past:
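The collapsed part of this hunk is where `perm_mask` and `target_mapping` are built. The idea, reconstructed here for illustration rather than copied from the diff, is to append a dummy token, forbid every position from attending to it, and ask the model to predict only that last slot:

    import torch

    input_ids = torch.tensor([[35, 86, 112]])          # placeholder token ids
    batch_size = input_ids.shape[0]

    # dummy token at the end (no attention on this one)
    dummy = torch.zeros((batch_size, 1), dtype=torch.long)
    input_ids = torch.cat([input_ids, dummy], dim=1)
    seq_len = input_ids.shape[1]

    # perm_mask[k, i, j] = 1 means i may NOT attend to j, so the last column
    # hides the dummy token from every other position
    perm_mask = torch.zeros((batch_size, seq_len, seq_len))
    perm_mask[:, :, -1] = 1.0

    # predict only the dummy slot
    target_mapping = torch.zeros((batch_size, 1, seq_len))
    target_mapping[:, 0, -1] = 1.0

    inputs = {
        "input_ids": input_ids,
        "perm_mask": perm_mask,
        "target_mapping": target_mapping,
        "use_cache": True,
    }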
@@ -975,6 +983,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
+       use_cache=True,
        labels=None,
    ):
        r"""
@@ -1050,6 +1059,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
+           use_cache=use_cache,
        )
        logits = self.lm_loss(transformer_outputs[0])
@@ -1093,6 +1103,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
+       use_cache=True,
        labels=None,
    ):
        r"""
@@ -1148,6 +1159,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
+           use_cache=use_cache,
        )
        output = transformer_outputs[0]
@@ -1196,6 +1208,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
+       use_cache=True,
        labels=None,
    ):
        r"""
@@ -1252,6 +1265,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
+           use_cache=use_cache,
        )
        sequence_output = outputs[0]
@@ -1301,9 +1315,10 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
        mems=None,
        perm_mask=None,
        target_mapping=None,
-       labels=None,
        head_mask=None,
        inputs_embeds=None,
+       use_cache=True,
+       labels=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@@ -1368,6 +1383,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
            target_mapping=target_mapping,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
+           use_cache=use_cache,
        )
        output = transformer_outputs[0]
@@ -1414,6 +1430,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
+       use_cache=True,
        start_positions=None,
        end_positions=None,
    ):
@@ -1478,6 +1495,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
+           use_cache=use_cache,
        )
        sequence_output = outputs[0]
@@ -1538,6 +1556,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
        input_mask=None,
        head_mask=None,
        inputs_embeds=None,
+       use_cache=True,
        start_positions=None,
        end_positions=None,
        is_impossible=None,
@@ -1616,6 +1635,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
            input_mask=input_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
+           use_cache=use_cache,
        )
        hidden_states = transformer_outputs[0]
        start_logits = self.start_logits(hidden_states, p_mask=p_mask)
...
@@ -128,7 +128,6 @@ class ModelTesterMixin:
        for model_class in self.all_model_classes:
            config.output_attentions = True
            config.output_hidden_states = False
-           config.output_past = False
            model = model_class(config)
            model.to(torch_device)
            model.eval()
@@ -664,10 +663,6 @@ class ModelTesterMixin:
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]

-       if self.is_encoder_decoder:
-           # needed for Bart beam search
-           config.output_past = True

        for model_class in self.all_generative_model_classes:
            model = model_class(config)
...
@@ -227,7 +227,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
        model.eval()

        # first forward pass
-       output, past_key_value_states = model(input_ids)
+       output, past_key_value_states = model(input_ids, use_cache=True)

        # create hypothetical next token and extent to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
@@ -235,8 +235,8 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
        # append to next input_ids and
        next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)

-       output_from_no_past, _ = model(next_input_ids)
+       output_from_no_past = model(next_input_ids)[0]
-       output_from_past, _ = model(next_tokens, past_key_value_states=past_key_value_states)
+       output_from_past = model(next_tokens, past_key_value_states=past_key_value_states)[0]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
@@ -260,7 +260,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
        attn_mask[:, half_seq_length:] = 0

        # first forward pass
-       output, past_key_value_states = model(input_ids, attention_mask=attn_mask)
+       output, past_key_value_states = model(input_ids, attention_mask=attn_mask, use_cache=True)

        # create hypothetical next token and extent to next_input_ids
        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
@@ -277,10 +277,10 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
        )

        # get two different outputs
-       output_from_no_past, _ = model(next_input_ids, attention_mask=attn_mask)
+       output_from_no_past = model(next_input_ids, attention_mask=attn_mask)[0]
-       output_from_past, _ = model(
+       output_from_past = model(
            next_tokens, past_key_value_states=past_key_value_states, attention_mask=attn_mask
-       )
+       )[0]

        # select random slice
        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
@@ -298,10 +298,10 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
        model.to(torch_device)
        model.eval()
        torch.manual_seed(0)
-       model.set_output_past(False)
-       output_without_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
+       output_without_past_cache = model.generate(
+           input_ids[:1], num_beams=2, max_length=5, do_sample=True, use_cache=False
+       )
        torch.manual_seed(0)
-       model.set_output_past(True)
        output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
        self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
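The assertion above is the user-visible contract: with a fixed seed, generation must be identical with and without the cache, only the speed differs. From user code the flag is toggled directly on `generate`; a sketch along those lines (the checkpoint name and the `T5ForConditionalGeneration` class are illustrative, not taken from this diff):

    import torch
    from transformers import T5ForConditionalGeneration, T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = T5ForConditionalGeneration.from_pretrained("t5-small")
    model.eval()

    input_ids = tokenizer.encode("translate English to German: The cache works.", return_tensors="pt")

    torch.manual_seed(0)
    slow = model.generate(input_ids, num_beams=2, max_length=10, do_sample=True, use_cache=False)

    torch.manual_seed(0)
    fast = model.generate(input_ids, num_beams=2, max_length=10, do_sample=True)  # cache on by default

    assert torch.equal(slow, fast)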
@@ -321,6 +321,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
            "attention_mask": attention_mask,
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_attention_mask,
+           "use_cache": False,
        }
        return config, inputs_dict
...
@@ -460,10 +460,6 @@ class TFModelTesterMixin:
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
        input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]

-       if self.is_encoder_decoder:
-           # needed for Bart beam search
-           config.output_past = True

        for model_class in self.all_generative_model_classes:
            model = model_class(config)
...