Unverified Commit 8801861d authored by Nicolas Patry, committed by GitHub

Fixing m4t. (#27240)

* Fixing m4t.

* Trying to remove the comparison? Odd test failure.

* Adding shared. But why on earth does it hang?

* Putting back the model weights checks, since the test is silently failing on cuda.

* Fix style + remove a leftover comment.
parent 443bf5e9
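
The diff below gives each composite model a single shared token-embedding table and hands it to the text decoder. Here is a minimal sketch of that wiring pattern, using toy module names rather than the real SeamlessM4T classes; how the real decoder consumes the second argument is not shown in this diff, so treating it as an embedding to reuse is an assumption:

import torch.nn as nn

class ToyDecoder(nn.Module):
    # Accepts an optional externally owned embedding; builds its own otherwise.
    def __init__(self, vocab_size, hidden_size, pad_token_id, embed_tokens=None):
        super().__init__()
        if embed_tokens is not None:
            self.embed_tokens = embed_tokens  # reuse the caller's table
        else:
            self.embed_tokens = nn.Embedding(vocab_size, hidden_size, pad_token_id)

class ToyModel(nn.Module):
    def __init__(self, vocab_size=32, hidden_size=8, pad_token_id=0):
        super().__init__()
        # One table owned by the top-level model and handed to the decoder,
        # mirroring `self.shared` / `SeamlessM4TDecoder(config, self.shared)` in the diff.
        self.shared = nn.Embedding(vocab_size, hidden_size, pad_token_id)
        self.text_decoder = ToyDecoder(vocab_size, hidden_size, pad_token_id, self.shared)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

model = ToyModel()
assert model.text_decoder.embed_tokens.weight is model.shared.weight

The point of the final assert is that both modules end up reading from one weight tensor, so a checkpoint saved from either view round-trips cleanly.
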
@@ -3051,8 +3051,9 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel):
     def __init__(self, config: SeamlessM4TConfig):
         super().__init__(config)
 
+        self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
         self.speech_encoder = SeamlessM4TSpeechEncoder(config)
-        self.text_decoder = SeamlessM4TDecoder(config)
+        self.text_decoder = SeamlessM4TDecoder(config, self.shared)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
         # Initialize weights and apply final processing
@@ -3710,8 +3711,9 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
 
+        self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
         self.speech_encoder = SeamlessM4TSpeechEncoder(config)
-        self.text_decoder = SeamlessM4TDecoder(config)
+        self.text_decoder = SeamlessM4TDecoder(config, self.shared)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
         # Initialize weights and apply final processing
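
A quick sanity check for the change above, as a hedged sketch: build the model from a default config and list every state_dict entry that aliases the storage of shared.weight. With the fix in place, the text decoder's token embedding should appear in that list; this is an expectation inferred from the diff, not something the diff itself verifies, and the default config builds a full-size randomly initialized model, so the snippet is slow.

from transformers import SeamlessM4TConfig, SeamlessM4TForSpeechToText

model = SeamlessM4TForSpeechToText(SeamlessM4TConfig())
shared_ptr = model.shared.weight.data_ptr()
# Every state_dict key whose tensor shares storage with the shared table.
tied_keys = [k for k, v in model.state_dict().items() if v.data_ptr() == shared_ptr]
print(tied_keys)
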
@@ -863,17 +863,23 @@ class SeamlessM4TGenerationTest(unittest.TestCase):
         output_original_text = self.factory_generation_speech_test(model, input_text)
         output_original_speech = self.factory_generation_speech_test(model, input_speech)
 
-        model = SeamlessM4TForTextToSpeech.from_pretrained(self.tmpdirname)
-        self.update_generation(model)
-        model.to(torch_device)
-        model.eval()
+        state_dict = model.state_dict()
+
+        text_model = SeamlessM4TForTextToSpeech.from_pretrained(self.tmpdirname)
+        self.update_generation(text_model)
+        text_model.to(torch_device)
+        text_model.eval()
 
         output_text = self.factory_generation_speech_test(model, input_text)
 
-        model = SeamlessM4TForSpeechToSpeech.from_pretrained(self.tmpdirname)
-        self.update_generation(model)
-        model.to(torch_device)
-        model.eval()
+        speech_model = SeamlessM4TForSpeechToSpeech.from_pretrained(self.tmpdirname)
+        self.update_generation(speech_model)
+        speech_model.to(torch_device)
+        speech_model.eval()
+
+        for name, tensor in speech_model.state_dict().items():
+            right_tensor = state_dict.get(name)
+            self.assertEqual(tensor.tolist(), right_tensor.tolist(), f"Tensor {name}")
 
         output_speech = self.factory_generation_speech_test(model, input_speech)
 
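
The loop added above compares every tensor elementwise through tolist(), which is simple but materializes huge Python lists for the embedding tables. A hedged alternative sketch (not what the test uses) that performs the same exact-equality check with torch's testing helper and reports only the offending tensor name:

import torch

def assert_same_state_dict(reference, other):
    # Compare two state dicts key by key; rtol=0/atol=0 enforces exact equality.
    # usage sketch: assert_same_state_dict(state_dict, speech_model.state_dict())
    for name, tensor in other.items():
        ref = reference.get(name)
        assert ref is not None, f"Missing tensor {name}"
        torch.testing.assert_close(tensor, ref.to(tensor.device), rtol=0, atol=0, msg=f"Tensor {name} differs")
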
@@ -882,8 +888,15 @@ class SeamlessM4TGenerationTest(unittest.TestCase):
         self.assertListEqual(output_original_text[1].ravel().tolist(), output_text[1].ravel().tolist())
 
         # test same speech output from input text
-        self.assertListEqual(output_original_speech[0].ravel().tolist(), output_speech[0].ravel().tolist())
-        self.assertListEqual(output_original_speech[1].ravel().tolist(), output_speech[1].ravel().tolist())
+        # assertTrue because super long list makes this hang in case of failure
+        self.assertTrue(
+            output_original_speech[0].ravel().tolist() == output_speech[0].ravel().tolist(),
+            "Speech generated was different",
+        )
+        self.assertTrue(
+            output_original_speech[1].ravel().tolist() == output_speech[1].ravel().tolist(),
+            "Speech generated was different",
+        )
 
     def test_text_generation(self):
         config, input_speech, input_text = self.prepare_speech_and_text_input()
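
The replacement above swaps assertListEqual for a plain boolean check because, on a mismatch, assertListEqual tries to render an element-by-element diff of the entire waveform, which is what made the failure hang. A tiny standalone illustration of the same idea (hypothetical test, not part of this PR):

import unittest

class CompactFailureMessage(unittest.TestCase):
    def test_long_sequences(self):
        a = list(range(1_000_000))
        b = list(range(1_000_000))
        # self.assertListEqual(a, b) would build an enormous diff on mismatch;
        # a boolean check keeps any failure output to a single short message.
        self.assertTrue(a == b, "Sequences differ")

if __name__ == "__main__":
    unittest.main()
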
@@ -905,19 +918,30 @@ class SeamlessM4TGenerationTest(unittest.TestCase):
         input_speech.pop("generate_speech")
         input_text.pop("generate_speech")
 
-        model = SeamlessM4TForTextToText.from_pretrained(self.tmpdirname)
-        self.update_generation(model)
-        model.to(torch_device)
-        model.eval()
+        state_dict = model.state_dict()
+
+        text_model = SeamlessM4TForTextToText.from_pretrained(self.tmpdirname)
+        self.update_generation(text_model)
+        text_model.to(torch_device)
+        text_model.eval()
 
-        output_text = self.factory_generation_speech_test(model, input_text)
+        for name, tensor in text_model.state_dict().items():
+            right_tensor = state_dict.get(name)
+            self.assertEqual(tensor.tolist(), right_tensor.tolist())
+
+        output_text = self.factory_generation_speech_test(text_model, input_text)
 
-        model = SeamlessM4TForSpeechToText.from_pretrained(self.tmpdirname)
-        self.update_generation(model)
-        model.to(torch_device)
-        model.eval()
+        speech_model = SeamlessM4TForSpeechToText.from_pretrained(self.tmpdirname)
+
+        for name, tensor in speech_model.state_dict().items():
+            right_tensor = state_dict.get(name)
+            self.assertEqual(tensor.tolist(), right_tensor.tolist(), f"Tensor {name}")
+
+        self.update_generation(speech_model)
+        speech_model.to(torch_device)
+        speech_model.eval()
 
-        output_speech = self.factory_generation_speech_test(model, input_speech)
+        output_speech = self.factory_generation_speech_test(speech_model, input_speech)
 
         # test same text output from input text
         self.assertListEqual(output_original_text[0].ravel().tolist(), output_text.ravel().tolist())