Unverified Commit 38b53da3 authored by Sanchit Gandhi, committed by GitHub

[examples] update whisper fine-tuning (#29938)

* [examples] update whisper fine-tuning

* deprecate forced/suppress tokens

* item assignment

* update readme

* final fix
parent aafa7ce7
@@ -368,6 +368,7 @@ python run_speech_recognition_seq2seq.py \
     --dataset_name="mozilla-foundation/common_voice_11_0" \
     --dataset_config_name="hi" \
     --language="hindi" \
+    --task="transcribe" \
     --train_split_name="train+validation" \
     --eval_split_name="test" \
     --max_steps="5000" \
@@ -384,12 +385,10 @@ python run_speech_recognition_seq2seq.py \
     --save_steps="1000" \
     --generation_max_length="225" \
     --preprocessing_num_workers="16" \
-    --length_column_name="input_length" \
     --max_duration_in_seconds="30" \
     --text_column_name="sentence" \
     --freeze_feature_encoder="False" \
     --gradient_checkpointing \
-    --group_by_length \
     --fp16 \
     --overwrite_output_dir \
     --do_train \
@@ -399,7 +398,8 @@ python run_speech_recognition_seq2seq.py \
 ```
 On a single V100, training should take approximately 8 hours, with a final cross-entropy loss of **1e-4** and word error rate of **32.6%**.
-If training on a different language, you should be sure to change the `language` argument. The `language` argument should be omitted for English speech recognition.
+If training on a different language, you should be sure to change the `language` argument. The `language` and `task`
+arguments should be omitted for English speech recognition.
 #### Multi GPU Whisper Training

 The following example shows how to fine-tune the [Whisper small](https://huggingface.co/openai/whisper-small) checkpoint on the Hindi subset of [Common Voice 11](https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0) using 2 GPU devices in half-precision:
@@ -410,6 +410,7 @@ torchrun \
     --dataset_name="mozilla-foundation/common_voice_11_0" \
     --dataset_config_name="hi" \
     --language="hindi" \
+    --task="transcribe" \
     --train_split_name="train+validation" \
     --eval_split_name="test" \
     --max_steps="5000" \
@@ -425,12 +426,10 @@ torchrun \
     --save_steps="1000" \
     --generation_max_length="225" \
     --preprocessing_num_workers="16" \
-    --length_column_name="input_length" \
     --max_duration_in_seconds="30" \
     --text_column_name="sentence" \
     --freeze_feature_encoder="False" \
     --gradient_checkpointing \
-    --group_by_length \
     --fp16 \
     --overwrite_output_dir \
     --do_train \
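For readers of the README change above: a minimal sketch, outside this diff, of what the `--language` and `--task` arguments control during fine-tuning. The checkpoint name and example string are assumptions for illustration; the APIs are standard `transformers` Whisper tokenizer calls.

```python
from transformers import WhisperTokenizer

# Assumed checkpoint: the multilingual openai/whisper-small.
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small")

# Every label sequence is prefixed with the language/task special tokens,
# so the model learns to condition on them during fine-tuning.
tokenizer.set_prefix_tokens(language="hindi", task="transcribe")

ids = tokenizer("some transcription").input_ids
print(tokenizer.convert_ids_to_tokens(ids[:4]))
# ['<|startoftranscript|>', '<|hi|>', '<|transcribe|>', '<|notimestamps|>']
```

This is why both arguments are omitted for English-only checkpoints: those models have no language or task tokens to set.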
@@ -119,17 +119,16 @@ class ModelArguments:
     )
     forced_decoder_ids: List[List[int]] = field(
         default=None,
-        metadata={
+        metadata={"help": "Deprecated. Please use the `language` and `task` arguments instead."},
+    )
+    suppress_tokens: List[int] = field(
+        default=None,
+        metadata={
             "help": (
-                "A list of pairs of integers which indicates a mapping from generation indices to token indices "
-                "that will be forced before sampling. For example, [[0, 123]] means the first generated token "
-                "will always be a token of index 123."
+                "Deprecated. The use of `suppress_tokens` should not be required for the majority of fine-tuning examples. "
+                "Should you need to use `suppress_tokens`, please manually update them in the fine-tuning script directly."
             )
         },
     )
-    suppress_tokens: List[int] = field(
-        default=None, metadata={"help": "A list of tokens that will be suppressed at generation."}
-    )
     apply_spec_augment: bool = field(
         default=False,
         metadata={
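A brief illustration, not part of the commit, of why `forced_decoder_ids` can be deprecated here: the same `(position, token_id)` pairs are derivable from `language` and `task`. The checkpoint is an assumption; both tokenizer methods exist in `transformers`.

```python
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small")

# Old style: hand-built pairs forced at fixed generation positions.
forced_decoder_ids = tokenizer.get_decoder_prompt_ids(language="hindi", task="transcribe")
print(forced_decoder_ids)  # [(1, <lang id>), (2, <task id>), (3, <no-timestamps id>)]

# New style: pass --language / --task and let the script set the prefix tokens.
tokenizer.set_prefix_tokens(language="hindi", task="transcribe")
```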
@@ -400,8 +399,6 @@ def main():
         trust_remote_code=model_args.trust_remote_code,
     )

-    config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})
-
     # SpecAugment for whisper models
     if getattr(config, "model_type", None) == "whisper":
         config.update({"apply_spec_augment": model_args.apply_spec_augment})
@@ -440,9 +437,35 @@ def main():
         model.freeze_encoder()
         model.model.encoder.gradient_checkpointing = False

-    if data_args.language is not None:
-        # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
+    if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
+        # We only need to set the language and task ids in a multilingual setting
         tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
+        model.generation_config.update(
+            **{
+                "language": data_args.language,
+                "task": data_args.task,
+            }
+        )
+    elif data_args.language is not None:
+        raise ValueError(
+            "Setting language token for an English-only checkpoint is not permitted. The language argument should "
+            "only be set for multilingual checkpoints."
+        )
+
+    # TODO (Sanchit): deprecate these arguments in v4.41
+    if model_args.forced_decoder_ids is not None:
+        logger.warning(
+            "The use of `forced_decoder_ids` is deprecated and will be removed in v4.41. "
+            "Please use the `language` and `task` arguments instead."
+        )
+        model.generation_config.forced_decoder_ids = model_args.forced_decoder_ids
+
+    if model_args.suppress_tokens is not None:
+        logger.warning(
+            "The use of `suppress_tokens` is deprecated and will be removed in v4.41. "
+            "Should you need `suppress_tokens`, please manually set them in the fine-tuning script."
+        )
+        model.generation_config.suppress_tokens = model_args.suppress_tokens
+
     # 6. Resample speech dataset if necessary
     dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
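A hedged sketch of what the resampling step referenced in the trailing context line typically looks like with `datasets`. The dataset name and split mirror the README example and are assumptions here; the dataset is gated, so it may require `huggingface-cli login`.

```python
from datasets import Audio, load_dataset

common_voice = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="train")

# Whisper's feature extractor expects 16 kHz audio; cast_column resamples
# lazily each time an example's "audio" column is accessed.
common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16_000))
```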