Unverified Commit 26700a95 authored by Anton Lozhkov, committed by GitHub

Fix scheduled tests for `SpeechEncoderDecoderModel` (#13422)

* Add inputs to pretrained tests

* Make style
parent 73ad2588
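
In short: the shared EncoderDecoderMixin.test_real_model_save_load_from_pretrained previously built fixed-shape dummy tensors (a [13, 1] decoder_input_ids and a [13, 5] attention_mask) regardless of the model under test, which do not match the speech inputs each concrete model actually expects. The hook is therefore renamed to get_pretrained_model_and_inputs: each concrete test class now returns the pretrained model together with a ready-to-use dict of inputs, and the shared test simply calls model(**inputs). Below is a minimal standalone sketch of that pattern, not the test file itself; the checkpoints, argument names, and shapes mirror the diff, while torch.randn, torch.ones, and torch.randint stand in for the repository's test helpers floats_tensor, random_attention_mask, and ids_tensor.

import torch
from transformers import SpeechEncoderDecoderModel


def get_pretrained_model_and_inputs():
    # Pretrained speech encoder (Wav2Vec2) paired with a text decoder (BERT),
    # as in the Wav2Vec2BertModelTest class in the diff below.
    model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
        "facebook/wav2vec2-base-960h", "bert-base-cased"
    )
    batch_size = 13
    inputs = {
        # raw waveform values and matching mask for the speech encoder
        "input_values": torch.randn(batch_size, 512),
        "attention_mask": torch.ones(batch_size, 512, dtype=torch.long),
        # token ids and matching mask for the text decoder
        "decoder_input_ids": torch.randint(0, model.decoder.config.vocab_size, (batch_size, 4)),
        "decoder_attention_mask": torch.ones(batch_size, 4, dtype=torch.long),
    }
    return model, inputs


model, inputs = get_pretrained_model_and_inputs()
with torch.no_grad():
    # The shared test can forward any encoder/decoder pair this way,
    # with no per-model branching on input names or shapes.
    outputs = model(**inputs)
print(outputs.logits.shape)  # (batch_size, 4, decoder vocab size)
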
@@ -21,7 +21,7 @@ from transformers import is_torch_available
 from transformers.testing_utils import require_torch, slow, torch_device
 
 from .test_modeling_bert import BertModelTester
-from .test_modeling_common import ids_tensor
+from .test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
 from .test_modeling_speech_to_text import Speech2TextModelTester
 from .test_modeling_speech_to_text_2 import Speech2Text2StandaloneDecoderModelTester
 from .test_modeling_wav2vec2 import Wav2Vec2ModelTester
@@ -50,7 +50,7 @@ class EncoderDecoderMixin:
     def prepare_config_and_inputs(self):
         pass
 
-    def get_pretrained_model(self):
+    def get_pretrained_model_and_inputs(self):
         pass
 
     def check_encoder_decoder_model_from_pretrained_configs(
@@ -350,17 +350,11 @@ class EncoderDecoderMixin:
 
     @slow
     def test_real_model_save_load_from_pretrained(self):
-        model_2 = self.get_pretrained_model()
+        model_2, inputs = self.get_pretrained_model_and_inputs()
         model_2.to(torch_device)
-        input_name, inputs = self.get_inputs()
-        decoder_input_ids = ids_tensor([13, 1], model_2.config.encoder.vocab_size)
-        attention_mask = ids_tensor([13, 5], vocab_size=2)
+
         with torch.no_grad():
-            outputs = model_2(
-                **{input_name: inputs},
-                decoder_input_ids=decoder_input_ids,
-                attention_mask=attention_mask,
-            )
+            outputs = model_2(**inputs)
             out_2 = outputs[0].cpu().numpy()
             out_2[np.isnan(out_2)] = 0
@@ -369,11 +363,7 @@ class EncoderDecoderMixin:
             model_1 = SpeechEncoderDecoderModel.from_pretrained(tmp_dirname)
             model_1.to(torch_device)
 
-            after_outputs = model_1(
-                **{input_name: inputs},
-                decoder_input_ids=decoder_input_ids,
-                attention_mask=attention_mask,
-            )
+            after_outputs = model_1(**inputs)
             out_1 = after_outputs[0].cpu().numpy()
             out_1[np.isnan(out_1)] = 0
             max_diff = np.amax(np.abs(out_1 - out_2))
@@ -382,10 +372,23 @@ class EncoderDecoderMixin:
 
 @require_torch
 class Wav2Vec2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
-    def get_pretrained_model(self):
-        return SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
+    def get_pretrained_model_and_inputs(self):
+        model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
             "facebook/wav2vec2-base-960h", "bert-base-cased"
         )
+        batch_size = 13
+        input_values = floats_tensor([batch_size, 512], model.encoder.config.vocab_size)
+        attention_mask = random_attention_mask([batch_size, 512])
+        decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
+        decoder_attention_mask = random_attention_mask([batch_size, 4])
+        inputs = {
+            "input_values": input_values,
+            "attention_mask": attention_mask,
+            "decoder_input_ids": decoder_input_ids,
+            "decoder_attention_mask": decoder_attention_mask,
+        }
+        return model, inputs
 
     def get_encoder_decoder_model(self, config, decoder_config):
         encoder_model = Wav2Vec2Model(config).eval()
@@ -433,10 +436,23 @@ class Wav2Vec2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
 
 @require_torch
 class Speech2TextBertModelTest(EncoderDecoderMixin, unittest.TestCase):
-    def get_pretrained_model(self):
-        return SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
+    def get_pretrained_model_and_inputs(self):
+        model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
             "facebook/s2t-small-librispeech-asr", "bert-base-cased"
         )
+        batch_size = 13
+        input_features = floats_tensor([batch_size, 7, 80], model.encoder.config.vocab_size)
+        attention_mask = random_attention_mask([batch_size, 7])
+        decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
+        decoder_attention_mask = random_attention_mask([batch_size, 4])
+        inputs = {
+            "input_features": input_features,
+            "attention_mask": attention_mask,
+            "decoder_input_ids": decoder_input_ids,
+            "decoder_attention_mask": decoder_attention_mask,
+        }
+        return model, inputs
 
     def get_encoder_decoder_model(self, config, decoder_config):
         encoder_model = Speech2TextEncoder(config).eval()
@@ -489,6 +505,10 @@ class Speech2TextBertModelTest(EncoderDecoderMixin, unittest.TestCase):
     def test_save_and_load_from_pretrained(self):
         pass
 
+    # all published pretrained models are Speech2TextModel != Speech2TextEncoder
+    def test_real_model_save_load_from_pretrained(self):
+        pass
+
 
 @require_torch
 class Wav2Vec2Speech2Text2(EncoderDecoderMixin, unittest.TestCase):
@@ -524,5 +544,6 @@ class Wav2Vec2Speech2Text2(EncoderDecoderMixin, unittest.TestCase):
             "decoder_attention_mask": decoder_attention_mask,
         }
 
-    def get_pretrained_model(self):
-        return SpeechEncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "facebook/bart-large")
+    # there are no published pretrained Speech2Text2ForCausalLM for now
+    def test_real_model_save_load_from_pretrained(self):
+        pass
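
Note that test_real_model_save_load_from_pretrained is decorated with @slow, so it only runs in the scheduled CI job or when slow tests are enabled locally, typically with something like RUN_SLOW=1 python -m pytest tests/test_modeling_speech_encoder_decoder.py -k real_model_save_load (the flat tests/ path is assumed from the layout at the time of this commit). The two overrides that end in pass intentionally skip that test where no suitable published checkpoint exists, as their inline comments explain.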