Unverified Commit 1609a436 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Add MMS CTC Fine-Tuning (#24281)

* Add mms ctc fine tuning

* make style

* More fixes that are needed

* make fix-copies

* make draft for README

* add new file

* move to new file

* make style

* make style

* add quick test

* make style

* make style
parent 0c3fdccf
......@@ -26,6 +26,10 @@ limitations under the License.
- [Librispeech](#librispeech-ctc)
- [Common Voice](#common-voice-ctc)
- [Multilingual Librispeech](#multilingual-librispeech-ctc)
- [Automatic Speech Recognition with CTC and Adapter Layers](#connectionist-temporal-classification-with-adapters)
- [Massive Multilingual Speech (MMS)](#mms-model)
- [Examples](#examples-ctc-adapter)
- [Common Voice](#common-voice-ctc-adapter)
- [Automatic Speech Recognition with Sequence-to-Sequence](#sequence-to-sequence)
- [Whisper Model](#whisper-model)
- [Speech-Encoder-Decoder Model](#warm-started-speech-encoder-decoder-model)
......@@ -243,6 +247,111 @@ they can serve as a baseline to improve upon.
| [Multilingual Librispeech](https://huggingface.co/datasets/multilingual_librispeech)| `"german"` | [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) | 0.13 | - | 1 GPU Titan 24 GB RAM | 15h04 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-xlsr-53-300m-mls-german-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-xlsr-53-300m-mls-german-ft/blob/main/run.sh) |
| [Multilingual Librispeech](https://huggingface.co/datasets/multilingual_librispeech)| `"german"` | [facebook/wav2vec2-xls-r-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 0.15 | - | 1 GPU Titan 24 GB RAM | 15h04 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-300m-mls-german-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-300m-mls-german-ft/blob/main/run.sh) |
## Connectionist Temporal Classification With Adapters
The script [`run_speech_recognition_ctc_adapter.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/speech-recognition/run_speech_recognition_ctc_adapter.py) can be used to fine-tune adapter layers for [Wav2Vec2-like models like MMS](https://huggingface.co/docs/transformers/main/en/model_doc/mms) for automatic speech recognition.
### MMS Model
The [Massive Multilingual Speech (MMS) model](https://huggingface.co/facebook/mms-1b-all) has been pre-trained and fine-tuned
on 1000+ languages. The model makes use of adapter attention layers to fine-tune only a small part
of the model on a specific language. The model already comes with fine-tuned adapter layers for 1000+ languages and
can be used for inference for 1000+ languages out of the box.
However, for improved performance or more specific use cases one can re-initialize the adapter weights, freeze all
other weights and fine-tune them on a specific dataset as shown in the [example below](#examples-ctc-adapter).
Note that the adapter weights include low dimensional linear layers for every attention block as well as the final language
model head layers.
### Examples CTC Adapter
In the following we will look at how one can fine-tune adapter weights for any of the
[MMS CTC checkpoints](https://huggingface.co/models?pipeline_tag=automatic-speech-recognition&other=mms&sort=downloads) in less than 1 hour.
#### Common Voice CTC Adapter
As in the examples [above](#examples-ctc), we fine-tune on Common Voice's 6 dataset in Turkish as an example.
Contrary to [`run_speech_recognition_ctc.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py) before there is a `--target_language` which has to be defined to state for which
language or concept the adapter layers shall be trained. The adapter weights will then
accordingly be called `adapter.{<target_language}.safetensors`.
Let's run an example script. Make sure to be logged in so that your model can be directly uploaded to the Hub.
```
huggingface-cli login
```
Now, let's run an example and upload it to the Hub under `wav2vec2-common_voice-tr-mms-demo`.
```sh
python run_speech_recognition_ctc.py \
--dataset_name="common_voice" \
--model_name_or_path="facebook/mms-1b-all" \
--dataset_config_name="tr" \
--output_dir="./wav2vec2-common_voice-tr-mms-demo" \
--num_train_epochs="4" \
--per_device_train_batch_size="32" \
--learning_rate="1e-3" \
--warmup_steps="100" \
--evaluation_strategy="steps" \
--text_column_name="sentence" \
--length_column_name="input_length" \
--save_steps="200" \
--eval_steps="100" \
--save_total_limit="3" \
--target_language="tur" \
--gradient_checkpointing \
--chars_to_ignore , ? . ! - \; \: \" “ % ‘ ” � \
--fp16 \
--group_by_length \
--do_train --do_eval \
--push_to_hub
```
This should take less than 10 minutes on most GPUs and you should very quickly get word error rates
below 27%.
For an example run, you can have a look at [`patrickvonplaten/wav2vec2-common_voice-tr-mms-demo`](https://huggingface.co/patrickvonplaten/wav2vec2-common_voice-tr-mms-demo).
If you'd like to train another adapter model with the same base model, you can simply re-use the same `--output_dir`,
but make sure to pass the `--output_dir` folder also to `--tokenizer_name_or_path` so that the vocabulary is not
overwritten but **extended**. Assuming you would like to train adapter weights on Swedish in addition to Turkish and save
the adapter weights in the same model repo, you can run:
```sh
python run_speech_recognition_ctc.py \
--dataset_name="common_voice" \
--model_name_or_path="facebook/mms-1b-all" \
--dataset_config_name="sw" \
--output_dir="./wav2vec2-common_voice-tr-mms-demo" \
--tokenizer_name_or_path="./wav2vec2-common_voice-tr-mms-demo" \
--num_train_epochs="4" \
--per_device_train_batch_size="32" \
--learning_rate="1e-3" \
--warmup_steps="100" \
--evaluation_strategy="steps" \
--text_column_name="sentence" \
--length_column_name="input_length" \
--save_steps="200" \
--eval_steps="100" \
--save_total_limit="3" \
--target_language="swe" \
--gradient_checkpointing \
--chars_to_ignore , ? . ! - \; \: \" “ % ‘ ” � \
--fp16 \
--group_by_length \
--do_train --do_eval \
--push_to_hub
```
Now you should have both `adapter.tur.safetensors` and `adapter.swe.safetensors` in the model repo
and you can load the respective language with:
```py
model.load_adapter("tur") # or "swe"
```
respectively.
## Sequence to Sequence
The script [`run_speech_recognition_seq2seq.py`](https://github.com/huggingface/transformers/blob/main/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py) can be used to fine-tune any [Speech Sequence-to-Sequence Model](https://huggingface.co/docs/transformers/main/en/model_doc/auto#transformers.AutoModelForSpeechSeq2Seq) for automatic speech
......
......@@ -63,6 +63,7 @@ if SRC_DIRS is not None:
import run_semantic_segmentation
import run_seq2seq_qa as run_squad_seq2seq
import run_speech_recognition_ctc
import run_speech_recognition_ctc_adapter
import run_speech_recognition_seq2seq
import run_summarization
import run_swag
......@@ -446,6 +447,38 @@ class ExamplesTests(TestCasePlus):
result = get_results(tmp_dir)
self.assertLess(result["eval_loss"], result["train_loss"])
def test_run_speech_recognition_ctc_adapter(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_speech_recognition_ctc_adapter.py
--output_dir {tmp_dir}
--model_name_or_path hf-internal-testing/tiny-random-wav2vec2
--dataset_name hf-internal-testing/librispeech_asr_dummy
--dataset_config_name clean
--train_split_name validation
--eval_split_name validation
--do_train
--do_eval
--learning_rate 1e-4
--per_device_train_batch_size 2
--per_device_eval_batch_size 1
--remove_unused_columns False
--overwrite_output_dir True
--preprocessing_num_workers 16
--max_steps 10
--target_language tur
--seed 42
""".split()
if is_cuda_and_apex_available():
testargs.append("--fp16")
with patch.object(sys, "argv", testargs):
run_speech_recognition_ctc_adapter.main()
result = get_results(tmp_dir)
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "./adapter.tur.safetensors")))
self.assertLess(result["eval_loss"], result["train_loss"])
def test_run_speech_recognition_seq2seq(self):
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
......
......@@ -1181,6 +1181,14 @@ class HubertForCTC(HubertPreTrainedModel):
"""
self.hubert.feature_extractor._freeze_parameters()
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for param in self.hubert.parameters():
param.requires_grad = False
@add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
......
......@@ -1016,6 +1016,14 @@ class SEWForCTC(SEWPreTrainedModel):
"""
self.sew.feature_extractor._freeze_parameters()
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for param in self.sew.parameters():
param.requires_grad = False
@add_start_docstrings_to_model_forward(SEW_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
......
......@@ -1556,6 +1556,14 @@ class SEWDForCTC(SEWDPreTrainedModel):
"""
self.sew_d.feature_extractor._freeze_parameters()
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for param in self.sew_d.parameters():
param.requires_grad = False
@add_start_docstrings_to_model_forward(SEWD_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
......
......@@ -1425,6 +1425,14 @@ class UniSpeechForCTC(UniSpeechPreTrainedModel):
"""
self.unispeech.feature_extractor._freeze_parameters()
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for param in self.unispeech.parameters():
param.requires_grad = False
@add_start_docstrings_to_model_forward(UNISPEECH_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
......
......@@ -1432,6 +1432,14 @@ class UniSpeechSatForCTC(UniSpeechSatPreTrainedModel):
"""
self.unispeech_sat.feature_extractor._freeze_parameters()
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for param in self.unispeech_sat.parameters():
param.requires_grad = False
@add_start_docstrings_to_model_forward(UNISPEECH_SAT_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
......
......@@ -1194,6 +1194,19 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
return adapter_weights
def init_adapter_layers(self):
"""
(Re-)initialize attention adapter layers and lm head for adapter-only fine-tuning
"""
# init attention adapters
for module in self.modules():
if isinstance(module, Wav2Vec2AttnAdapterLayer):
self._init_weights(module)
# init lm head
if isinstance(self, Wav2Vec2ForCTC):
self._init_weights(self.lm_head)
def load_adapter(self, target_lang: str, **kwargs):
r"""
Load a language adapter model from a pre-trained adapter model.
......@@ -1888,6 +1901,14 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
"""
self.wav2vec2.feature_extractor._freeze_parameters()
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for param in self.wav2vec2.parameters():
param.requires_grad = False
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
......
......@@ -1319,6 +1319,14 @@ class WavLMForCTC(WavLMPreTrainedModel):
"""
self.wavlm.feature_extractor._freeze_parameters()
def freeze_base_model(self):
"""
Calling this function will disable the gradient computation for the base model so that its parameters will not
be updated during training. Only the classification head will be updated.
"""
for param in self.wavlm.parameters():
param.requires_grad = False
@add_start_docstrings_to_model_forward(WAVLM_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
checkpoint=_CHECKPOINT_FOR_DOC,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment