Unverified Commit bb1d0d0d authored by Yoach Lacombe's avatar Yoach Lacombe Committed by GitHub
Browse files

Fix languages covered by M4Tv2 (#28019)



* correct language assessment  + add tests

* Update src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* make style + simplify and enrich test

---------
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent e2b16485
......@@ -4596,7 +4596,11 @@ class SeamlessM4Tv2Model(SeamlessM4Tv2PreTrainedModel):
if tgt_lang is not None:
# also accept __xxx__
tgt_lang = tgt_lang.replace("__", "")
for key in ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]:
if generate_speech:
keys_to_check = ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]
else:
keys_to_check = ["text_decoder_lang_to_code_id"]
for key in keys_to_check:
lang_code_to_id = getattr(self.generation_config, key, None)
if lang_code_to_id is None:
raise ValueError(
......
......@@ -758,7 +758,13 @@ class SeamlessM4Tv2GenerationTest(unittest.TestCase):
self.tmpdirname = tempfile.mkdtemp()
def update_generation(self, model):
lang_code_to_id = {
text_lang_code_to_id = {
"fra": 4,
"eng": 4,
"rus": 4,
}
speech_lang_code_to_id = {
"fra": 4,
"eng": 4,
}
......@@ -773,9 +779,9 @@ class SeamlessM4Tv2GenerationTest(unittest.TestCase):
generation_config = copy.deepcopy(model.generation_config)
generation_config.__setattr__("text_decoder_lang_to_code_id", lang_code_to_id)
generation_config.__setattr__("t2u_lang_code_to_id", lang_code_to_id)
generation_config.__setattr__("vocoder_lang_code_to_id", lang_code_to_id)
generation_config.__setattr__("text_decoder_lang_to_code_id", text_lang_code_to_id)
generation_config.__setattr__("t2u_lang_code_to_id", speech_lang_code_to_id)
generation_config.__setattr__("vocoder_lang_code_to_id", speech_lang_code_to_id)
generation_config.__setattr__("id_to_text", id_to_text)
generation_config.__setattr__("char_to_id", char_to_id)
generation_config.__setattr__("eos_token_id", 0)
......@@ -784,13 +790,13 @@ class SeamlessM4Tv2GenerationTest(unittest.TestCase):
model.generation_config = generation_config
def prepare_text_input(self):
def prepare_text_input(self, tgt_lang):
config, inputs, decoder_input_ids, input_mask, lm_labels = self.text_model_tester.prepare_config_and_inputs()
input_dict = {
"input_ids": inputs,
"attention_mask": input_mask,
"tgt_lang": "eng",
"tgt_lang": tgt_lang,
"num_beams": 2,
"do_sample": True,
}
......@@ -837,6 +843,26 @@ class SeamlessM4Tv2GenerationTest(unittest.TestCase):
output = model.generate(**inputs)
return output
def test_generation_languages(self):
config, input_text_rus = self.prepare_text_input(tgt_lang="rus")
model = SeamlessM4Tv2Model(config=config)
self.update_generation(model)
model.to(torch_device)
model.eval()
# make sure that generating speech, with a language that is only supported for text translation, raises error
with self.assertRaises(ValueError):
model.generate(**input_text_rus)
# make sure that generating text only works
model.generate(**input_text_rus, generate_speech=False)
# make sure it works for languages supported by both output modalities
config, input_text_eng = self.prepare_text_input(tgt_lang="eng")
model.generate(**input_text_eng)
model.generate(**input_text_eng, generate_speech=False)
def test_speech_generation(self):
config, input_speech, input_text = self.prepare_speech_and_text_input()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment