Unverified Commit d97d06d0 authored by Julien Plu, committed by GitHub

Fix TF T5 (#9301)

* Fix T5

* Fix test

* Fix test
parent 83fdd252
@@ -268,9 +268,9 @@ class TFT5Attention(tf.keras.layers.Layer):
             ), "past_key_value should have 2 past states: keys and values. Got {} past states".format(
                 len(past_key_value)
             )
-            real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
+            real_seq_length += shape_list(past_key_value[0])[2] if query_length is None else query_length
 
-        key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
+        key_length = real_seq_length if key_value_states is None else shape_list(key_value_states)[1]
 
         def shape(hidden_states):
             """ projection """
@@ -1147,13 +1147,14 @@ class TFT5Model(TFT5PreTrainedModel):
             training=inputs["training"],
         )
 
-        past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
-
         if not inputs["return_dict"]:
+            past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
             if past is not None:
                 decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
             return decoder_outputs + inputs["encoder_outputs"]
 
+        past = (inputs["encoder_outputs"].to_tuple(), decoder_outputs[1]) if inputs["use_cache"] else None
+
         return TFSeq2SeqModelOutput(
             last_hidden_state=decoder_outputs.last_hidden_state,
             past_key_values=past,
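In the TFT5Model hunk above, `past` is no longer built once up front: the tuple branch (`return_dict=False`) keeps caching the encoder outputs object as it was received, while the `TFSeq2SeqModelOutput` branch now converts it with `.to_tuple()` first, so `past_key_values` holds plain nested tuples of tensors rather than a ModelOutput dataclass (likely friendlier to TF tracing and to downstream indexing). A small illustration of that conversion, with placeholder tensors standing in for real encoder and decoder states:

import tensorflow as tf
from transformers.modeling_tf_outputs import TFBaseModelOutput

# Stand-ins for inputs["encoder_outputs"] and decoder_outputs[1]; shapes are arbitrary.
encoder_outputs = TFBaseModelOutput(last_hidden_state=tf.zeros((1, 4, 8)))
decoder_present = tf.zeros((1, 2, 4, 8))
use_cache = True

# Dict path after the fix: cache a plain tuple, not the TFBaseModelOutput itself.
past = (encoder_outputs.to_tuple(), decoder_present) if use_cache else None
print(type(past[0]))  # <class 'tuple'>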
@@ -1332,8 +1333,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
         loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
 
-        past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
-
         if not inputs["return_dict"]:
+            past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
             if past is not None:
                 decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
             output = (logits,) + decoder_outputs[1:] + inputs["encoder_outputs"]
@@ -1358,6 +1359,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModelingLoss):
             attentions=attentions,
         )
 
+        past = (inputs["encoder_outputs"].to_tuple(), decoder_outputs[1]) if inputs["use_cache"] else None
+
         return TFSeq2SeqLMOutput(
             loss=loss,
             logits=logits,
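The last two hunks apply the same reorganization to TFT5ForConditionalGeneration: the tuple branch builds `past` from the raw `encoder_outputs`, and the `TFSeq2SeqLMOutput` branch builds it from `encoder_outputs.to_tuple()` just before the return. A hedged paraphrase of the combined control flow, wrapped in a hypothetical `build_past` helper (not a function in the library) so it runs standalone:

import tensorflow as tf
from transformers.modeling_tf_outputs import TFBaseModelOutput


def build_past(encoder_outputs, decoder_present, use_cache, return_dict):
    # Hypothetical helper mirroring the branching the two hunks introduce.
    if not use_cache:
        return None
    if not return_dict:
        # Tuple path: cache encoder_outputs exactly as it was received.
        return (encoder_outputs, decoder_present)
    # Dict path: cache the plain-tuple form of the encoder output.
    return (encoder_outputs.to_tuple(), decoder_present)


enc = TFBaseModelOutput(last_hidden_state=tf.zeros((1, 4, 8)))
present = tf.zeros((1, 2, 4, 8))
assert isinstance(build_past(enc, present, use_cache=True, return_dict=True)[0], tuple)
assert build_past(enc, present, use_cache=False, return_dict=True) is None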