"docs/vscode:/vscode.git/clone" did not exist on "b219d6b5a5d21791dbe5bf19960b8266bb6e9d1d"
Unverified Commit 26700a95 authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

Fix scheduled tests for `SpeechEncoderDecoderModel` (#13422)

* Add inputs to pretrained tests

* Make style
parent 73ad2588
......@@ -21,7 +21,7 @@ from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_modeling_bert import BertModelTester
from .test_modeling_common import ids_tensor
from .test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
from .test_modeling_speech_to_text import Speech2TextModelTester
from .test_modeling_speech_to_text_2 import Speech2Text2StandaloneDecoderModelTester
from .test_modeling_wav2vec2 import Wav2Vec2ModelTester
......@@ -50,7 +50,7 @@ class EncoderDecoderMixin:
def prepare_config_and_inputs(self):
pass
def get_pretrained_model(self):
def get_pretrained_model_and_inputs(self):
pass
def check_encoder_decoder_model_from_pretrained_configs(
......@@ -350,17 +350,11 @@ class EncoderDecoderMixin:
@slow
def test_real_model_save_load_from_pretrained(self):
model_2 = self.get_pretrained_model()
model_2, inputs = self.get_pretrained_model_and_inputs()
model_2.to(torch_device)
input_name, inputs = self.get_inputs()
decoder_input_ids = ids_tensor([13, 1], model_2.config.encoder.vocab_size)
attention_mask = ids_tensor([13, 5], vocab_size=2)
with torch.no_grad():
outputs = model_2(
**{input_name: inputs},
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
)
outputs = model_2(**inputs)
out_2 = outputs[0].cpu().numpy()
out_2[np.isnan(out_2)] = 0
......@@ -369,11 +363,7 @@ class EncoderDecoderMixin:
model_1 = SpeechEncoderDecoderModel.from_pretrained(tmp_dirname)
model_1.to(torch_device)
after_outputs = model_1(
**{input_name: inputs},
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
)
after_outputs = model_1(**inputs)
out_1 = after_outputs[0].cpu().numpy()
out_1[np.isnan(out_1)] = 0
max_diff = np.amax(np.abs(out_1 - out_2))
......@@ -382,10 +372,23 @@ class EncoderDecoderMixin:
@require_torch
class Wav2Vec2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
def get_pretrained_model(self):
return SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
def get_pretrained_model_and_inputs(self):
model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
"facebook/wav2vec2-base-960h", "bert-base-cased"
)
batch_size = 13
input_values = floats_tensor([batch_size, 512], model.encoder.config.vocab_size)
attention_mask = random_attention_mask([batch_size, 512])
decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
decoder_attention_mask = random_attention_mask([batch_size, 4])
inputs = {
"input_values": input_values,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
}
return model, inputs
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = Wav2Vec2Model(config).eval()
......@@ -433,10 +436,23 @@ class Wav2Vec2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
@require_torch
class Speech2TextBertModelTest(EncoderDecoderMixin, unittest.TestCase):
def get_pretrained_model(self):
return SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
def get_pretrained_model_and_inputs(self):
model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
"facebook/s2t-small-librispeech-asr", "bert-base-cased"
)
batch_size = 13
input_features = floats_tensor([batch_size, 7, 80], model.encoder.config.vocab_size)
attention_mask = random_attention_mask([batch_size, 7])
decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
decoder_attention_mask = random_attention_mask([batch_size, 4])
inputs = {
"input_features": input_features,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
}
return model, inputs
def get_encoder_decoder_model(self, config, decoder_config):
encoder_model = Speech2TextEncoder(config).eval()
......@@ -489,6 +505,10 @@ class Speech2TextBertModelTest(EncoderDecoderMixin, unittest.TestCase):
def test_save_and_load_from_pretrained(self):
pass
# all published pretrained models are Speech2TextModel != Speech2TextEncoder
def test_real_model_save_load_from_pretrained(self):
pass
@require_torch
class Wav2Vec2Speech2Text2(EncoderDecoderMixin, unittest.TestCase):
......@@ -524,5 +544,6 @@ class Wav2Vec2Speech2Text2(EncoderDecoderMixin, unittest.TestCase):
"decoder_attention_mask": decoder_attention_mask,
}
def get_pretrained_model(self):
return SpeechEncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "facebook/bart-large")
# there are no published pretrained Speech2Text2ForCausalLM for now
def test_real_model_save_load_from_pretrained(self):
pass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment