"git@developer.sourcefind.cn:OpenDAS/fast_rnnt.git" did not exist on "e0bc402953538cc74be6002b3803731af809995c"
Unverified commit 7a9ef818, authored by Joao Gante, committed by GitHub

TF: properly handle kwargs in encoder_decoder architectures (#16465)

* properly handle kwargs in encoder_decoder architectures

* make fixup
parent 0540d1b6
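For context, both architectures touched below split a single **kwargs dict into encoder- and decoder-specific parts before building the processing-input dicts patched in this diff. The sketch below is a minimal standalone illustration of that convention (decoder kwargs carry a `decoder_` prefix); the helper name and the example values are assumptions, not code from this commit.

# Minimal sketch, not the code in this commit: how encoder-decoder `call` implementations
# conventionally split **kwargs before building the processing-input dicts patched below.
def split_encoder_decoder_kwargs(**kwargs):
    # kwargs without a "decoder_" prefix are routed to the encoder
    kwargs_encoder = {k: v for k, v in kwargs.items() if not k.startswith("decoder_")}
    # kwargs with a "decoder_" prefix are stripped of it and routed to the decoder
    kwargs_decoder = {k[len("decoder_"):]: v for k, v in kwargs.items() if k.startswith("decoder_")}
    return kwargs_encoder, kwargs_decoder

# hypothetical usage: output_attentions goes to the encoder, decoder_use_cache to the decoder
kwargs_encoder, kwargs_decoder = split_encoder_decoder_kwargs(output_attentions=True, decoder_use_cache=False)
assert kwargs_encoder == {"output_attentions": True}
assert kwargs_decoder == {"use_cache": False}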
@@ -569,13 +569,12 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             "output_hidden_states": output_hidden_states,
             "return_dict": return_dict,
             "training": training,
-            "kwargs_call": kwargs_encoder,
+            "kwargs_call": {},
         }
         # Add arguments to encoder from `kwargs_encoder`
         for k, v in kwargs_encoder.items():
             encoder_processing_inputs[k] = v
-        kwargs_encoder = {}
         encoder_inputs = input_processing(**encoder_processing_inputs)
@@ -622,13 +621,12 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             "past_key_values": past_key_values,
             "return_dict": return_dict,
             "training": training,
-            "kwargs_call": kwargs_decoder,
+            "kwargs_call": {},
         }
         # Add arguments to decoder from `kwargs_decoder`
         for k, v in kwargs_decoder.items():
             decoder_processing_inputs[k] = v
-        kwargs_decoder = {}
         decoder_inputs = input_processing(**decoder_processing_inputs)
         decoder_outputs = self.decoder(**decoder_inputs)
......
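To see what the two hunks above leave in place, here is a small standalone sketch: `input_processing` is stubbed out (the real helper lives in `transformers`), and the kwargs values are hypothetical. The point is that caller kwargs become first-class entries in the processing dict while `kwargs_call` stays empty, so the same arguments are no longer handed to `input_processing` through two paths.

# Standalone sketch of the post-change merge; `input_processing` is a stub, values are made up.
def input_processing(**processing_inputs):
    # stand-in for the library's input_processing helper: just echo what it received
    return processing_inputs

kwargs_encoder = {"output_attentions": True}  # hypothetical caller kwargs for the encoder
encoder_processing_inputs = {
    "output_hidden_states": False,
    "return_dict": True,
    "training": False,
    "kwargs_call": {},  # after this change: no longer a second copy of kwargs_encoder
}
# Add arguments to encoder from `kwargs_encoder` (same loop as in the hunk above)
for k, v in kwargs_encoder.items():
    encoder_processing_inputs[k] = v

encoder_inputs = input_processing(**encoder_processing_inputs)
assert encoder_inputs["output_attentions"] is True
assert encoder_inputs["kwargs_call"] == {}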
@@ -593,12 +593,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             "output_hidden_states": output_hidden_states,
             "return_dict": return_dict,
             "training": training,
-            "kwargs_call": kwargs_encoder,
+            "kwargs_call": {},
         }
         # Add arguments to encoder from `kwargs_encoder`
         encoder_processing_inputs.update(kwargs_encoder)
-        kwargs_encoder = {}
         encoder_inputs = input_processing(**encoder_processing_inputs)
@@ -654,12 +653,11 @@ class TFVisionEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss):
             "past_key_values": past_key_values,
             "return_dict": return_dict,
             "training": training,
-            "kwargs_call": kwargs_decoder,
+            "kwargs_call": {},
         }
         # Add arguments to decoder from `kwargs_decoder`
         decoder_processing_inputs.update(kwargs_decoder)
-        kwargs_decoder = {}
         decoder_inputs = input_processing(**decoder_processing_inputs)
         decoder_outputs = self.decoder(**decoder_inputs)
......
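The vision variant differs from the text-to-text one only in using dict.update for the merge instead of an explicit loop; both produce the same processing dict. A tiny sketch of the equivalence, with made-up values:

# dict.update(...) and the explicit per-item loop perform the same merge
kwargs_decoder = {"use_cache": False}  # hypothetical decoder kwargs

via_update = {"training": False, "kwargs_call": {}}
via_update.update(kwargs_decoder)

via_loop = {"training": False, "kwargs_call": {}}
for k, v in kwargs_decoder.items():
    via_loop[k] = v

assert via_update == via_loop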
@@ -91,6 +91,7 @@ class TFEncoderDecoderMixin:
             decoder_input_ids=decoder_input_ids,
             attention_mask=attention_mask,
             decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
         )
         self.assertEqual(
@@ -122,6 +123,7 @@ class TFEncoderDecoderMixin:
             decoder_input_ids=decoder_input_ids,
             attention_mask=attention_mask,
             decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
         )
         self.assertEqual(
             outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
@@ -137,6 +139,7 @@ class TFEncoderDecoderMixin:
             decoder_input_ids=decoder_input_ids,
             attention_mask=attention_mask,
             decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
         )
         self.assertEqual(
@@ -167,6 +170,7 @@ class TFEncoderDecoderMixin:
             attention_mask=attention_mask,
             decoder_attention_mask=decoder_attention_mask,
             return_dict=True,
+            kwargs=kwargs,
         )
         self.assertEqual(
@@ -195,6 +199,7 @@ class TFEncoderDecoderMixin:
             decoder_input_ids=decoder_input_ids,
             attention_mask=attention_mask,
             decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
         )
         out_2 = np.array(outputs[0])
         out_2[np.isnan(out_2)] = 0
@@ -208,6 +213,7 @@ class TFEncoderDecoderMixin:
             decoder_input_ids=decoder_input_ids,
             attention_mask=attention_mask,
             decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
         )
         out_1 = np.array(after_outputs[0])
         out_1[np.isnan(out_1)] = 0
@@ -235,6 +241,7 @@ class TFEncoderDecoderMixin:
             attention_mask=attention_mask,
             decoder_attention_mask=decoder_attention_mask,
             labels=labels,
+            kwargs=kwargs,
         )
         # Make sure `loss` exist
@@ -269,6 +276,7 @@ class TFEncoderDecoderMixin:
             attention_mask=attention_mask,
             decoder_attention_mask=decoder_attention_mask,
             output_attentions=True,
+            kwargs=kwargs,
         )
         encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
......
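The test hunks above all add the same `kwargs=kwargs,` argument so that the mixin's checks exercise the kwargs path. As a rough illustration of the pattern (not the mixin itself; the helper name is made up), a determinism check that forwards extra kwargs the same way looks like this:

# Rough sketch of the mixin pattern: forward extra kwargs via `kwargs=kwargs`
# and compare two runs the same way the hunks above do (NaNs zeroed first).
import numpy as np

def check_outputs_stable(model, inputs, **kwargs):
    first = model(**inputs, kwargs=kwargs)
    second = model(**inputs, kwargs=kwargs)
    out_1 = np.array(first[0])
    out_2 = np.array(second[0])
    out_1[np.isnan(out_1)] = 0
    out_2[np.isnan(out_2)] = 0
    # mirror the tolerance-style comparison used by the test mixins
    max_diff = np.amax(np.abs(out_1 - out_2))
    assert max_diff <= 1e-5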
@@ -96,6 +96,7 @@ class TFVisionEncoderDecoderMixin:
             pixel_values=pixel_values,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
         )
         self.assertEqual(
@@ -124,6 +125,7 @@ class TFVisionEncoderDecoderMixin:
             pixel_values=pixel_values,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
         )
         self.assertEqual(
             outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
@@ -137,6 +139,7 @@ class TFVisionEncoderDecoderMixin:
             encoder_outputs=encoder_outputs,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
         )
         self.assertEqual(
@@ -164,6 +167,7 @@ class TFVisionEncoderDecoderMixin:
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
             return_dict=True,
+            kwargs=kwargs,
         )
         self.assertEqual(
@@ -189,6 +193,7 @@ class TFVisionEncoderDecoderMixin:
             pixel_values=pixel_values,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
         )
         out_2 = np.array(outputs[0])
         out_2[np.isnan(out_2)] = 0
@@ -201,6 +206,7 @@ class TFVisionEncoderDecoderMixin:
             pixel_values=pixel_values,
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
+            kwargs=kwargs,
         )
         out_1 = np.array(after_outputs[0])
         out_1[np.isnan(out_1)] = 0
@@ -226,6 +232,7 @@ class TFVisionEncoderDecoderMixin:
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
             labels=labels,
+            kwargs=kwargs,
         )
         # Make sure `loss` exist
@@ -257,6 +264,7 @@ class TFVisionEncoderDecoderMixin:
             decoder_input_ids=decoder_input_ids,
             decoder_attention_mask=decoder_attention_mask,
             output_attentions=True,
+            kwargs=kwargs,
         )
         encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
......
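The vision test mixin gets the same treatment, with `pixel_values` in place of `input_ids`. A hypothetical end-to-end usage of the calling shape (the extra kwarg and the helper are illustrative only; the shape assertion is the one the mixin itself performs):

# Hypothetical smoke test for the vision encoder-decoder kwargs path.
def smoke_test_vision_kwargs(model, pixel_values, decoder_input_ids, decoder_attention_mask, vocab_size):
    outputs = model(
        pixel_values=pixel_values,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        kwargs={"output_hidden_states": True},  # illustrative extra kwarg
    )
    # same logits shape check the mixin performs
    assert outputs["logits"].shape == decoder_input_ids.shape + (vocab_size,)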