Unverified Commit 70203b59 authored by Joao Gante, committed by GitHub

TF generate refactor - past without encoder outputs (#15944)

* Remove packed past from generation_tf_utils

* update models with the new past format

* update template accordingly
parent 62d84760
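
Note: a minimal, self-contained sketch of the cache format change described above. All shapes and names are illustrative, not the models' real internals.

    import tensorflow as tf

    batch, heads, seq, head_dim, num_layers = 2, 4, 5, 8, 3

    # New format: `past` holds one entry per decoder layer, each entry being
    # (self-attn key, self-attn value, cross-attn key, cross-attn value).
    past = tuple(
        tuple(tf.zeros((batch, heads, seq, head_dim)) for _ in range(4))
        for _ in range(num_layers)
    )

    # Old (packed) format additionally wrapped the encoder output into `past`:
    #     packed_past = (encoder_outputs, past)
    # That wrapping is what this commit removes.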
@@ -1777,7 +1777,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
 {% else %}
 import random
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple, Union
 
 import tensorflow as tf
@@ -2736,9 +2736,6 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
         if inputs["output_hidden_states"]:
             all_hidden_states += (hidden_states,)
 
-        if inputs["use_cache"]:
-            present_key_values = (inputs["encoder_hidden_states"], present_key_values)
-
         if not inputs["return_dict"]:
             return hidden_states, present_key_values, all_hidden_states, all_self_attns, all_cross_attns
         else:
@@ -3186,43 +3183,23 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
     def prepare_inputs_for_generation(
         self,
         decoder_input_ids,
-        past,
-        attention_mask,
+        past=None,
+        attention_mask=None,
         head_mask=None,
         decoder_head_mask=None,
         cross_attn_head_mask=None,
-        use_cache=False,
+        use_cache=None,
+        encoder_outputs=None,
         **kwargs
-    ) -> Dict:
-        assert past is not None and len(past) in {1, 2}, f"past has to be an iterable of length 1,2 got {past}"
-        if len(past) == 1:
-            assert isinstance(past[0], tf.Tensor), f"`past[0]` has to be of type `tf.Tensor`, but is {type(past[0])}"
-            encoder_outputs = TFBaseModelOutput(last_hidden_state=past[0])
-            past_key_values = None
-        else:
-            assert (
-                len(past) == 2
-            ), "`past` has to be of length 2 with the encoder_outputs at the first position and past_key_values at the second position."
-            encoder_outputs, past_key_values = past
-            if isinstance(encoder_outputs, tuple):
-                assert isinstance(
-                    encoder_outputs[0], tf.Tensor
-                ), f"`encoder_outputs[0]` has to be of type `tf.Tensor`, but is {type(encoder_outputs[0])}"
-                encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs[0])
-            elif isinstance(encoder_outputs, tf.Tensor):
-                encoder_outputs = TFBaseModelOutput(last_hidden_state=encoder_outputs)
-            assert (
-                past_key_values
-            ), f"decoder cached states must be truthy. got {past_key_values} from the 2nd element of past"
+    ):
+        # cut decoder_input_ids if past is used
+        if past is not None:
+            decoder_input_ids = decoder_input_ids[:, -1:]
 
-        assert isinstance(
-            encoder_outputs, TFBaseModelOutput
-        ), f"encoder_outputs should be a TFBaseModelOutput, Instead got {type(encoder_outputs)}."
         return {
-            "input_ids": None,  # encoder_outputs is defined. input_ids not needed
+            "input_ids": None,  # needs to be passed to make Keras.layer.__call__ happy
             "encoder_outputs": encoder_outputs,
-            "past_key_values": past_key_values,
+            "past_key_values": past,
             "decoder_input_ids": decoder_input_ids,
             "attention_mask": attention_mask,
             "head_mask": head_mask,
@@ -3233,17 +3210,10 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec
     @staticmethod
     def _reorder_cache(past, beam_idx):
-        if len(past) == 1:
-            return past
-
-        past_key_values = past[1]
         reordered_past = ()
-        for layer_past_key_values in past_key_values:
-            reordered_past += (
-                tuple(tf.gather(layer_past_key_value, beam_idx) for layer_past_key_value in layer_past_key_values[:2]) + layer_past_key_values[2:],
-            )
-        return (past[0], reordered_past)
+        for layer_past in past:
+            reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)
+        return reordered_past
 
     def hf_compute_loss(self, labels, logits):
         """CrossEntropyLoss that ignores pad tokens"""
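
Note: a standalone toy check of what the new `_reorder_cache` computes: `tf.gather` along axis 0 reorders every cached tensor by the surviving beam indices. All shapes are invented for the example.

    import tensorflow as tf

    num_beams, heads, seq, dim = 3, 2, 4, 8
    past = tuple(
        tuple(tf.random.normal((num_beams, heads, seq, dim)) for _ in range(4))
        for _ in range(2)  # two decoder layers
    )
    beam_idx = tf.constant([2, 0, 1])  # beam 2 moves to slot 0, and so on

    reordered_past = ()
    for layer_past in past:
        reordered_past += (tuple(tf.gather(past_state, beam_idx, axis=0) for past_state in layer_past),)

    # Slot 0 of every reordered tensor now holds what beam 2 held before.
    assert bool(tf.reduce_all(reordered_past[0][0][0] == past[0][0][2]))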
@@ -802,7 +802,6 @@ class TF{{cookiecutter.camelcase_modelname}}ModelTester:
         outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
 
         output, past_key_values = outputs.to_tuple()
-        past_key_values = past_key_values[1]
 
         # create hypothetical next token and extend to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
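
Note: the same one-line deletion repeats across the testers below. A hedged sketch of the resulting pattern; `check_cache_roundtrip` is a hypothetical helper, and the real testers additionally compare per-token slices of the two outputs.

    def check_cache_roundtrip(model, input_ids, attention_mask, next_tokens, next_attention_mask):
        # The first pass returns a flat cache; no `past_key_values[1]` unpacking is needed.
        outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
        output, past_key_values = outputs.to_tuple()
        # The cache is fed straight back on the second pass. `next_attention_mask`
        # is assumed to cover both the original and the new tokens.
        return model(next_tokens, attention_mask=next_attention_mask, past_key_values=past_key_values)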
@@ -116,7 +116,6 @@ class TFBartModelTester:
         outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
 
         output, past_key_values = outputs.to_tuple()
-        past_key_values = past_key_values[1]
 
         # create hypothetical next token and extend to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
@@ -114,7 +114,6 @@ class TFBlenderbotModelTester:
         outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
 
         output, past_key_values = outputs.to_tuple()
-        past_key_values = past_key_values[1]
 
         # create hypothetical next token and extend to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
@@ -114,7 +114,6 @@ class TFBlenderbotSmallModelTester:
         outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
 
         output, past_key_values = outputs.to_tuple()
-        past_key_values = past_key_values[1]
 
         # create hypothetical next token and extend to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
@@ -133,7 +133,6 @@ class TFLEDModelTester:
         outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
 
         output, past_key_values = outputs.to_tuple()
-        past_key_values = past_key_values[1]
 
         # create hypothetical next token and extend to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
@@ -116,7 +116,6 @@ class TFMarianModelTester:
         outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
 
         output, past_key_values = outputs.to_tuple()
-        past_key_values = past_key_values[1]
 
         # create hypothetical next token and extend to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
@@ -114,7 +114,6 @@ class TFPegasusModelTester:
         outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True)
 
         output, past_key_values = outputs.to_tuple()
-        past_key_values = past_key_values[1]
 
         # create hypothetical next token and extend to next_input_ids
         next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
@@ -182,7 +182,7 @@ class TFSpeech2TextModelTester:
         # first forward pass
         outputs = model(input_ids, attention_mask=attention_mask, use_cache=True)
 
-        _, (_, past_key_values) = outputs.to_tuple()
+        _, past_key_values = outputs.to_tuple()
 
         # create hypothetical multiple next tokens and extend to next_input_ids
         next_tokens = tf.math.maximum(ids_tensor((self.batch_size, 3), config.vocab_size), 2)
@@ -98,13 +98,10 @@ class TFT5ModelTester:
         encoder_output = result.encoder_last_hidden_state
 
         self.parent.assertListEqual(list(encoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
         self.parent.assertListEqual(list(decoder_output.shape), [self.batch_size, self.seq_length, self.hidden_size])
-        self.parent.assertEqual(len(decoder_past), 2)
-        # decoder_past[0] should correspond to encoder output
-        self.parent.assertTrue(tf.reduce_all(tf.math.equal(decoder_past[0][0], encoder_output)))
-        # There should be `num_layers` key value embeddings stored in decoder_past[1]
-        self.parent.assertEqual(len(decoder_past[1]), config.num_layers)
-        # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple
-        self.parent.assertEqual(len(decoder_past[1][0]), 4)
+        # There should be `num_layers` key value embeddings stored in decoder_past
+        self.parent.assertEqual(len(decoder_past), config.num_layers)
+        # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past tuple
+        self.parent.assertEqual(len(decoder_past[0]), 4)
 
     def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
         model = TFT5ForConditionalGeneration(config=config)
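
Note: a standalone illustration of the cache structure the updated T5 assertions expect; every number below is invented for the example.

    import tensorflow as tf

    num_layers = 2
    decoder_past = tuple(
        tuple(tf.zeros((1, 4, 7, 8)) for _ in range(4))  # self-attn k/v + cross-attn k/v
        for _ in range(num_layers)
    )
    assert len(decoder_past) == num_layers  # one entry per layer, no packed encoder output
    assert len(decoder_past[0]) == 4  # four cached tensors per layer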