Unverified Commit 01c37dcd authored by Patrick von Platen, committed by GitHub

[Config, Caching] Remove `output_past` everywhere and replace by `use_cache` argument (#3734)

* remove output_past from pt

* make style

* add optional input length for gpt2

* add use cache to prepare input

* save memory in gpt2

* correct gpt2 test inputs

* make past input optional for gpt2

* finish use_cache for all models

* make style

* delete modeling_gpt2 change in test file

* correct docstring

* correct is true statements for gpt2
parent 092cf881
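This commit threads a `use_cache` argument through the PyTorch and TensorFlow forward passes (and through `generate`) in place of the old `config.output_past` flag. As a rough illustration of the caller-facing effect, here is a minimal sketch, not part of this commit, of incremental decoding with the post-commit `GPT2LMHeadModel` signature; the checkpoint name and the 5-token loop are illustrative choices only.

```python
# Minimal sketch (not part of this commit): incremental decoding where
# `use_cache=True` returns per-layer (key, value) stacks as `past`, and
# re-feeding `past` lets the model skip recomputing previous tokens.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

generated = tokenizer.encode("The use_cache argument", return_tensors="pt")
past = None
next_input = generated

with torch.no_grad():
    for _ in range(5):
        outputs = model(next_input, past=past, use_cache=True)
        logits, past = outputs[:2]  # `past` holds one stacked (key, value) pair per layer
        next_input = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_input], dim=-1)

print(tokenizer.decode(generated[0].tolist()))
```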
......@@ -19,7 +19,7 @@ class BartSystem(BaseTransformer):
mode = "language-modeling"
def __init__(self, hparams):
super().__init__(hparams, num_labels=None, mode=self.mode, output_past=False)
super().__init__(hparams, num_labels=None, mode=self.mode)
def forward(self, input_ids, attention_mask=None, decoder_input_ids=None, lm_labels=None):
return self.model(
......
......@@ -59,7 +59,7 @@ class PretrainedConfig(object):
# Attributes with defaults
self.output_attentions = kwargs.pop("output_attentions", False)
self.output_hidden_states = kwargs.pop("output_hidden_states", False)
self.output_past = kwargs.pop("output_past", True) # Not used by all models
self.use_cache = kwargs.pop("use_cache", True) # Not used by all models
self.torchscript = kwargs.pop("torchscript", False) # Only used by PyTorch models
self.use_bfloat16 = kwargs.pop("use_bfloat16", False)
self.pruned_heads = kwargs.pop("pruned_heads", {})
......
......@@ -933,7 +933,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
return outputs
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, **kwargs):
def prepare_inputs_for_generation(self, decoder_input_ids, past, attention_mask, use_cache, **kwargs):
assert past is not None, "past has to be defined for encoder_outputs"
# first step, decoder_cached_states are empty
......@@ -947,7 +947,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
"decoder_cached_states": decoder_cached_states,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"use_cache": True, # change this to avoid caching (presumably for debugging)
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}
def prepare_scores_for_generation(self, scores, cur_len, max_length):
......@@ -980,10 +980,6 @@ class BartForConditionalGeneration(PretrainedBartModel):
def get_output_embeddings(self):
return _make_linear_from_emb(self.model.shared) # make it on the fly
def _do_output_past(self, *args, **kwargs):
""" We should always use the cache in generate."""
return True
@add_start_docstrings(
"""Bart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """,
......
......@@ -98,7 +98,7 @@ class MultiHeadAttention(torch.nn.Module):
x = x.reshape(batch_size, -1, self.num_heads, self.depth)
return x.permute([0, 2, 1, 3])
def forward(self, v, k, q, mask, layer_past=None, attention_mask=None, head_mask=None):
def forward(self, v, k, q, mask, layer_past=None, attention_mask=None, head_mask=None, use_cache=False):
batch_size = q.shape[0]
q = self.Wq(q)
......@@ -112,7 +112,11 @@ class MultiHeadAttention(torch.nn.Module):
past_key, past_value = layer_past[0], layer_past[1]
k = torch.cat((past_key, k), dim=-2)
v = torch.cat((past_value, v), dim=-2)
present = torch.stack((k, v))
if use_cache is True:
present = torch.stack((k, v))
else:
present = (None,)
output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
scaled_attention = output[0].permute([0, 2, 1, 3])
......@@ -143,10 +147,17 @@ class EncoderLayer(torch.nn.Module):
self.dropout1 = torch.nn.Dropout(rate)
self.dropout2 = torch.nn.Dropout(rate)
def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None):
def forward(self, x, mask, layer_past=None, attention_mask=None, head_mask=None, use_cache=False):
normed = self.layernorm1(x)
attn_outputs = self.multi_head_attention(
normed, normed, normed, mask, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask
normed,
normed,
normed,
mask,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
)
attn_output = attn_outputs[0]
attn_output = self.dropout1(attn_output)
......@@ -199,6 +210,7 @@ CTRL_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
If `past` is used, optionally only the last `input_ids` have to be input (see `past`).
Indices can be obtained using :class:`transformers.CTRLTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
......@@ -207,8 +219,10 @@ CTRL_INPUTS_DOCSTRING = r"""
`What are input IDs? <../glossary.html#input-ids>`__
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `past` output below). Can be used to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
(see `past` output below). Can be used to speed up sequential decoding.
If `past` is used, the user can optionally input only the last `input_ids`
(those that don't have their past given to this model) of shape :obj:`(batch_size, 1)`
instead of all `input_ids` of shape :obj:`(batch_size, sequence_length)`.
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
......@@ -219,6 +233,7 @@ CTRL_INPUTS_DOCSTRING = r"""
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token
If `past` is used, optionally only the last `token_type_ids` have to be input (see `past`).
`What are token type IDs? <../glossary.html#token-type-ids>`_
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
......@@ -234,6 +249,10 @@ CTRL_INPUTS_DOCSTRING = r"""
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
If `past` is used, optionally only the last `input_embeds` have to be input (see `past`).
use_cache (:obj:`bool`):
If `use_cache` is True, `past` key value states are returned and
can be used to speed up decoding (see `past`). Defaults to `True`.
"""
......@@ -246,7 +265,6 @@ class CTRLModel(CTRLPreTrainedModel):
super().__init__(config)
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.output_past = config.output_past
self.d_model_size = config.n_embd
self.num_layers = config.n_layer
......@@ -289,6 +307,7 @@ class CTRLModel(CTRLPreTrainedModel):
position_ids=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
):
r"""
Return:
......@@ -297,8 +316,7 @@ class CTRLModel(CTRLPreTrainedModel):
Sequence of hidden-states at the last layer of the model.
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
Can be used (see `past` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -325,6 +343,17 @@ class CTRLModel(CTRLPreTrainedModel):
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
# If using past key value states, only the last tokens
# should be given as an input
if past is not None:
if input_ids is not None:
input_ids = input_ids[:, -1:]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:, -1:]
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1:]
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
......@@ -414,10 +443,15 @@ class CTRLModel(CTRLPreTrainedModel):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
outputs = h(
hidden_states, mask, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask[i]
hidden_states,
mask,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
use_cache=use_cache,
)
hidden_states, present = outputs[:2]
if self.output_past:
if use_cache is True:
presents = presents + (present,)
if self.output_attentions:
......@@ -429,7 +463,7 @@ class CTRLModel(CTRLPreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_past:
if use_cache is True:
outputs = outputs + (presents,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
......@@ -462,7 +496,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
return {"input_ids": input_ids, "past": past}
return {"input_ids": input_ids, "past": past, "use_cache": kwargs["use_cache"]}
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
def forward(
......@@ -475,6 +509,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=True,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
......@@ -492,8 +527,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
Can be used (see `past` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -527,6 +561,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
)
hidden_states = transformer_outputs[0]
......
......@@ -177,7 +177,7 @@ class Attention(nn.Module):
else:
return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False):
x = self.c_attn(x)
query, key, value = x.split(self.split_size, dim=2)
query = self.split_heads(query)
......@@ -187,7 +187,11 @@ class Attention(nn.Module):
past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
key = torch.cat((past_key, key), dim=-1)
value = torch.cat((past_value, value), dim=-2)
present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
if use_cache is True:
present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
else:
present = (None,)
attn_outputs = self._attn(query, key, value, attention_mask, head_mask)
a = attn_outputs[0]
......@@ -224,9 +228,13 @@ class Block(nn.Module):
self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
self.mlp = MLP(4 * nx, config)
def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
def forward(self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False):
output_attn = self.attn(
self.ln_1(x), layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask
self.ln_1(x),
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
)
a = output_attn[0] # output_attn: a, present, (attentions)
......@@ -279,10 +287,9 @@ GPT2_START_DOCSTRING = r"""
GPT2_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
`input_ids_length` = `sequence_length` if `past` is None else `1`
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
If using `past` as an input make sure that `input_ids` are those of the last position.
If `past` is used, optionally only the last `input_ids` have to be input (see `past`).
Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
......@@ -292,8 +299,8 @@ GPT2_INPUTS_DOCSTRING = r"""
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `past` output below). Can be used to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
(see `past` output below). Can be used to speed up sequential decoding.
If `past` is used, the user can optionally input only the last `input_ids` (those that don't have their past given to this model) of shape :obj:`(batch_size, 1)` instead of all `input_ids` of shape :obj:`(batch_size, sequence_length)`.
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
......@@ -305,7 +312,7 @@ GPT2_INPUTS_DOCSTRING = r"""
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token
If using `past` as an input make sure that `token_type_ids` correspond to the `input_ids` of the last position.
If `past` is used, optionally only the last `token_type_ids` have to be input (see `past`).
`What are token type IDs? <../glossary.html#token-type-ids>`_
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
......@@ -321,6 +328,9 @@ GPT2_INPUTS_DOCSTRING = r"""
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
If `past` is used, optionally only the last `input_embeds` have to be input (see `past`).
use_cache (:obj:`bool`):
If `use_cache` is True, `past` key value states are returned and can be used to speed up decoding (see `past`). Defaults to `True`.
"""
......@@ -333,7 +343,6 @@ class GPT2Model(GPT2PreTrainedModel):
super().__init__(config)
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.output_past = config.output_past
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
self.wpe = nn.Embedding(config.n_positions, config.n_embd)
......@@ -366,16 +375,17 @@ class GPT2Model(GPT2PreTrainedModel):
position_ids=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
):
r"""
Return:
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.GPT2Config`) and inputs:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the last layer of the model.
If `past` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
Can be used (see `past` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -400,6 +410,17 @@ class GPT2Model(GPT2PreTrainedModel):
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
# If using past key value states, only the last tokens
# should be given as an input
if past is not None:
if input_ids is not None:
input_ids = input_ids[:, -1:]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:, -1:]
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1:]
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
......@@ -484,11 +505,15 @@ class GPT2Model(GPT2PreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)
outputs = block(
hidden_states, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask[i]
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
use_cache=use_cache,
)
hidden_states, present = outputs[:2]
if self.output_past:
if use_cache is True:
presents = presents + (present,)
if self.output_attentions:
......@@ -502,7 +527,7 @@ class GPT2Model(GPT2PreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_past:
if use_cache is True:
outputs = outputs + (presents,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
......@@ -535,7 +560,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
return {"input_ids": input_ids, "past": past}
return {"input_ids": input_ids, "past": past, "use_cache": kwargs["use_cache"]}
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
def forward(
......@@ -548,6 +573,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=True,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
......@@ -565,8 +591,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
Can be used (see `past` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -600,6 +625,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
)
hidden_states = transformer_outputs[0]
......@@ -652,6 +678,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
mc_token_ids=None,
lm_labels=None,
mc_labels=None,
use_cache=True,
):
r"""
mc_token_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, num_choices)`, `optional`, default to index of the last token of the input)
......@@ -680,8 +707,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
past (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers` with each tensor of shape :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
Can be used (see `past` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -726,6 +752,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
)
hidden_states = transformer_outputs[0]
......
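The "save memory in gpt2" item in the commit message corresponds to the `present = (None,)` branches above: with `use_cache=False` the per-layer (key, value) stacks are never built. A minimal sketch, assuming the post-commit `GPT2LMHeadModel` signature, of a training step that disables caching:

```python
# Minimal sketch (assumption, not part of this commit): a language-modeling
# training step with `use_cache=False`, so no key/value caches are materialized.
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.train()

input_ids = tokenizer.encode("Caching is unnecessary when training.", return_tensors="pt")

# With `labels`, the first output is the LM loss; `use_cache=False` skips
# stacking the per-layer (key, value) tensors entirely.
outputs = model(input_ids, labels=input_ids, use_cache=False)
loss = outputs[0]
loss.backward()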
......@@ -188,7 +188,6 @@ class T5Attention(nn.Module):
super().__init__()
self.is_decoder = config.is_decoder
self.has_relative_attention_bias = has_relative_attention_bias
self.output_past = config.output_past
self.output_attentions = config.output_attentions
self.relative_attention_num_buckets = config.relative_attention_num_buckets
......@@ -300,6 +299,7 @@ class T5Attention(nn.Module):
past_key_value_state=None,
head_mask=None,
query_length=None,
use_cache=False,
):
"""
Self-attention (if kv is None) or attention over source sentence (provided by kv).
......@@ -351,7 +351,7 @@ class T5Attention(nn.Module):
else:
k, v = past_key_value_state
if self.is_decoder and self.output_past:
if self.is_decoder and use_cache:
present_key_value_state = ((k, v),)
else:
present_key_value_state = (None,)
......@@ -385,14 +385,8 @@ class T5Attention(nn.Module):
context = self.o(context)
outputs = (context,)
outputs = (context,) + present_key_value_state
if self.output_past is False or self.is_decoder is False:
assert (
present_key_value_state[0] is None
), "Key/Value projections should not be stored if {} is not decoder or output_past is False".format(self)
outputs = outputs + present_key_value_state
if self.output_attentions:
outputs = outputs + (weights,)
if self.has_relative_attention_bias:
......@@ -408,7 +402,13 @@ class T5LayerSelfAttention(nn.Module):
self.dropout = nn.Dropout(config.dropout_rate)
def forward(
self, hidden_states, attention_mask=None, position_bias=None, head_mask=None, past_key_value_state=None
self,
hidden_states,
attention_mask=None,
position_bias=None,
head_mask=None,
past_key_value_state=None,
use_cache=False,
):
norm_x = self.layer_norm(hidden_states)
attention_output = self.SelfAttention(
......@@ -417,6 +417,7 @@ class T5LayerSelfAttention(nn.Module):
position_bias=position_bias,
head_mask=head_mask,
past_key_value_state=past_key_value_state,
use_cache=use_cache,
)
y = attention_output[0]
layer_output = hidden_states + self.dropout(y)
......@@ -439,6 +440,7 @@ class T5LayerCrossAttention(nn.Module):
position_bias=None,
head_mask=None,
past_key_value_state=None,
use_cache=False,
query_length=None,
):
norm_x = self.layer_norm(hidden_states)
......@@ -449,6 +451,7 @@ class T5LayerCrossAttention(nn.Module):
position_bias=position_bias,
head_mask=head_mask,
past_key_value_state=past_key_value_state,
use_cache=use_cache,
query_length=query_length,
)
y = attention_output[0]
......@@ -460,7 +463,6 @@ class T5LayerCrossAttention(nn.Module):
class T5Block(nn.Module):
def __init__(self, config, has_relative_attention_bias=False):
super().__init__()
self.output_past = config.output_past
self.is_decoder = config.is_decoder
self.layer = nn.ModuleList()
self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
......@@ -479,6 +481,7 @@ class T5Block(nn.Module):
encoder_decoder_position_bias=None,
head_mask=None,
past_key_value_state=None,
use_cache=False,
):
if past_key_value_state is not None:
......@@ -499,6 +502,7 @@ class T5Block(nn.Module):
position_bias=position_bias,
head_mask=head_mask,
past_key_value_state=self_attn_past_key_value_state,
use_cache=use_cache,
)
hidden_states, present_key_value_state = self_attention_outputs[:2]
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
......@@ -519,6 +523,7 @@ class T5Block(nn.Module):
head_mask=head_mask,
past_key_value_state=cross_attn_past_key_value_state,
query_length=query_length,
use_cache=use_cache,
)
hidden_states = cross_attention_outputs[0]
# Combine self attn and cross attn key value states
......@@ -620,7 +625,6 @@ class T5Stack(T5PreTrainedModel):
self.embed_tokens = embed_tokens
self.is_decoder = config.is_decoder
self.output_past = config.output_past and self.is_decoder
self.block = nn.ModuleList(
[T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
......@@ -648,6 +652,7 @@ class T5Stack(T5PreTrainedModel):
inputs_embeds=None,
head_mask=None,
past_key_value_states=None,
use_cache=False,
):
if input_ids is not None and inputs_embeds is not None:
......@@ -699,7 +704,7 @@ class T5Stack(T5PreTrainedModel):
causal_mask = seq_ids[None, None, :].repeat(batch_size, mask_seq_length, 1) <= seq_ids[None, :, None]
causal_mask = causal_mask.to(attention_mask)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
if self.output_past and past_key_value_states[0] is not None:
if past_key_value_states[0] is not None:
extended_attention_mask = extended_attention_mask[:, :, -1:, :]
else:
extended_attention_mask = attention_mask[:, None, None, :]
......@@ -776,6 +781,7 @@ class T5Stack(T5PreTrainedModel):
encoder_decoder_position_bias=encoder_decoder_position_bias,
head_mask=head_mask[i],
past_key_value_state=past_key_value_state,
use_cache=use_cache,
)
# layer_outputs is a tuple with:
# hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
......@@ -800,7 +806,8 @@ class T5Stack(T5PreTrainedModel):
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.is_decoder and self.output_past:
if use_cache is True:
assert self.is_decoder, "`use_cache` can only be set to `True` if {} is used as a decoder".format(self)
outputs = outputs + (present_key_value_states,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
......@@ -833,7 +840,7 @@ T5_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
T5 is a model with relative position embeddings so you should be able to pad the inputs on both the right and the left. If `decoder_past_key_value_states` is used, optionally only the last `input_ids` have to be input (see `decoder_past_key_value_states`).
T5 is a model with relative position embeddings so you should be able to pad the inputs on both the right and the left.
Indices can be obtained using :class:`transformers.T5Tokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
......@@ -849,19 +856,26 @@ T5_INPUTS_DOCSTRING = r"""
Used in the cross-attention of the decoder.
decoder_input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
If `decoder_past_key_value_states` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_value_states`).
To know more on how to prepare :obj:`decoder_input_ids` for pre-training take a look at
`T5 Training <./t5.html#training>`_ .
decoder_attention_mask (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains pre-computed key and value hidden-states of the attention blocks.
Can be used to speed up decoding. If `decoder_past_key_value_states` are used, the user can optionally input only the last `decoder_input_ids` of shape :obj:`(batch_size, 1)` instead of all `decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
Can be used to speed up decoding.
If `decoder_past_key_value_states` are used, the user can optionally input only the last `decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
instead of all `decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
If `use_cache` is True, `decoder_past_key_value_states` are returned and can be used to speed up decoding (see `decoder_past_key_value_states`).
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
decoder_inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, target_sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Optionally, instead of passing :obj:`decoder_input_ids` you can choose to directly pass an embedded representation.
If `decoder_past_key_value_states` is used, optionally only the last `decoder_inputs_embeds` have to be input (see `decoder_past_key_value_states`).
This is useful if you want more control over how to convert `decoder_input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
head_mask: (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
......@@ -897,14 +911,6 @@ class T5Model(T5PreTrainedModel):
self.encoder.set_input_embeddings(new_embeddings)
self.decoder.set_input_embeddings(new_embeddings)
def set_output_past(self, do_output_past: bool):
self.config.output_past = do_output_past
self.decoder.output_past = do_output_past
for block in self.decoder.block:
block.output_past = do_output_past
block.layer[0].SelfAttention.output_past = do_output_past
block.layer[1].EncDecAttention.output_past = do_output_past
def get_encoder(self):
return self.encoder
......@@ -928,6 +934,7 @@ class T5Model(T5PreTrainedModel):
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_past_key_value_states=None,
use_cache=True,
inputs_embeds=None,
decoder_inputs_embeds=None,
head_mask=None,
......@@ -938,7 +945,7 @@ class T5Model(T5PreTrainedModel):
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
If `decoder_past_key_value_states` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``config.output_past=True``):
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``use_cache=True``):
Contains pre-computed key and value hidden-states of the attention blocks.
Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input).
Note that when using `decoder_past_key_value_states`, the model only outputs the last `hidden-state` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`.
......@@ -976,7 +983,7 @@ class T5Model(T5PreTrainedModel):
# If decoding with past key value states, only the last tokens
# should be given as an input
if decoder_past_key_value_states is not None and self.decoder.output_past is True:
if decoder_past_key_value_states is not None:
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_inputs_embeds is not None:
......@@ -991,9 +998,10 @@ class T5Model(T5PreTrainedModel):
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
)
if self.decoder.output_past:
if use_cache is True:
past = ((encoder_outputs, decoder_outputs[1]),)
decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:]
......@@ -1022,14 +1030,6 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
def get_input_embeddings(self):
return self.shared
def set_output_past(self, do_output_past: bool):
self.config.output_past = do_output_past
self.decoder.output_past = do_output_past
for block in self.decoder.block:
block.output_past = do_output_past
block.layer[0].SelfAttention.output_past = do_output_past
block.layer[1].EncDecAttention.output_past = do_output_past
def set_input_embeddings(self, new_embeddings):
self.shared = new_embeddings
self.encoder.set_input_embeddings(new_embeddings)
......@@ -1053,6 +1053,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
decoder_input_ids=None,
decoder_attention_mask=None,
decoder_past_key_value_states=None,
use_cache=True,
lm_labels=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
......@@ -1072,7 +1073,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
If `past_key_value_states` is used only the last prediction_scores of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``config.output_past=True``):
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``use_cache=True``):
Contains pre-computed key and value hidden-states of the attention blocks.
Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input).
Note that when using `decoder_past_key_value_states`, the model only outputs the last `prediction_score` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`.
......@@ -1116,10 +1117,8 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
# If decoding with past key value states, only the last tokens
# should be given as an input
if decoder_past_key_value_states is not None and self.decoder.output_past is True:
assert (
lm_labels is None
), "Decoder should not use cached key value states when training. Also consider setting model.set_output_past(False) for less memory consumption"
if decoder_past_key_value_states is not None:
assert lm_labels is None, "Decoder should not use cached key value states when training."
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_inputs_embeds is not None:
......@@ -1134,11 +1133,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
)
# insert decoder past at right place
# to speed up decoding
if self.decoder.output_past:
if use_cache is True:
past = ((encoder_outputs, decoder_outputs[1]),)
decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:]
......@@ -1157,7 +1157,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
return decoder_outputs + encoder_outputs
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, **kwargs):
assert past is not None, "past has to be defined for encoder_outputs"
# first step
......@@ -1171,13 +1171,14 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
"decoder_past_key_value_states": decoder_past_key_value_states,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"use_cache": use_cache,
}
def _reorder_cache(self, past, beam_idx):
# if decoder past is not included in output
# speedy decoding is disabled and no need to reorder
if len(past) < 2:
logger.warning("You might want to consider setting model.set_output_past(True) to speed up decoding")
logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
return past
decoder_past = past[1]
......
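For T5 the cache flows through `decoder_past_key_value_states`, and with `use_cache=True` the second output slot of `T5ForConditionalGeneration` packs `(encoder_outputs, decoder_past_key_value_states)` as assembled in the hunks above. A minimal sketch of two decoding steps under that assumption (checkpoint name and prompt are illustrative):

```python
# Minimal sketch (assumption, not part of this commit): two decoding steps with
# T5, re-feeding the cache returned when `use_cache=True`.
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")
model.eval()

input_ids = tokenizer.encode("translate English to German: How are you?", return_tensors="pt")
attention_mask = torch.ones_like(input_ids)
# T5 uses the pad token as the decoder start token (see the docstring above).
decoder_input_ids = torch.full((1, 1), model.config.pad_token_id, dtype=torch.long)

with torch.no_grad():
    # First step: no cache yet; with `use_cache=True` the second output packs
    # (encoder_outputs, decoder_past_key_value_states).
    lm_logits, (encoder_outputs, decoder_past) = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        use_cache=True,
    )[:2]
    next_token = lm_logits[:, -1, :].argmax(dim=-1, keepdim=True)

    # Second step: reuse the encoder outputs and the decoder cache, feeding only
    # the newest decoder token (the forward pass would slice it anyway).
    lm_logits, (encoder_outputs, decoder_past) = model(
        encoder_outputs=encoder_outputs,
        attention_mask=attention_mask,
        decoder_input_ids=next_token,
        decoder_past_key_value_states=decoder_past,
        use_cache=True,
    )[:2]
```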
......@@ -94,7 +94,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
return tf.transpose(x, perm=[0, 2, 1, 3])
def call(self, inputs, training=False):
v, k, q, mask, layer_past, attention_mask, head_mask = inputs
v, k, q, mask, layer_past, attention_mask, head_mask, use_cache = inputs
batch_size = shape_list(q)[0]
q = self.Wq(q)
......@@ -104,11 +104,25 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
q = self.split_into_heads(q, batch_size)
k = self.split_into_heads(k, batch_size)
v = self.split_into_heads(v, batch_size)
if layer_past is not None:
past_key, past_value = tf.unstack(layer_past, axis=0)
k = tf.concat((past_key, k), axis=-2)
v = tf.concat((past_value, v), axis=-2)
present = tf.stack((k, v), axis=0)
# to cope with keras serialization
# we need to cast `use_cache` to correct bool
# if it is a tensor
if tf.is_tensor(use_cache):
if hasattr(use_cache, "numpy"):
use_cache = bool(use_cache.numpy())
else:
use_cache = True
if use_cache is True:
present = tf.stack((k, v), axis=0)
else:
present = (None,)
output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
scaled_attention = tf.transpose(output[0], perm=[0, 2, 1, 3])
......@@ -147,10 +161,10 @@ class TFEncoderLayer(tf.keras.layers.Layer):
self.dropout2 = tf.keras.layers.Dropout(rate)
def call(self, inputs, training=False):
x, mask, layer_past, attention_mask, head_mask = inputs
x, mask, layer_past, attention_mask, head_mask, use_cache = inputs
normed = self.layernorm1(x)
attn_outputs = self.multi_head_attention(
[normed, normed, normed, mask, layer_past, attention_mask, head_mask], training=training
[normed, normed, normed, mask, layer_past, attention_mask, head_mask, use_cache], training=training
)
attn_output = attn_outputs[0]
attn_output = self.dropout1(attn_output, training=training)
......@@ -173,7 +187,6 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
super().__init__(**kwargs)
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.output_past = config.output_past
self.d_model_size = config.n_embd
self.num_layers = config.n_layer
......@@ -220,8 +233,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
position_ids=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
training=False,
):
if isinstance(inputs, (tuple, list)):
input_ids = inputs[0]
past = inputs[1] if len(inputs) > 1 else past
......@@ -230,7 +245,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
position_ids = inputs[4] if len(inputs) > 4 else position_ids
head_mask = inputs[5] if len(inputs) > 5 else head_mask
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
assert len(inputs) <= 7, "Too many inputs."
use_cache = inputs[7] if len(inputs) > 7 else use_cache
assert len(inputs) <= 8, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
past = inputs.get("past", past)
......@@ -239,10 +255,21 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 7, "Too many inputs."
use_cache = inputs.get("use_cache", use_cache)
assert len(inputs) <= 8, "Too many inputs."
else:
input_ids = inputs
# If using past key value states, only the last tokens
# should be given as an input
if past is not None:
if input_ids is not None:
input_ids = input_ids[:, -1:]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:, -1:]
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1:]
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
......@@ -319,10 +346,10 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
for i, (h, layer_past) in enumerate(zip(self.h, past)):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = h([hidden_states, mask, layer_past, attention_mask, head_mask[i]], training=training)
outputs = h([hidden_states, mask, layer_past, attention_mask, head_mask[i], use_cache], training=training)
hidden_states, present = outputs[:2]
if self.output_past:
if use_cache is True:
presents = presents + (present,)
if self.output_attentions:
......@@ -334,7 +361,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_past:
if use_cache is True:
outputs = outputs + (presents,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
......@@ -386,6 +413,7 @@ CTRL_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
If `past` is used, optionally only the last `input_ids` have to be input (see `past`).
Indices can be obtained using :class:`transformers.CTRLTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
......@@ -394,8 +422,10 @@ CTRL_INPUTS_DOCSTRING = r"""
`What are input IDs? <../glossary.html#input-ids>`__
past (:obj:`List[tf.Tensor]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `past` output below). Can be used to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
(see `past` output below). Can be used to speed up sequential decoding.
If `past` is used, the user can optionally input only the last `input_ids`
(those that don't have their past given to this model) of shape :obj:`(batch_size, 1)`
instead of all `input_ids` of shape :obj:`(batch_size, sequence_length)`.
attention_mask (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
......@@ -406,6 +436,7 @@ CTRL_INPUTS_DOCSTRING = r"""
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token
If `past` is used, optionally only the last `token_type_ids` have to be input (see `past`).
`What are token type IDs? <../glossary.html#token-type-ids>`_
position_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
......@@ -421,6 +452,10 @@ CTRL_INPUTS_DOCSTRING = r"""
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
If `past` is used, optionally only the last `input_embeds` have to be input (see `past`).
use_cache (:obj:`bool`):
If `use_cache` is True, `past` key value states are returned and
can be used to speed up decoding (see `past`). Defaults to `True`.
training (:obj:`boolean`, `optional`, defaults to :obj:`False`):
Whether to activate dropout modules (if set to :obj:`True`) during training or to de-activate them
(if set to :obj:`False`) for evaluation.
......@@ -514,7 +549,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
if past:
inputs = tf.expand_dims(inputs[:, -1], -1)
return {"inputs": inputs, "past": past}
return {"inputs": inputs, "past": past, "use_cache": kwargs["use_cache"]}
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
def call(self, inputs, **kwargs):
......
......@@ -134,7 +134,7 @@ class TFAttention(tf.keras.layers.Layer):
return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features)
def call(self, inputs, training=False):
x, layer_past, attention_mask, head_mask = inputs
x, layer_past, attention_mask, head_mask, use_cache = inputs
x = self.c_attn(x)
query, key, value = tf.split(x, 3, axis=2)
......@@ -145,7 +145,20 @@ class TFAttention(tf.keras.layers.Layer):
past_key, past_value = tf.unstack(layer_past, axis=0)
key = tf.concat([past_key, key], axis=-2)
value = tf.concat([past_value, value], axis=-2)
present = tf.stack([key, value], axis=0)
# to cope with keras serialization
# we need to cast `use_cache` to correct bool
# if it is a tensor
if tf.is_tensor(use_cache):
if hasattr(use_cache, "numpy"):
use_cache = bool(use_cache.numpy())
else:
use_cache = True
if use_cache is True:
present = tf.stack([key, value], axis=0)
else:
present = (None,)
attn_outputs = self._attn([query, key, value, attention_mask, head_mask], training=training)
a = attn_outputs[0]
......@@ -184,10 +197,10 @@ class TFBlock(tf.keras.layers.Layer):
self.mlp = TFMLP(4 * nx, config, name="mlp")
def call(self, inputs, training=False):
x, layer_past, attention_mask, head_mask = inputs
x, layer_past, attention_mask, head_mask, use_cache = inputs
a = self.ln_1(x)
output_attn = self.attn([a, layer_past, attention_mask, head_mask], training=training)
output_attn = self.attn([a, layer_past, attention_mask, head_mask, use_cache], training=training)
a = output_attn[0] # output_attn: a, present, (attentions)
x = x + a
......@@ -245,6 +258,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
position_ids=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
training=False,
):
if isinstance(inputs, (tuple, list)):
......@@ -255,7 +269,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
position_ids = inputs[4] if len(inputs) > 4 else position_ids
head_mask = inputs[5] if len(inputs) > 5 else head_mask
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
assert len(inputs) <= 7, "Too many inputs."
use_cache = inputs[7] if len(inputs) > 7 else use_cache
assert len(inputs) <= 8, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
past = inputs.get("past", past)
......@@ -264,10 +279,21 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
position_ids = inputs.get("position_ids", position_ids)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 7, "Too many inputs."
use_cache = inputs.get("use_cache", use_cache)
assert len(inputs) <= 8, "Too many inputs."
else:
input_ids = inputs
# If using past key value states, only the last tokens
# should be given as an input
if past is not None:
if input_ids is not None:
input_ids = input_ids[:, -1:]
if inputs_embeds is not None:
inputs_embeds = inputs_embeds[:, -1:]
if token_type_ids is not None:
token_type_ids = token_type_ids[:, -1:]
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
......@@ -338,7 +364,7 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = block([hidden_states, layer_past, attention_mask, head_mask[i]], training=training)
outputs = block([hidden_states, layer_past, attention_mask, head_mask[i], use_cache], training=training)
hidden_states, present = outputs[:2]
presents = presents + (present,)
......@@ -353,7 +379,10 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states, presents)
outputs = (hidden_states,)
if use_cache is True:
outputs = outputs + (presents,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
......@@ -404,6 +433,7 @@ GPT2_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`Numpy array` or :obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`):
Indices of input sequence tokens in the vocabulary.
If `past` is used, optionally only the last `input_ids` have to be input (see `past`).
Indices can be obtained using :class:`transformers.GPT2Tokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
......@@ -424,6 +454,7 @@ GPT2_INPUTS_DOCSTRING = r"""
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token
If `past` is used, optionally only the last `token_type_ids` have to be input (see `past`).
`What are token type IDs? <../glossary.html#token-type-ids>`_
position_ids (:obj:`tf.Tensor` or :obj:`Numpy array` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
......@@ -439,6 +470,7 @@ GPT2_INPUTS_DOCSTRING = r"""
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
If `past` is used, optionally only the last `input_embeds` have to be input (see `past`).
training (:obj:`boolean`, `optional`, defaults to :obj:`False`):
Whether to activate dropout modules (if set to :obj:`True`) during training or to de-activate them
(if set to :obj:`False`) for evaluation.
......@@ -511,7 +543,7 @@ class TFGPT2LMHeadModel(TFGPT2PreTrainedModel):
if past:
inputs = tf.expand_dims(inputs[:, -1], -1)
return {"inputs": inputs, "past": past}
return {"inputs": inputs, "past": past, "use_cache": kwargs["use_cache"]}
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
def call(self, inputs, **kwargs):
......@@ -590,6 +622,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
head_mask=None,
inputs_embeds=None,
mc_token_ids=None,
use_cache=True,
training=False,
):
r"""
......@@ -656,7 +689,8 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
head_mask = inputs[5] if len(inputs) > 5 else head_mask
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
mc_token_ids = inputs[7] if len(inputs) > 7 else mc_token_ids
assert len(inputs) <= 8, "Too many inputs."
use_cache = inputs[8] if len(inputs) > 8 else use_cache
assert len(inputs) <= 9, "Too many inputs."
elif isinstance(inputs, dict):
input_ids = inputs.get("input_ids")
past = inputs.get("past", past)
......@@ -666,7 +700,8 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
mc_token_ids = inputs.get("mc_token_ids", mc_token_ids)
assert len(inputs) <= 8, "Too many inputs."
use_cache = inputs.get("use_cache", use_cache)
assert len(inputs) <= 9, "Too many inputs."
else:
input_ids = inputs
......@@ -690,6 +725,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
flat_position_ids,
head_mask,
inputs_embeds,
use_cache,
]
transformer_outputs = self.transformer(flat_inputs, training=training)
......
......@@ -444,16 +444,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
def prepare_inputs_for_generation(self, inputs, **kwargs):
return {"inputs": inputs}
def _do_output_past(self, outputs):
has_output_past = hasattr(self.config, "output_past") and self.config.output_past
has_mem_len = hasattr(self.config, "mem_len") and self.config.mem_len
if has_output_past and not has_mem_len and len(outputs) > 1:
return True
elif has_mem_len and self.config.mem_len > 0 and len(outputs) > 1:
return True
return False
def _use_cache(self, outputs, use_cache):
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
if len(outputs) <= 1 or use_cache is False:
return False
if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
return False
return True
def generate(
self,
......@@ -476,6 +473,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
num_return_sequences=None,
attention_mask=None,
decoder_start_token_id=None,
use_cache=None,
):
r""" Generates sequences for models with a LM head. The method currently supports greedy or penalized greedy decoding, sampling with top-k or nucleus sampling
and beam-search.
......@@ -551,6 +549,9 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
If an encoder-decoder model starts decoding with a different token than BOS.
Defaults to `None` and is changed to `BOS` later.
use_cache: (`optional`) bool
If `use_cache` is True, past key values are used to speed up decoding if applicable to model. Defaults to `True`.
Return:
output: `tf.Tensor` of `dtype=tf.int32` shape `(batch_size * num_return_sequences, sequence_length)`
......@@ -605,6 +606,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
min_length = min_length if min_length is not None else self.config.min_length
do_sample = do_sample if do_sample is not None else self.config.do_sample
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
use_cache = use_cache if use_cache is not None else self.config.use_cache
num_beams = num_beams if num_beams is not None else self.config.num_beams
temperature = temperature if temperature is not None else self.config.temperature
top_k = top_k if top_k is not None else self.config.top_k
......@@ -634,6 +636,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
assert temperature > 0, "`temperature` should be strictly positive."
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
......@@ -782,6 +785,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
vocab_size=vocab_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
)
else:
output = self._generate_no_beam_search(
......@@ -804,6 +808,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
vocab_size=vocab_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
)
return output
......@@ -829,6 +834,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
vocab_size,
encoder_outputs,
attention_mask,
use_cache,
):
""" Generate sequences for each example without beam search (num_beams == 1).
All returned sequences are generated independently.
......@@ -841,12 +847,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
# if model has past, then set the past variable to speed up decoding
if self._do_output_past(outputs):
if self._use_cache(outputs, use_cache):
past = outputs[1]
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
......@@ -993,6 +1001,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
vocab_size,
encoder_outputs,
attention_mask,
use_cache,
):
""" Generate sequences for each example with beam search.
"""
......@@ -1020,12 +1029,14 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
done = [False for _ in range(batch_size)]
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
# if model has past, then set the past variable to speed up decoding
if self._do_output_past(outputs):
if self._use_cache(outputs, use_cache):
past = outputs[1]
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
......
......@@ -358,7 +358,6 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
super().__init__(**kwargs)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.output_past = config.output_past
self.mem_len = config.mem_len
self.reuse_len = config.reuse_len
......@@ -503,6 +502,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
training=False,
):
if isinstance(inputs, (tuple, list)):
......@@ -515,7 +515,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
input_mask = inputs[6] if len(inputs) > 6 else input_mask
head_mask = inputs[7] if len(inputs) > 7 else head_mask
inputs_embeds = inputs[8] if len(inputs) > 8 else inputs_embeds
assert len(inputs) <= 9, "Too many inputs."
use_cache = inputs[9] if len(inputs) > 9 else use_cache
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
attention_mask = inputs.get("attention_mask", attention_mask)
......@@ -526,7 +527,8 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
input_mask = inputs.get("input_mask", input_mask)
head_mask = inputs.get("head_mask", head_mask)
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
assert len(inputs) <= 9, "Too many inputs."
use_cache = inputs.get("use_cache", use_cache)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
......@@ -657,7 +659,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
hidden_states = []
for i, layer_module in enumerate(self.layer):
# cache new mems
if self.mem_len is not None and self.mem_len > 0 and self.output_past:
if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
if self.output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
......@@ -679,7 +681,7 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
outputs = (tf.transpose(output, perm=(1, 0, 2)),)
if self.mem_len is not None and self.mem_len > 0 and self.output_past:
if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
outputs = outputs + (new_mems,)
if self.output_hidden_states:
......@@ -783,6 +785,8 @@ XLNET_INPUTS_DOCSTRING = r"""
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
use_cache (:obj:`bool`):
If `use_cache` is True, `mems` are returned and can be used to speed up decoding (see `mems`). Defaults to `True`.
"""
......@@ -848,7 +852,7 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
def get_output_embeddings(self):
return self.lm_loss.input_embeddings
def prepare_inputs_for_generation(self, inputs, past, **model_kwargs):
def prepare_inputs_for_generation(self, inputs, past, **kwargs):
# Add dummy token at the end (no attention on this one)
effective_batch_size = inputs.shape[0]
......@@ -866,7 +870,12 @@ class TFXLNetLMHeadModel(TFXLNetPreTrainedModel):
target_mapping_seq_end = tf.ones((effective_batch_size, 1, 1), dtype=tf.float32)
target_mapping = tf.concat([target_mapping, target_mapping_seq_end], axis=-1)
inputs = {"inputs": inputs, "perm_mask": perm_mask, "target_mapping": target_mapping}
inputs = {
"inputs": inputs,
"perm_mask": perm_mask,
"target_mapping": target_mapping,
"use_cache": kwargs["use_cache"],
}
# if past is defined in model kwargs then use it for faster decoding
if past:
......
......@@ -652,15 +652,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
def prepare_scores_for_generation(self, scores, **kwargs):
return scores
def _do_output_past(self, outputs):
def _use_cache(self, outputs, use_cache):
"""During generation, decide whether to pass the `past` variable to the next forward pass."""
has_output_past = getattr(self.config, "output_past", False)
mem_len = getattr(self.config, "mem_len", 0)
if len(outputs) <= 1:
if len(outputs) <= 1 or use_cache is False:
return False
if mem_len > 0 or has_output_past:
return True
return False
if hasattr(self.config, "mem_len") and self.config.mem_len == 0:
return False
return True
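The same decision, sketched outside `generate()` with the public `gpt2` checkpoint (no `mem_len` on the config, so the helper returns True as long as the model emits a `past`):
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()
input_ids = tokenizer.encode("Reusing the cache means", return_tensors="pt")
with torch.no_grad():
    outputs = model(input_ids)                     # (logits, past) for GPT-2
    if model._use_cache(outputs, use_cache=True):  # True: len(outputs) > 1 and no `mem_len`
        past = outputs[1]
    next_token = outputs[0][:, -1, :].argmax(dim=-1, keepdim=True)
    # Only the newly sampled token is fed back; `past` carries the processed prefix.
    outputs = model(next_token, past=past)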
def enforce_repetition_penalty_(self, lprobs, batch_size, num_beams, prev_output_tokens, repetition_penalty):
"""repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858). """
......@@ -694,6 +692,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
num_return_sequences=None,
attention_mask=None,
decoder_start_token_id=None,
use_cache=None,
):
r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
......@@ -768,6 +767,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
If an encoder-decoder model starts decoding with a different token than BOS.
Defaults to `None` and is changed to `BOS` later.
use_cache: (`optional`) bool
If `use_cache` is True, past key values are used to speed up decoding if applicable to the model. Defaults to `True`.
Return:
output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
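A hedged usage sketch of the new argument for the PyTorch `generate`, assuming the public `gpt2` checkpoint; greedy decoding is deterministic, so both calls should produce the same tokens and caching only saves compute:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
input_ids = tokenizer.encode("With the new flag,", return_tensors="pt")
cached = model.generate(input_ids, max_length=20)                     # use_cache falls back to config (True)
uncached = model.generate(input_ids, max_length=20, use_cache=False)  # recomputes the full prefix each step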
......@@ -822,6 +824,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
min_length = min_length if min_length is not None else self.config.min_length
do_sample = do_sample if do_sample is not None else self.config.do_sample
early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping
use_cache = use_cache if use_cache is not None else self.config.use_cache
num_beams = num_beams if num_beams is not None else self.config.num_beams
temperature = temperature if temperature is not None else self.config.temperature
top_k = top_k if top_k is not None else self.config.top_k
......@@ -851,6 +854,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
assert isinstance(min_length, int) and min_length >= 0, "`min_length` should be a positive integer."
assert isinstance(do_sample, bool), "`do_sample` should be a boolean."
assert isinstance(early_stopping, bool), "`early_stopping` should be a boolean."
assert isinstance(use_cache, bool), "`use_cache` should be a boolean."
assert isinstance(num_beams, int) and num_beams > 0, "`num_beams` should be a strictly positive integer."
assert temperature > 0, "`temperature` should be strictly positive."
assert isinstance(top_k, int) and top_k >= 0, "`top_k` should be a positive integer."
......@@ -1011,6 +1015,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
vocab_size=vocab_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
)
else:
output = self._generate_no_beam_search(
......@@ -1032,6 +1037,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
batch_size=effective_batch_size,
encoder_outputs=encoder_outputs,
attention_mask=attention_mask,
use_cache=use_cache,
)
return output
......@@ -1056,6 +1062,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
batch_size,
encoder_outputs,
attention_mask,
use_cache,
):
""" Generate sequences for each example without beam search (num_beams == 1).
All returned sequences are generated independently.
......@@ -1067,13 +1074,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
past = encoder_outputs # defined for encoder-decoder models, None for decoder-only models
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
)
outputs = self(**model_inputs)
next_token_logits = outputs[0][:, -1, :]
# if model has past, then set the past variable to speed up decoding
if self._do_output_past(outputs):
if self._use_cache(outputs, use_cache):
past = outputs[1]
# repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858)
......@@ -1178,6 +1187,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
vocab_size,
encoder_outputs,
attention_mask,
use_cache,
):
""" Generate sequences for each example with beam search.
"""
......@@ -1203,12 +1213,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
done = [False for _ in range(batch_size)]
while cur_len < max_length:
model_inputs = self.prepare_inputs_for_generation(input_ids, past=past, attention_mask=attention_mask)
model_inputs = self.prepare_inputs_for_generation(
input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
)
outputs = self(**model_inputs) # (batch_size * num_beams, cur_len, vocab_size)
next_token_logits = outputs[0][:, -1, :] # (batch_size * num_beams, vocab_size)
# if model has past, then set the past variable to speed up decoding
if self._do_output_past(outputs):
if self._use_cache(outputs, use_cache):
past = outputs[1]
# repetition penalty (from CTRL paper https://arxiv.org/abs/1909.05858)
......
......@@ -524,6 +524,7 @@ XLNET_INPUTS_DOCSTRING = r"""
Contains pre-computed hidden states (keys and values in the attention blocks) as computed by the model
(see `mems` output below). Can be used to speed up sequential decoding. Token ids whose mems are
given to this model should not be passed as input ids, as they have already been computed.
`use_cache` has to be set to `True` to make use of `mems`.
perm_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to indicate the attention pattern for each input token with values selected in ``[0, 1]``:
If ``perm_mask[k, i, j] = 0``, i attends to j in batch k;
......@@ -555,6 +556,8 @@ XLNET_INPUTS_DOCSTRING = r"""
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
use_cache (:obj:`bool`):
If `use_cache` is True, `mems` are returned and can be used to speed up decoding (see `mems`). Defaults to `True`.
"""
......@@ -567,7 +570,6 @@ class XLNetModel(XLNetPreTrainedModel):
super().__init__(config)
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.output_past = config.output_past
self.mem_len = config.mem_len
self.reuse_len = config.reuse_len
......@@ -698,6 +700,7 @@ class XLNetModel(XLNetPreTrainedModel):
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
):
r"""
Return:
......@@ -864,7 +867,7 @@ class XLNetModel(XLNetPreTrainedModel):
attentions = []
hidden_states = []
for i, layer_module in enumerate(self.layer):
if self.mem_len is not None and self.mem_len > 0 and self.output_past:
if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
# cache new mems
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
if self.output_hidden_states:
......@@ -894,7 +897,7 @@ class XLNetModel(XLNetPreTrainedModel):
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
outputs = (output.permute(1, 0, 2).contiguous(),)
if self.mem_len is not None and self.mem_len > 0 and self.output_past:
if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
outputs = outputs + (new_mems,)
if self.output_hidden_states:
......@@ -935,7 +938,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
def get_output_embeddings(self):
return self.lm_loss
def prepare_inputs_for_generation(self, input_ids, past, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
# Add dummy token at the end (no attention on this one)
effective_batch_size = input_ids.shape[0]
......@@ -955,7 +958,12 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
)
target_mapping[0, 0, -1] = 1.0
inputs = {"input_ids": input_ids, "perm_mask": perm_mask, "target_mapping": target_mapping}
inputs = {
"input_ids": input_ids,
"perm_mask": perm_mask,
"target_mapping": target_mapping,
"use_cache": kwargs["use_cache"],
}
# if past is defined in model kwargs then use it for faster decoding
if past:
......@@ -975,6 +983,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
labels=None,
):
r"""
......@@ -1050,6 +1059,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
)
logits = self.lm_loss(transformer_outputs[0])
......@@ -1093,6 +1103,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
labels=None,
):
r"""
......@@ -1148,6 +1159,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
)
output = transformer_outputs[0]
......@@ -1196,6 +1208,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
labels=None,
):
r"""
......@@ -1252,6 +1265,7 @@ class XLNetForTokenClassification(XLNetPreTrainedModel):
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
)
sequence_output = outputs[0]
......@@ -1301,9 +1315,10 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
mems=None,
perm_mask=None,
target_mapping=None,
labels=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
labels=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
......@@ -1368,6 +1383,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
target_mapping=target_mapping,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
)
output = transformer_outputs[0]
......@@ -1414,6 +1430,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
start_positions=None,
end_positions=None,
):
......@@ -1478,6 +1495,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
)
sequence_output = outputs[0]
......@@ -1538,6 +1556,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_cache=True,
start_positions=None,
end_positions=None,
is_impossible=None,
......@@ -1616,6 +1635,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
)
hidden_states = transformer_outputs[0]
start_logits = self.start_logits(hidden_states, p_mask=p_mask)
......
......@@ -128,7 +128,6 @@ class ModelTesterMixin:
for model_class in self.all_model_classes:
config.output_attentions = True
config.output_hidden_states = False
config.output_past = False
model = model_class(config)
model.to(torch_device)
model.eval()
......@@ -664,10 +663,6 @@ class ModelTesterMixin:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
if self.is_encoder_decoder:
# needed for Bart beam search
config.output_past = True
for model_class in self.all_generative_model_classes:
model = model_class(config)
......
......@@ -227,7 +227,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
model.eval()
# first forward pass
output, past_key_value_states = model(input_ids)
output, past_key_value_states = model(input_ids, use_cache=True)
# create hypothetical next token and extend to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
......@@ -235,8 +235,8 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
# append to next input_ids and
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
output_from_no_past, _ = model(next_input_ids)
output_from_past, _ = model(next_tokens, past_key_value_states=past_key_value_states)
output_from_no_past = model(next_input_ids)[0]
output_from_past = model(next_tokens, past_key_value_states=past_key_value_states)[0]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
......@@ -260,7 +260,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
attn_mask[:, half_seq_length:] = 0
# first forward pass
output, past_key_value_states = model(input_ids, attention_mask=attn_mask)
output, past_key_value_states = model(input_ids, attention_mask=attn_mask, use_cache=True)
# create hypothetical next token and extend to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
......@@ -277,10 +277,10 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
)
# get two different outputs
output_from_no_past, _ = model(next_input_ids, attention_mask=attn_mask)
output_from_past, _ = model(
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)[0]
output_from_past = model(
next_tokens, past_key_value_states=past_key_value_states, attention_mask=attn_mask
)
)[0]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
......@@ -298,10 +298,10 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
model.to(torch_device)
model.eval()
torch.manual_seed(0)
model.set_output_past(False)
output_without_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
output_without_past_cache = model.generate(
input_ids[:1], num_beams=2, max_length=5, do_sample=True, use_cache=False
)
torch.manual_seed(0)
model.set_output_past(True)
output_with_past_cache = model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)
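# caching only reuses already-computed key/value states, so with the same seed the sampled tokens should be identical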
self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
......@@ -321,6 +321,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"use_cache": False,
}
return config, inputs_dict
......
......@@ -460,10 +460,6 @@ class TFModelTesterMixin:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict["input_ids"] if "input_ids" in inputs_dict else inputs_dict["inputs"]
if self.is_encoder_decoder:
# needed for Bart beam search
config.output_past = True
for model_class in self.all_generative_model_classes:
model = model_class(config)
......