"tests/vscode:/vscode.git/clone" did not exist on "4c9e0f029e55d9f22d1c119d4be018a3e552b0a0"
Unverified Commit 1c121916 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Add Speech Seq2Seq Training script (#14792)

* start

* add gradient checkpointing and feature extractor freezing

* Apply suggestions from code review

* up

* up

* up

* correct

* up

* more changes

* up

* up

* up

* remove rst
parent 10fd4fa1
......@@ -14,12 +14,27 @@ See the License for the specific language governing permissions and
limitations under the License.
-->
# Automatic Speech Recognition examples
## Connectionist Temporal Classification without Language Model (CTC w/o LM)
The script [`run_speech_recognition_ctc.py`](https://github.com/huggingface/transformers/blob/master/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py) can be used to fine-tune any pretrained [Connectionist Temporal Classification Model](https://huggingface.co/transformers/master/model_doc/auto.html?highlight=automodelforctc#automodelforctc) for automatic speech
# Automatic Speech Recognition Examples
## Table of Contents
- [Automatic Speech Recognition with CTC](#connectionist-temporal-classification)
- [Single GPU example](#single-gpu)
- [Multi GPU example](#multi-gpu)
- [Examples](#examples)
- [TIMIT](#timit)
- [Librispeech](#librispeech)
- [Common Voice](#common-voice)
- [Multilingual Librispeech](#multilingual-librispeech)
- [Automatic Speech Recognition with Sequence-to-Sequence](#sequence-to-sequence)
- [Single GPU example](#single-gpu)
- [Multi GPU example](#multi-gpu)
- [Examples](#examples)
- [Librispeech](#librispeech)
## Connectionist Temporal Classification
The script [`run_speech_recognition_ctc.py`](https://github.com/huggingface/transformers/blob/master/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py) can be used to fine-tune any pretrained [Connectionist Temporal Classification Model](https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForCTC) for automatic speech
recognition on one of the [official speech recognition datasets](https://huggingface.co/datasets?task_ids=task_ids:automatic-speech-recognition) or a custom dataset.
Speech recognition models that have been pretrained in unsupervised fashion on audio data alone, *e.g.* [Wav2Vec2](https://huggingface.co/transformers/master/model_doc/wav2vec2.html), [HuBERT](https://huggingface.co/transformers/master/model_doc/hubert.html), [XLSR-Wav2Vec2](https://huggingface.co/transformers/master/model_doc/xlsr_wav2vec2.html), have shown to require only
......@@ -41,7 +56,7 @@ If the environment variable is not set, the training script might freeze, *i.e.*
---
### Single-GPU
### Single GPU
The following command shows how to fine-tune [XLSR-Wav2Vec2](https://huggingface.co/transformers/master/model_doc/xlsr_wav2vec2.html) on [Common Voice](https://huggingface.co/datasets/common_voice) using a single GPU in half-precision.
......@@ -75,7 +90,7 @@ python run_speech_recognition_ctc.py \
On a single V100 GPU, this script should run in *ca.* 1 hour 20 minutes and yield a CTC loss of **0.39** and word error rate
of **0.35**.
### Multi-GPU
### Multi GPU
The following command shows how to fine-tune [XLSR-Wav2Vec2](https://huggingface.co/transformers/master/model_doc/xlsr_wav2vec2.html) on [Common Voice](https://huggingface.co/datasets/common_voice) using 8 GPUs in half-precision.
......@@ -92,7 +107,6 @@ python -m torch.distributed.launch \
--learning_rate="3e-4" \
--warmup_steps="500" \
--evaluation_strategy="steps" \
--audio_column_name="path" \
--text_column_name="sentence" \
--save_steps="400" \
--eval_steps="100" \
......@@ -118,6 +132,8 @@ The presented performances are by no means optimal as no hyper-parameter tuning
they can serve as a baseline to improve upon.
#### TIMIT
- [TIMIT](https://huggingface.co/datasets/timit_asr)
| Dataset | Dataset Config | Pretrained Model | Word error rate on eval | Phoneme error rate on eval | GPU setup | Training time | Fine-tuned Model & Logs | Command to reproduce |
......@@ -129,6 +145,7 @@ they can serve as a baseline to improve upon.
| [TIMIT](https://huggingface.co/datasets/timit_asr)| - | [ntu-spml/distilhubert](https://huggingface.co/ntu-spml/distilhubert) | 0.68 | - | 1 GPU TITAN RTX | 26min | [here](https://huggingface.co/patrickvonplaten/distilhubert-timit) | [run.sh](https://huggingface.co/patrickvonplaten/distilhubert-timit/blob/main/run.sh) |
#### Librispeech
- [Librispeech](https://huggingface.co/datasets/librispeech_asr)
......@@ -139,7 +156,10 @@ they can serve as a baseline to improve upon.
| [Librispeech](https://huggingface.co/datasets/librispeech_asr)| `"clean"` - `"train.100"` | [facebook/wav2vec2-large-lv60](https://huggingface.co/facebook/wav2vec2-large-lv60) | 0.042 | - | 8 GPU V100 | 1h30min | [here](https://huggingface.co/patrickvonplaten/wav2vec2-librispeech-clean-100h-demo-dist) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-librispeech-clean-100h-demo-dist/blob/main/run.sh) |
| [Librispeech](https://huggingface.co/datasets/librispeech_asr)| `"clean"` - `"train.100"` | [facebook/wav2vec2-large-lv60](https://huggingface.co/facebook/wav2vec2-large-lv60) | 0.042 | - | 8 GPU V100 | 1h30min | [here](https://huggingface.co/patrickvonplaten/wav2vec2-librispeech-clean-100h-demo-dist) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-librispeech-clean-100h-demo-dist/blob/main/run.sh) |
| [Librispeech](https://huggingface.co/datasets/librispeech_asr)| `"clean"` - `"train.100"` | [facebook/hubert-large-ll60k](https://huggingface.co/facebook/hubert-large-ll60k) | 0.088 | - | 8 GPU V100 | 1h30min | [here](https://huggingface.co/patrickvonplaten/hubert-librispeech-clean-100h-demo-dist) | [run.sh](https://huggingface.co/patrickvonplaten/hubert-librispeech-clean-100h-demo-dist/blob/main/run.sh) |
| [Librispeech](https://huggingface.co/datasets/librispeech_asr)| `"clean"` - `"train.100"` | [asapp/sew-mid-100k](https://huggingface.co/asapp/sew-mid-100k) | 0.167 | | | 8 GPU V100 | 54min | [here](https://huggingface.co/patrickvonplaten/sew-mid-100k-librispeech-clean-100h-ft) | [run.sh](https://huggingface.co/patrickvonplaten/sew-mid-100k-librispeech-clean-100h-ft/blob/main/run.sh) |
| [Librispeech](https://huggingface.co/datasets/librispeech_asr)| `"clean"` - `"train.100"` | [asapp/sew-mid-100k](https://huggingface.co/asapp/sew-mid-100k) | 0.167 | | 8 GPU V100 | 54min | [here](https://huggingface.co/patrickvonplaten/sew-mid-100k-librispeech-clean-100h-ft) | [run.sh](https://huggingface.co/patrickvonplaten/sew-mid-100k-librispeech-clean-100h-ft/blob/main/run.sh) |
#### Common Voice
- [Common Voice](https://huggingface.co/datasets/common_voice)
......@@ -154,9 +174,196 @@ they can serve as a baseline to improve upon.
| [Common Voice](https://huggingface.co/datasets/common_voice)| `"tr"` | [facebook/wav2vec2-xls-r-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 0.31 | - | 8 GPU V100 | 1h05 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-300m-common_voice-tr-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-300m-common_voice-tr-ft/blob/main/run.sh) |
| [Common Voice](https://huggingface.co/datasets/common_voice)| `"tr"` | [facebook/wav2vec2-xls-r-1b](https://huggingface.co/facebook/wav2vec2-xls-r-1b) | 0.21 | - | 2 GPU Titan 24 GB RAM | 15h10 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-xls-r-1b-common_voice-tr-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-1b-common_voice-tr-ft/blob/main/run.sh) |
#### Multilingual Librispeech
- [Multilingual Librispeech](https://huggingface.co/datasets/multilingual_librispeech)
| Dataset | Dataset Config | Pretrained Model | Word error rate on eval | Phoneme error rate on eval | GPU setup | Training time | Fine-tuned Model & Logs | Command to reproduce |
|-------|------------------------------|-------------|---------------|---------------|----------------------|-------------| -------------| ------- |
| [Multilingual Librispeech](https://huggingface.co/datasets/multilingual_librispeech)| `"german"` | [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) | 0.13 | - | 1 GPU Titan 24 GB RAM | 15h04 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-xlsr-53-300m-mls-german-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-xlsr-53-300m-mls-german-ft/blob/main/run.sh) |
| [Multilingual Librispeech](https://huggingface.co/datasets/multilingual_librispeech)| `"german"` | [facebook/wav2vec2-xls-r-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 0.15 | - | 1 GPU Titan 24 GB RAM | 15h04 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-300m-mls-german-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-300m-mls-german-ft/blob/main/run.sh) |
## Sequence to Sequence
The script [`run_speech_recognition_seq2seq.py`](https://github.com/huggingface/transformers/blob/master/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py) can be used to fine-tune any [Speech Sequence-to-Sequence Model](https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModelForSpeechSeq2Seq) for automatic speech
recognition on one of the [official speech recognition datasets](https://huggingface.co/datasets?task_ids=task_ids:automatic-speech-recognition) or a custom dataset.
A very common use case is to leverage a pretrained speech [encoding model](https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModel),
*e.g.* [Wav2Vec2](https://huggingface.co/transformers/master/model_doc/wav2vec2.html), [HuBERT](https://huggingface.co/transformers/master/model_doc/hubert.html), [XLSR-Wav2Vec2](https://huggingface.co/transformers/master/model_doc/xlsr_wav2vec2.html) with a pretrained [text decoding model](https://huggingface.co/docs/transformers/master/en/model_doc/auto#transformers.AutoModel), *e.g.* [Bart](https://huggingface.co/docs/transformers/master/en/model_doc/bart#transformers.BartForCausalLM) to create a [SpeechEnocderDecoderModel](https://huggingface.co/docs/transformers/master/en/model_doc/speechencoderdecoder#speech-encoder-decoder-models).
Consequently, the warm-started Speech-Encoder-Decoder model can be fine-tuned in
this script.
As an example, let's instantiate a *Wav2Vec2-2-Bart* model with the `SpeechEnocderDecoderModel` framework:
First create an empty repo on `hf.co`:
```bash
huggingface-cli repo create wav2vec2-2-bart-base
git clone https://huggingface.co/<your-user-name>/wav2vec2-2-bart-base
cd wav2vec2-2-bart-base
```
Next, run the following script **inside** the just cloned repo:
```py
from transformers import SpeechEncoderDecoderModel, AutoFeatureExtractor, AutoTokenizer, Wav2Vec2Processor
# checkpoints to leverage
encoder_id = "facebook/wav2vec2-base"
decoder_id = "facebook/bart-base"
# load and save speech-encoder-decoder model
# set some hyper-parameters for training and evaluation
model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(encoder_id, decoder_id, encoder_add_adapter=True, encoder_feat_proj_dropout=0.0, encoder_layerdrop=0.0, max_length=200, num_beams=5)
model.config.decoder_start_token_id = model.decoder.config.bos_token_id
model.config.pad_token_id = model.decoder.config.pad_token_id
model.config.eos_token_id = model.decoder.config.eos_token_id
model.save_pretrained("./")
# load and save processor
feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_id)
tokenizer = AutoTokenizer.from_pretrained(decoder_id)
processor = Wav2Vec2Processor(feature_extractor, tokenizer)
processor.save_pretrained("./")
```
Finally, we can upload all files:
```bash
git lfs install
git add . && git commit -m "upload model files" && git push
```
and link the official `run_speech_recognition_seq2seq.py` script to the folder:
```bash
ln -s $(realpath <path/to/transformers>/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py) ./
```
Note that we have added a randomly initialized adapter to `wav2vec2-base` with
`encoder_add_adapter=True` which further samples the output sequence of
`wav2vec2-base` along the time dimension. The reason is that by default a single
output vector of `wav2vec2-base` has a receptive field of *ca.* 25ms (*cf.* with
section *4.2* of the [official Wav2Vec2 paper](https://arxiv.org/pdf/2006.11477.pdf)), which represents a little less a single character. BART on the other hand
makes use of a sentence-piece tokenizer as an input processor so that a single
hidden vector of `bart-base` represents *ca.* 4 characters. To better align
the output of *Wav2Vec2* and *BART*'s hidden vectors for the cross-attention
mechanism, we further subsample *Wav2Vec2*'s output by a factor of 8 by
adding a convolution-based adapter.
Having warm-started the speech-encoder-decoder model `<your-user-name>/wav2vec2-2-bart`, we can now fine-tune it on speech recognition.
In the script [`run_speech_recognition_seq2seq`], we load the warm-started model,
the feature extractor, and the tokenizer, process a speech recognition dataset,
and then make use of the [`Seq2SeqTrainer`](https://huggingface.co/docs/transformers/master/en/main_classes/trainer#transformers.Seq2SeqTrainer).
Note that it is important to also align the decoder's vocabulary with
the speech transcriptions of the dataset. *E.g.* the [`Librispeech`](https://huggingface.co/datasets/librispeech_asr) has only captilized letters in the transcriptions,
whereas BART was pretrained mostly on normalized text. Thus it is recommended to add
`--do_lower_case` to the fine-tuning script when using a warm-started `SpeechEncoderDecoderModel`. The model is fine-tuned on the standard cross-entropy language modeling
loss for sequence-to-sequence (just like *T5* or *BART* in natural language processing).
---
**NOTE**
If you encounter problems with data preprocessing by setting `--preprocessing_num_workers` > 1,
you might want to set the environment variable `OMP_NUM_THREADS` to 1 as follows:
```bash
OMP_NUM_THREADS=1 python run_speech_recognition_ctc ...
```
If the environment variable is not set, the training script might freeze, *i.e.* see: https://github.com/pytorch/audio/issues/1021#issuecomment-726915239
---
### Single GPU
The following command shows how to fine-tune [XLSR-Wav2Vec2](https://huggingface.co/transformers/master/model_doc/xlsr_wav2vec2.html) on [Common Voice](https://huggingface.co/datasets/common_voice) using a single GPU in half-precision.
```bash
python run_speech_recognition_seq2seq.py \
--nproc_per_node 8 run_speech_recognition_seq2seq.py \
--dataset_name="librispeech_asr" \
--model_name_or_path="./" \
--dataset_config_name="clean" \
--train_split_name="train.100" \
--eval_split_name="validation" \
--output_dir="./" \
--preprocessing_num_workers="16" \
--length_column_name="input_length" \
--overwrite_output_dir \
--num_train_epochs="5" \
--per_device_train_batch_size="8" \
--per_device_eval_batch_size="8" \
--gradient_accumulation_steps="8" \
--learning_rate="3e-4" \
--warmup_steps="400" \
--evaluation_strategy="steps" \
--text_column_name="text" \
--save_steps="400" \
--eval_steps="400" \
--logging_steps="10" \
--save_total_limit="1" \
--freeze_feature_extractor \
--gradient_checkpointing \
--fp16 \
--group_by_length \
--predict_with_generate \
--generation_max_length="40" \
--generation_num_beams="1" \
--do_train --do_eval \
--do_lower_case
```
On a single V100 GPU, this script should run in *ca.* 5 hours and yield a
cross-entropy loss of **0.405** and word error rate of **0.0728**.
### Multi GPU
The following command shows how to fine-tune [XLSR-Wav2Vec2](https://huggingface.co/transformers/master/model_doc/xlsr_wav2vec2.html) on [Common Voice](https://huggingface.co/datasets/common_voice) using 8 GPUs in half-precision.
```bash
python -m torch.distributed.launch \
--nproc_per_node 8 run_speech_recognition_seq2seq.py \
--dataset_name="librispeech_asr" \
--model_name_or_path="./" \
--dataset_config_name="clean" \
--train_split_name="train.100" \
--eval_split_name="validation" \
--output_dir="./" \
--preprocessing_num_workers="16" \
--length_column_name="input_length" \
--overwrite_output_dir \
--num_train_epochs="5" \
--per_device_train_batch_size="8" \
--per_device_eval_batch_size="8" \
--gradient_accumulation_steps="1" \
--learning_rate="3e-4" \
--warmup_steps="400" \
--evaluation_strategy="steps" \
--text_column_name="text" \
--save_steps="400" \
--eval_steps="400" \
--logging_steps="10" \
--save_total_limit="1" \
--freeze_feature_extractor \
--gradient_checkpointing \
--fp16 \
--group_by_length \
--predict_with_generate \
--do_train --do_eval \
--do_lower_case
```
On 8 V100 GPUs, this script should run in *ca.* 45 minutes and yield a cross-entropy loss of **0.405** and word error rate of **0.0728**
### Examples
#### Librispeech
- [Librispeech](https://huggingface.co/datasets/librispeech_asr)
| Dataset | Dataset Config | Pretrained Model | Word error rate on eval | Phoneme error rate on eval | GPU setup | Training time | Fine-tuned Model & Logs | Command to reproduce |
|-------|------------------------------|-------------|---------------|---------------|----------------------|-------------| -------------| ------- |
| [Librispeech](https://huggingface.co/datasets/librispeech_asr)| `"clean"` - `"train.100"` | [facebook/wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base) and [facebook/bart-base](https://huggingface.co/facebook/bart-base) | 0.0728 | - | 8 GPU V100 | 45min | [here](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-base) | [create_model.py](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-base/blob/main/create_model.py) & [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-base/blob/main/run_librispeech.sh) |
| [Librispeech](https://huggingface.co/datasets/librispeech_asr)| `"clean"` - `"train.100"` | [facebook/wav2vec2-large-lv60](https://huggingface.co/facebook/wav2vec2-large-lv60) and [facebook/bart-large](https://huggingface.co/facebook/bart-large) | 0.0486 | - | 8 GPU V100 | 1h20min | [here](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-large) | [create_model.py](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-large/blob/main/create_model.py) & [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-2-bart-large/blob/main/run_librispeech.sh) |
......@@ -635,14 +635,13 @@ def main():
return metrics
# Now create a single processor
# Now save everything to be able to create a single processor later
if is_main_process(training_args.local_rank):
# save feature extractor, tokenizer and config
feature_extractor.save_pretrained(training_args.output_dir)
tokenizer.save_pretrained(training_args.output_dir)
config.save_pretrained(training_args.output_dir)
# load processor
try:
processor = AutoProcessor.from_pretrained(training_args.output_dir)
except (OSError, KeyError):
......
......@@ -59,6 +59,7 @@ if SRC_DIRS is not None:
import run_qa as run_squad
import run_seq2seq_qa as run_squad_seq2seq
import run_speech_recognition_ctc
import run_speech_recognition_seq2seq
import run_summarization
import run_swag
import run_translation
......@@ -473,6 +474,39 @@ class ExamplesTests(TestCasePlus):
result = get_results(tmp_dir)
self.assertLess(result["eval_loss"], result["train_loss"])
def test_run_speech_recognition_seq2seq(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_speech_recognition_seq2seq.py
--output_dir {tmp_dir}
--model_name_or_path hf-internal-testing/tiny-random-speech-encoder-decoder
--dataset_name hf-internal-testing/librispeech_asr_dummy
--dataset_config_name clean
--train_split_name validation
--eval_split_name validation
--do_train
--do_eval
--learning_rate 1e-4
--per_device_train_batch_size 2
--per_device_eval_batch_size 4
--remove_unused_columns False
--overwrite_output_dir True
--preprocessing_num_workers 16
--max_steps 10
--seed 42
""".split()
if is_cuda_and_apex_available():
testargs.append("--fp16")
with patch.object(sys, "argv", testargs):
run_speech_recognition_seq2seq.main()
result = get_results(tmp_dir)
self.assertLess(result["eval_loss"], result["train_loss"])
def test_run_audio_classification(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
......@@ -521,10 +555,10 @@ class ExamplesTests(TestCasePlus):
--dataset_config_names clean
--dataset_split_names validation
--learning_rate 1e-4
--per_device_train_batch_size 2
--per_device_eval_batch_size 2
--per_device_train_batch_size 4
--per_device_eval_batch_size 4
--preprocessing_num_workers 16
--max_train_steps 5
--max_train_steps 2
--validation_split_percentage 5
--seed 42
""".split()
......
......@@ -164,7 +164,7 @@ class AutoProcessor:
model_type = config_class_to_model_type(type(config).__name__)
if getattr(config, "processor_class", None) is not None:
processor_class = config.processor_class
processor_class = processor_class_from_name(config.processor_class)
return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
model_type = config_class_to_model_type(type(config).__name__)
......
......@@ -905,6 +905,7 @@ class HubertModel(HubertPreTrainedModel):
self.feature_extractor = HubertFeatureExtractor(config)
self.feature_projection = HubertFeatureProjection(config)
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
if config.do_stable_layer_norm:
......
......@@ -805,6 +805,7 @@ class SEWModel(SEWPreTrainedModel):
self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
self.feature_dropout = nn.Dropout(config.feat_proj_dropout)
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
self.encoder = SEWEncoder(config)
......
......@@ -1341,6 +1341,7 @@ class SEWDModel(SEWDPreTrainedModel):
self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
self.feature_dropout = nn.Dropout(config.feat_proj_dropout)
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
self.encoder = SEWDEncoder(config)
......
......@@ -181,6 +181,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
config_class = SpeechEncoderDecoderConfig
base_model_prefix = "speech_encoder_decoder"
main_input_name = "inputs"
supports_gradient_checkpointing = True
def __init__(
self,
......@@ -247,6 +248,11 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
)
def _set_gradient_checkpointing(self, module, value=False):
# call both encoder and decoder function on gradient checkpointing
self.encoder._set_gradient_checkpointing(module, value=value)
self.decoder._set_gradient_checkpointing(module, value=value)
def get_encoder(self):
return self.encoder
......@@ -259,6 +265,13 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
def set_output_embeddings(self, new_embeddings):
return self.decoder.set_output_embeddings(new_embeddings)
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature extractor of the speech encoder so
that its parameters will not be updated during training.
"""
self.encoder.freeze_feature_extractor()
@classmethod
def from_pretrained(cls, *args, **kwargs):
# At the moment fast initialization is not supported for composite models
......@@ -367,7 +380,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
)
if "config" not in kwargs_encoder:
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path, **kwargs_encoder)
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
logger.info(
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
......@@ -378,7 +391,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
kwargs_encoder["config"] = encoder_config
encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args)
decoder = kwargs_decoder.pop("model", None)
if decoder is None:
......@@ -389,7 +402,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
)
if "config" not in kwargs_decoder:
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
......@@ -411,7 +424,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
)
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path)
# instantiate config with corresponding kwargs
config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
......
......@@ -1052,6 +1052,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
self.feature_extractor = UniSpeechFeatureExtractor(config)
self.feature_projection = UniSpeechFeatureProjection(config)
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
if config.do_stable_layer_norm:
......
......@@ -1197,6 +1197,8 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
self.feature_extractor = Wav2Vec2FeatureExtractor(config)
self.feature_projection = Wav2Vec2FeatureProjection(config)
# model only needs masking vector if mask prob is > 0.0
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
if config.do_stable_layer_norm:
......@@ -1209,6 +1211,13 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature extractor so that its parameters
will not be updated during training.
"""
self.feature_extractor._freeze_parameters()
def _mask_hidden_states(
self,
hidden_states: torch.FloatTensor,
......
......@@ -19,6 +19,7 @@ import warnings
from contextlib import contextmanager
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_fast import PreTrainedTokenizerFast
from ..auto.tokenization_auto import AutoTokenizer
from .feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
from .tokenization_wav2vec2 import Wav2Vec2CTCTokenizer
......@@ -44,7 +45,7 @@ class Wav2Vec2Processor:
raise ValueError(
f"`feature_extractor` has to be of type {Wav2Vec2FeatureExtractor.__class__}, but is {type(feature_extractor)}"
)
if not isinstance(tokenizer, PreTrainedTokenizer):
if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
raise ValueError(
f"`tokenizer` has to be of type {PreTrainedTokenizer.__class__}, but is {type(tokenizer)}"
)
......
......@@ -1149,6 +1149,8 @@ class WavLMModel(WavLMPreTrainedModel):
self.feature_extractor = WavLMFeatureExtractor(config)
self.feature_projection = WavLMFeatureProjection(config)
# model only needs masking vector if mask prob is > 0.0
if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
if config.do_stable_layer_norm:
......@@ -1161,6 +1163,13 @@ class WavLMModel(WavLMPreTrainedModel):
# Initialize weights and apply final processing
self.post_init()
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature extractor so that its parameters
will not be updated during training.
"""
self.feature_extractor._freeze_parameters()
def _mask_hidden_states(
self,
hidden_states: torch.FloatTensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment