Unverified Commit 60ba4820 authored by Sanchit Gandhi's avatar Sanchit Gandhi Committed by GitHub
Browse files

fix bug in PT speech-encoder-decoder (#15699)



* fix bug in PT speech-encoder-decoder

* add pt test for `inputs is not None`

* fix test

* new pt test

* Update tests/test_modeling_speech_encoder_decoder.py

* make fixup
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 3de12906
...@@ -490,15 +490,16 @@ class SpeechEncoderDecoderModel(PreTrainedModel): ...@@ -490,15 +490,16 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_") argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
} }
if encoder_outputs is None and inputs is None: if encoder_outputs is None:
if input_values is not None and input_features is not None: if inputs is None:
raise ValueError("You cannot specify both input_values and input_features at the same time") if input_values is not None and input_features is not None:
elif input_values is not None: raise ValueError("You cannot specify both input_values and input_features at the same time")
inputs = input_values elif input_values is not None:
elif input_features is not None: inputs = input_values
inputs = input_features elif input_features is not None:
else: inputs = input_features
raise ValueError("You have to specify either input_values or input_features") else:
raise ValueError("You have to specify either input_values or input_features")
encoder_outputs = self.encoder( encoder_outputs = self.encoder(
inputs, inputs,
......
...@@ -125,6 +125,43 @@ class EncoderDecoderMixin: ...@@ -125,6 +125,43 @@ class EncoderDecoderMixin:
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)) outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
) )
def check_encoder_decoder_model_with_inputs(
self,
config,
attention_mask,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
input_values=None,
input_features=None,
**kwargs
):
inputs = input_values if input_features is None else input_features
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
inputs,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
output_hidden_states=True,
)
self.assertEqual(
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
)
outputs_encoder_decoder_kwarg = enc_dec_model(
inputs=inputs,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
output_hidden_states=True,
)
self.assertEqual(
outputs_encoder_decoder_kwarg["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
)
def check_encoder_decoder_model_from_pretrained( def check_encoder_decoder_model_from_pretrained(
self, self,
config, config,
...@@ -325,6 +362,10 @@ class EncoderDecoderMixin: ...@@ -325,6 +362,10 @@ class EncoderDecoderMixin:
input_ids_dict = self.prepare_config_and_inputs() input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model(**input_ids_dict) self.check_encoder_decoder_model(**input_ids_dict)
def test_encoder_decoder_model_with_inputs(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_with_inputs(**input_ids_dict)
def test_encoder_decoder_model_from_pretrained_configs(self): def test_encoder_decoder_model_from_pretrained_configs(self):
input_ids_dict = self.prepare_config_and_inputs() input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict) self.check_encoder_decoder_model_from_pretrained_configs(**input_ids_dict)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment