Unverified Commit 8801861d authored by Nicolas Patry, committed by GitHub

Fixing m4t. (#27240)

* Fixing m4t.

* Trying to remove comparison? Odd test failure.

* Adding shared. But why on earth does it hang?

* Putting back the model weight checks; the test is silently failing on CUDA.

* Fix style + remove leftover comment.
parent 443bf5e9
@@ -3051,8 +3051,9 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel):
     def __init__(self, config: SeamlessM4TConfig):
         super().__init__(config)
 
+        self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
         self.speech_encoder = SeamlessM4TSpeechEncoder(config)
-        self.text_decoder = SeamlessM4TDecoder(config)
+        self.text_decoder = SeamlessM4TDecoder(config, self.shared)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
         # Initialize weights and apply final processing
@@ -3710,8 +3711,9 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
 
+        self.shared = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id)
         self.speech_encoder = SeamlessM4TSpeechEncoder(config)
-        self.text_decoder = SeamlessM4TDecoder(config)
+        self.text_decoder = SeamlessM4TDecoder(config, self.shared)
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
         # Initialize weights and apply final processing
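For context, the constructor change in both classes follows the standard PyTorch weight-sharing pattern: the wrapper model owns a single nn.Embedding and hands that same module to its decoder, so there is one parameter tensor rather than independently initialized copies. A minimal sketch of the idea with toy modules (TinyDecoder/TinyModel are illustrative stand-ins, not the real SeamlessM4T classes):

```python
import torch
from torch import nn


class TinyDecoder(nn.Module):
    def __init__(self, hidden_size, vocab_size, embed_tokens=None):
        super().__init__()
        # Reuse the shared embedding when one is passed in, as the diff above does.
        self.embed_tokens = embed_tokens if embed_tokens is not None else nn.Embedding(vocab_size, hidden_size)
        self.layer = nn.Linear(hidden_size, hidden_size)

    def forward(self, input_ids):
        return self.layer(self.embed_tokens(input_ids))


class TinyModel(nn.Module):
    def __init__(self, vocab_size=100, hidden_size=16, pad_token_id=0):
        super().__init__()
        # The wrapper owns the embedding and passes it down, mirroring
        # `self.text_decoder = SeamlessM4TDecoder(config, self.shared)`.
        self.shared = nn.Embedding(vocab_size, hidden_size, pad_token_id)
        self.text_decoder = TinyDecoder(hidden_size, vocab_size, embed_tokens=self.shared)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids):
        return self.lm_head(self.text_decoder(input_ids))


model = TinyModel()
# Both attribute paths resolve to the exact same parameter tensor.
assert model.shared.weight.data_ptr() == model.text_decoder.embed_tokens.weight.data_ptr()
print(model(torch.tensor([[1, 2, 3]])).shape)  # torch.Size([1, 3, 100])
```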
@@ -863,17 +863,23 @@ class SeamlessM4TGenerationTest(unittest.TestCase):
         output_original_text = self.factory_generation_speech_test(model, input_text)
         output_original_speech = self.factory_generation_speech_test(model, input_speech)
 
-        model = SeamlessM4TForTextToSpeech.from_pretrained(self.tmpdirname)
-        self.update_generation(model)
-        model.to(torch_device)
-        model.eval()
+        state_dict = model.state_dict()
+
+        text_model = SeamlessM4TForTextToSpeech.from_pretrained(self.tmpdirname)
+        self.update_generation(text_model)
+        text_model.to(torch_device)
+        text_model.eval()
 
         output_text = self.factory_generation_speech_test(model, input_text)
 
-        model = SeamlessM4TForSpeechToSpeech.from_pretrained(self.tmpdirname)
-        self.update_generation(model)
-        model.to(torch_device)
-        model.eval()
+        speech_model = SeamlessM4TForSpeechToSpeech.from_pretrained(self.tmpdirname)
+        self.update_generation(speech_model)
+        speech_model.to(torch_device)
+        speech_model.eval()
+
+        for name, tensor in speech_model.state_dict().items():
+            right_tensor = state_dict.get(name)
+            self.assertEqual(tensor.tolist(), right_tensor.tolist(), f"Tensor {name}")
 
         output_speech = self.factory_generation_speech_test(model, input_speech)
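The loop added above compares every tensor of the reloaded task-specific model against the state_dict captured from the full model, so a weight that gets silently re-initialized during loading turns into an explicit assertion failure rather than a subtle output mismatch. A minimal sketch of that comparison pattern, using plain toy modules in place of the SeamlessM4T models and load_state_dict in place of save_pretrained/from_pretrained:

```python
import torch
from torch import nn

# Stand-ins for the full model and a task-specific model reloaded from the same checkpoint.
reference = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 10, bias=False))
reloaded = nn.Sequential(nn.Embedding(10, 4), nn.Linear(4, 10, bias=False))
reloaded.load_state_dict(reference.state_dict())

state_dict = reference.state_dict()
for name, tensor in reloaded.state_dict().items():
    right_tensor = state_dict.get(name)
    assert right_tensor is not None, f"Tensor {name} missing from the reference state_dict"
    # .tolist() gives exact element-wise equality, matching the test's assertEqual usage.
    assert tensor.tolist() == right_tensor.tolist(), f"Tensor {name} differs after reload"
print("all parameters match")
```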
@@ -882,8 +888,15 @@ class SeamlessM4TGenerationTest(unittest.TestCase):
         self.assertListEqual(output_original_text[1].ravel().tolist(), output_text[1].ravel().tolist())
 
         # test same speech output from input text
-        self.assertListEqual(output_original_speech[0].ravel().tolist(), output_speech[0].ravel().tolist())
-        self.assertListEqual(output_original_speech[1].ravel().tolist(), output_speech[1].ravel().tolist())
+        # assertTrue because super long list makes this hang in case of failure
+        self.assertTrue(
+            output_original_speech[0].ravel().tolist() == output_speech[0].ravel().tolist(),
+            "Speech generated was different",
+        )
+        self.assertTrue(
+            output_original_speech[1].ravel().tolist() == output_speech[1].ravel().tolist(),
+            "Speech generated was different",
+        )
 
     def test_text_generation(self):
         config, input_speech, input_text = self.prepare_speech_and_text_input()
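On the assertion change: when two very long lists differ, unittest's assertListEqual builds an element-by-element diff for the failure message, which for waveform-sized outputs can take long enough to look like a hang (the new inline comment says as much). Comparing the lists first and asserting on the boolean keeps failures cheap. A small self-contained illustration of the pattern, not the original test:

```python
import unittest


class LongListComparison(unittest.TestCase):
    def test_waveforms_match(self):
        # Stand-ins for generated waveforms; the real outputs are long float lists.
        expected = [0.0] * 1_000_000
        produced = [0.0] * 1_000_000

        # On mismatch, assertListEqual would try to render a huge element-wise diff.
        # assertTrue on a precomputed comparison fails fast with a short message.
        self.assertTrue(expected == produced, "Speech generated was different")


if __name__ == "__main__":
    unittest.main()
```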
@@ -905,19 +918,30 @@ class SeamlessM4TGenerationTest(unittest.TestCase):
         input_speech.pop("generate_speech")
         input_text.pop("generate_speech")
 
-        model = SeamlessM4TForTextToText.from_pretrained(self.tmpdirname)
-        self.update_generation(model)
-        model.to(torch_device)
-        model.eval()
+        state_dict = model.state_dict()
 
-        output_text = self.factory_generation_speech_test(model, input_text)
+        text_model = SeamlessM4TForTextToText.from_pretrained(self.tmpdirname)
+        self.update_generation(text_model)
+        text_model.to(torch_device)
+        text_model.eval()
 
-        model = SeamlessM4TForSpeechToText.from_pretrained(self.tmpdirname)
-        self.update_generation(model)
-        model.to(torch_device)
-        model.eval()
+        for name, tensor in text_model.state_dict().items():
+            right_tensor = state_dict.get(name)
+            self.assertEqual(tensor.tolist(), right_tensor.tolist())
 
-        output_speech = self.factory_generation_speech_test(model, input_speech)
+        output_text = self.factory_generation_speech_test(text_model, input_text)
+
+        speech_model = SeamlessM4TForSpeechToText.from_pretrained(self.tmpdirname)
+
+        for name, tensor in speech_model.state_dict().items():
+            right_tensor = state_dict.get(name)
+            self.assertEqual(tensor.tolist(), right_tensor.tolist(), f"Tensor {name}")
+
+        self.update_generation(speech_model)
+        speech_model.to(torch_device)
+        speech_model.eval()
+
+        output_speech = self.factory_generation_speech_test(speech_model, input_speech)
 
         # test same text output from input text
         self.assertListEqual(output_original_text[0].ravel().tolist(), output_text.ravel().tolist())
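One plausible mechanism behind the "silently failing on CUDA" note in the commit message, sketched here as an assumption rather than a confirmed root cause: if the architecture expects a parameter (such as the new `shared` embedding) under a key the checkpoint does not provide, a tolerant load leaves that parameter at its random initialization and the model still runs, so outputs drift with no error; the state_dict checks added above make that visible. The sketch uses plain torch.nn and load_state_dict(strict=False) to show the mechanism; Hugging Face's from_pretrained likewise reports such weights as newly initialized rather than raising.

```python
import torch
from torch import nn


def make_model():
    return nn.ModuleDict(
        {
            "shared": nn.Embedding(10, 4),
            "lm_head": nn.Linear(4, 10, bias=False),
        }
    )


torch.manual_seed(0)
saved = make_model()
checkpoint = saved.state_dict()
del checkpoint["shared.weight"]  # pretend the checkpoint was written before `shared` existed

torch.manual_seed(1)
reloaded = make_model()
# A tolerant load: keys missing from the checkpoint stay randomly initialized.
missing, unexpected = reloaded.load_state_dict(checkpoint, strict=False)
print(missing)  # ['shared.weight']

print(torch.equal(saved["lm_head"].weight, reloaded["lm_head"].weight))  # True: loaded from checkpoint
print(torch.equal(saved["shared"].weight, reloaded["shared"].weight))    # False: silent divergence
```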