"vscode:/vscode.git/clone" did not exist on "d28b7aa8cb2e1d10ec5acc5e214faf9525a64a46"
Unverified commit 17083b9b, authored by Connor Henderson, committed by GitHub
Browse files

fix: Passing language as acronym to Whisper generate (#23141)

* add fix

* address comments

* remove error formatting
parent 40082d59
...@@ -1562,6 +1562,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): ...@@ -1562,6 +1562,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
generation_config.return_timestamps = False generation_config.return_timestamps = False
if language is not None: if language is not None:
language = language.lower()
generation_config.language = language generation_config.language = language
if task is not None: if task is not None:
generation_config.task = task generation_config.task = task
...@@ -1573,10 +1574,13 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): ...@@ -1573,10 +1574,13 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel):
language_token = generation_config.language language_token = generation_config.language
elif generation_config.language in TO_LANGUAGE_CODE.keys(): elif generation_config.language in TO_LANGUAGE_CODE.keys():
language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>" language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>"
elif generation_config.language in TO_LANGUAGE_CODE.values():
language_token = f"<|{generation_config.language}|>"
else: else:
is_language_code = len(generation_config.language) == 2
raise ValueError( raise ValueError(
f"Unsupported language: {self.language}. Language should be one of:" f"Unsupported language: {generation_config.language}. Language should be one of:"
f" {list(TO_LANGUAGE_CODE.keys()) if generation_config.language in TO_LANGUAGE_CODE.keys() else list(TO_LANGUAGE_CODE.values())}." f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}."
) )
forced_decoder_ids.append((1, generation_config.lang_to_id[language_token])) forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
else: else:
......
...@@ -414,6 +414,21 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi ...@@ -414,6 +414,21 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
model.generate(input_features) model.generate(input_features)
model.generate(input_features, num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3) model.generate(input_features, num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3)
def test_generate_language(self):
    """Smoke-test `generate` with the `language` argument in each accepted spelling.

    The same language is passed as a bare language code, as a tokenizer
    token, and as a capitalized language name. The test makes no assertions
    on the output; it passes as long as none of the calls raises.
    """
    config, input_dict = self.model_tester.prepare_config_and_inputs()
    input_features = input_dict["input_features"]
    model = WhisperForConditionalGeneration(config).to(torch_device)

    # Hack: populate the lookup tables by hand so the test stays fast and does
    # not need to download a checkpoint that ships a full generation_config.
    generation_config = model.generation_config
    generation_config.lang_to_id = {"<|en|>": 1}
    generation_config.task_to_id = {"transcribe": 2}

    # In order: language code, tokenizer token, language name.
    for language in ("en", "<|en|>", "English"):
        model.generate(input_features, language=language)
def test_forward_signature(self): def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common() config, _ = self.model_tester.prepare_config_and_inputs_for_common()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment