"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "7c44c864a5c93372b9043629aaa9634dabcaea9f"
Unverified commit 9fef6683, authored by Joao Gante and committed by GitHub
Browse files

TF - update (vision_)encoder_decoder past variable (#16260)

parent f9387c94
...@@ -647,19 +647,17 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss): ...@@ -647,19 +647,17 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
# The starting index of the remaining elements in `decoder_outputs` # The starting index of the remaining elements in `decoder_outputs`
start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)]) start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
past = (encoder_outputs[0], past_key_values) if past_key_values else None
if not decoder_inputs["return_dict"]: if not decoder_inputs["return_dict"]:
if not isinstance(encoder_outputs, tuple): if not isinstance(encoder_outputs, tuple):
encoder_outputs = encoder_outputs.to_tuple() encoder_outputs = encoder_outputs.to_tuple()
output = (loss, logits, past) + decoder_outputs[start_index:] + encoder_outputs output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs
output = tuple([x for x in output if x is not None]) output = tuple([x for x in output if x is not None])
return output return output
return TFSeq2SeqLMOutput( return TFSeq2SeqLMOutput(
loss=loss, loss=loss,
logits=decoder_outputs.logits, logits=decoder_outputs.logits,
past_key_values=past, past_key_values=past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states, decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions, cross_attentions=decoder_outputs.cross_attentions,
......
...@@ -678,19 +678,17 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos ...@@ -678,19 +678,17 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
# The starting index of the remaining elements in `decoder_outputs` # The starting index of the remaining elements in `decoder_outputs`
start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)]) start_index = sum([1 if x is not None else 0 for x in (loss, logits, past_key_values)])
past = (encoder_outputs[0], past_key_values) if past_key_values else None
if not decoder_inputs["return_dict"]: if not decoder_inputs["return_dict"]:
if not isinstance(encoder_outputs, tuple): if not isinstance(encoder_outputs, tuple):
encoder_outputs = encoder_outputs.to_tuple() encoder_outputs = encoder_outputs.to_tuple()
output = (loss, logits, past) + decoder_outputs[start_index:] + encoder_outputs output = (loss, logits, past_key_values) + decoder_outputs[start_index:] + encoder_outputs
output = tuple([x for x in output if x is not None]) output = tuple([x for x in output if x is not None])
return output return output
return TFSeq2SeqLMOutput( return TFSeq2SeqLMOutput(
loss=loss, loss=loss,
logits=decoder_outputs.logits, logits=decoder_outputs.logits,
past_key_values=past, past_key_values=past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states, decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions, cross_attentions=decoder_outputs.cross_attentions,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment