Unverified Commit 38f7461d authored by Patrick von Platen, committed by GitHub

[TFT5, Cache] Add cache to TFT5 (#3772)

* correct gpt2 test inputs

* make style

* delete modeling_gpt2 change in test file

* translate from pytorch

* correct tests

* fix conflicts

* fix conflicts

* fix conflicts

* fix conflicts

* make tensorflow t5 caching work

* make style

* clean reorder cache

* remove unnecessary spaces

* fix test
parent a5b24947
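For context, a minimal usage sketch of what this change enables (the checkpoint name is illustrative, and whether `generate()` accepts `use_cache` directly depends on the library version): with caching enabled, each decoder step reuses the key/value projections computed for earlier tokens instead of recomputing them over the whole prefix.

```python
import tensorflow as tf
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer.encode(
    "translate English to German: The house is wonderful.", return_tensors="tf"
)

# generate() goes through prepare_inputs_for_generation(), which threads the
# decoder past key/value states from one step to the next when caching is on.
# use_cache defaults to True; it is passed explicitly here only for clarity.
output_ids = model.generate(input_ids, max_length=40, use_cache=True)
print(tokenizer.decode(output_ids[0]))
```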
......@@ -351,12 +351,11 @@ class T5Attention(nn.Module):
else:
k, v = past_key_value_state
if self.is_decoder and use_cache:
if self.is_decoder and use_cache is True:
present_key_value_state = ((k, v),)
else:
present_key_value_state = (None,)
# q = q / math.sqrt(dim_per_head) # No scaling in T5
scores = torch.einsum("bnqd,bnkd->bnqk", q, k) # (bs, n_heads, qlen, klen)
if position_bias is None:
......@@ -486,11 +485,15 @@ class T5Block(nn.Module):
if past_key_value_state is not None:
assert self.is_decoder, "Only decoder can use `past_key_value_states`"
assert (
len(past_key_value_state) == 4
), "The should be 4 past states. 2 (past / key) for self attention. 2 (past / key) for cross attention. Got {} past key / value states".format(
len(past_key_value_state)
expected_num_past_key_value_states = 2 if encoder_hidden_states is None else 4
error_message = "There should be {} past states. 2 (past / key) for self attention.{} Got {} past key / value states".format(
expected_num_past_key_value_states,
"2 (past / key) for cross attention" if expected_num_past_key_value_states == 4 else "",
len(past_key_value_state),
)
assert len(past_key_value_state) == expected_num_past_key_value_states, error_message
self_attn_past_key_value_state = past_key_value_state[:2]
cross_attn_past_key_value_state = past_key_value_state[2:]
else:
......@@ -507,7 +510,7 @@ class T5Block(nn.Module):
hidden_states, present_key_value_state = self_attention_outputs[:2]
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
if self.is_decoder:
if self.is_decoder and encoder_hidden_states is not None:
# the actual query length is unknown for cross attention
# if using past key value states. Need to inject it here
if present_key_value_state is not None:
......@@ -691,7 +694,6 @@ class T5Stack(T5PreTrainedModel):
if past_key_value_states is None:
past_key_value_states = [None] * len(self.block)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, self.device)
......@@ -732,7 +734,7 @@ class T5Stack(T5PreTrainedModel):
# We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, key-value-states (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
position_bias = layer_outputs[3 if self.output_attentions else 2]
if self.is_decoder:
if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 3]
# append next layer key value states
present_key_value_states = present_key_value_states + (present_key_value_state,)
......
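The TensorFlow changes below mirror the PyTorch hunks above. The core mechanism is simple: when a cache is passed in, the freshly projected key/value tensors for the current token are concatenated onto the cached ones along the sequence axis, and the combined pair becomes the new cache. A self-contained sketch with toy shapes (the function name and sizes are illustrative, not the module's API):

```python
import tensorflow as tf

def append_to_kv_cache(k_new, v_new, past_key_value_state=None):
    """Concatenate the current token's key/value projections onto the cache.

    k_new, v_new: (batch, n_heads, 1, dim_per_head) for the token being decoded.
    past_key_value_state: optional (k_past, v_past), each of shape
        (batch, n_heads, seen_len, dim_per_head).
    Returns the full key/value tensors plus the updated cache tuple.
    """
    if past_key_value_state is not None:
        k_past, v_past = past_key_value_state
        k = tf.concat([k_past, k_new], axis=2)  # (batch, n_heads, seen_len + 1, dim_per_head)
        v = tf.concat([v_past, v_new], axis=2)
    else:
        k, v = k_new, v_new
    return k, v, (k, v)

# Two toy decoding steps.
k, v, cache = append_to_kv_cache(tf.random.normal((1, 2, 1, 4)), tf.random.normal((1, 2, 1, 4)))
k, v, cache = append_to_kv_cache(tf.random.normal((1, 2, 1, 4)), tf.random.normal((1, 2, 1, 4)), cache)
print(k.shape)  # (1, 2, 2, 4)
```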
......@@ -185,16 +185,39 @@ class TFT5Attention(tf.keras.layers.Layer):
return values
def call(
self, input, mask=None, kv=None, position_bias=None, cache=None, head_mask=None, training=False,
self,
input,
mask=None,
kv=None,
position_bias=None,
cache=None,
past_key_value_state=None,
head_mask=None,
query_length=None,
use_cache=False,
training=False,
):
"""
Self-attention (if kv is None) or attention over source sentence (provided by kv).
"""
# Input is (bs, qlen, dim)
# Mask is (bs, klen) (non-causal) or (bs, klen, klen)
# past_key_value_state[0] is (bs, n_heads, q_len - 1, dim_per_head)
bs, qlen, dim = shape_list(input)
if past_key_value_state is not None:
assert self.is_decoder is True, "Encoder cannot cache past key value states"
assert (
len(past_key_value_state) == 2
), "past_key_value_state should have 2 past states: keys and values. Got {} past states".format(
len(past_key_value_state)
)
real_qlen = qlen + shape_list(past_key_value_state[0])[2] if query_length is None else query_length
else:
real_qlen = qlen
if kv is None:
klen = qlen if cache is None else cache["slen"] + qlen
klen = real_qlen
else:
klen = shape_list(kv)[1]
......@@ -207,36 +230,51 @@ class TFT5Attention(tf.keras.layers.Layer):
return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.inner_dim))
q = shape(self.q(input)) # (bs, n_heads, qlen, dim_per_head)
if kv is None:
k = shape(self.k(input)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v(input)) # (bs, n_heads, qlen, dim_per_head)
elif cache is None or self.layer_id not in cache:
elif past_key_value_state is None:
k = v = kv
k = shape(self.k(k)) # (bs, n_heads, qlen, dim_per_head)
v = shape(self.v(v)) # (bs, n_heads, qlen, dim_per_head)
if cache is not None:
if self.layer_id in cache:
if past_key_value_state is not None:
if kv is None:
k_, v_ = cache[self.layer_id]
k_, v_ = past_key_value_state
k = tf.concat([k_, k], axis=2) # (bs, n_heads, klen, dim_per_head)
v = tf.concat([v_, v], axis=2) # (bs, n_heads, klen, dim_per_head)
else:
k, v = cache[self.layer_id]
cache[self.layer_id] = (k, v)
k, v = past_key_value_state
# to cope with keras serialization
# we need to cast `use_cache` to correct bool
# if it is a tensor
if tf.is_tensor(use_cache):
if hasattr(use_cache, "numpy"):
use_cache = bool(use_cache.numpy())
else:
use_cache = True
if self.is_decoder and use_cache is True:
present_key_value_state = ((k, v),)
else:
present_key_value_state = (None,)
# q = q / math.sqrt(dim_per_head) # No scaling in T5
# scores = tf.matmul(q, k, transpose_b=True) # (bs, n_heads, qlen, klen)
scores = tf.einsum("bnqd,bnkd->bnqk", q, k) # (bs, n_heads, qlen, klen)
if position_bias is None:
if not self.has_relative_attention_bias:
raise ValueError("No position_bias provided and no weights to compute position_bias")
position_bias = self.compute_bias(qlen, klen)
position_bias = self.compute_bias(real_qlen, klen)
# if key and values are already calculated
# we want only the last query position bias
if past_key_value_state is not None:
position_bias = position_bias[:, :, -1:, :]
if mask is not None:
position_bias = position_bias + mask
# mask = (mask == 0).expand_as(scores) # (bs, n_heads, qlen, klen)
# scores.masked_fill_(mask, -float('inf')) # (bs, n_heads, qlen, klen)
position_bias = position_bias + mask # (bs, n_heads, qlen, klen)
scores += position_bias
weights = tf.nn.softmax(scores, axis=-1) # (bs, n_heads, qlen, klen)
......@@ -251,7 +289,8 @@ class TFT5Attention(tf.keras.layers.Layer):
context = self.o(context)
outputs = (context,)
outputs = (context,) + present_key_value_state
if self.output_attentions:
outputs = outputs + (weights,)
if self.has_relative_attention_bias:
......@@ -269,11 +308,24 @@ class TFT5LayerSelfAttention(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
def call(
self, hidden_states, attention_mask=None, position_bias=None, head_mask=None, training=False,
self,
hidden_states,
attention_mask=None,
position_bias=None,
head_mask=None,
past_key_value_state=None,
use_cache=False,
training=False,
):
norm_x = self.layer_norm(hidden_states)
attention_output = self.SelfAttention(
norm_x, mask=attention_mask, position_bias=position_bias, head_mask=head_mask, training=training,
norm_x,
mask=attention_mask,
position_bias=position_bias,
head_mask=head_mask,
past_key_value_state=past_key_value_state,
use_cache=use_cache,
training=training,
)
y = attention_output[0]
layer_output = hidden_states + self.dropout(y, training=training)
......@@ -291,11 +343,28 @@ class TFT5LayerCrossAttention(tf.keras.layers.Layer):
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
def call(
self, hidden_states, kv, attention_mask=None, position_bias=None, head_mask=None, training=False,
self,
hidden_states,
kv,
attention_mask=None,
position_bias=None,
head_mask=None,
past_key_value_state=None,
query_length=None,
use_cache=False,
training=False,
):
norm_x = self.layer_norm(hidden_states)
attention_output = self.EncDecAttention(
norm_x, mask=attention_mask, kv=kv, position_bias=position_bias, head_mask=head_mask, training=training,
norm_x,
mask=attention_mask,
kv=kv,
position_bias=position_bias,
head_mask=head_mask,
past_key_value_state=past_key_value_state,
query_length=query_length,
use_cache=use_cache,
training=training,
)
y = attention_output[0]
layer_output = hidden_states + self.dropout(y, training=training)
......@@ -317,9 +386,8 @@ class TFT5Block(tf.keras.layers.Layer):
config, has_relative_attention_bias=has_relative_attention_bias, name="layer_._1",
)
)
self.layer.append(TFT5LayerFF(config, name="layer_._2"))
else:
self.layer.append(TFT5LayerFF(config, name="layer_._1"))
self.layer.append(TFT5LayerFF(config, name="layer_._{}".format(len(self.layer))))
def call(
self,
......@@ -330,35 +398,73 @@ class TFT5Block(tf.keras.layers.Layer):
encoder_attention_mask=None,
encoder_decoder_position_bias=None,
head_mask=None,
past_key_value_state=None,
use_cache=False,
training=False,
):
if past_key_value_state is not None:
assert self.is_decoder, "Only decoder can use `past_key_value_states`"
expected_num_past_key_value_states = 2 if encoder_hidden_states is None else 4
error_message = "There should be {} past states. 2 (past / key) for self attention.{} Got {} past key / value states".format(
expected_num_past_key_value_states,
"2 (past / key) for cross attention" if expected_num_past_key_value_states == 4 else "",
len(past_key_value_state),
)
assert len(past_key_value_state) == expected_num_past_key_value_states, error_message
self_attn_past_key_value_state = past_key_value_state[:2]
cross_attn_past_key_value_state = past_key_value_state[2:]
else:
self_attn_past_key_value_state, cross_attn_past_key_value_state = None, None
self_attention_outputs = self.layer[0](
hidden_states,
attention_mask=attention_mask,
position_bias=position_bias,
head_mask=head_mask,
past_key_value_state=self_attn_past_key_value_state,
use_cache=use_cache,
training=training,
)
hidden_states = self_attention_outputs[0]
outputs = self_attention_outputs[1:]
if not self.is_decoder:
hidden_states = self.layer[1](hidden_states, training=training)
hidden_states, present_key_value_state = self_attention_outputs[:2]
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
if self.is_decoder and encoder_hidden_states is not None:
# the actual query length is unknown for cross attention
# if using past key value states. Need to inject it here
if present_key_value_state is not None:
query_length = shape_list(present_key_value_state[0])[2]
else:
query_length = None
cross_attention_outputs = self.layer[1](
hidden_states,
kv=encoder_hidden_states,
attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias,
head_mask=head_mask,
past_key_value_state=cross_attn_past_key_value_state,
query_length=query_length,
use_cache=use_cache,
training=training,
)
hidden_states = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:]
hidden_states = self.layer[2](hidden_states, training=training)
# Combine self attn and cross attn key value states
if present_key_value_state is not None:
present_key_value_state = present_key_value_state + cross_attention_outputs[1]
# Keep cross-attention outputs and relative position weights
attention_outputs = attention_outputs + cross_attention_outputs[2:]
# Apply Feed Forward layer
hidden_states = self.layer[-1](hidden_states, training=training)
outputs = (hidden_states,)
outputs = (hidden_states,) + outputs # add attentions if we output them
return outputs # hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
# Add attentions if we output them
outputs = outputs + (present_key_value_state,) + attention_outputs
return outputs # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
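As the comment above notes, a decoder block with cross-attention flattens its cache into a single 4-tuple: the self-attention key/value pair followed by the cross-attention pair. A toy illustration of how that tuple is combined and later split back apart (shapes are made up):

```python
import tensorflow as tf

# Self-attention cache grows with the decoded sequence; cross-attention cache
# stays fixed at the encoder sequence length.
self_kv = (tf.zeros((1, 2, 3, 4)), tf.zeros((1, 2, 3, 4)))
cross_kv = (tf.zeros((1, 2, 7, 4)), tf.zeros((1, 2, 7, 4)))

present_key_value_state = self_kv + cross_kv          # one flat tuple of length 4 per layer
self_attn_past_key_value_state = present_key_value_state[:2]
cross_attn_past_key_value_state = present_key_value_state[2:]
assert len(present_key_value_state) == 4
```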
class _NoLayerEmbedTokens(object):
......@@ -437,6 +543,8 @@ class TFT5MainLayer(tf.keras.layers.Layer):
encoder_attention_mask=None,
inputs_embeds=None,
head_mask=None,
past_key_value_states=None,
use_cache=False,
training=False,
):
......@@ -456,12 +564,26 @@ class TFT5MainLayer(tf.keras.layers.Layer):
batch_size, seq_length = input_shape
if past_key_value_states is not None:
assert seq_length == 1, "Input shape is {}, but should be {} when using past_key_value_states".format(
input_shape, (batch_size, 1)
)
# required mask seq length can be calculated via length of past
# key value states and seq_length = 1 for the last token
mask_seq_length = shape_list(past_key_value_states[0][0])[2] + seq_length
else:
mask_seq_length = seq_length
if attention_mask is None:
attention_mask = tf.fill((batch_size, seq_length), 1)
if self.is_decoder and encoder_attention_mask is None:
encoder_seq_length = encoder_hidden_states.shape[1]
attention_mask = tf.fill((batch_size, mask_seq_length), 1)
if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
encoder_seq_length = shape_list(encoder_hidden_states)[1]
encoder_attention_mask = tf.fill((batch_size, encoder_seq_length), 1)
# initialize past_key_value_states with `None` if past does not exist
if past_key_value_states is None:
past_key_value_states = [None] * len(self.block)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
attention_mask = tf.cast(attention_mask, dtype=tf.float32)
......@@ -469,16 +591,18 @@ class TFT5MainLayer(tf.keras.layers.Layer):
if num_dims_attention_mask == 3:
extended_attention_mask = attention_mask[:, None, :, :]
elif num_dims_attention_mask == 2:
# Provided a padding mask of dimensions [batch_size, seq_length]
# Provided a padding mask of dimensions [batch_size, mask_seq_length]
# - if the model is a decoder, apply a causal mask in addition to the padding mask
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
if self.config.is_decoder:
seq_ids = tf.range(seq_length)
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
if self.is_decoder:
seq_ids = tf.range(mask_seq_length)
causal_mask = tf.less_equal(
tf.tile(seq_ids[None, None, :], (batch_size, seq_length, 1)), seq_ids[None, :, None],
tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), seq_ids[None, :, None],
)
causal_mask = tf.cast(causal_mask, dtype=tf.float32)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
if past_key_value_states[0] is not None:
extended_attention_mask = extended_attention_mask[:, :, -1:, :]
else:
extended_attention_mask = attention_mask[:, None, None, :]
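A standalone sketch of the causal-mask handling above, using toy sizes: the mask is built over the full `mask_seq_length` (past plus current token), and when a cache is present only the last query row is kept, since only one new token attends to the history.

```python
import tensorflow as tf

batch_size, mask_seq_length = 1, 4
attention_mask = tf.ones((batch_size, mask_seq_length), dtype=tf.float32)

seq_ids = tf.range(mask_seq_length)
causal_mask = tf.less_equal(
    tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), seq_ids[None, :, None],
)
causal_mask = tf.cast(causal_mask, dtype=tf.float32)
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]

# With past key/value states, only the newest token's row of the mask is needed.
last_row_only = extended_attention_mask[:, :, -1:, :]
print(extended_attention_mask.shape, last_row_only.shape)  # (1, 1, 4, 4) (1, 1, 1, 4)
```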
......@@ -495,8 +619,9 @@ class TFT5MainLayer(tf.keras.layers.Layer):
extended_attention_mask = (1.0 - extended_attention_mask) * -1e9
if self.is_decoder:
if self.is_decoder and encoder_attention_mask is not None:
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make it broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length]
# we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=tf.float32)
num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
......@@ -525,13 +650,15 @@ class TFT5MainLayer(tf.keras.layers.Layer):
head_mask = [None] * self.num_hidden_layers
# head_mask = tf.constant([0] * self.num_hidden_layers)
present_key_value_states = ()
all_hidden_states = ()
all_attentions = ()
position_bias = None
encoder_decoder_position_bias = None
hidden_states = inputs_embeds
for i, layer_module in enumerate(self.block):
hidden_states = self.dropout(inputs_embeds, training=training)
for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_value_states)):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
......@@ -543,18 +670,24 @@ class TFT5MainLayer(tf.keras.layers.Layer):
encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias,
head_mask=head_mask[i],
past_key_value_state=past_key_value_state,
use_cache=use_cache,
training=training,
)
hidden_states = layer_outputs[0]
# layer_outputs is a tuple with:
# hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
hidden_states, present_key_value_state = layer_outputs[:2]
if i == 0:
# We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
position_bias = layer_outputs[2 if self.output_attentions else 1]
if self.is_decoder:
encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 2]
position_bias = layer_outputs[3 if self.output_attentions else 2]
if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 3]
# append next layer key value states
present_key_value_states = present_key_value_states + (present_key_value_state,)
if self.output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
all_attentions = all_attentions + (layer_outputs[2],)
hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states, training=training)
......@@ -564,6 +697,9 @@ class TFT5MainLayer(tf.keras.layers.Layer):
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if use_cache is True:
assert self.is_decoder, "`use_cache` can only be set to `True` if {} is used as a decoder".format(self)
outputs = outputs + (present_key_value_states,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
......@@ -650,6 +786,7 @@ T5_INPUTS_DOCSTRING = r"""
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
decoder_input_ids (:obj:`tf.Tensor` of shape :obj:`(batch_size, target_sequence_length)`, `optional`, defaults to :obj:`None`):
Provide for sequence to sequence training. T5 uses the pad_token_id as the starting token for decoder_input_ids generation.
If `decoder_past_key_value_states` is used, optionally only the last `decoder_input_ids` have to be input (see `decoder_past_key_value_states`).
attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
......@@ -660,6 +797,13 @@ T5_INPUTS_DOCSTRING = r"""
Used in the cross-attention of the decoder.
decoder_attention_mask (:obj:`tf.Tensor` of shape :obj:`(batch_size, tgt_seq_len)`, `optional`, defaults to :obj:`None`):
Default behavior: generate a tensor that ignores pad tokens in decoder_input_ids. Causal mask will also be used by default.
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
Contains pre-computed key and value hidden-states of the attention blocks.
Can be used to speed up decoding.
If `decoder_past_key_value_states` are used, the user can optionally input only the last `decoder_input_ids`
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
If `use_cache` is True, `decoder_past_key_value_states` are returned and can be used to speed up decoding (see `decoder_past_key_value_states`).
inputs_embeds (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
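A hedged sketch of the decoding pattern these arguments enable, assuming the default config (no extra hidden-state or attention outputs) and an available `t5-small` checkpoint: the first call returns the cache as its second output, and later calls only need the newest decoder token together with that cache.

```python
import tensorflow as tf
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

input_ids = tokenizer.encode(
    "summarize: Caching avoids recomputing keys and values.", return_tensors="tf"
)
decoder_input_ids = tf.fill((1, 1), model.config.pad_token_id)  # T5 uses pad as decoder start

# Step 1: with use_cache=True the second output is (encoder_outputs, decoder key/value states).
lm_logits, past, _ = model(input_ids, decoder_input_ids=decoder_input_ids, use_cache=True)
next_token = tf.argmax(lm_logits[:, -1, :], axis=-1, output_type=tf.int32)[:, None]

# Step 2: only the newly chosen token is fed; the cached states cover the prefix.
# (The encoder is re-run here for simplicity; generate() reuses encoder_outputs instead.)
lm_logits, past, _ = model(
    input_ids,
    decoder_input_ids=next_token,
    decoder_past_key_value_states=past[1],
    use_cache=True,
)
```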
......@@ -705,6 +849,12 @@ class TFT5Model(TFT5PreTrainedModel):
def get_output_embeddings(self):
return self.shared
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
@add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING)
def call(self, inputs, **kwargs):
r"""
......@@ -712,6 +862,11 @@ class TFT5Model(TFT5PreTrainedModel):
:obj:`tuple(tf.Tensor)` comprising various elements depending on the configuration (:class:`~transformers.T5Config`) and inputs.
last_hidden_state (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
If `decoder_past_key_value_states` is used only the last hidden-state of the sequences of shape :obj:`(batch_size, 1, hidden_size)` is output.
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``use_cache=True``):
Contains pre-computed key and value hidden-states of the attention blocks.
Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input).
Note that when using `decoder_past_key_value_states`, the model only outputs the last `hidden-state` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -743,12 +898,14 @@ class TFT5Model(TFT5PreTrainedModel):
# retrieve arguments
input_ids = kwargs.get("inputs", None)
decoder_input_ids = kwargs.get("decoder_input_ids", None)
inputs_embeds = kwargs.get("inputs_embeds", None)
attention_mask = kwargs.get("attention_mask", None)
encoder_outputs = kwargs.get("encoder_outputs", None)
decoder_input_ids = kwargs.get("decoder_input_ids", None)
decoder_attention_mask = kwargs.get("decoder_attention_mask", None)
inputs_embeds = kwargs.get("inputs_embeds", None)
decoder_inputs_embeds = kwargs.get("decoder_inputs_embeds", None)
decoder_past_key_value_states = kwargs.get("decoder_past_key_value_states", None)
use_cache = kwargs.get("use_cache", True)
head_mask = kwargs.get("head_mask", None)
# Encode if needed (training, first prediction pass)
......@@ -759,16 +916,30 @@ class TFT5Model(TFT5PreTrainedModel):
hidden_states = encoder_outputs[0]
# If decoding with past key value states, only the last tokens
# should be given as an input
if decoder_past_key_value_states is not None:
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_inputs_embeds is not None:
decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
# Decode
decoder_outputs = self.decoder(
decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
past_key_value_states=decoder_past_key_value_states,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
)
if use_cache is True:
past = ((encoder_outputs, decoder_outputs[1]),)
decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:]
return decoder_outputs + encoder_outputs
......@@ -802,6 +973,9 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
def get_encoder(self):
return self.encoder
def get_decoder(self):
return self.decoder
@add_start_docstrings_to_callable(T5_INPUTS_DOCSTRING)
def call(self, inputs, **kwargs):
r"""
......@@ -811,6 +985,10 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
Classification loss (cross entropy).
prediction_scores (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
decoder_past_key_value_states (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length, embed_size_per_head)`, `optional`, returned when ``use_cache=True``):
Contains pre-computed key and value hidden-states of the attention blocks.
Can be used to speed up sequential decoding (see `decoder_past_key_value_states` input).
Note that when using `decoder_past_key_value_states`, the model only outputs the last `prediction_score` of the sequence of shape :obj:`(batch_size, 1, config.vocab_size)`.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
......@@ -850,6 +1028,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
attention_mask = kwargs.get("attention_mask", None)
encoder_outputs = kwargs.get("encoder_outputs", None)
decoder_attention_mask = kwargs.get("decoder_attention_mask", None)
decoder_past_key_value_states = kwargs.get("decoder_past_key_value_states", None)
use_cache = kwargs.get("use_cache", True)
inputs_embeds = kwargs.get("inputs_embeds", None)
decoder_inputs_embeds = kwargs.get("decoder_inputs_embeds", None)
head_mask = kwargs.get("head_mask", None)
......@@ -863,16 +1043,32 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
hidden_states = encoder_outputs[0]
# If decoding with past key value states, only the last tokens
# should be given as an input
if decoder_past_key_value_states is not None:
if decoder_input_ids is not None:
decoder_input_ids = decoder_input_ids[:, -1:]
if decoder_inputs_embeds is not None:
decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
# Decode
decoder_outputs = self.decoder(
decoder_input_ids,
attention_mask=decoder_attention_mask,
inputs_embeds=decoder_inputs_embeds,
past_key_value_states=decoder_past_key_value_states,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
)
# insert decoder past at right place
# to speed up decoding
if use_cache is True:
past = ((encoder_outputs, decoder_outputs[1]),)
decoder_outputs = decoder_outputs[:1] + past + decoder_outputs[2:]
sequence_output = decoder_outputs[0] * (self.model_dim ** -0.5)
embed_tokens = self.get_output_embeddings()
lm_logits = embed_tokens(sequence_output, mode="linear")
......@@ -880,22 +1076,46 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel):
return decoder_outputs + encoder_outputs
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, use_cache, **kwargs):
assert past is not None, "past has to be defined for encoder_outputs"
# first step
if type(past) is tuple:
encoder_outputs = past
if len(past) < 2:
encoder_outputs, decoder_past_key_value_states = past, None
else:
encoder_outputs = (past,)
encoder_outputs, decoder_past_key_value_states = past[0], past[1]
return {
"inputs": None, # inputs don't have to be defined, but still need to be passed to make Keras.layer.__call__ happy
"decoder_input_ids": input_ids, # input_ids are the decoder_input_ids
"decoder_past_key_value_states": decoder_past_key_value_states,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"use_cache": use_cache,
}
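During generation, the `past` argument received here takes two forms: on the first step it is just the encoder outputs, and on later steps it is the pair `(encoder_outputs, decoder_past_key_value_states)` assembled in `call`. A toy illustration of that branching (all names and sizes are made up):

```python
import tensorflow as tf

encoder_outputs = (tf.zeros((1, 8, 16)),)                        # (last_hidden_state,)
layer_cache = tuple(tf.zeros((1, 2, 5, 4)) for _ in range(4))    # self k/v + cross k/v
decoder_past_key_value_states = (layer_cache,)                   # one entry per decoder layer

past_first_step = encoder_outputs                                # len < 2 -> no decoder cache yet
past_later_steps = (encoder_outputs, decoder_past_key_value_states)

def split_past(past):
    # Mirrors the branching in prepare_inputs_for_generation above.
    if len(past) < 2:
        return past, None
    return past[0], past[1]

print(split_past(past_first_step)[1] is None)   # True
print(len(split_past(past_later_steps)[1]))     # 1 decoder layer in this toy cache
```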
def _reorder_cache(self, past, beam_idx):
# past does not have to be re-ordered for T5.
# if decoder past is not included in output
# speedy decoding is disabled and no need to reorder
if len(past) < 2:
logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
return past
decoder_past = past[1]
past = (past[0],)
reordered_decoder_past = ()
for layer_past_states in decoder_past:
# get the correct batch idx from layer past batch dim
# batch dim of the cached key / value states is the 1st position (axis 0), so tf.gather's default axis applies
reordered_layer_past_states = ()
for layer_past_state in layer_past_states:
# need to set correct `past` for each of the four key / value states
reordered_layer_past_states = reordered_layer_past_states + (tf.gather(layer_past_state, beam_idx),)
assert shape_list(reordered_layer_past_states[0]) == shape_list(layer_past_states[0])
assert len(reordered_layer_past_states) == len(layer_past_states)
reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
return past + (reordered_decoder_past,)
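A standalone sketch of what this reordering does during beam search, with toy shapes and an illustrative `beam_idx`: each cached tensor is gathered along its batch (beam) dimension so the cache follows the surviving beams.

```python
import tensorflow as tf

# Toy cache for one decoder layer: four tensors of shape (batch, n_heads, seq_len, head_dim).
batch, n_heads, seq_len, head_dim = 3, 2, 5, 4
layer_past_states = tuple(tf.random.normal((batch, n_heads, seq_len, head_dim)) for _ in range(4))

# Suppose beam search keeps beams 2, 0 and 0; the cache must be reordered to match.
beam_idx = tf.constant([2, 0, 0])
reordered = tuple(tf.gather(t, beam_idx) for t in layer_past_states)  # gathers along axis 0

assert reordered[0].shape == layer_past_states[0].shape
```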
......@@ -1299,17 +1299,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = []
for layer_past in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` and `mems` is at 2nd position
reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[:, i], 1)) for i in beam_idx]
reordered_layer_past = tf.concat(reordered_layer_past, axis=1)
# check that shape matches
assert shape_list(reordered_layer_past) == shape_list(layer_past)
reordered_past.append(reordered_layer_past)
past = tuple(reordered_past)
return past
return tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past)
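The new one-liner is behaviorally equivalent to the removed expand/concat loop: for caches whose batch dimension sits at axis 1 (as the old comment notes for `past`/`mems`), a single `tf.gather` selects the surviving beams. A toy check under that layout assumption:

```python
import tensorflow as tf

layer_past = tf.random.normal((2, 3, 2, 5, 4))   # toy cache with batch at axis 1
beam_idx = tf.constant([1, 1, 0])

reordered = tf.gather(layer_past, beam_idx, axis=1)
assert reordered.shape == layer_past.shape
```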
def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
......
......@@ -244,7 +244,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
# test that outputs are equal for slice
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-6))
def create_and_check_t5_decoder_model_attention_mask_past(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
......@@ -293,7 +293,6 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
def create_t5_and_check_t5_generate_with_past_key_value_states(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
config.num_layers = 1
model = T5ForConditionalGeneration(config=config)
model.to(torch_device)
model.eval()
......
......@@ -191,7 +191,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
output_from_past_slice = output_from_past[:, 0, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-12)
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
def create_and_check_gpt2_model_attention_mask_past(
self, config, input_ids, input_mask, head_mask, token_type_ids, *args
......
......@@ -24,6 +24,7 @@ from .utils import CACHE_DIR, require_tf, slow
if is_tf_available():
import tensorflow as tf
from transformers import TFT5Model, TFT5ForConditionalGeneration, T5Tokenizer
......@@ -111,14 +112,14 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
"decoder_input_ids": input_ids,
"decoder_attention_mask": input_mask,
}
encoder_output, decoder_output = model(inputs)
decoder_output, decoder_past, encoder_output = model(inputs)
encoder_output, decoder_output = model(
decoder_output, decoder_past, encoder_output = model(
input_ids, decoder_attention_mask=input_mask, decoder_input_ids=input_ids
)
result = {
"encoder_output": encoder_output.numpy(),
"decoder_past": decoder_past,
"decoder_output": decoder_output.numpy(),
}
self.parent.assertListEqual(
......@@ -127,6 +128,13 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(
list(result["decoder_output"].shape), [self.batch_size, self.seq_length, self.hidden_size]
)
self.parent.assertEqual(len(decoder_past), 2)
# decoder_past[0] should correspond to encoder output
self.parent.assertTrue(tf.reduce_all(tf.math.equal(decoder_past[0][0], encoder_output)))
# There should be `num_layers` key value embeddings stored in decoder_past[1]
self.parent.assertEqual(len(decoder_past[1]), config.num_layers)
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple
self.parent.assertEqual(len(decoder_past[1][0]), 4)
def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
model = TFT5ForConditionalGeneration(config=config)
......@@ -136,7 +144,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
"decoder_attention_mask": input_mask,
}
prediction_scores, decoder_output = model(inputs_dict)
prediction_scores, _, _ = model(inputs_dict)
result = {
"prediction_scores": prediction_scores.numpy(),
......@@ -145,6 +153,76 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size]
)
def create_and_check_t5_decoder_model_past(self, config, input_ids, decoder_input_ids, attention_mask):
model = TFT5Model(config=config).get_decoder()
input_ids = input_ids[:1, :]
self.batch_size = 1
# first forward pass
_, past_key_value_states = model(input_ids, use_cache=True)
# create hypothetical next token and extend to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# append to next input_ids
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
output_from_no_past = model(next_input_ids)[0]
output_from_past = model(next_tokens, past_key_value_states=past_key_value_states)[0]
# select random slice
random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
output_from_past_slice = output_from_past[:, 0, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
def create_and_check_t5_decoder_model_attention_mask_past(
self, config, input_ids, decoder_input_ids, attention_mask
):
model = TFT5Model(config=config).get_decoder()
# create attention mask
half_seq_length = self.seq_length // 2
attn_mask_begin = tf.ones((self.batch_size, half_seq_length), dtype=tf.int32)
attn_mask_end = tf.zeros((self.batch_size, self.seq_length - half_seq_length), dtype=tf.int32)
attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1)
# first forward pass
_, past_key_value_states = model(input_ids, attention_mask=attn_mask, use_cache=True)
# create hypothetical next token and extend to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
# change a random masked slice from input_ids
random_seq_idx_to_change = ids_tensor((1,), half_seq_length).numpy() + 1
random_other_next_tokens = ids_tensor((self.batch_size, self.seq_length), config.vocab_size)
vector_condition = tf.range(self.seq_length) == (self.seq_length - random_seq_idx_to_change)
condition = tf.transpose(
tf.broadcast_to(tf.expand_dims(vector_condition, -1), (self.seq_length, self.batch_size))
)
input_ids = tf.where(condition, random_other_next_tokens, input_ids)
# append to next input_ids and attn_mask
next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
attn_mask = tf.concat([attn_mask, tf.ones((attn_mask.shape[0], 1), dtype=tf.int32)], axis=1,)
# get two different outputs
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)[0]
output_from_past = model(
next_tokens, past_key_value_states=past_key_value_states, attention_mask=attn_mask
)[0]
# select random slice
random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).numpy().item()
output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
output_from_past_slice = output_from_past[:, 0, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask, token_labels) = config_and_inputs
......@@ -152,6 +230,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
"inputs": input_ids,
"decoder_input_ids": input_ids,
"decoder_attention_mask": input_mask,
"use_cache": tf.convert_to_tensor([False]),
}
return config, inputs_dict
......@@ -170,6 +249,14 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs)
def test_t5_decoder_model_past(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_decoder_model_past(*config_and_inputs)
def test_t5_decoder_model_past_with_attn_mask(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_decoder_model_attention_mask_past(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
for model_name in ["t5-small"]:
......