Unverified commit 57edd84b, authored by Sanchit Gandhi and committed by GitHub
Browse files

[whisper] fix multilingual fine-tuning (#30865)

* [whisper] fix multilingual fine-tuning

* config ids as well

Parent commit: 977ce58a
@@ -425,12 +425,8 @@ def main():
     if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
         # We only need to set the language and task ids in a multilingual setting
         tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
-        model.generation_config.update(
-            **{
-                "language": data_args.language,
-                "task": data_args.task,
-            }
-        )
+        model.generation_config.language = data_args.language
+        model.generation_config.task = data_args.task
     elif data_args.language is not None:
         raise ValueError(
             "Setting language token for an English-only checkpoint is not permitted. The language argument should "
@@ -444,6 +440,9 @@ def main():
             "Please use the `language` and `task` arguments instead"
         )
         model.generation_config.forced_decoder_ids = model_args.forced_decoder_ids
+    else:
+        model.generation_config.forced_decoder_ids = None
+        model.config.forced_decoder_ids = None
 
     if model_args.suppress_tokens is not None:
         logger.warning(
...
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment