Unverified Commit d97d06d0 authored by Julien Plu's avatar Julien Plu Committed by GitHub
Browse files

Fix TF T5 (#9301)

* Fix T5

* Fix test

* Fix test
parent 83fdd252
......@@ -268,9 +268,9 @@ class TFT5Attention(tf.keras.layers.Layer):
), "past_key_value should have 2 past states: keys and values. Got {} past states".format(
len(past_key_value)
)
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
real_seq_length += shape_list(past_key_value[0])[2] if query_length is None else query_length
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
key_length = real_seq_length if key_value_states is None else shape_list(key_value_states)[1]
def shape(hidden_states):
""" projection """
......@@ -1147,13 +1147,14 @@ class TFT5Model(TFT5PreTrainedModel):
training=inputs["training"],
)
past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
if not inputs["return_dict"]:
past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
if past is not None:
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
return decoder_outputs + inputs["encoder_outputs"]
past = (inputs["encoder_outputs"].to_tuple(), decoder_outputs[1]) if inputs["use_cache"] else None
return TFSeq2SeqModelOutput(
last_hidden_state=decoder_outputs.last_hidden_state,
past_key_values=past,
......@@ -1332,8 +1333,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
if not inputs["return_dict"]:
past = (inputs["encoder_outputs"], decoder_outputs[1]) if inputs["use_cache"] else None
if past is not None:
decoder_outputs = decoder_outputs[:1] + (past,) + decoder_outputs[2:]
output = (logits,) + decoder_outputs[1:] + inputs["encoder_outputs"]
......@@ -1358,6 +1359,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
attentions=attentions,
)
past = (inputs["encoder_outputs"].to_tuple(), decoder_outputs[1]) if inputs["use_cache"] else None
return TFSeq2SeqLMOutput(
loss=loss,
logits=logits,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment