"git@developer.sourcefind.cn:chenpangpang/ComfyUI.git" did not exist on "a50e1118155d8057dd0364644e8259d55a84bb4a"
Unverified Commit bb1d0d0d authored by Yoach Lacombe's avatar Yoach Lacombe Committed by GitHub
Browse files

Fix languages covered by M4Tv2 (#28019)



* correct language assessment  + add tests

* Update src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* make style + simplify and enrich test

---------
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent e2b16485
...@@ -4596,7 +4596,11 @@ class SeamlessM4Tv2Model(SeamlessM4Tv2PreTrainedModel): ...@@ -4596,7 +4596,11 @@ class SeamlessM4Tv2Model(SeamlessM4Tv2PreTrainedModel):
if tgt_lang is not None: if tgt_lang is not None:
# also accept __xxx__ # also accept __xxx__
tgt_lang = tgt_lang.replace("__", "") tgt_lang = tgt_lang.replace("__", "")
for key in ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]: if generate_speech:
keys_to_check = ["text_decoder_lang_to_code_id", "t2u_lang_code_to_id", "vocoder_lang_code_to_id"]
else:
keys_to_check = ["text_decoder_lang_to_code_id"]
for key in keys_to_check:
lang_code_to_id = getattr(self.generation_config, key, None) lang_code_to_id = getattr(self.generation_config, key, None)
if lang_code_to_id is None: if lang_code_to_id is None:
raise ValueError( raise ValueError(
......
...@@ -758,7 +758,13 @@ class SeamlessM4Tv2GenerationTest(unittest.TestCase): ...@@ -758,7 +758,13 @@ class SeamlessM4Tv2GenerationTest(unittest.TestCase):
self.tmpdirname = tempfile.mkdtemp() self.tmpdirname = tempfile.mkdtemp()
def update_generation(self, model): def update_generation(self, model):
lang_code_to_id = { text_lang_code_to_id = {
"fra": 4,
"eng": 4,
"rus": 4,
}
speech_lang_code_to_id = {
"fra": 4, "fra": 4,
"eng": 4, "eng": 4,
} }
...@@ -773,9 +779,9 @@ class SeamlessM4Tv2GenerationTest(unittest.TestCase): ...@@ -773,9 +779,9 @@ class SeamlessM4Tv2GenerationTest(unittest.TestCase):
generation_config = copy.deepcopy(model.generation_config) generation_config = copy.deepcopy(model.generation_config)
generation_config.__setattr__("text_decoder_lang_to_code_id", lang_code_to_id) generation_config.__setattr__("text_decoder_lang_to_code_id", text_lang_code_to_id)
generation_config.__setattr__("t2u_lang_code_to_id", lang_code_to_id) generation_config.__setattr__("t2u_lang_code_to_id", speech_lang_code_to_id)
generation_config.__setattr__("vocoder_lang_code_to_id", lang_code_to_id) generation_config.__setattr__("vocoder_lang_code_to_id", speech_lang_code_to_id)
generation_config.__setattr__("id_to_text", id_to_text) generation_config.__setattr__("id_to_text", id_to_text)
generation_config.__setattr__("char_to_id", char_to_id) generation_config.__setattr__("char_to_id", char_to_id)
generation_config.__setattr__("eos_token_id", 0) generation_config.__setattr__("eos_token_id", 0)
...@@ -784,13 +790,13 @@ class SeamlessM4Tv2GenerationTest(unittest.TestCase): ...@@ -784,13 +790,13 @@ class SeamlessM4Tv2GenerationTest(unittest.TestCase):
model.generation_config = generation_config model.generation_config = generation_config
def prepare_text_input(self): def prepare_text_input(self, tgt_lang):
config, inputs, decoder_input_ids, input_mask, lm_labels = self.text_model_tester.prepare_config_and_inputs() config, inputs, decoder_input_ids, input_mask, lm_labels = self.text_model_tester.prepare_config_and_inputs()
input_dict = { input_dict = {
"input_ids": inputs, "input_ids": inputs,
"attention_mask": input_mask, "attention_mask": input_mask,
"tgt_lang": "eng", "tgt_lang": tgt_lang,
"num_beams": 2, "num_beams": 2,
"do_sample": True, "do_sample": True,
} }
...@@ -837,6 +843,26 @@ class SeamlessM4Tv2GenerationTest(unittest.TestCase): ...@@ -837,6 +843,26 @@ class SeamlessM4Tv2GenerationTest(unittest.TestCase):
output = model.generate(**inputs) output = model.generate(**inputs)
return output return output
def test_generation_languages(self):
config, input_text_rus = self.prepare_text_input(tgt_lang="rus")
model = SeamlessM4Tv2Model(config=config)
self.update_generation(model)
model.to(torch_device)
model.eval()
# make sure that generating speech, with a language that is only supported for text translation, raises error
with self.assertRaises(ValueError):
model.generate(**input_text_rus)
# make sure that generating text only works
model.generate(**input_text_rus, generate_speech=False)
# make sure it works for languages supported by both output modalities
config, input_text_eng = self.prepare_text_input(tgt_lang="eng")
model.generate(**input_text_eng)
model.generate(**input_text_eng, generate_speech=False)
def test_speech_generation(self): def test_speech_generation(self):
config, input_speech, input_text = self.prepare_speech_and_text_input() config, input_speech, input_text = self.prepare_speech_and_text_input()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment