"git@developer.sourcefind.cn:OpenDAS/torch-scatter.git" did not exist on "755144694369460914385326de9adb027f167236"
Unverified Commit 7a9ef818 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

TF: properly handle kwargs in encoder_decoder architectures (#16465)

* properly handle kwargs in encoder_decoder architectures

* make fixup
parent 0540d1b6
......@@ -569,13 +569,12 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
"output_hidden_states": output_hidden_states,
"return_dict": return_dict,
"training": training,
"kwargs_call": kwargs_encoder,
"kwargs_call": {},
}
# Add arguments to encoder from `kwargs_encoder`
for k, v in kwargs_encoder.items():
encoder_processing_inputs[k] = v
kwargs_encoder = {}
encoder_inputs = input_processing(**encoder_processing_inputs)
......@@ -622,13 +621,12 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
"past_key_values": past_key_values,
"return_dict": return_dict,
"training": training,
"kwargs_call": kwargs_decoder,
"kwargs_call": {},
}
# Add arguments to decoder from `kwargs_decoder`
for k, v in kwargs_decoder.items():
decoder_processing_inputs[k] = v
kwargs_decoder = {}
decoder_inputs = input_processing(**decoder_processing_inputs)
decoder_outputs = self.decoder(**decoder_inputs)
......
......@@ -593,12 +593,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
"output_hidden_states": output_hidden_states,
"return_dict": return_dict,
"training": training,
"kwargs_call": kwargs_encoder,
"kwargs_call": {},
}
# Add arguments to encoder from `kwargs_encoder`
encoder_processing_inputs.update(kwargs_encoder)
kwargs_encoder = {}
encoder_inputs = input_processing(**encoder_processing_inputs)
......@@ -654,12 +653,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLos
"past_key_values": past_key_values,
"return_dict": return_dict,
"training": training,
"kwargs_call": kwargs_decoder,
"kwargs_call": {},
}
# Add arguments to decoder from `kwargs_decoder`
decoder_processing_inputs.update(kwargs_decoder)
kwargs_decoder = {}
decoder_inputs = input_processing(**decoder_processing_inputs)
decoder_outputs = self.decoder(**decoder_inputs)
......
......@@ -91,6 +91,7 @@ class TFEncoderDecoderMixin:
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
self.assertEqual(
......@@ -122,6 +123,7 @@ class TFEncoderDecoderMixin:
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
......@@ -137,6 +139,7 @@ class TFEncoderDecoderMixin:
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
self.assertEqual(
......@@ -167,6 +170,7 @@ class TFEncoderDecoderMixin:
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
return_dict=True,
kwargs=kwargs,
)
self.assertEqual(
......@@ -195,6 +199,7 @@ class TFEncoderDecoderMixin:
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
out_2 = np.array(outputs[0])
out_2[np.isnan(out_2)] = 0
......@@ -208,6 +213,7 @@ class TFEncoderDecoderMixin:
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
out_1 = np.array(after_outputs[0])
out_1[np.isnan(out_1)] = 0
......@@ -235,6 +241,7 @@ class TFEncoderDecoderMixin:
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
labels=labels,
kwargs=kwargs,
)
# Make sure `loss` exist
......@@ -269,6 +276,7 @@ class TFEncoderDecoderMixin:
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
kwargs=kwargs,
)
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
......
......@@ -96,6 +96,7 @@ class TFVisionEncoderDecoderMixin:
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
self.assertEqual(
......@@ -124,6 +125,7 @@ class TFVisionEncoderDecoderMixin:
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
......@@ -137,6 +139,7 @@ class TFVisionEncoderDecoderMixin:
encoder_outputs=encoder_outputs,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
self.assertEqual(
......@@ -164,6 +167,7 @@ class TFVisionEncoderDecoderMixin:
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
return_dict=True,
kwargs=kwargs,
)
self.assertEqual(
......@@ -189,6 +193,7 @@ class TFVisionEncoderDecoderMixin:
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
out_2 = np.array(outputs[0])
out_2[np.isnan(out_2)] = 0
......@@ -201,6 +206,7 @@ class TFVisionEncoderDecoderMixin:
pixel_values=pixel_values,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
kwargs=kwargs,
)
out_1 = np.array(after_outputs[0])
out_1[np.isnan(out_1)] = 0
......@@ -226,6 +232,7 @@ class TFVisionEncoderDecoderMixin:
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
labels=labels,
kwargs=kwargs,
)
# Make sure `loss` exist
......@@ -257,6 +264,7 @@ class TFVisionEncoderDecoderMixin:
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
kwargs=kwargs,
)
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment