.. Wenet documentation master file, created by
   sphinx-quickstart on Thu Dec 3 11:43:53 2020.
   You can adapt this file completely to your liking, but it should at least
   contain the root `toctree` directive.

Welcome to Wenet's documentation!
=================================
WeNet is a transformer-based end-to-end ASR toolkit.

.. toctree::
   :maxdepth: 1
   :caption: Tutorial:

   ./python_binding.md
   ./papers.md
   ./tutorial_librispeech.md
   ./tutorial_aishell.md
   ./pretrained_models.md
   ./lm.md
   ./context.md
   ./runtime.md
   ./jit_in_wenet.md
   ./UIO.md

Indices and tables
==================
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
# JIT in WeNet
We want our PyTorch model to be directly exportable with the torch.jit.script method,
which is essential for deploying the model to production.
See the following resources for details on how to deploy PyTorch models in production.
- [INTRODUCTION TO TORCHSCRIPT](https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html)
- [TORCHSCRIPT LANGUAGE REFERENCE](https://pytorch.org/docs/stable/jit_language_reference.html#language-reference)
- [LOADING A TORCHSCRIPT MODEL IN C++](https://pytorch.org/tutorials/advanced/cpp_export.html)
- [TorchScript and PyTorch JIT | Deep Dive](https://www.youtube.com/watch?v=2awmrMRf0dA&t=574s)
- [Research to Production: PyTorch JIT/TorchScript Updates](https://www.youtube.com/watch?v=St3gdHJzic0)
To ensure this, we try to export the model before the training stage.
If the export fails, we modify the training code to satisfy the export requirements.
``` python
# See in wenet/bin/train.py
script_model = torch.jit.script(model)
script_model.save(os.path.join(args.model_dir, 'init.zip'))
```
Two principles should be taken into consideration when contributing python code
to WeNet, especially for subclasses of torch.nn.Module and their forward functions.
1. Know what is allowed and what is disallowed.
   - [Torch and Tensor Unsupported Attributes](https://pytorch.org/docs/master/jit_unsupported.html#jit-unsupported)
   - [Python Language Reference Coverage](https://pytorch.org/docs/master/jit_python_reference.html#python-language-reference)
2. Use explicit typing as much as possible. You can enforce type checking with typeguard,
   see https://typeguard.readthedocs.io/en/latest/userguide.html for details. A minimal typing sketch is shown below.
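For illustration, here is a minimal, hypothetical module (not actual WeNet code) that uses explicit type annotations so that `torch.jit.script` can compile it:
``` python
import torch
from typing import List, Optional, Tuple


class Subsampler(torch.nn.Module):
    """Toy module illustrating TorchScript-friendly typing (not WeNet code)."""

    def __init__(self, idim: int, odim: int):
        super().__init__()
        self.linear = torch.nn.Linear(idim, odim)

    def forward(self,
                xs: torch.Tensor,
                cache: Optional[torch.Tensor] = None
                ) -> Tuple[torch.Tensor, List[int]]:
        # Explicit annotations let torch.jit.script infer stable types.
        ys = self.linear(xs)
        lengths: List[int] = [int(ys.size(1))]
        return ys, lengths


model = Subsampler(80, 256)
script_model = torch.jit.script(model)  # should compile without errors
```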
# LM for WeNet
WeNet uses an n-gram based statistical language model and the WFST framework to support custom language models.
Note that LM is only supported in the WeNet runtime.
## Motivation
Why an n-gram based LM? This may be the first question many people will ask.
Now that LMs based on RNNs and Transformers are in full swing, why does WeNet go backward?
The reason is simple: productivity.
The n-gram based language model has mature and complete training tools,
it can be trained on any amount of corpus, training is very fast, hotfixes are easy,
and it has a wide range of mature applications in actual products.
Why WFST? This may be the second question many people will ask,
since both industry and research have been working so hard to abandon traditional speech recognition,
especially its complex decoding technology. Why does WeNet go back to it?
The reason is also very simple: productivity.
WFST is a standard and powerful tool in traditional speech recognition.
Based on this solution, we have mature and complete bug-fix and product solutions;
for example, we can use the replace operation in WFST for class-based personalization such as contact recognition.
Therefore, just like WeNet's design goal "Production first and Production Ready",
LM in WeNet also puts productivity as the first priority,
so it draws on many very productive tools and solutions accumulated in traditional speech recognition.
The differences from traditional speech recognition are:
1. The training in WeNet is purely end-to-end.
2. As described below, LM is optional in decoding; you can choose whether to use it according to your needs and application scenarios.
## System Design
The whole system is shown in the following picture. There are two ways to generate the N-best.
![LM System Design](./images/lm_system.png)
1. Without LM, we use CTC prefix beam search to generate N-best.
2. With LM, we use CTC WFST search to generate N-best, and CTC WFST search is the traditional WFST based decoder.
There are two main parts in the CTC WFST based search.
The first is building the decoding graph, which composes the model unit T, the lexicon L, and the language model G into one unified graph TLG, in which:
1. T is the model unit in E2E training. Typically it's a char in Chinese, and a char or BPE in English.
2. L is the lexicon. The lexicon is very simple: what we need to do is just split a word into its modeling unit sequence.
For example, the word "我们" is split into two chars "我 们", and the word "APPLE" is split into five letters "A P P L E".
We can see there are no phonemes and there is no need to design pronunciations on purpose.
3. G is the language model, namely the n-gram compiled into its standard WFST representation.
The second part is the decoder, which is the same as a traditional decoder and uses the standard Viterbi beam search algorithm in decoding.
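To make the lexicon rule concrete, here is a minimal, illustrative python sketch (not the actual `tools/fst/prepare_dict.py`, which also handles BPE units):
``` python
# A minimal, illustrative sketch of splitting a word into its
# modeling-unit sequence when building the lexicon L.
def split_word(word: str) -> list:
    # English words are split into letters, Chinese words into chars.
    if word.isascii():
        return list(word.upper())   # "APPLE" -> A P P L E
    return list(word)               # "我们"   -> 我 们


for word in ["我们", "APPLE"]:
    print(word, " ".join(split_word(word)))
```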
## Implementation
WeNet draws on the decoder and related tools in Kaldi to support LM and WFST based decoding.
For ease of use and to keep the code independent, we directly migrated the decoding-related code in Kaldi to [this directory](https://github.com/wenet-e2e/wenet/tree/main/runtime/core/kaldi) in the WeNet runtime,
and modified and organized it according to the following principles:
1. To minimize changes, the migrated code remains the same directory structure as the original.
2. We use GLOG to replace the log system in Kaldi.
3. We modify the code format to meet the lint requirements of the code style in WeNet.
The core code is https://github.com/wenet-e2e/wenet/blob/main/runtime/core/decoder/ctc_wfst_beam_search.cc,
which wraps the LatticeFasterDecoder in Kaldi.
And we use blank frame skipping to speed up decoding.
In addition, WeNet also migrated related tools for building the decoding graph,
such as arpa2fst, fstdeterminizestar, fsttablecompose, fstminimizeencoded, and other tools.
So all the tools related to LM are built-in tools and can be used out of the box.
## Results
We get a consistent gain (3%~10%) on different datasets,
including aishell, aishell2, and librispeech.
Please go to the corresponding example directories for details.
## How to use?
Here is an example from aishell, which shows how to prepare the dictionary, how to train the LM,
how to build the graph, and how to decode with the runtime.
``` sh
# 7.1 Prepare dict
unit_file=$dict
mkdir -p data/local/dict
cp $unit_file data/local/dict/units.txt
tools/fst/prepare_dict.py $unit_file ${data}/resource_aishell/lexicon.txt \
data/local/dict/lexicon.txt
# 7.2 Train lm
lm=data/local/lm
mkdir -p $lm
tools/filter_scp.pl data/train/text \
$data/data_aishell/transcript/aishell_transcript_v0.8.txt > $lm/text
local/aishell_train_lms.sh
# 7.3 Build decoding TLG
tools/fst/compile_lexicon_token_fst.sh \
data/local/dict data/local/tmp data/local/lang
tools/fst/make_tlg.sh data/local/lm data/local/lang data/lang_test || exit 1;
# 7.4 Decoding with runtime
./tools/decode.sh --nj 16 \
--beam 15.0 --lattice_beam 7.5 --max_active 7000 \
--blank_skip_thresh 0.98 --ctc_weight 0.5 --rescoring_weight 1.0 \
--fst_path data/lang_test/TLG.fst \
--dict_path data/lang_test/words.txt \
data/test/wav.scp data/test/text $dir/final.zip \
data/lang_test/units.txt $dir/lm_with_runtime
```
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
if "%1" == "" goto help
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
exit /b 1
)
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
:end
popd
## Papers
* [WeNet: Production Oriented Streaming and Non-streaming End-to-End Speech Recognition Toolkit](https://arxiv.org/pdf/2102.01547.pdf), accepted by InterSpeech 2021.
* [WeNet 2.0: More Productive End-to-End Speech Recognition Toolkit](https://arxiv.org/pdf/2203.15455.pdf), accepted by InterSpeech 2022.
# Pretrained Models in WeNet
## Model Types
We provide two types of pretrained models in WeNet to facilitate users with different requirements.
1. **Checkpoint Model**, with suffix **.pt**, is the model trained and saved as a checkpoint by the WeNet python code. You can reproduce our published results with it, or use it as a checkpoint to continue training.
2. **Runtime Model**, with suffix **.zip**, can be used directly in our [x86](https://github.com/wenet-e2e/wenet/tree/main/runtime/libtorch) or [android](https://github.com/wenet-e2e/wenet/tree/main/runtime/android) runtime. The `runtime model` is exported by PyTorch JIT from the `checkpoint model`, and has been quantized to reduce the model size and network traffic.
## Model License
Each pretrained model in WeNet follows the license of its corresponding dataset.
For example, the pretrained model on LibriSpeech follows `CC BY 4.0`, since that is the license of the LibriSpeech dataset, see http://openslr.org/12/.
## Model List
Here is a list of the pretrained models on different datasets. The model structure, model size, and download link are given.
| Datasets | Languages | Checkpoint Model | Runtime Model | Contributor |
|--- |--- |--- |--- |--- |
| [aishell](../examples/aishell/s0/README.md) | CN | [Conformer](https://wenet-1256283475.cos.ap-shanghai.myqcloud.com/models/aishell/20210601_u2%2B%2B_conformer_exp.tar.gz) | [Conformer](https://wenet-1256283475.cos.ap-shanghai.myqcloud.com/models/aishell/20210601_u2%2B%2B_conformer_libtorch.tar.gz) | <a href="https://www.chumenwenwen.com" target="_blank"><img src="https://raw.githubusercontent.com/wenet-e2e/wenet-contributors/main/companies/chumenwenwen.png" width="100px"></a> |
| [aishell2](../examples/aishell2/s0/README.md) | CN | [Conformer](https://wenet-1256283475.cos.ap-shanghai.myqcloud.com/models/aishell2/20210618_u2pp_conformer_exp.tar.gz) | [Conformer](https://wenet-1256283475.cos.ap-shanghai.myqcloud.com/models/aishell2/20210618_u2pp_conformer_libtorch.tar.gz) | <a href="https://www.chumenwenwen.com" target="_blank"><img src="https://raw.githubusercontent.com/wenet-e2e/wenet-contributors/main/companies/chumenwenwen.png" width="100px"></a> |
| [gigaspeech](../examples/gigaspeech/s0/README.md) | EN | [Conformer](https://wenet-1256283475.cos.ap-shanghai.myqcloud.com/models/gigaspeech/gigaspeech_u2pp_conformer_exp.tar.gz) | [Conformer](https://wenet-1256283475.cos.ap-shanghai.myqcloud.com/models/gigaspeech/gigaspeech_u2pp_conformer_libtorch.tar.gz) | <a href="https://www.chumenwenwen.com" target="_blank"><img src="https://raw.githubusercontent.com/wenet-e2e/wenet-contributors/main/companies/chumenwenwen.png" width="100px"></a> |
| [librispeech](../examples/librispeech/s0/README.md) | EN | [Conformer](https://wenet-1256283475.cos.ap-shanghai.myqcloud.com/models/librispeech/20210610_u2pp_conformer_exp.tar.gz) | [Conformer](https://wenet-1256283475.cos.ap-shanghai.myqcloud.com/models/librispeech/20210610_u2pp_conformer_libtorch.tar.gz) | <a href="https://www.chumenwenwen.com" target="_blank"><img src="https://raw.githubusercontent.com/wenet-e2e/wenet-contributors/main/companies/chumenwenwen.png" width="100px"></a> |
| [multi_cn](../examples/multi_cn/s0/README.md) | CN | [Conformer](https://wenet-1256283475.cos.ap-shanghai.myqcloud.com/models/multi_cn/20210815_unified_conformer_exp.tar.gz) | [Conformer](https://wenet-1256283475.cos.ap-shanghai.myqcloud.com/models/multi_cn/20210815_unified_conformer_libtorch.tar.gz) | <a href="https://www.jd.com" target="_blank"><img src="https://raw.githubusercontent.com/wenet-e2e/wenet-contributors/main/companies/jd.jpeg" width="100px"></a> |
| [wenetspeech](../examples/wenetspeech/s0/README.md) | CN | [Conformer](https://wenet-1256283475.cos.ap-shanghai.myqcloud.com/models/wenetspeech/wenetspeech_u2pp_conformer_exp.tar.gz) | [Conformer](https://wenet-1256283475.cos.ap-shanghai.myqcloud.com/models/wenetspeech/wenetspeech_u2pp_conformer_libtorch.tar.gz) | <a href="https://horizon.ai" target="_blank"><img src="https://raw.githubusercontent.com/wenet-e2e/wenet-contributors/main/companies/hobot.png" width="100px"></a> |
# Pretrained Models in WeNet
## Model Types
We provide two types of pretrained models in WeNet to facilitate users with different requirements.
1. **Checkpoint Model**, with suffix **.pt**, is the model trained and saved as a checkpoint by the WeNet python code. You can reproduce our published results with it, or use it as a checkpoint to continue training.
2. **Runtime Model**, with suffix **.zip**, can be used directly in our [x86](https://github.com/wenet-e2e/wenet/tree/main/runtime/libtorch) or [android](https://github.com/wenet-e2e/wenet/tree/main/runtime/android) runtime. The `runtime model` is exported by PyTorch JIT from the `checkpoint model`, and has been quantized to reduce the model size and network traffic.
## Model License
Each pretrained model in WeNet follows the license of its corresponding dataset.
For example, the pretrained model on LibriSpeech follows `CC BY 4.0`, since that is the license of the LibriSpeech dataset, see http://openslr.org/12/.
## Model List
Here is a list of the pretrained models on different datasets.
For non-Chinese users, please visit [Pretrained Models(En)](./pretrained_models.en.md) to download.
| Datasets | Languages | Checkpoint Model | Runtime Model | Contributor |
|--- |--- |--- |--- |--- |
| [aishell](../examples/aishell/s0/README.md) | CN | [Conformer](https://docs.qq.com/form/page/DZnRkVHlnUk5QaFdC) | [Conformer](https://docs.qq.com/form/page/DZnRkVHlnUk5QaFdC) | <a href="https://www.chumenwenwen.com" target="_blank"><img src="https://raw.githubusercontent.com/wenet-e2e/wenet-contributors/main/companies/chumenwenwen.png" width="100px"></a> |
| [aishell2](../examples/aishell2/s0/README.md) | CN | [Conformer](https://docs.qq.com/form/page/DZnRkVHlnUk5QaFdC) | [Conformer](https://docs.qq.com/form/page/DZnRkVHlnUk5QaFdC) | <a href="https://www.chumenwenwen.com" target="_blank"><img src="https://raw.githubusercontent.com/wenet-e2e/wenet-contributors/main/companies/chumenwenwen.png" width="100px"></a> |
| [gigaspeech](../examples/gigaspeech/s0/README.md) | EN | [Conformer](https://docs.qq.com/form/page/DZnRkVHlnUk5QaFdC) | [Conformer](https://docs.qq.com/form/page/DZnRkVHlnUk5QaFdC) | <a href="https://www.chumenwenwen.com" target="_blank"><img src="https://raw.githubusercontent.com/wenet-e2e/wenet-contributors/main/companies/chumenwenwen.png" width="100px"></a> |
| [librispeech](../examples/librispeech/s0/README.md) | EN | [Conformer](https://docs.qq.com/form/page/DZnRkVHlnUk5QaFdC) | [Conformer](https://docs.qq.com/form/page/DZnRkVHlnUk5QaFdC) | <a href="https://www.chumenwenwen.com" target="_blank"><img src="https://raw.githubusercontent.com/wenet-e2e/wenet-contributors/main/companies/chumenwenwen.png" width="100px"></a> |
| [multi_cn](../examples/multi_cn/s0/README.md) | CN | [Conformer](https://docs.qq.com/form/page/DZnRkVHlnUk5QaFdC) | [Conformer](https://docs.qq.com/form/page/DZnRkVHlnUk5QaFdC) | <a href="https://www.jd.com" target="_blank"><img src="https://raw.githubusercontent.com/wenet-e2e/wenet-contributors/main/companies/jd.jpeg" width="100px"></a> |
| [wenetspeech](../examples/wenetspeech/s0/README.md) | CN | [Conformer](https://docs.qq.com/form/page/DZnRkVHlnUk5QaFdC) | [Conformer](https://docs.qq.com/form/page/DZnRkVHlnUk5QaFdC) | <a href="https://horizon.ai" target="_blank"><img src="https://raw.githubusercontent.com/wenet-e2e/wenet-contributors/main/companies/hobot.png" width="100px"></a> |
../runtime/binding/python/README.md
# Runtime for WeNet
WeNet runtime uses [Unified Two Pass (U2)](https://arxiv.org/pdf/2102.01547.pdf) framework for inference. U2 has the following advantages:
* **Unified**: U2 unifies the streaming and non-streaming models in a simple way, and our runtime is also unified. Therefore, you can easily balance latency and accuracy by changing the chunk_size (described in the following section).
* **Accurate**: U2 achieves better accuracy by CTC joint training.
* **Fast**: Our runtime uses the attention rescoring decoding method described in U2, which is much faster than a traditional autoregressive beam search.
* **Other benefits**: In practice, we find U2 is more stable on long-form speech than the standard transformer, which usually fails or degrades a lot on long-form speech, and we can easily get word-level time stamps from the CTC spikes in U2. Both of these aspects are favored for industry adoption.
## Platforms Supported
The WeNet runtime supports the following platforms.
* Server
* [x86](https://github.com/wenet-e2e/wenet/tree/main/runtime/libtorch)
* Device
* [android](https://github.com/wenet-e2e/wenet/tree/main/runtime/android)
## Architecture and Implementation
### Architecture
The following picture shows how U2 works.
![U2](images/u2.gif)
When the input is not finished, the input frames $x_t$ are fed into the *Shared Encoder* module frame by frame to get the encoder output $e_t$; $e_t$ is then transformed by the *CTC Activation* module (typically just a linear transform with a log_softmax) to get the CTC prob $y_t$ at the current frame, and $y_t$ is further used by the *CTC prefix beam search* module to generate the n-best results at the current time $t$. The best result is used as the partial result of the U2 system.
When the input is finished at time $T$, the n-best results from the *CTC prefix beam search* module and the encoder output $e_1, e_2, e_3, ..., e_T$ are fed into the *Attention Decoder* module, which then computes a score for every result. The result with the best score is selected as the final result of the U2 system.
We can group $C$ consecutive frames $x_t, x_{t+1}, ..., x_{t+C-1}$ as one chunk for the *Shared Encoder* module, and $C$ is called the chunk_size in the U2 framework. The chunk_size affects the attention computation in the *Shared Encoder* module. When the chunk_size is infinite, it is the non-streaming case: the system gives the best accuracy with infinite latency. When the chunk_size is limited (typically less than 1s), it is the streaming case: the system has limited latency and still gives promising accuracy. So the developer can balance accuracy and latency by setting a proper chunk_size.
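For intuition, chunking just groups consecutive frames before they are fed to the encoder; a minimal, illustrative sketch (not WeNet's `forward_chunk` implementation):
``` python
# Illustrative only: how C consecutive frames are grouped into chunks for the
# Shared Encoder; the caches described later keep the history between chunks.
import torch

chunk_size = 4                       # C, the chunk_size
feats = torch.randn(16, 80)          # (T, feature_dim) acoustic frames
for t in range(0, feats.size(0), chunk_size):
    chunk = feats[t:t + chunk_size]  # one chunk fed to the encoder at a time
    print("chunk starting at frame", t, "shape", tuple(chunk.shape))
```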
### Interface Design
We use LibTorch to implement the U2 runtime in WeNet, and we export several interfaces in the PyTorch python code
via @torch.jit.export (see [asr_model.py](https://github.com/wenet-e2e/wenet/tree/main/wenet/transformer/asr_model.py)),
which are required and used by the C++ runtime in [torch_asr_model.cc](https://github.com/wenet-e2e/wenet/tree/main/runtime/libtorch/decoder/torch_asr_model.cc).
Here we just list the interfaces and give a brief introduction.
| interface | description |
|----------------------------------|-----------------------------------------|
| subsampling_rate (args) | get the subsampling rate of the model |
| right_context (args) | get the right context of the model |
| sos_symbol (args) | get the sos symbol id of the model |
| eos_symbol (args) | get the eos symbol id of the model |
| forward_encoder_chunk (args) | used for the *Shared Encoder* module |
| ctc_activation (args) | used for the *CTC Activation* module |
| forward_attention_decoder (args) | used for the *Attention Decoder* module |
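For intuition, here is a toy, hypothetical sketch of how such interfaces can be exposed with `@torch.jit.export`; the method names mirror the table above, but the bodies and signatures are simplified placeholders, see `asr_model.py` for the real code:
``` python
import torch


class ToyASRModel(torch.nn.Module):
    """Not the real wenet ASR model; signatures are simplified for illustration."""

    def __init__(self):
        super().__init__()
        self.ctc_fc = torch.nn.Linear(256, 5002)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

    @torch.jit.export
    def subsampling_rate(self) -> int:
        return 4          # e.g. conv2d subsampling reduces the frame rate by 4

    @torch.jit.export
    def sos_symbol(self) -> int:
        return 5001

    @torch.jit.export
    def ctc_activation(self, encoder_out: torch.Tensor) -> torch.Tensor:
        # typically just a linear transform followed by log_softmax
        return torch.nn.functional.log_softmax(self.ctc_fc(encoder_out), dim=-1)


script_model = torch.jit.script(ToyASRModel())  # exported methods survive scripting
```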
### Cache in Details
For the streaming scenario, the *Shared Encoder* module works in an incremental way. The current chunk computation requires the inputs and outputs of all the history chunks. We implement the incremental computation by using caches. Overall, two types of cache are used in our runtime.
* att_cache: the attention cache of the *Shared Encoder*(Conformer/Transformer) module.
* cnn_cache: the cnn cache of the *Shared Encoder*, which caches the left context for causal CNN computation in Conformer.
Please see [encoder.py:forward_chunk()](https://github.com/wenet-e2e/wenet/tree/main/wenet/transformer/encoder.py) and [torch_asr_model.cc](https://github.com/wenet-e2e/wenet/tree/main/runtime/libtorch/decoder/torch_asr_model.cc) for details of the caches.
In practice, CNN is also used for subsampling, so we should handle the CNN cache in subsampling as well.
However, the subsampling module contains different CNN layers with different left contexts, right contexts, and strides, which makes it tricky to implement the CNN cache in subsampling directly.
In our implementation, we simply overlap the input to avoid a subsampling CNN cache.
This is simple and straightforward, with negligible additional cost, since the subsampling CNN accounts for only a very small fraction of the whole computation.
The following picture shows how it works, where the blue color marks the overlap between the current inputs and the previous inputs.
![Overlap input for Subsampling CNN](images/subsampling_overalp.gif)
## References
1. [Sequence Modeling With CTC](https://distill.pub/2017/ctc/)
2. [First-Pass Large Vocabulary Continuous Speech Recognition using Bi-Directional Recurrent DNNs](https://arxiv.org/pdf/1408.2873.pdf)
3. [Unified Streaming and Non-streaming Two-pass End-to-end Model for Speech Recognition](https://arxiv.org/pdf/2012.05481.pdf)
## Tutorial on AIShell
If you meet any problems when going through this tutorial, please feel free to ask in github [issues](https://github.com/mobvoi/wenet/issues). Thanks for any kind of feedback.
### Setup environment
Please follow [Installation](https://github.com/wenet-e2e/wenet#installation) to install WeNet.
### First Experiment
We provide a recipe `examples/aishell/s0/run.sh` on aishell-1 data.
The recipe is simple and we suggest you run each stage one by one manually and check the result to understand the whole process.
```
cd examples/aishell/s0
bash run.sh --stage -1 --stop_stage -1
bash run.sh --stage 0 --stop_stage 0
bash run.sh --stage 1 --stop_stage 1
bash run.sh --stage 2 --stop_stage 2
bash run.sh --stage 3 --stop_stage 3
bash run.sh --stage 4 --stop_stage 4
bash run.sh --stage 5 --stop_stage 5
bash run.sh --stage 6 --stop_stage 6
```
You could also just run the whole script
```
bash run.sh --stage -1 --stop_stage 6
```
#### Stage -1: Download data
This stage downloads the aishell-1 data to the local path `$data`. This may take several hours. If you have already downloaded the data, please change the `$data` variable in `run.sh` and start from `--stage 0`.
Please set an **absolute path** for `$data`, e.g. `/home/username/asr-data/aishell/`.
#### Stage 0: Prepare Training data
In this stage, `local/aishell_data_prep.sh` organizes the original aishell-1 data into two files:
* **wav.scp** each line records two tab-separated columns : `wav_id` and `wav_path`
* **text** each line records two tab-separated columns : `wav_id` and `text_label`
**wav.scp**
```
BAC009S0002W0122 /export/data/asr-data/OpenSLR/33/data_aishell/wav/train/S0002/BAC009S0002W0122.wav
BAC009S0002W0123 /export/data/asr-data/OpenSLR/33/data_aishell/wav/train/S0002/BAC009S0002W0123.wav
BAC009S0002W0124 /export/data/asr-data/OpenSLR/33/data_aishell/wav/train/S0002/BAC009S0002W0124.wav
BAC009S0002W0125 /export/data/asr-data/OpenSLR/33/data_aishell/wav/train/S0002/BAC009S0002W0125.wav
...
```
**text**
```
BAC009S0002W0122 而对楼市成交抑制作用最大的限购
BAC009S0002W0123 也成为地方政府的眼中钉
BAC009S0002W0124 自六月底呼和浩特市率先宣布取消限购后
BAC009S0002W0125 各地政府便纷纷跟进
...
```
If you want to train using your customized data, just organize the data into two files `wav.scp` and `text`, and start from `stage 1`.
#### Stage 1: Extract optional CMVN features
`examples/aishell/s0` uses raw wav as input and [TorchAudio](https://pytorch.org/audio/stable/index.html) to extract the features just in time in the dataloader. So in this step we just copy the training wav.scp and text files into the `raw_wav/train/` dir.
`tools/compute_cmvn_stats.py` is used to extract global CMVN (cepstral mean and variance normalization) statistics. These statistics will be used to normalize the acoustic features. Setting `cmvn=false` will skip this step.
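Conceptually, global CMVN accumulates a per-dimension mean and variance over the training features and uses them to normalize every frame; a minimal sketch (not the actual `tools/compute_cmvn_stats.py`):
``` python
# Conceptual sketch of global CMVN: accumulate mean/variance statistics over
# all training frames, then normalize each feature frame with them.
import torch

utt_feats = [torch.randn(120, 80), torch.randn(95, 80)]  # per-utterance fbank (T, 80)
all_frames = torch.cat(utt_feats, dim=0)
mean = all_frames.mean(dim=0)
std = all_frames.std(dim=0)
normalized = (utt_feats[0] - mean) / (std + 1e-6)         # applied in the dataloader
print(normalized.shape)
```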
#### Stage 2: Generate label token dictionary
The dict is a map between label tokens (we use characters for Aishell-1) and
the integer indices.
An example dict is as follows
```
<blank> 0
<unk> 1
一 2
丁 3
...
龚 4230
龟 4231
<sos/eos> 4232
```
* `<blank>` denotes the blank symbol for CTC.
* `<unk>` denotes the unknown token; any out-of-vocabulary token will be mapped to it.
* `<sos/eos>` denotes the start-of-speech and end-of-speech symbols for attention based encoder decoder training, and they share the same id.
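Downstream tools simply read this dict into a token-to-id mapping; a minimal sketch (the path is illustrative) looks like this:
``` python
# A minimal sketch of reading the "<token> <id>" dict into a python mapping;
# the path below is illustrative.
symbol_table = {}
with open("data/dict/lang_char.txt", "r", encoding="utf-8") as fin:
    for line in fin:
        token, idx = line.strip().split()
        symbol_table[token] = int(idx)
print(symbol_table["<blank>"], symbol_table["<sos/eos>"])
```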
#### Stage 3: Prepare WeNet data format
This stage generates the WeNet required format file `data.list`. Each line in `data.list` is in json format which contains the following fields.
1. `key`: key of the utterance
2. `wav`: audio file path of the utterance
3. `txt`: normalized transcription of the utterance, the transcription will be tokenized to the model units on-the-fly at the training stage.
Here is an example of the `data.list`, and please see the generated training feature file in `data/train/data.list`.
```
{"key": "BAC009S0002W0122", "wav": "/export/data/asr-data/OpenSLR/33//data_aishell/wav/train/S0002/BAC009S0002W0122.wav", "txt": "而对楼市成交抑制作用最大的限购"}
{"key": "BAC009S0002W0123", "wav": "/export/data/asr-data/OpenSLR/33//data_aishell/wav/train/S0002/BAC009S0002W0123.wav", "txt": "也成为地方政府的眼中钉"}
{"key": "BAC009S0002W0124", "wav": "/export/data/asr-data/OpenSLR/33//data_aishell/wav/train/S0002/BAC009S0002W0124.wav", "txt": "自六月底呼和浩特市率先宣布取消限购后"}
```
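Since each line is a self-contained JSON object, `data.list` can be consumed line by line; a minimal reader sketch (the path is illustrative):
``` python
# Each line of data.list is a standalone JSON object; a minimal reader.
import json

with open("data/train/data.list", "r", encoding="utf-8") as fin:
    for line in fin:
        sample = json.loads(line)
        print(sample["key"], sample["wav"], sample["txt"])
```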
We also design another format for `data.list` named `shard`, which is for big data training.
Please see [gigaspeech](https://github.com/wenet-e2e/wenet/tree/main/examples/gigaspeech/s0)(10k hours) or
[wenetspeech](https://github.com/wenet-e2e/wenet/tree/main/examples/wenetspeech/s0)(10k hours)
for how to use the `shard` style `data.list` if you want to apply WeNet to a big dataset (more than 5k hours).
#### Stage 4: Neural Network training
The NN model is trained in this step.
- Multi-GPU mode
If using DDP mode for multi-GPU training, we suggest using `dist_backend="nccl"`. If NCCL does not work, try `gloo` or use `torch==1.6.0`.
Set the GPU ids in CUDA_VISIBLE_DEVICES. For example, set `export CUDA_VISIBLE_DEVICES="0,1,2,3,6,7"` to use cards 0,1,2,3,6,7.
- Resume training
If your experiment is terminated after running several epochs for some reason (e.g. the GPU is accidentally used by other people and runs out of memory), you can continue the training from a checkpoint model. Just find the last finished epoch in `exp/your_exp/`, set `checkpoint=exp/your_exp/$n.pt` and run `run.sh --stage 4`. Then the training will continue from epoch `n+1`.
- Config
The config of the neural network structure, optimization parameters, loss parameters, and dataset can be set in a YAML format file.
In `conf/`, we provide several models like transformer and conformer. See `conf/train_conformer.yaml` for reference.
- Use Tensorboard
The training takes several hours. The actual time depends on the number and type of your GPU cards. On an 8-card 2080 Ti machine, it takes less than one day for 50 epochs.
You could use tensorboard to monitor the loss.
```
tensorboard --logdir tensorboard/$your_exp_name/ --port 12598 --bind_all
```
#### Stage 5: Recognize wav using the trained model
This stage shows how to recognize a set of wavs into texts. It also shows how to do the model averaging.
- Average model
If `${average_checkpoint}` is set to `true`, the best `${average_num}` models on cross validation set will be averaged to generate a boosted model and used for recognition.
- Decoding
Recognition is also called decoding or inference. The function of the NN will be applied on the input acoustic feature sequence to output a sequence of text.
Four decoding methods are provided in WeNet:
* `ctc_greedy_search` : encoder + CTC greedy search
* `ctc_prefix_beam_search` : encoder + CTC prefix beam search
* `attention` : encoder + attention-based decoder decoding
* `attention_rescoring` : rescoring the ctc candidates from ctc prefix beam search with encoder output on attention-based decoder.
In general, attention_rescoring is the best method. Please see [U2 paper](https://arxiv.org/pdf/2012.05481.pdf) for the details of these algorithms.
`--beam_size` is a tunable parameter, a large beam size may get better results but also cause higher computation cost.
`--batch_size` can be greater than 1 for "ctc_greedy_search" and "attention" decoding mode, and must be 1 for "ctc_prefix_beam_search" and "attention_rescoring" decoding mode.
- WER evaluation
`tools/compute-wer.py` will calculate the word (or char) error rate of the result. If you run the recipe without any change, you may get WER ~= 5%.
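For intuition, the error rate is the edit distance between reference and hypothesis divided by the reference length; a toy sketch (not the actual `tools/compute-wer.py`) is shown below:
``` python
# Illustrative character/word error rate via edit distance.
def edit_distance(ref, hyp):
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        d[i][0] = i
    for j in range(len(hyp) + 1):
        d[0][j] = j
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            d[i][j] = min(d[i - 1][j] + 1, d[i][j - 1] + 1, d[i - 1][j - 1] + cost)
    return d[len(ref)][len(hyp)]


ref = list("而对楼市成交抑制作用最大的限购")
hyp = list("而对楼市成交抑制作用最大限购")
print("CER = %.2f%%" % (100.0 * edit_distance(ref, hyp) / len(ref)))
```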
#### Stage 6: Export the trained model
`wenet/bin/export_jit.py` will export the trained model using Libtorch. The exported model files can be easily used for inference in other programming languages such as C++.
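If you want a quick sanity check of the exported model from python before moving to the C++ runtime, it can be loaded back with `torch.jit.load`; the path and the called interface below are illustrative:
``` python
# Quick sanity check of the exported TorchScript model in python
# (C++ loads the same file via LibTorch).
import torch

model = torch.jit.load("exp/your_exp/final.zip")
model.eval()
print(model.subsampling_rate())   # one of the exported runtime interfaces
```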
## Tutorial on LibriSpeech
If you meet any problems when going through this tutorial, please feel free to ask in github [issues](https://github.com/mobvoi/wenet/issues). Thanks for any kind of feedback.
### Setup environment
Please follow [Installation](https://github.com/wenet-e2e/wenet#installation) to install WeNet.
### First Experiment
We provide a recipe `examples/librispeech/s0/run.sh` on librispeech data.
The recipe is simple and we suggest you run each stage one by one manually and check the result to understand the whole process.
```
cd examples/librispeech/s0
bash run.sh --stage -1 --stop_stage -1
bash run.sh --stage 0 --stop_stage 0
bash run.sh --stage 1 --stop_stage 1
bash run.sh --stage 2 --stop_stage 2
bash run.sh --stage 3 --stop_stage 3
bash run.sh --stage 4 --stop_stage 4
bash run.sh --stage 5 --stop_stage 5
bash run.sh --stage 6 --stop_stage 6
bash run.sh --stage 7 --stop_stage 7
```
You could also just run the whole script
```
bash run.sh --stage -1 --stop_stage 7
```
#### Stage -1: Download data
``` sh
data_url=www.openslr.org/resources/12
datadir=/export/data/en-asr-data/OpenSLR/
if [ ${stage} -le -1 ] && [ ${stop_stage} -ge -1 ]; then
echo "stage -1: Data Download"
for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
local/download_and_untar.sh ${datadir} ${data_url} ${part}
done
fi
```
This stage downloads the librispeech data to the local path `$data`. This may take several hours. If you have already downloaded the data, please change the `$data` variable in `run.sh` and start from `--stage 0`.
#### Stage 0: Prepare Training data
``` sh
wave_data=data
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
# use underscore-separated names in data directories.
local/data_prep_torchaudio.sh ${datadir}/LibriSpeech/${part} $wave_data/${part//-/_}
done
fi
```
In this stage, `local/data_prep_torchaudio.sh` organizes the original data into two files:
* **wav.scp** each line records two tab-separated columns : `wav_id` and `wav_path`
* **text** each line records two tab-separated columns : `wav_id` and `text_label`
**wav.scp**
```
1867-154075-0014 /export/data/en-asr-data/OpenSLR//LibriSpeech/train-clean-100/1867/154075/1867-154075-0014.flac
1970-26100-0022 /export/data/en-asr-data/OpenSLR//LibriSpeech/train-clean-100/1970/26100/1970-26100-0022.flac
...
```
**text**
```
1867-154075-0014 YOU SHOW HIM THAT IT IS POSSIBLE
1970-26100-0022 DID YOU SEE HIM AT THAT TIME
...
```
If you want to train using your customized data, just organize the data into two files `wav.scp` and `text`, and start from `stage 1`.
#### Stage 1: Extract optional CMVN features
``` sh
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
### Task dependent. You have to design training and dev sets by yourself.
### But you can utilize Kaldi recipes in most cases
echo "stage 1: Feature Generation"
mkdir -p $wave_data/train_960
# merge total training data
for set in train_clean_100 train_clean_360 train_other_500; do
for f in `ls $wave_data/$set`; do
cat $wave_data/$set/$f >> $wave_data/train_960/$f
done
done
mkdir -p $wave_data/dev
# merge total dev data
for set in dev_clean dev_other; do
for f in `ls $wave_data/$set`; do
cat $wave_data/$set/$f >> $wave_data/dev/$f
done
done
tools/compute_cmvn_stats.py --num_workers 16 --train_config $train_config \
--in_scp $wave_data/$train_set/wav.scp \
--out_cmvn $wave_data/$train_set/global_cmvn
fi
```
The librispeech corpus contains 3 subsets for training, namely `train_clean_100`, `train_clean_360`, and `train_other_500`,
so we first merge them to get our final training data.
`tools/compute_cmvn_stats.py` is used to extract global CMVN (cepstral mean and variance normalization) statistics. These statistics will be used to normalize the acoustic features. Setting `cmvn=false` will skip this step.
#### Stage 2: Generate label token dictionary
``` sh
dict=$wave_data/lang_char/${train_set}_${bpemode}${nbpe}_units.txt
bpemodel=$wave_data/lang_char/${train_set}_${bpemode}${nbpe}
echo "dictionary: ${dict}"
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
### Task dependent. You have to check non-linguistic symbols used in the corpus.
echo "stage 2: Dictionary and Json Data Preparation"
mkdir -p data/lang_char/
echo "<blank> 0" > ${dict} # 0 will be used for "blank" in CTC
echo "<unk> 1" >> ${dict} # <unk> must be 1
# we borrowed these code and scripts which are related bpe from ESPnet.
cut -f 2- -d" " $wave_data/${train_set}/text > $wave_data/lang_char/input.txt
tools/spm_train --input=$wave_data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000
tools/spm_encode --model=${bpemodel}.model --output_format=piece < $wave_data/lang_char/input.txt | tr ' ' '\n' | sort | uniq | awk '{print $0 " " NR+1}' >> ${dict}
num_token=$(cat $dict | wc -l)
echo "<sos/eos> $num_token" >> $dict # <eos>
wc -l ${dict}
fi
```
The model unit of an English e2e speech recognition system can be char or BPE (byte pair encoding).
Typically, BPE gives better results, so here we use BPE as the model unit,
and the BPE model is trained with the [sentencepiece](https://github.com/google/sentencepiece) tool on the librispeech training data.
The model unit is defined as a dict in WeNet, which maps each BPE token to an integer index.
The librispeech dict is like:
```
<blank> 0
<unk> 1
' 2
▁ 3
A 4
▁A 5
AB 6
▁AB 7
▁YOU 4995
▁YOUNG 4996
▁YOUR 4997
▁YOUTH 4998
Z 4999
ZZ 5000
<sos/eos> 5001
```
* `<blank>` denotes the blank symbol for CTC.
* `<unk>` denotes the unknown token; any out-of-vocabulary token will be mapped to it.
* `<sos/eos>` denotes the start-of-speech and end-of-speech symbols for attention based encoder decoder training, and they share the same id.
#### Stage 3: Prepare WeNet data format
``` sh
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
# Prepare wenet required data
echo "Prepare data, prepare required format"
for x in dev ${recog_set} $train_set ; do
tools/make_raw_list.py $wave_data/$x/wav.scp $wave_data/$x/text \
$wave_data/$x/data.list
done
fi
```
This stage generates the WeNet required format file `data.list`. Each line in `data.list` is in json format which contains the following fields.
1. `key`: key of the utterance
2. `wav`: audio file path of the utterance
3. `txt`: normalized transcription of the utterance, the transcription will be tokenized to the model units on-the-fly at the training stage.
Here is an example of the `data.list`, and please see the generated training feature file in `data/train/data.list`.
```
{"key": "1455-134435-0000", "wav": "/mnt/nfs/ptm1/open-data/LibriSpeech/train-clean-100/1455/134435/1455-134435-0000.flac", "txt": "THE GIRL WHO CAME INTO THE WORLD ON THAT NIGHT WHEN JESSE RAN THROUGH THE FIELDS CRYING TO GOD THAT HE BE GIVEN A SON HAD GROWN TO WOMANHOOD ON THE FARM"}
{"key": "1455-134435-0001", "wav": "/mnt/nfs/ptm1/open-data/LibriSpeech/train-clean-100/1455/134435/1455-134435-0001.flac", "txt": "AND WHEN NOT ANGRY SHE WAS OFTEN MOROSE AND SILENT IN WINESBURG IT WAS SAID THAT SHE DRANK HER HUSBAND THE BANKER"}
{"key": "1455-134435-0002", "wav": "/mnt/nfs/ptm1/open-data/LibriSpeech/train-clean-100/1455/134435/1455-134435-0002.flac", "txt": "BUT LOUISE COULD NOT BE MADE HAPPY SHE FLEW INTO HALF INSANE FITS OF TEMPER DURING WHICH SHE WAS SOMETIMES SILENT SOMETIMES NOISY AND QUARRELSOME SHE SWORE AND CRIED OUT IN HER ANGER SHE GOT A KNIFE FROM THE KITCHEN AND THREATENED HER HUSBAND'S LIFE"}
```
We also design another format for `data.list` named `shard`, which is for big data training.
Please see [gigaspeech](https://github.com/wenet-e2e/wenet/tree/main/examples/gigaspeech/s0)(10k hours) or
[wenetspeech](https://github.com/wenet-e2e/wenet/tree/main/examples/wenetspeech/s0)(10k hours)
for how to use the `shard` style `data.list` if you want to apply WeNet to a big dataset (more than 5k hours).
#### Stage 4: Neural Network training
``` sh
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
# Training
mkdir -p $dir
INIT_FILE=$dir/ddp_init
rm -f $INIT_FILE # delete old one before starting
init_method=file://$(readlink -f $INIT_FILE)
echo "$0: init method is $init_method"
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
# Use "nccl" if it works, otherwise use "gloo"
dist_backend="nccl"
cmvn_opts=
$cmvn && cmvn_opts="--cmvn $wave_data/${train_set}/global_cmvn"
# train.py will write $train_config to $dir/train.yaml with model input
# and output dimension, train.yaml will be used for inference or model
# export later
for ((i = 0; i < $num_gpus; ++i)); do
{
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$i+1])
python wenet/bin/train.py --gpu $gpu_id \
--config $train_config \
--data_type raw \
--symbol_table $dict \
--train_data $wave_data/$train_set/data.list \
--cv_data $wave_data/dev/data.list \
${checkpoint:+--checkpoint $checkpoint} \
--model_dir $dir \
--ddp.init_method $init_method \
--ddp.world_size $num_gpus \
--ddp.rank $i \
--ddp.dist_backend $dist_backend \
--num_workers 1 \
$cmvn_opts \
--pin_memory
} &
done
wait
fi
```
The NN model is trained in this step.
- Multi-GPU mode
If using DDP mode for multi-GPU training, we suggest using `dist_backend="nccl"`. If NCCL does not work, try `gloo` or use `torch==1.6.0`.
Set the GPU ids in CUDA_VISIBLE_DEVICES. For example, set `export CUDA_VISIBLE_DEVICES="0,1,2,3,6,7"` to use cards 0,1,2,3,6,7.
- Resume training
If your experiment is terminated after running several epochs for some reason (e.g. the GPU is accidentally used by other people and runs out of memory), you can continue the training from a checkpoint model. Just find the last finished epoch in `exp/your_exp/`, set `checkpoint=exp/your_exp/$n.pt` and run `run.sh --stage 4`. Then the training will continue from epoch `n+1`.
- Config
The config of the neural network structure, optimization parameters, loss parameters, and dataset can be set in a YAML format file.
In `conf/`, we provide several models like transformer and conformer. See `conf/train_conformer.yaml` for reference.
- Use Tensorboard
The training takes several hours. The actual time depends on the number and type of your GPU cards. On an 8-card 2080 Ti machine, it takes less than one day for 50 epochs.
You could use tensorboard to monitor the loss.
```
tensorboard --logdir tensorboard/$your_exp_name/ --port 12598 --bind_all
```
#### Stage 5: Recognize wav using the trained model
``` sh
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
# Test model, please specify the model you want to test by --checkpoint
cmvn_opts=
$cmvn && cmvn_opts="--cmvn data/${train_set}/global_cmvn"
# TODO, Add model average here
mkdir -p $dir/test
if [ ${average_checkpoint} == true ]; then
decode_checkpoint=$dir/avg_${average_num}.pt
echo "do model average and final checkpoint is $decode_checkpoint"
python wenet/bin/average_model.py \
--dst_model $decode_checkpoint \
--src_path $dir \
--num ${average_num} \
--val_best
fi
# Specify decoding_chunk_size if it's a unified dynamic chunk trained model
# -1 for full chunk
decoding_chunk_size=
ctc_weight=0.5
# Polling GPU id begin with index 0
num_gpus=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
idx=0
for test in $recog_set; do
for mode in ${decode_modes}; do
{
{
test_dir=$dir/${test}_${mode}
mkdir -p $test_dir
gpu_id=$(echo $CUDA_VISIBLE_DEVICES | cut -d',' -f$[$idx+1])
python wenet/bin/recognize.py --gpu $gpu_id \
--mode $mode \
--config $dir/train.yaml \
--data_type raw \
--test_data $wave_data/$test/data.list \
--checkpoint $decode_checkpoint \
--beam_size 10 \
--batch_size 1 \
--penalty 0.0 \
--dict $dict \
--result_file $test_dir/text_bpe \
--ctc_weight $ctc_weight \
${decoding_chunk_size:+--decoding_chunk_size $decoding_chunk_size}
tools/spm_decode --model=${bpemodel}.model --input_format=piece < $test_dir/text_bpe | sed -e "s/▁/ /g" > $test_dir/text
python tools/compute-wer.py --char=1 --v=1 \
$wave_data/$test/text $test_dir/text > $test_dir/wer
} &
((idx+=1))
if [ $idx -eq $num_gpus ]; then
idx=0
fi
}
done
done
wait
fi
```
This stage shows how to recognize a set of wavs into texts. It also shows how to do the model averaging.
- Average model
If `${average_checkpoint}` is set to `true`, the best `${average_num}` models on cross validation set will be averaged to generate a boosted model and used for recognition.
- Decoding
Recognition is also called decoding or inference. The function of the NN will be applied on the input acoustic feature sequence to output a sequence of text.
Four decoding methods are provided in WeNet:
* `ctc_greedy_search` : encoder + CTC greedy search
* `ctc_prefix_beam_search` : encoder + CTC prefix beam search
* `attention` : encoder + attention-based decoder decoding
* `attention_rescoring` : rescoring the ctc candidates from ctc prefix beam search with encoder output on attention-based decoder.
In general, attention_rescoring is the best method. Please see [U2 paper](https://arxiv.org/pdf/2012.05481.pdf) for the details of these algorithms.
`--beam_size` is a tunable parameter, a large beam size may get better results but also cause higher computation cost.
`--batch_size` can be greater than 1 for "ctc_greedy_search" and "attention" decoding mode, and must be 1 for "ctc_prefix_beam_search" and "attention_rescoring" decoding mode.
- WER evaluation
`tools/compute-wer.py` will calculate the word (or char) error rate of the result.
#### Stage 6(Optional): Export the trained model
``` sh
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
# Export the best model you want
python wenet/bin/export_jit.py \
--config $dir/train.yaml \
--checkpoint $dir/avg_${average_num}.pt \
--output_file $dir/final.zip
fi
```
`wenet/bin/export_jit.py` will export the trained model using LibTorch.
The exported model files can be easily used for C++ inference in our runtime.
This is required if you want to integrate a language model (LM), as shown in Stage 7.
#### Stage 7(Optional): Add LM and test it with runtime
``` sh
if [ ${stage} -le 7 ] && [ ${stop_stage} -ge 7 ]; then
lm=data/local/lm
lexicon=data/local/dict/lexicon.txt
mkdir -p $lm
mkdir -p data/local/dict
# 7.1 Download & format LM
which_lm=3-gram.pruned.1e-7.arpa.gz
if [ ! -e ${lm}/${which_lm} ]; then
wget http://www.openslr.org/resources/11/${which_lm} -P ${lm}
fi
echo "unzip lm($which_lm)..."
gunzip -k ${lm}/${which_lm} -c > ${lm}/lm.arpa
echo "Lm saved as ${lm}/lm.arpa"
# 7.2 Prepare dict
unit_file=$dict
bpemodel=$bpemodel
# use $dir/words.txt (unit_file) and $dir/train_960_unigram5000 (bpemodel)
# if you download pretrained librispeech conformer model
cp $unit_file data/local/dict/units.txt
if [ ! -e ${lm}/librispeech-lexicon.txt ]; then
wget http://www.openslr.org/resources/11/librispeech-lexicon.txt -P ${lm}
fi
echo "build lexicon..."
tools/fst/prepare_dict.py $unit_file ${lm}/librispeech-lexicon.txt \
$lexicon $bpemodel.model
echo "lexicon saved as '$lexicon'"
# 7.3 Build decoding TLG
tools/fst/compile_lexicon_token_fst.sh \
data/local/dict data/local/tmp data/local/lang
tools/fst/make_tlg.sh data/local/lm data/local/lang data/lang_test || exit 1;
# 7.4 Decoding with runtime
fst_dir=data/lang_test
for test in ${recog_set}; do
./tools/decode.sh --nj 6 \
--beam 10.0 --lattice_beam 5 --max_active 7000 --blank_skip_thresh 0.98 \
--ctc_weight 0.5 --rescoring_weight 1.0 --acoustic_scale 1.2 \
--fst_path $fst_dir/TLG.fst \
--dict_path $fst_dir/words.txt \
data/$test/wav.scp data/$test/text $dir/final.zip $fst_dir/units.txt \
$dir/lm_with_runtime_${test}
tail $dir/lm_with_runtime_${test}/wer
done
fi
```
LM is only supported in the runtime, so you have to build the runtime as shown in [Installation](https://github.com/wenet-e2e/wenet#installation).
Please refer to [LM for WeNet](https://wenet-e2e.github.io/wenet/lm.html) for the details of the LM design.
# Recipe to run Noisy Student Training with LM filter in WeNet
Noisy Student Training (NST) has recently demonstrated extremely strong performance in Automatic Speech Recognition (ASR).
Here, we provide a recipe to run NST with the `LM filter` strategy, using AISHELL-1 as supervised data and WenetSpeech as unsupervised data, from [this paper](https://arxiv.org/abs/2211.04717), where hypotheses generated with and without a language model are compared and the CER difference between them is used as a filter threshold to improve ASR performance on non-target-domain data.
## Table of Contents
- [Guideline](#guideline)
  - [Data preparation](#data-preparation)
  - [Initial supervised teacher](#initial-supervised-teacher)
  - [Noisy student iterations](#noisy-student-iterations)
- [Performance Record](#performance-record)
  - [Supervised baseline and standard NST](#supervised-baseline-and-standard-nst)
  - [Supervised AISHELL-1 and unsupervised 1khr WenetSpeech](#supervised-aishell-1-and-unsupervised-1khr-wenetspeech)
  - [Supervised AISHELL-2 and unsupervised 4khr WenetSpeech](#supervised-aishell-2-and-unsupervised-4khr-wenetspeech)
- [Citations](#citations)
## Guideline
First, you have to prepare supervised and unsupervised data for NST. Then in stage 1 of `run.sh`, you will train an initial supervised teacher and generate pseudo labels for unsupervised data.
After that, you can run the noisy student training iteratively in stage 2. The whole pipeline is illustrated in the following picture.
![plot](local/NST_plot.png)
### Data preparation
To run this recipe, you should follow the steps from [WeNet examples](https://github.com/wenet-e2e/wenet/tree/main/examples) to prepare [AISHELL1](https://github.com/wenet-e2e/wenet/tree/main/examples/aishell/s0) and [WenetSpeech](https://github.com/wenet-e2e/wenet/tree/main/examples/wenetspeech/s0) data.
We extract 1khr of data from WenetSpeech, and the data should be prepared and stored in the following format:
```
data/
├── train/
├──── data_aishell.list
├──── wenet_1khr.list
├──── wav_dir/
├──── utter_time.json (optional)
├── dev/
└── test/
```
- The `*.list` files contain the paths of all the data shards for training.
- A JSON file containing the audio lengths should be prepared as `utter_time.json` if you want to apply the `speaking rate` filter (a sketch for generating it is shown after this list).
- `wav_dir` contains all the audio files (id.wav) and optional labels (id.txt) for the unsupervised data.
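For reference, here is a minimal sketch of building `utter_time.json` with torchaudio, assuming the utt-id-to-seconds mapping that the filtering script in this recipe expects; the paths are illustrative:
``` python
# A minimal sketch for building utter_time.json
# (assumed format: {"utt_id": duration_in_seconds}).
import json
import os

import torchaudio

wav_dir = "data/train/wav_dir"          # illustrative path
utter_time = {}
for name in os.listdir(wav_dir):
    if name.endswith(".wav"):
        info = torchaudio.info(os.path.join(wav_dir, name))
        utter_time[name[:-4]] = info.num_frames / info.sample_rate

with open("data/train/utter_time.json", "w", encoding="utf-8") as fout:
    json.dump(utter_time, fout)
```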
### Initial supervised teacher
To train an initial supervised teacher model, run the following command:
```bash
bash run.sh --stage 1 --stop-stage 1
```
Full arguments are listed below; you can check `run.sh` and `run_nst.sh` for more information about the steps in each stage and their arguments. We used `num_split = 60` and generated shards on different CPUs for the experiments in our paper, which saved a lot of inference and shard-generation time.
```bash
bash run.sh --stage 1 --stop-stage 1 --dir exp/conformer_test_fully_supervised --supervised_data_list data_aishell.list --enable_nst 0 --num_split 1 --unsupervised_data_list wenet_1khr.list --dir_split wenet_split_60_test/ --job_num 0 --hypo_name hypothesis_nst0.txt --label 1 --wav_dir data/train/wenet_1k_untar/ --cer_hypo_dir wenet_cer_hypo --cer_label_dir wenet_cer_label --label_file label.txt --cer_hypo_threshold 10 --speak_rate_threshold 0 --utter_time_file utter_time.json --untar_dir data/train/wenet_1khr_untar/ --tar_dir data/train/wenet_1khr_tar/ --out_data_list data/train/wenet_1khr.list
```
- `dir` contains the training parameters.
- `data_list` contains the paths for the training data list.
- `supervised_data_list` contains the paths of the supervised data shards.
- `unsupervised_data_list` contains the paths of the unsupervised data shards, which are used for inference.
- `dir_split` is the directory that stores the split unsupervised data for parallel computing.
- `out_data_list` is the output file path of the pseudo-label data list.
- `enable_nst` indicates whether we train with pseudo labels and split data; for the initial teacher we set it to 0.
- This recipe uses the default `num_split=1`, while we strongly recommend using a larger number to decrease the inference and shard-generation time.
> **HINTS** If num_split is set to an N larger than 1, you need to modify steps 4-8 in run_nst.sh to submit N tasks to your own cluster (such as Slurm, NGC, etc.).
> We strongly recommend doing so, since inference and pseudo-data generation are time-consuming.
### Noisy student iterations
After finishing the initial fully supervised baseline, we now have a mixed list, `wenet_1khr_nst0.list`, which contains both supervised and pseudo data.
We will use it as the `data_list` in the training step, and the `data_list` for the next NST iteration will be generated.
Here is an example command:
```bash
bash run.sh --stage 2 --stop-stage 2 --iter_num 2
```
Here we add an extra argument `iter_num` for the number of NST iterations. Intermediate files are named with `iter_num` as a suffix.
Please check the `run.sh` and `run_nst.sh` scripts for more information about each stage and their arguments.
## Performance Record
### Supervised baseline and standard NST
* Non-streaming conformer model with attention rescoring decoder.
* Without filter strategy, first iteration
* Feature info: using FBANK feature, dither, cmvn, online speed perturb
* Training info: lr 0.002, batch size 32, 8 gpu, acc_grad 4, 240 epochs, dither 0.1
* Decoding info: ctc_weight 0.3, average_num 30
| Supervised | Unsupervised | Test CER |
|--------------------------|--------------|----------|
| AISHELL-1 Only | ---- | 4.85 |
| AISHELL-1+WenetSpeech | ---- | 3.54 |
| AISHELL-1+AISHELL-2 | ---- | 1.01 |
| AISHELL-1 (standard NST) | WenetSpeech | 5.52 |
### Supervised AISHELL-1 and unsupervised 1khr WenetSpeech
* Non-streaming conformer model with attention rescoring decoder.
* Feature info: using FBANK feature
* Training info: lr=0.002, batch_size=32, 8 GPUs, acc_grad=4, 120 epochs, dither=0.1
* Decoding info: ctc_weight=0.3, average_num=30, pseudo_ratio=0.75
| # nst iteration | AISHELL-1 test CER | Pseudo CER| Filtered CER | Filtered hours |
|----------------|--------------------|-----------|--------------|----------------|
| 0 | 4.85 | 47.10 | 25.18 | 323 |
| 1 | 4.86 | 37.02 | 20.93 | 436 |
| 2 | 4.75 | 31.81 | 19.74 | 540 |
| 3 | 4.69 | 28.27 | 17.85 | 592 |
| 4 | 4.48 | 26.64 | 14.76 | 588 |
| 5 | 4.41 | 24.70 | 15.86 | 670 |
| 6 | 4.34 | 23.64 | 15.40 | 669 |
| 7 | 4.31 | 23.79 | 15.75 | 694 |
### Supervised AISHELL-2 and unsupervised 4khr WenetSpeech
* Non-streaming conformer model with attention rescoring decoder.
* Feature info: using FBANK feature
* Training info: lr=0.002, batch_size=32, 8 GPUs, acc_grad=4, 120 epochs, dither=0.1
* Decoding info: ctc_weight=0.3, average_num=30, pseudo_ratio=0.75
| # nst iteration | AISHELL-2 test CER | Pseudo CER | Filtered CER | Filtered hours |
|----------------|--------------------|------------|--------------|----------------|
| 0 | 5.48 | 30.10 | 11.73 | 1637 |
| 1 | 5.09 | 28.31 | 9.39 | 2016 |
| 2 | 4.88 | 25.38 | 9.99 | 2186 |
| 3 | 4.74 | 22.47 | 10.66 | 2528 |
| 4 | 4.73 | 22.23 | 10.43 | 2734 |
## Citations
``` bibtex
@article{chen2022NST,
  title={Improving Noisy Student Training on Non-target Domain Data for Automatic Speech Recognition},
  author={Chen, Yu and Wen, Ding and Lai, Junjie},
  journal={arXiv preprint arXiv:2211.04717},
  year={2022}
}
```
# network architecture
# encoder related
encoder: conformer
encoder_conf:
    output_size: 256    # dimension of attention
    attention_heads: 4
    linear_units: 2048  # the number of units of position-wise feed forward
    num_blocks: 12      # the number of encoder blocks
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    attention_dropout_rate: 0.0
    input_layer: conv2d # encoder input type, you can choose conv2d, conv2d6 and conv2d8
    normalize_before: true
    cnn_module_kernel: 15
    use_cnn_module: True
    activation_type: 'swish'
    pos_enc_layer_type: 'rel_pos'
    selfattention_layer_type: 'rel_selfattn'

# decoder related
decoder: transformer
decoder_conf:
    attention_heads: 4
    linear_units: 2048
    num_blocks: 6
    dropout_rate: 0.1
    positional_dropout_rate: 0.1
    self_attention_dropout_rate: 0.0
    src_attention_dropout_rate: 0.0

# hybrid CTC/attention
model_conf:
    ctc_weight: 0.3
    lsm_weight: 0.1     # label smoothing option
    length_normalized_loss: false

dataset_conf:
    filter_conf:
        max_length: 1200
        min_length: 0
        token_max_length: 200
        token_min_length: 1
    resample_conf:
        resample_rate: 16000
    speed_perturb: true
    fbank_conf:
        num_mel_bins: 80
        frame_shift: 10
        frame_length: 25
        dither: 0.1
    spec_aug: true
    spec_aug_conf:
        num_t_mask: 2
        num_f_mask: 2
        max_t: 50
        max_f: 10
    shuffle: true
    shuffle_conf:
        shuffle_size: 1500
    sort: true
    sort_conf:
        sort_size: 500  # sort_size should be less than shuffle_size
    batch_conf:
        batch_type: 'static' # static or dynamic
        batch_size: 16

grad_clip: 5
accum_grad: 4
max_epoch: 240
log_interval: 100

optim: adam
optim_conf:
    lr: 0.002
scheduler: warmuplr     # pytorch v1.1.0+ required
scheduler_conf:
    warmup_steps: 25000
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import random
def get_args():
    parser = argparse.ArgumentParser(description='generate data.list file ')
    parser.add_argument('--tar_dir', help='path for tar file')
    parser.add_argument('--supervised_data_list',
                        help='path for supervised data list')
    parser.add_argument('--pseudo_data_ratio',
                        type=float,
                        help='ratio of pseudo data, '
                             '0 means none pseudo data, '
                             '1 means all using pseudo data.')
    parser.add_argument('--out_data_list', help='output path for data list')
    args = parser.parse_args()
    return args


def main():
    args = get_args()
    target_dir = args.tar_dir
    pseudo_data_list = os.listdir(target_dir)
    output_file = args.out_data_list
    pseudo_data_ratio = args.pseudo_data_ratio
    supervised_path = args.supervised_data_list
    with open(supervised_path, "r") as reader:
        supervised_data_list = reader.readlines()
    pseudo_len = len(pseudo_data_list)
    supervised_len = len(supervised_data_list)
    random.shuffle(pseudo_data_list)
    random.shuffle(supervised_data_list)
    cur_ratio = pseudo_len / (pseudo_len + supervised_len)
    if cur_ratio < pseudo_data_ratio:
        pseudo_to_super_ratio = pseudo_data_ratio / (1 - pseudo_data_ratio)
        supervised_len = int(pseudo_len / pseudo_to_super_ratio)
    elif cur_ratio > pseudo_data_ratio:
        super_to_pseudo_ratio = (1 - pseudo_data_ratio) / pseudo_data_ratio
        pseudo_len = int(supervised_len / super_to_pseudo_ratio)
    for i in range(len(pseudo_data_list)):
        pseudo_data_list[i] = target_dir + "/" + pseudo_data_list[i] + "\n"
    fused_list = (pseudo_data_list[:pseudo_len] +
                  supervised_data_list[:supervised_len])
    with open(output_file, "w") as writer:
        for line in fused_list:
            writer.write(line)


if __name__ == '__main__':
    main()
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import tarfile
import time
import json
def get_args():
    parser = argparse.ArgumentParser(description='generate filter pseudo label')
    parser.add_argument('--dir_num', required=True, help='split directory number')
    parser.add_argument('--cer_hypo_dir', required=True,
                        help='prefix for cer_hypo_dir')
    parser.add_argument('--utter_time_file', required=True,
                        help='the json file that contains audio time infos ')
    parser.add_argument('--cer_hypo_threshold', required=True, type=float,
                        help='the cer-hypo threshold used to filter')
    parser.add_argument('--speak_rate_threshold', type=float,
                        help='the cer threshold we use to filter')
    parser.add_argument('--dir', required=True, help='dir for the experiment ')
    # output untar and tar
    parser.add_argument('--untar_dir', required=True,
                        help='the output path, '
                             'eg: data/train/wenet_untar_cer_hypo_nst1/')
    parser.add_argument('--tar_dir', required=True,
                        help='the tar file path, '
                             'eg: data/train/wenet_tar_cer_hypo_leq_10_nst1/')
    parser.add_argument('--wav_dir', required=True,
                        help='dir to store wav files, '
                             'eg "data/train/wenet_1k_untar/"')
    parser.add_argument('--start_tar_id', default=0, type=int,
                        help='the initial tar id (for debugging)')
    args = parser.parse_args()
    return args


def make_tarfile(output_filename, source_dir):
    with tarfile.open(output_filename, "w") as tar:
        tar.add(source_dir, arcname=os.path.basename(source_dir))


def main():
    args = get_args()
    dir_num = args.dir_num
    dir_name = args.dir
    output_dir = args.untar_dir
    cer_hypo_threshold = args.cer_hypo_threshold
    speak_rate_threshold = args.speak_rate_threshold
    utter_time_file = args.utter_time_file
    tar_dir = args.tar_dir
    wav_dir = args.wav_dir
    start_tar_id = args.start_tar_id
    os.makedirs(tar_dir, exist_ok=True)
    os.makedirs(output_dir, exist_ok=True)
    cer_hypo_name = args.cer_hypo_dir
    print("start tar id is", start_tar_id)
    print("make dirs")
    utter_time_enable = True
    dataset = "wenet"
    utter_time = {}
    if utter_time_enable:
        if dataset == "wenet":
            print("wenet")
            with open(utter_time_file, encoding='utf-8') as fh:
                utter_time = json.load(fh)
        if dataset == "aishell2":
            aishell2_jason = utter_time_file
            print("aishell2")
            with open(aishell2_jason, "r", encoding="utf-8") as f:
                for line in f:
                    data = json.loads(line)
                    data_audio = data["audio_filepath"]
                    t_id = data_audio.split("/")[-1].split(".")[0]
                    data_duration = data["duration"]
                    utter_time[t_id] = data_duration
    print(time.time(), "start time ")
    cer_dict = {}
    print("dir_num = ", dir_num)
    cer_hypo_path = dir_name + "/Hypo_LM_diff10/" + cer_hypo_name
    cer_hypo_path = cer_hypo_path + "_" + dir_num + "/wer"
    with open(cer_hypo_path, 'r', encoding="utf-8") as reader:
        data = reader.readlines()
    for i in range(len(data)):
        line = data[i]
        if line[:3] == 'utt':
            wer_list = data[i + 1].split()
            wer_pred_lm = float(wer_list[1])
            n_hypo = int(wer_list[3].split("=")[1])
            utt_list = line.split()
            lab_list = data[i + 2].split()
            rec_list = data[i + 3].split()
            utt_id = utt_list[1]
            pred_no_lm = "".join(lab_list[1:])
            pred_lm = "".join(rec_list[1:])
            prediction = "".join(lab_list[1:])
            if utter_time_enable:
                utt_time = utter_time[utt_id]
                cer_dict[utt_id] = [pred_no_lm, pred_lm, wer_pred_lm,
                                    utt_time, n_hypo, prediction]
            else:
                cer_dict[utt_id] = [pred_no_lm, pred_lm,
                                    wer_pred_lm, -1, -1, prediction]
    c = 0
    cer_preds = []
    uttr_len = []
    speak_rates = []
    num_lines = 0
    data_filtered = []
    for key, item in cer_dict.items():
        cer_pred = item[2]
        speak_rate = item[4] / item[3]  # char per second
        if cer_pred <= cer_hypo_threshold and speak_rate > speak_rate_threshold:
            num_lines += 1
            c += 1
            cer_preds.append(cer_pred)
            uttr_len.append(item[4])
            speak_rates.append(speak_rate)
            pred = item[1]
            utt_id = key
            filtered_line = [utt_id, pred]
            data_filtered.append(filtered_line)
    num_uttr = 1000
    len_data = len(data_filtered)
    print("total sentences after filter ")
    cur_id = start_tar_id * 1000
    end_id = cur_id + num_uttr
    if cur_id < len_data < end_id:
        end_id = len_data
    tar_id = start_tar_id
    not_exist = []
    while end_id <= len_data:
        tar_s = str(tar_id)
        diff = 6 - len(tar_s)
        for _ in range(diff):
            tar_s = "0" + tar_s
        out_put_dir = output_dir + "dir" + str(dir_num)
        out_put_dir = out_put_dir + "_" + "tar" + tar_s + "/"
        os.makedirs(out_put_dir, exist_ok=True)
        for i in range(cur_id, end_id):
            print("dir:", dir_num, ", " "tar: ", tar_id,
                  ", ", "progress:", i / len_data)
            t_id, utter = data_filtered[i]
            output_path = out_put_dir + t_id + ".txt"
            wav_path = wav_dir + t_id + ".wav"
            print(wav_path)
            wav_exist = os.path.exists(wav_path)
            if wav_exist:
                # update .txt
                with open(output_path, "w", encoding="utf-8") as writer:
                    writer.write(utter)
                # update .wav
                os.system("cp" + " " + wav_path + " "
                          + out_put_dir + t_id + ".wav")
            else:
                print(" wav does not exists ! ", wav_path)
                not_exist.append(wav_path)
        tar_file_name = tar_dir + "dir" + str(dir_num) + "_" + tar_s + ".tar"
        # tar the dir
        make_tarfile(tar_file_name, out_put_dir)
        # update index
        tar_id += 1
        cur_id += num_uttr
        end_id += num_uttr
        if cur_id < len_data < end_id:
            end_id = len_data
        print("end, now removing untar files for saving storge space.")
        print("rm -rf" + " " + out_put_dir[:-1])
        os.system("rm -rf" + " " + out_put_dir[:-1])
        print("remove done")
    print("There are ", len(not_exist), "wav files not exist")
    print(not_exist)