Unverified Commit a459f7f9 authored by Anton Lozhkov's avatar Anton Lozhkov Committed by GitHub
Browse files

Add ASR CTC streaming example (#15309)



* Single-epoch run

* Apply suggestions from code review
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Infinite dataset

* Trainer fix + distributed benchmark

* Benchmark fix

* unused import

* interleaved splits

* interleaved splits

* has_length util

* Move to research projects

* Leftover Sized checks

* Bump min version

* Unused import

* Revert trainer changes
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 75b13f82
...@@ -127,6 +127,62 @@ python -m torch.distributed.launch \ ...@@ -127,6 +127,62 @@ python -m torch.distributed.launch \
On 8 V100 GPUs, this script should run in *ca.* 18 minutes and yield a CTC loss of **0.39** and word error rate On 8 V100 GPUs, this script should run in *ca.* 18 minutes and yield a CTC loss of **0.39** and word error rate
of **0.36**. of **0.36**.
### Multi GPU CTC with Dataset Streaming
The following command shows how to use [Dataset Streaming mode](https://huggingface.co/docs/datasets/dataset_streaming.html)
to fine-tune [XLS-R](https://huggingface.co/transformers/master/model_doc/xls_r.html)
on [Common Voice](https://huggingface.co/datasets/common_voice) using 4 GPUs in half-precision.
Streaming mode imposes several constraints on training:
1. We need to construct a tokenizer beforehand and define it via `--tokenizer_name_or_path`.
2. `--num_train_epochs` has to be replaced by `--max_steps`. Similarly, all other epoch-based arguments have to be
replaced by step-based ones.
3. Full dataset shuffling on each epoch is not possible, since we don't have the whole dataset available at once.
However, the `--shuffle_buffer_size` argument controls how many examples we can pre-download before shuffling them.
```bash
**python -m torch.distributed.launch \
--nproc_per_node 4 run_speech_recognition_ctc_streaming.py \
--dataset_name="common_voice" \
--model_name_or_path="facebook/wav2vec2-xls-r-300m" \
--tokenizer_name_or_path="anton-l/wav2vec2-tokenizer-turkish" \
--dataset_config_name="tr" \
--train_split_name="train+validation" \
--eval_split_name="test" \
--output_dir="wav2vec2-xls-r-common_voice-tr-ft" \
--overwrite_output_dir \
--max_steps="5000" \
--per_device_train_batch_size="8" \
--gradient_accumulation_steps="2" \
--learning_rate="5e-4" \
--warmup_steps="500" \
--evaluation_strategy="steps" \
--text_column_name="sentence" \
--save_steps="500" \
--eval_steps="500" \
--logging_steps="1" \
--layerdrop="0.0" \
--eval_metrics wer cer \
--save_total_limit="1" \
--mask_time_prob="0.3" \
--mask_time_length="10" \
--mask_feature_prob="0.1" \
--mask_feature_length="64" \
--freeze_feature_encoder \
--chars_to_ignore , ? . ! - \; \: \" “ % ‘ ” � \
--max_duration_in_seconds="20" \
--shuffle_buffer_size="500" \
--fp16 \
--push_to_hub \
--do_train --do_eval \
--gradient_checkpointing**
```
On 4 V100 GPUs, this script should run in *ca.* 3h 31min and yield a CTC loss of **0.35** and word error rate
of **0.29**.
### Examples CTC ### Examples CTC
The following tables present a couple of example runs on the most popular speech-recognition datasets. The following tables present a couple of example runs on the most popular speech-recognition datasets.
...@@ -175,6 +231,7 @@ they can serve as a baseline to improve upon. ...@@ -175,6 +231,7 @@ they can serve as a baseline to improve upon.
| [Common Voice](https://huggingface.co/datasets/common_voice)| `"tr"` | [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) | 0.35 | - | 1 GPU V100 | 1h20min | [here](https://huggingface.co/patrickvonplaten/wav2vec2-common_voice-tr-demo) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-common_voice-tr-demo/blob/main/run.sh) | | [Common Voice](https://huggingface.co/datasets/common_voice)| `"tr"` | [facebook/wav2vec2-large-xlsr-53](https://huggingface.co/facebook/wav2vec2-large-xlsr-53) | 0.35 | - | 1 GPU V100 | 1h20min | [here](https://huggingface.co/patrickvonplaten/wav2vec2-common_voice-tr-demo) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-common_voice-tr-demo/blob/main/run.sh) |
| [Common Voice](https://huggingface.co/datasets/common_voice)| `"tr"` | [facebook/wav2vec2-xls-r-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 0.31 | - | 8 GPU V100 | 1h05 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-300m-common_voice-tr-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-300m-common_voice-tr-ft/blob/main/run.sh) | | [Common Voice](https://huggingface.co/datasets/common_voice)| `"tr"` | [facebook/wav2vec2-xls-r-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 0.31 | - | 8 GPU V100 | 1h05 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-300m-common_voice-tr-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-300m-common_voice-tr-ft/blob/main/run.sh) |
| [Common Voice](https://huggingface.co/datasets/common_voice)| `"tr"` | [facebook/wav2vec2-xls-r-1b](https://huggingface.co/facebook/wav2vec2-xls-r-1b) | 0.21 | - | 2 GPU Titan 24 GB RAM | 15h10 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-xls-r-1b-common_voice-tr-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-1b-common_voice-tr-ft/blob/main/run.sh) | | [Common Voice](https://huggingface.co/datasets/common_voice)| `"tr"` | [facebook/wav2vec2-xls-r-1b](https://huggingface.co/facebook/wav2vec2-xls-r-1b) | 0.21 | - | 2 GPU Titan 24 GB RAM | 15h10 | [here](https://huggingface.co/patrickvonplaten/wav2vec2-xls-r-1b-common_voice-tr-ft) | [run.sh](https://huggingface.co/patrickvonplaten/wav2vec2-large-xls-r-1b-common_voice-tr-ft/blob/main/run.sh) |
| [Common Voice](https://huggingface.co/datasets/common_voice)| `"tr"` in streaming mode | [facebook/wav2vec2-xls-r-300m](https://huggingface.co/facebook/wav2vec2-xls-r-300m) | 0.29 | - | 4 GPU V100 | 3h31 | [here](https://huggingface.co/anton-l/wav2vec2-xls-r-common_voice-tr-ft-stream) | [run.sh](https://huggingface.co/anton-l/wav2vec2-xls-r-common_voice-tr-ft-stream/blob/main/run.sh) |
#### Multilingual Librispeech CTC #### Multilingual Librispeech CTC
......
datasets >= 1.13.3 datasets >= 1.18.0
torch >= 1.5 torch >= 1.5
torchaudio torchaudio
librosa librosa
......
...@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version ...@@ -51,7 +51,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.17.0.dev0") check_min_version("4.17.0.dev0")
require_version("datasets>=1.13.3", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -146,7 +146,8 @@ class DataTrainingArguments: ...@@ -146,7 +146,8 @@ class DataTrainingArguments:
train_split_name: str = field( train_split_name: str = field(
default="train+validation", default="train+validation",
metadata={ metadata={
"help": "The name of the training data set split to use (via the datasets library). Defaults to 'train+validation'" "help": "The name of the training data set split to use (via the datasets library). Defaults to "
"'train+validation'"
}, },
) )
eval_split_name: str = field( eval_split_name: str = field(
......
...@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version ...@@ -49,7 +49,7 @@ from transformers.utils.versions import require_version
# Will error if the minimal version of Transformers is not installed. Remove at your own risks. # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.17.0.dev0") check_min_version("4.17.0.dev0")
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") require_version("datasets>=1.18.0", "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment