Unverified Commit 9fef6683 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

TF - update (vision_)encoder_decoder past variable (#16260)

parent f9387c94
......@@ -647,19 +647,17 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
# The starting index of the remaining elements in `decoder_outputs`
start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
past = (encoder_outputs[0], past_key_values) if past_key_values else None
if not decoder_inputs["return_dict"]:
if not isinstance(encoder_outputs, tuple):
encoder_outputs = encoder_outputs.to_tuple()
output = (loss, logits, past) + decoder_outputs[start_index:] + encoder_outputs
output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs
output = tuple([x for x in output if x is not None])
return output
return TFSeq2SeqLMOutput(
loss=loss,
logits=decoder_outputs.logits,
past_key_values=past,
past_key_values=past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
......
......@@ -678,19 +678,17 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
# The starting index of the remaining elements in `decoder_outputs`
start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
past = (encoder_outputs[0], past_key_values) if past_key_values else None
if not decoder_inputs["return_dict"]:
if not isinstance(encoder_outputs, tuple):
encoder_outputs = encoder_outputs.to_tuple()
output = (loss, logits, past) + decoder_outputs[start_index:] + encoder_outputs
output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs
output = tuple([x for x in output if x is not None])
return output
return TFSeq2SeqLMOutput(
loss=loss,
logits=decoder_outputs.logits,
past_key_values=past,
past_key_values=past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment