Commit 60a2c57a authored by sunzhq2's avatar sunzhq2 Committed by xuxo
Browse files

update conformer

parent 4a699441
name: Test import espnet
on:
push:
branches:
- master
pull_request:
branches:
- master
jobs:
test_import:
runs-on: ${{ matrix.os }}
strategy:
max-parallel: 20
matrix:
os: [ubuntu-latest]
python-version: ["3.10"]
pytorch-version: [1.13.1]
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v1
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
sudo apt-get install -qq -y libsndfile1-dev
python3 -m pip install --upgrade pip setuptools wheel
- name: Install espnet with the least requirement
env:
TH_VERSION: ${{ matrix.pytorch-version }}
run: |
python3 -m pip install -U numba
./tools/installers/install_torch.sh false ${TH_VERSION} CPU
./tools/installers/install_chainer.sh CPU
python3 setup.py bdist_wheel
python3 -m pip install dist/espnet-*.whl
# log
python3 -m pip freeze
- name: Import all modules (Try1)
run: |
python3 ./ci/test_import_all.py
- name: Install espnet with the full requirement
env:
TH_VERSION: ${{ matrix.pytorch-version }}
run: |
python3 -m pip install "$(ls dist/espnet-*.whl)[all]"
# log
python3 -m pip freeze
- name: Import all modules (Try2)
run: |
python3 ./ci/test_import_all.py
name: Windows
on:
push:
branches:
- master
pull_request:
branches:
- master
jobs:
test_windows:
runs-on: Windows-latest
strategy:
matrix:
python-version: ["3.10"]
pytorch-version: [1.13.1]
defaults:
run:
shell: bash
steps:
- uses: actions/checkout@master
- uses: actions/cache@v3
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ matrix.pytorch-version }}-${{ hashFiles('**/setup.py') }}-${{ hashFiles('**/Makefile') }}
- uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
architecture: 'x64'
- name: install dependencies
run: |
choco install -y wget
- name: install espnet
env:
ESPNET_PYTHON_VERSION: ${{ matrix.python-version }}
TH_VERSION: ${{ matrix.pytorch-version }}
CHAINER_VERSION: 6.0.0
USE_CONDA: false
run: |
./ci/install.sh
# general
*~
*.pyc
\#*\#
.\#*
*DS_Store
out.txt
espnet.egg-info/
doc/_build
slurm-*.out
tmp*
.eggs/
.hypothesis/
.idea
.pytest_cache/
__pycache__/
check_autopep8
.coverage
htmlcov
coverage.xml*
bats-core/
shellcheck*
check_shellcheck*
test_spm.vocab
test_spm.model
.vscode*
*.vim
*.swp
*.nfs*
# recipe related
egs*/*/*/data*
egs*/*/*/db
egs*/*/*/downloads
egs*/*/*/dump
egs*/*/*/enhan
egs*/*/*/exp
egs*/*/*/fbank
egs*/*/*/mfcc
egs*/*/*/stft
egs*/*/*/tensorboard
egs*/*/*/wav*
egs*/*/*/score*
egs*/*/*/nltk*
egs*/*/*/.cache*
egs*/*/*/pretrained_models*
egs*/fisher_callhome_spanish/*/local/mapping*
# tools related
tools/chainer
tools/bin
tools/include
tools/lib
tools/lib64
tools/bats-core
tools/chainer_ctc/
tools/kaldi*
tools/activate_python.sh
tools/miniconda.sh
tools/moses/
tools/mwerSegmenter/
tools/nkf/
tools/venv/
tools/sentencepiece/
tools/swig/
tools/warp-ctc/
tools/warp-transducer/
tools/*.done
tools/PESQ*
tools/hts_engine_API*
tools/open_jtalk*
tools/pyopenjtalk*
tools/tdmelodic_openjtalk*
tools/s3prl
tools/sctk*
tools/sph2pipe*
tools/espeak-ng*
tools/MBROLA*
tools/festival*
tools/speech_tools*
tools/phonemizer*
tools/py3mmseg
tools/anaconda
tools/ice-g2p
tools/fairseq
tools/RawNet
tools/._*
tools/ice-g2p*
tools/fairseq*
tools/featbin*
tools/miniconda
pull_request_rules:
- name: automatic merge if label=auto-merge
conditions:
- "label=auto-merge"
- "check-success=test_centos7"
- "check-success=test_debian11"
- "check-success=linter_and_test (ubuntu-latest, 3.7, 1.10.2, 6.0.0, false)"
- "check-success=linter_and_test (ubuntu-latest, 3.7, 1.11.0, 6.0.0, false)"
- "check-success=linter_and_test (ubuntu-latest, 3.7, 1.12.1, 6.0.0, false)"
- "check-success=linter_and_test (ubuntu-latest, 3.7, 1.13.1, 6.0.0, false)"
- "check-success=linter_and_test (ubuntu-latest, 3.8, 1.10.2, 6.0.0, false)"
- "check-success=linter_and_test (ubuntu-latest, 3.8, 1.11.0, 6.0.0, false)"
- "check-success=linter_and_test (ubuntu-latest, 3.8, 1.12.1, 6.0.0, false)"
- "check-success=linter_and_test (ubuntu-latest, 3.8, 1.13.1, 6.0.0, false)"
- "check-success=linter_and_test (ubuntu-latest, 3.9, 1.10.2, 6.0.0, false)"
- "check-success=linter_and_test (ubuntu-latest, 3.9, 1.11.0, 6.0.0, false)"
- "check-success=linter_and_test (ubuntu-latest, 3.9, 1.12.1, 6.0.0, false)"
- "check-success=linter_and_test (ubuntu-latest, 3.9, 1.13.1, 6.0.0, false)"
- "check-success=linter_and_test (ubuntu-latest, 3.10, 1.13.1, false, 6.0.0)"
- "check-success=test_import (ubuntu-latest, 3.10, 1.13.1)"
- "check-success=check_kaldi_symlinks"
actions:
merge:
method: merge
- name: delete head branch after merged
conditions:
- merged
actions:
delete_head_branch: {}
- name: "add label=auto-merge for PR by mergify"
conditions:
- author=mergify[bot]
actions:
label:
add: ["auto-merge"]
- name: warn on conflicts
conditions:
- conflict
actions:
comment:
message: This pull request is now in conflict :(
label:
add: ["conflicts"]
- name: unlabel conflicts
conditions:
- -conflict
actions:
label:
remove: ["conflicts"]
- name: "auto add label=ESPnet1"
conditions:
- files~=^(espnet/|egs/)
actions:
label:
add: ["ESPnet1"]
- name: "auto add label=ESPnet2"
conditions:
- files~=^(espnet2/|egs2/)
actions:
label:
add: ["ESPnet2"]
- name: "auto add label=ASR"
conditions:
- files~=^(espnet*/asr|egs*/*/asr1)
actions:
label:
add: ["ASR"]
- name: "auto add label=TTS"
conditions:
- files~=^(espnet*/tts|egs*/*/tts1)
actions:
label:
add: ["TTS"]
- name: "auto add label=MT"
conditions:
- files~=^(espnet*/mt|egs*/*/mt1)
actions:
label:
add: ["MT"]
- name: "auto add label=LM"
conditions:
- files~=^(espnet*/lm)
actions:
label:
add: ["LM"]
- name: "auto add label=README"
conditions:
- files~=README.md
actions:
label:
add: ["README"]
- name: "auto add label=Documentation"
conditions:
- files~=^doc/
actions:
label:
add: ["Documentation"]
- name: "auto add label=CI"
conditions:
- files~=^(.circleci/|ci/|.github/|.travis.yml)
actions:
label:
add: ["CI"]
- name: "auto add label=Installation"
conditions:
- files~=^(tools/|setup.py)
actions:
label:
add: ["Installation"]
- name: "auto add label=mergify"
conditions:
- files~=^.mergify.yml
actions:
label:
add: ["mergify"]
- name: "auto add label=Docker"
conditions:
- files~=^docker/
actions:
label:
add: ["Docker"]
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: trailing-whitespace
exclude: ^(egs2/TEMPLATE/asr1/utils|egs2/TEMPLATE/asr1/steps|egs2/TEMPLATE/tts1/sid|tools/installers/patch_mwerSegmenter)
- id: end-of-file-fixer
exclude: ^(egs2/TEMPLATE/asr1/utils|egs2/TEMPLATE/asr1/steps|egs2/TEMPLATE/tts1/sid|tools/installers/patch_mwerSegmenter)
- id: check-yaml
exclude: ^(egs2/TEMPLATE/asr1/utils|egs2/TEMPLATE/asr1/steps|egs2/TEMPLATE/tts1/sid|tools/installers/patch_mwerSegmenter)
- id: check-added-large-files
exclude: ^(egs2/TEMPLATE/asr1/utils|egs2/TEMPLATE/asr1/steps|egs2/TEMPLATE/tts1/sid|tools/installers/patch_mwerSegmenter)
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
exclude: ^(egs2/TEMPLATE/asr1/utils|egs2/TEMPLATE/asr1/steps|egs2/TEMPLATE/tts1/sid|doc)
- repo: https://github.com/pycqa/isort
rev: 5.12.0
hooks:
- id: isort
exclude: ^(egs2/TEMPLATE/asr1/utils|egs2/TEMPLATE/asr1/steps|egs2/TEMPLATE/tts1/sid|doc)
Requirement already satisfied: typeguard in /root/miniconda3/lib/python3.10/site-packages (2.13.3)
# How to contribute to ESPnet
## 1. What to contribute
If you are interested in contributing to ESPnet, your contributions will fall into three categories: major features, minor updates, and recipes.
### 1.1 Major features
If you want to ask or propose a new feature, please first open a new issue with the tag `Feature request`
or directly contact Shinji Watanabe <shinjiw@ieee.org> or other main developers. Each feature implementation
and design should be discussed and modified according to ongoing and future works.
You can find ongoing major development plans at https://github.com/espnet/espnet/milestones
or in https://github.com/espnet/espnet/issues (pinned issues)
### 1.2 Minor Updates (minor feature, bug-fix for an issue)
If you want to propose a minor feature, update an existing minor feature, or fix a bug, please first take a look at
the existing [issues](https://github.com/espnet/espnet/pulls) and/or [pull requests](https://github.com/espnet/espnet/pulls).
Pick an issue and comment on the task that you want to work on this feature.
If you need help or additional information to propose the feature, you can open a new issue with the tag `Discussion` and ask ESPnet members.
### 1.3 Recipes
ESPnet provides and maintains many example scripts, called `recipes`, that demonstrate how to
use the toolkit. The recipes for ESPnet1 are put under `egs` directory, while ESPnet2 ones are put under `egs2`.
Similar to Kaldi, each subdirectory of `egs` and `egs2` corresponds to a corpus that we have example scripts for.
#### 1.3.1 ESPnet1 recipes
ESPnet1 recipes (`egs/X`) follow the convention from [Kaldi](https://github.com/kaldi-asr/kaldi) and may rely on
several utilities available in Kaldi. As such, porting a new recipe from Kaldi to ESPnet is natural, and the user
may refer to [port-kaldi-recipe](https://github.com/espnet/espnet/wiki/How-to-port-the-Kaldi-recipe-to-the-ESPnet-recipe%3F)
and other existing recipes for new additions. For the Kaldi-style recipe architecture, please refer to
[Prepare-Kaldi-Style-Directory](https://kaldi-asr.org/doc/data_prep.html).
For each recipe, we ask you to report the following: experiments results and environnement, model information.
For reproducibility, a link to upload the pre-trained model may also be added. All this information should be written
in a markdown file called `RESULTS.md` and put at the recipe root. You can refer to
[tedlium2-example](https://github.com/espnet/espnet/blob/master/egs/tedlium2/asr1/RESULTS.md) for an example.
To generate `RESULTS.md` for a recipe, please follow the following instructions:
- Execute `~/espnet/utils/show_result.sh` at the recipe root (where `run.sh` is located).
You'll get your environment information and evaluation results for each experiment in a markdown format.
From here, you can copy or redirect text output to `RESULTS.md`.
- Execute `~/espnet/utils/pack_model.sh` at the recipe root to generate a packed ESPnet model called `model.tar.gz`
and output model information. Executing the utility script without argument will give you the expected arguments.
- Put the model information in `RESULTS.md` and model link if you're using a private web storage
- If you don't have private web storage, please contact Shinji Watanabe <shinjiw@ieee.org> to give you access to ESPnet storage.
#### 1.3.2 ESPnet2 recipes
ESPnet2's recipes correspond to `egs2`. ESPnet2 applies a new paradigm without dependencies of Kaldi's binaries, which makes it lighter and more generalized.
For ESPnet2, we do not recommend preparing the recipe's stages for each corpus but using the common pipelines we provided in `asr.sh`, `tts.sh`, and
`enh.sh`. For details of creating ESPnet2 recipes, please refer to [egs2-readme](https://github.com/espnet/espnet/blob/master/egs2/TEMPLATE/README.md).
The common pipeline of ESPnet2 recipes will take care of the `RESULTS.md` generation, model packing, and uploading. ESPnet2 models are maintained at Hugging Face and Zenodo (Deprecated).
You can also refer to the document in https://github.com/espnet/espnet_model_zoo
To upload your model, you need first (This is currently deprecated , uploading to Huggingface Hub is prefered) :
1. Sign up to Zenodo: https://zenodo.org/
2. Create access token: https://zenodo.org/account/settings/applications/tokens/new/
3. Set your environment: % export ACCESS_TOKEN="<your token>"
To port models from zenodo using Hugging Face hub,
1. Create a Hugging Face account - https://huggingface.co/
2. Request to be added to espnet organisation - https://huggingface.co/espnet
3. Go to `egs2/RECIPE/*` and run `./scripts/utils/upload_models_to_hub.sh "ZENODO_MODEL_NAME"`
To upload models using Huggingface-cli follow the following steps:
You can also refer to https://huggingface.co/docs/transformers/model_sharing
1. Create a Hugging Face account - https://huggingface.co/
2. Request to be added to espnet organisation - https://huggingface.co/espnet
3. Run huggingface-cli login (You can get the token request at this step under setting > Access Tokens > espnet token
4. `huggingface-cli repo create your-model-name --organization espnet`
5. `git clone https://huggingface.co/username/your-model-name` (clone this outside ESPNet to avoid issues as this a git repo)
6. `cd your-model-name`
7. `git lfs install`
8. copy contents from exp diretory of your recipe into this directory (Check other models of similar task under ESPNet to confirm your directory structure)
9. `git add . `
10. `git commit -m "Add model files"`
11. `git push`
12. Check if the inference demo on HF is running successfully to verify the upload
#### 1.3.3 Additional requirements for new recipe
- Common/shared files and directories such as `utils`, `steps`, `asr.sh`, etc, should be linked using
a symbolic link (e.g.: `ln -s <source-path> <target-path>`). Please refer to existing recipes if you're
unaware which files/directories are shared. Noted that in espnet2, some of them are automatically generated by https://github.com/espnet/espnet/blob/master/egs2/TEMPLATE/asr1/setup.sh.
- Default training and decoding configurations (i.e.: the default one in `run.sh`) should be named respectively `train.yaml`
and `decode.yaml` and put in `conf/`. Additional or variant configurations should be put in `conf/tuning/` and named accordingly
to its differences.
- If a recipe for a new corpus is proposed, you should add its name and information to:
https://github.com/espnet/espnet/blob/master/egs/README.md if it's a ESPnet1 recipe,
or https://github.com/espnet/espnet/blob/master/egs2/README.md + `db.sh` if it's a ESPnet2 recipe.
#### 1.3.4 Checklist before you submit the recipe-based PR
- [ ] be careful about the name for the recipe. It is recommended to follow naming conventions of the other recipes
- [ ] common/shared files are linked with **soft link** (see Section 1.3.3)
- [ ] modified or new python scripts should be passed through **latest** black formating (by using python package black). The command to be executed could be `black espnet espnet2 test utils setup.py egs*/*/*/local egs2/TEMPLATE/*/pyscripts tools/*.py ci/*.py`
- [ ] modified or new python scripts should be passed through **latest** isort formating (by using python package isort). The command to be executed could be `isort espnet espnet2 test utils setup.py egs*/*/*/local egs2/TEMPLATE/*/pyscripts tools/*.py ci/*.py`
- [ ] cluster settings should be set as **default** (e.g., cmd.sh conf/slurm.conf conf/queue.conf conf/pbs.conf)
- [ ] update `egs/README.md` or `egs2/README.md` with corresponding recipes
- [ ] add corresponding entry in `egs2/TEMPLATE/db.sh` for a new corpus
- [ ] try to **simplify** the model configurations. We recommend to have only the best configuration for the start of a recipe. Please also follow the default rule defined in Section 1.3.3
- [ ] large meta-information for a corpus should be maintained elsewhere other than in the recipe itself
- [ ] recommend to also include results and pre-trained model with the recipe
## 2 Pull Request
If your proposed feature or bugfix is ready, please open a Pull Request (PR) at https://github.com/espnet/espnet
or use the Pull Request button in your forked repo. If you're not familiar with the process, please refer to the following guides:
- http://stackoverflow.com/questions/14680711/how-to-do-a-github-pull-request
- https://help.github.com/articles/creating-a-pull-request/
## 3 Version policy and development branches
We basically develop in the `master` branch.
1. We will keep the first version digit `0` until we have some super major changes in the project organization level.
2. The second version digit will be updated when we have major updates, including new functions and refactoring, and
their related bug fix and recipe changes.
This version update will be done roughly every half year so far (but it depends on the development plan).
3. The third version digit will be updated when we fix serious bugs or accumulate some minor changes, including
recipe related changes periodically (every two months or so).
## 4 Unit testing
ESPnet's testing is located under `test/`. You can install additional packages for testing as follows:
``` console
$ cd <espnet_root>
$ . ./tools/activate_python.sh
$ pip install -e ".[test]"
```
### 4.1 Python
Then you can run the entire test suite using [flake8](http://flake8.pycqa.org/en/latest/), [autopep8](https://github.com/hhatto/autopep8), [black](https://github.com/psf/black), [isort](https://github.com/PyCQA/isort) and [pytest](https://docs.pytest.org/en/latest/) with [coverage](https://pytest-cov.readthedocs.io/en/latest/reporting.html) by
``` console
./ci/test_python.sh
```
Followings are some useful tips when you are using pytest:
- New test file should be put under `test/` directory and named `test_xxx.py`. Each method in the test file should
have the format `def test_yyy(...)`. [Pytest](https://docs.pytest.org/en/latest/) will automatically find and test them.
- We recommend adding several small test files instead of grouping them in one big file (e.g.: `test_e2e_xxx.py`).
Technically, a test file should only cover methods from one file (e.g.: `test_transformer_utils.py` to test `transformer_utils.py`).
- To monitor test coverage and avoid the overlapping test, we recommend using `pytest --cov-report term-missing <test_file|dir>`
to highlight covered and missed lines. For more details, please refer to [coverage-test](https://pytest-cov.readthedocs.io/en/latest/readme.html).
- We limited test running time to 2.0 seconds (see: [pytest-timeouts](https://pypi.org/project/pytest-timeouts/)). As such,
we recommend using small model parameters and avoiding dynamic imports, file access, and unnecessary loops. If a unit test needs
more running time, you can annotate your test with `@pytest.mark.execution_timeout(sec)`.
- For test initialization (parameters, modules, etc), you can use pytest fixtures. Refer to [pytest fixtures](https://docs.pytest.org/en/latest/fixture.html#using-fixtures-from-classes-modules-or-projects) for more information.
In addition, please follow the [PEP 8 convention](https://peps.python.org/pep-0008/) for the coding style and [Google's convention for docstrings](https://google.github.io/styleguide/pyguide.html#383-functions-and-methods).
Below are some specific points that should be taken care of in particular:
- [import ordering](https://peps.python.org/pep-0008/#imports)
- Avoid writing python2-style code. For example, `super().__init__()` is preferred over `super(CLASS_NAME, self).__init()__`.
### 4.2 Bash scripts
You can also test the scripts in `utils` with [bats-core](https://github.com/bats-core/bats-core) and [shellcheck](https://github.com/koalaman/shellcheck).
To test:
``` console
./ci/test_shell.sh
```
## 5 Integration testing
Write new integration tests in [ci/test_integration_espnet1.sh](ci/test_integration_espnet1.sh) or [ci/test_integration_espnet2.sh](ci/test_integration_espnet2.sh) when you add new features in [espnet/bin](espnet/bin) or [espnet2/bin](espnet2/bin), respectively. They use our smallest dataset [egs/mini_an4](egs/mini_an4) or [egs2/mini_an4](egs/mini_an4) to test `run.sh`. **Don't call `python` directly in integration tests. Instead, use `coverage run --append`** as a python interpreter. Especially, `run.sh` should support `--python ${python}` to call the custom interpreter.
```bash
# ci/test_integration_espnet{1,2}.sh
python="coverage run --append"
cd egs/mini_an4/your_task
./run.sh --python "${python}"
```
### 5.1 Configuration files
- [setup.cfg](setup.cfg) configures pytest, black and flake8.
- [.travis.yml](.travis.yml) configures Travis-CI (unittests, doc deploy).
- [.circleci/config.yml](.circleci/config.yml) configures Circle-CI (unittests, integration tests).
- [.github/workflows](.github/workflows/) configures Github Actions (unittests, integration tests).
- [codecov.yml](codecov.yml) configures CodeCov (code coverage).
## 6 Writing new tools
You can place your new tools under
- `espnet/bin`: heavy and large (e.g., neural network related) core tools.
- `utils`: lightweight self-contained python/bash scripts.
For `utils` scripts, do not forget to add help messages and test scripts under `test_utils`.
### 6.1 Python tools guideline
To generate doc, do not forget `def get_parser(): -> ArgumentParser` in the main file.
```python
#!/usr/bin/env python3
# Copyright XXX
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
# NOTE: do not forget this
def get_parser():
parser = argparse.ArgumentParser(
description="awsome tool", # DO NOT forget this
)
...
return parser
if __name__ == '__main__':
args = get_parser().parse_args()
...
```
### 6.2 Bash tools guideline
To generate doc, support `--help` to show its usage. If you use Kaldi's `utils/parse_option.sh`, define `help_message="Usage: $0 ..."`.
## 7 Writing documentation
See [doc](doc/README.md).
## 8 Adding pretrained models
Pack your trained models using `utils/pack_model.sh` and upload it [here](https://drive.google.com/open?id=1k9RRyc06Zl0mM2A7mi-hxNiNMFb_YzTF) (You require permission).
Add the shared link to `utils/recog_wav.sh` or `utils/synth_wav.sh` as follows:
```sh
"tedlium.demo") share_url="https://drive.google.com/open?id=1UqIY6WJMZ4sxNxSugUqp3mrGb3j6h7xe" ;;
```
The model name is arbitrary for now.
## 9 On CI failure
### 9.1 Travis CI and Github Actions
1. read the log from PR checks > details
### 9.2 Circle CI
1. read the log from PR checks > details
2. turn on Rerun workflow > Rerun job with SSH
3. open your local terminal and `ssh -p xxx xxx` (check circle ci log for the exact address)
4. try anything you can to pass the CI
### 9.3 Codecov
1. write more tests to increase coverage
2. explain to reviewers why you can't increase coverage
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2017 Johns Hopkins University (Shinji Watanabe)
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
<div align="left"><img src="doc/image/espnet_logo1.png" width="550"/></div>
# ESPnet: end-to-end speech processing toolkit
|system/pytorch ver.|1.10.2|1.11.0|1.12.1|1.13.1|
| :---: | :---: | :---: | :---: | :---: |
|ubuntu/python3.10/pip||||[![Github Actions](https://github.com/espnet/espnet/workflows/CI/badge.svg)](https://github.com/espnet/espnet/actions)|
|ubuntu/python3.9/pip|[![Github Actions](https://github.com/espnet/espnet/workflows/CI/badge.svg)](https://github.com/espnet/espnet/actions)|[![Github Actions](https://github.com/espnet/espnet/workflows/CI/badge.svg)](https://github.com/espnet/espnet/actions)|[![Github Actions](https://github.com/espnet/espnet/workflows/CI/badge.svg)](https://github.com/espnet/espnet/actions)|[![Github Actions](https://github.com/espnet/espnet/workflows/CI/badge.svg)](https://github.com/espnet/espnet/actions)|
|ubuntu/python3.8/pip|[![Github Actions](https://github.com/espnet/espnet/workflows/CI/badge.svg)](https://github.com/espnet/espnet/actions)|[![Github Actions](https://github.com/espnet/espnet/workflows/CI/badge.svg)](https://github.com/espnet/espnet/actions)|[![Github Actions](https://github.com/espnet/espnet/workflows/CI/badge.svg)](https://github.com/espnet/espnet/actions)|[![Github Actions](https://github.com/espnet/espnet/workflows/CI/badge.svg)](https://github.com/espnet/espnet/actions)|
|ubuntu/python3.7/pip|[![Github Actions](https://github.com/espnet/espnet/workflows/CI/badge.svg)](https://github.com/espnet/espnet/actions)|[![Github Actions](https://github.com/espnet/espnet/workflows/CI/badge.svg)](https://github.com/espnet/espnet/actions)|[![Github Actions](https://github.com/espnet/espnet/workflows/CI/badge.svg)](https://github.com/espnet/espnet/actions)|[![Github Actions](https://github.com/espnet/espnet/workflows/CI/badge.svg)](https://github.com/espnet/espnet/actions)|
|debian11/python3.7/conda||||[![debian11](https://github.com/espnet/espnet/workflows/debian11/badge.svg)](https://github.com/espnet/espnet/actions?query=workflow%3Adebian11)|
|centos7/python3.7/conda||||[![centos7](https://github.com/espnet/espnet/workflows/centos7/badge.svg)](https://github.com/espnet/espnet/actions?query=workflow%3Acentos7)|
|ubuntu/doc/python3.8||||[![doc](https://github.com/espnet/espnet/workflows/doc/badge.svg)](https://github.com/espnet/espnet/actions?query=workflow%3Adoc)|
[![PyPI version](https://badge.fury.io/py/espnet.svg)](https://badge.fury.io/py/espnet)
[![Python Versions](https://img.shields.io/pypi/pyversions/espnet.svg)](https://pypi.org/project/espnet/)
[![Downloads](https://pepy.tech/badge/espnet)](https://pepy.tech/project/espnet)
[![GitHub license](https://img.shields.io/github/license/espnet/espnet.svg)](https://github.com/espnet/espnet)
[![codecov](https://codecov.io/gh/espnet/espnet/branch/master/graph/badge.svg)](https://codecov.io/gh/espnet/espnet)
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/)
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/espnet/espnet/master.svg)](https://results.pre-commit.ci/latest/github/espnet/espnet/master)
[![Mergify Status](https://img.shields.io/endpoint.svg?url=https://api.mergify.com/v1/badges/espnet/espnet&style=flat)](https://mergify.com)
[![Gitter](https://badges.gitter.im/espnet-en/community.svg)](https://gitter.im/espnet-en/community?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge)
[**Docs**](https://espnet.github.io/espnet/)
| [**Example**](https://github.com/espnet/espnet/tree/master/egs)
| [**Example (ESPnet2)**](https://github.com/espnet/espnet/tree/master/egs2)
| [**Docker**](https://github.com/espnet/espnet/tree/master/docker)
| [**Notebook**](https://github.com/espnet/notebook)
ESPnet is an end-to-end speech processing toolkit covering end-to-end speech recognition, text-to-speech, speech translation, speech enhancement, speaker diarization, spoken language understanding, and so on.
ESPnet uses [pytorch](http://pytorch.org/) as a deep learning engine and also follows [Kaldi](http://kaldi-asr.org/) style data processing, feature extraction/format, and recipes to provide a complete setup for various speech processing experiments.
## Tutorial Series
- 2019 Tutorial at Interspeech
- [Material](https://github.com/espnet/interspeech2019-tutorial)
- 2021 Tutorial at CMU
- [Online video](https://youtu.be/2mRz3wH1vd0)
- [Material](https://colab.research.google.com/github/espnet/notebook/blob/master/espnet2_tutorial_2021_CMU_11751_18781.ipynb)
- 2022 Tutorial at CMU
- Usage of ESPnet (ASR as an example)
- [Online video](https://youtu.be/YDN8cVjxSik)
- [Material](https://colab.research.google.com/github/espnet/notebook/blob/master/espnet2_recipe_tutorial_CMU_11751_18781_Fall2022.ipynb)
- Add new models/tasks to ESPnet
- [Online video](https://youtu.be/Css3XAes7SU)
- [Material](https://colab.research.google.com/github/espnet/notebook/blob/master/espnet2_new_task_tutorial_CMU_11751_18781_Fall2022.ipynb)
## Key Features
### Kaldi style complete recipe
- Support numbers of `ASR` recipes (WSJ, Switchboard, CHiME-4/5, Librispeech, TED, CSJ, AMI, HKUST, Voxforge, REVERB, etc.)
- Support numbers of `TTS` recipes with a similar manner to the ASR recipe (LJSpeech, LibriTTS, M-AILABS, etc.)
- Support numbers of `ST` recipes (Fisher-CallHome Spanish, Libri-trans, IWSLT'18, How2, Must-C, Mboshi-French, etc.)
- Support numbers of `MT` recipes (IWSLT'14, IWSLT'16, the above ST recipes etc.)
- Support numbers of `SLU` recipes (CATSLU-MAPS, FSC, Grabo, IEMOCAP, JDCINAL, SNIPS, SLURP, SWBD-DA, etc.)
- Support numbers of `SE/SS` recipes (DNS-IS2020, LibriMix, SMS-WSJ, VCTK-noisyreverb, WHAM!, WHAMR!, WSJ-2mix, etc.)
- Support voice conversion recipe (VCC2020 baseline)
- Support speaker diarization recipe (mini_librispeech, librimix)
- Support singing voice synthesis recipe (ofuton_p_utagoe_db)
### ASR: Automatic Speech Recognition
- **State-of-the-art performance** in several ASR benchmarks (comparable/superior to hybrid DNN/HMM and CTC)
- **Hybrid CTC/attention** based end-to-end ASR
- Fast/accurate training with CTC/attention multitask training
- CTC/attention joint decoding to boost monotonic alignment decoding
- Encoder: VGG-like CNN + BiRNN (LSTM/GRU), sub-sampling BiRNN (LSTM/GRU), Transformer, Conformer, [Branchformer](https://proceedings.mlr.press/v162/peng22a.html), or [E-Branchformer](https://arxiv.org/abs/2210.00077)
- Decoder: RNN (LSTM/GRU), Transformer, or S4
- Attention: Dot product, location-aware attention, variants of multi-head
- Incorporate RNNLM/LSTMLM/TransformerLM/N-gram trained only with text data
- Batch GPU decoding
- Data augmentation
- **Transducer** based end-to-end ASR
- Architecture:
- RNN-based encoder and decoder.
- Custom encoder and decoder supporting Transformer, Conformer (encoder), 1D Conv / TDNN (encoder) and causal 1D Conv (decoder) blocks.
- VGG2L (RNN/custom encoder) and Conv2D (custom encoder) bottlenecks.
- Search algorithms:
- Greedy search constrained to one emission by timestep.
- Default beam search algorithm [[Graves, 2012]](https://arxiv.org/abs/1211.3711) without prefix search.
- Alignment-Length Synchronous decoding [[Saon et al., 2020]](https://ieeexplore.ieee.org/abstract/document/9053040).
- Time Synchronous Decoding [[Saon et al., 2020]](https://ieeexplore.ieee.org/abstract/document/9053040).
- N-step Constrained beam search modified from [[Kim et al., 2020]](https://arxiv.org/abs/2002.03577).
- modified Adaptive Expansion Search based on [[Kim et al., 2021]](https://ieeexplore.ieee.org/abstract/document/9250505) and NSC.
- Features:
- Multi-task learning with various auxiliary losses:
- Encoder: CTC, auxiliary Transducer and symmetric KL divergence.
- Decoder: cross-entropy w/ label smoothing.
- Transfer learning with acoustic model and/or language model.
- Training with FastEmit regularization method [[Yu et al., 2021]](https://arxiv.org/abs/2010.11148).
> Please refer to the [tutorial page](https://espnet.github.io/espnet/tutorial.html#transducer) for complete documentation.
- CTC segmentation
- Non-autoregressive model based on Mask-CTC
- ASR examples for supporting endangered language documentation (Please refer to egs/puebla_nahuatl and egs/yoloxochitl_mixtec for details)
- Wav2Vec2.0 pretrained model as Encoder, imported from [FairSeq](https://github.com/pytorch/fairseq/tree/master/fairseq).
- Self-supervised learning representations as features, using upstream models in [S3PRL](https://github.com/s3prl/s3prl) in frontend.
- Set `frontend` to be `s3prl`
- Select any upstream model by setting the `frontend_conf` to the corresponding name.
- Transfer Learning :
- easy usage and transfers from models previously trained by your group, or models from [ESPnet Hugging Face repository](https://huggingface.co/espnet).
- [Documentation](https://github.com/espnet/espnet/tree/master/egs2/mini_an4/asr1/transfer_learning.md) and [toy example runnable on colab](https://github.com/espnet/notebook/blob/master/espnet2_asr_transfer_learning_demo.ipynb).
- Streaming Transformer/Conformer ASR with blockwise synchronous beam search.
- Restricted Self-Attention based on [Longformer](https://arxiv.org/abs/2004.05150) as an encoder for long sequences
- OpenAI [Whisper](https://openai.com/blog/whisper/) model, robust ASR based on large-scale, weakly-supervised multitask learning
Demonstration
- Real-time ASR demo with ESPnet2 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/espnet/notebook/blob/master/espnet2_asr_realtime_demo.ipynb)
- [Gradio](https://github.com/gradio-app/gradio) Web Demo on [Hugging Face Spaces](https://huggingface.co/docs/hub/spaces). Check out the [Web Demo](https://huggingface.co/spaces/akhaliq/espnet2_asr)
- Streaming Transformer ASR [Local Demo](https://github.com/espnet/notebook/blob/master/espnet2_streaming_asr_demo.ipynb) with ESPnet2.
### TTS: Text-to-speech
- Architecture
- Tacotron2
- Transformer-TTS
- FastSpeech
- FastSpeech2
- Conformer FastSpeech & FastSpeech2
- VITS
- JETS
- Multi-speaker & multi-language extention
- Pretrained speaker embedding (e.g., X-vector)
- Speaker ID embedding
- Language ID embedding
- Global style token (GST) embedding
- Mix of the above embeddings
- End-to-end training
- End-to-end text-to-wav model (e.g., VITS, JETS, etc.)
- Joint training of text2mel and vocoder
- Various language support
- En / Jp / Zn / De / Ru / And more...
- Integration with neural vocoders
- Parallel WaveGAN
- MelGAN
- Multi-band MelGAN
- HiFiGAN
- StyleMelGAN
- Mix of the above models
Demonstration
- Real-time TTS demo with ESPnet2 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/espnet/notebook/blob/master/espnet2_tts_realtime_demo.ipynb)
- Integrated to [Hugging Face Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See demo: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/akhaliq/ESPnet2-TTS)
To train the neural vocoder, please check the following repositories:
- [kan-bayashi/ParallelWaveGAN](https://github.com/kan-bayashi/ParallelWaveGAN)
- [r9y9/wavenet_vocoder](https://github.com/r9y9/wavenet_vocoder)
> **NOTE**:
> - We are moving on ESPnet2-based development for TTS.
> - The use of ESPnet1-TTS is deprecated, please use [ESPnet2-TTS](https://github.com/espnet/espnet/tree/master/egs2/TEMPLATE/tts1).
### SE: Speech enhancement (and separation)
- Single-speaker speech enhancement
- Multi-speaker speech separation
- Unified encoder-separator-decoder structure for time-domain and frequency-domain models
- Encoder/Decoder: STFT/iSTFT, Convolution/Transposed-Convolution
- Separators: BLSTM, Transformer, Conformer, [TasNet](https://arxiv.org/abs/1809.07454), [DPRNN](https://arxiv.org/abs/1910.06379), [SkiM](https://arxiv.org/abs/2201.10800), [SVoice](https://arxiv.org/abs/2011.02329), [DC-CRN](https://web.cse.ohio-state.edu/~wang.77/papers/TZW.taslp21.pdf), [DCCRN](https://arxiv.org/abs/2008.00264), [Deep Clustering](https://ieeexplore.ieee.org/document/7471631), [Deep Attractor Network](https://pubmed.ncbi.nlm.nih.gov/29430212/), [FaSNet](https://arxiv.org/abs/1909.13387), [iFaSNet](https://arxiv.org/abs/1910.14104), Neural Beamformers, etc.
- Flexible ASR integration: working as an individual task or as the ASR frontend
- Easy to import pretrained models from [Asteroid](https://github.com/asteroid-team/asteroid)
- Both the pre-trained models from Asteroid and the specific configuration are supported.
Demonstration
- Interactive SE demo with ESPnet2 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1fjRJCh96SoYLZPRxsjF9VDv4Q2VoIckI?usp=sharing)
### ST: Speech Translation & MT: Machine Translation
- **State-of-the-art performance** in several ST benchmarks (comparable/superior to cascaded ASR and MT)
- Transformer based end-to-end ST (new!)
- Transformer based end-to-end MT (new!)
### VC: Voice conversion
- Transformer and Tacotron2 based parallel VC using melspectrogram (new!)
- End-to-end VC based on cascaded ASR+TTS (Baseline system for Voice Conversion Challenge 2020!)
### SLU: Spoken Language Understanding
- Architecture
- Transformer based Encoder
- Conformer based Encoder
- [Branchformer](https://proceedings.mlr.press/v162/peng22a.html) based Encoder
- [E-Branchformer](https://arxiv.org/abs/2210.00077) based Encoder
- RNN based Decoder
- Transformer based Decoder
- Support Multitasking with ASR
- Predict both intent and ASR transcript
- Support Multitasking with NLU
- Deliberation encoder based 2 pass model
- Support using pretrained ASR models
- Hubert
- Wav2vec2
- VQ-APC
- TERA and more ...
- Support using pretrained NLP models
- BERT
- MPNet And more...
- Various language support
- En / Jp / Zn / Nl / And more...
- Supports using context from previous utterances
- Supports using other tasks like SE in pipeline manner
- Supports Two Pass SLU that combines audio and ASR transcript
Demonstration
- Performing noisy spoken language understanding using speech enhancement model followed by spoken language understanding model. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/14nCrJ05vJcQX0cJuXjbMVFWUHJ3Wfb6N?usp=sharing)
- Performing two pass spoken language understanding where the second pass model attends on both acoustic and semantic information. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1p2cbGIPpIIcynuDl4ZVHDpmNPl8Nh_ci?usp=sharing)
- Integrated to [Hugging Face Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See SLU demo on multiple languages: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/Siddhant/ESPnet2-SLU)
### SUM: Speech Summarization
- End to End Speech Summarization Recipe for Instructional Videos using Restricted Self-Attention [[Sharma et al., 2022]](https://arxiv.org/abs/2110.06263)
### SVS: Singing Voice Synthesis
- Framework merge from [Muskits](https://github.com/SJTMusicTeam/Muskits)
- Architecture
- RNN-based non-autoregressive model
- Xiaoice
- Sequence-to-sequence Transformer (with GLU-based encoder)
- MLP singer (in progress)
- Tacotron-singing (in progress)
- DiffSinger (in progress)
- VISinger
- Support multi-speaker & multilingual singing synthesis
- Speaker ID embedding
- Language ID embedding
- Various language support
- Jp / En / Kr / Zh
- Tight integration with neural vocoders (the same as TTS)
### SSL: Self-supervised Learning
- Support HuBERT Pretraining:
* Example recipe: [egs2/LibriSpeech/ssl1](egs2/LibriSpeech/ssl1)
### UASR: Unsupervised ASR (EURO: ESPnet Unsupervised Recognition - Open-source)
- Architecture
- wav2vec-U (with different self-supervised models)
- wav2vec-U 2.0 (in progress)
- Support PrefixBeamSearch and K2-based WFST decoding
### DNN Framework
- Flexible network architecture thanks to chainer and pytorch
- Flexible front-end processing thanks to [kaldiio](https://github.com/nttcslab-sp/kaldiio) and HDF5 support
- Tensorboard based monitoring
### ESPnet2
See [ESPnet2](https://espnet.github.io/espnet/espnet2_tutorial.html).
- Independent from Kaldi/Chainer, unlike ESPnet1
- On the fly feature extraction and text processing when training
- Supporting DistributedDataParallel and DaraParallel both
- Supporting multiple nodes training and integrated with [Slurm](https://slurm.schedmd.com/) or MPI
- Supporting Sharded Training provided by [fairscale](https://github.com/facebookresearch/fairscale)
- A template recipe which can be applied for all corpora
- Possible to train any size of corpus without CPU memory error
- [ESPnet Model Zoo](https://github.com/espnet/espnet_model_zoo)
- Integrated with [wandb](https://espnet.github.io/espnet/espnet2_training_option.html#weights-biases-integration)
## Installation
- If you intend to do full experiments including DNN training, then see [Installation](https://espnet.github.io/espnet/installation.html).
- If you just need the Python module only:
```sh
# We recommend you installing pytorch before installing espnet following https://pytorch.org/get-started/locally/
pip install espnet
# To install latest
# pip install git+https://github.com/espnet/espnet
# To install additional packages
# pip install "espnet[all]"
```
If you'll use ESPnet1, please install chainer and cupy.
```sh
pip install chainer==6.0.0 cupy==6.0.0 # [Option]
```
You might need to install some packages depending on each task. We prepared various installation scripts at [tools/installers](tools/installers).
- (ESPnet2) Once installed, run `wandb login` and set `--use_wandb true` to enable tracking runs using W&B.
## Usage
See [Usage](https://espnet.github.io/espnet/tutorial.html).
## Docker Container
go to [docker/](docker/) and follow [instructions](https://espnet.github.io/espnet/docker.html).
## Contribution
Thank you for taking times for ESPnet! Any contributions to ESPnet are welcome and feel free to ask any questions or requests to [issues](https://github.com/espnet/espnet/issues).
If it's the first contribution to ESPnet for you, please follow the [contribution guide](CONTRIBUTING.md).
## Results and demo
You can find useful tutorials and demos in [Interspeech 2019 Tutorial](https://github.com/espnet/interspeech2019-tutorial)
### ASR results
<details><summary>expand</summary><div>
We list the character error rate (CER) and word error rate (WER) of major ASR tasks.
| Task | CER (%) | WER (%) | Pretrained model |
| ----------------------------------------------------------------- | :-------------: | :-------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| Aishell dev/test | 4.6/5.1 | N/A | [link](https://github.com/espnet/espnet/blob/master/egs/aishell/asr1/RESULTS.md#conformer-kernel-size--15--specaugment--lm-weight--00-result) |
| **ESPnet2** Aishell dev/test | 4.1/4.4 | N/A | [link](https://github.com/espnet/espnet/tree/master/egs2/aishell/asr1#branchformer-initial) |
| Common Voice dev/test | 1.7/1.8 | 2.2/2.3 | [link](https://github.com/espnet/espnet/blob/master/egs/commonvoice/asr1/RESULTS.md#first-results-default-pytorch-transformer-setting-with-bpe-100-epochs-single-gpu) |
| CSJ eval1/eval2/eval3 | 5.7/3.8/4.2 | N/A | [link](https://github.com/espnet/espnet/blob/master/egs/csj/asr1/RESULTS.md#pytorch-backend-transformer-without-any-hyperparameter-tuning) |
| **ESPnet2** CSJ eval1/eval2/eval3 | 4.5/3.3/3.6 | N/A | [link](https://github.com/espnet/espnet/tree/master/egs2/csj/asr1#initial-conformer-results) |
| **ESPnet2** GigaSpeech dev/test | N/A | 10.6/10.5 | [link](https://github.com/espnet/espnet/tree/master/egs2/gigaspeech/asr1#e-branchformer) |
| HKUST dev | 23.5 | N/A | [link](https://github.com/espnet/espnet/blob/master/egs/hkust/asr1/RESULTS.md#transformer-only-20-epochs) |
| **ESPnet2** HKUST dev | 21.2 | N/A | [link](https://github.com/espnet/espnet/tree/master/egs2/hkust/asr1#transformer-asr--transformer-lm) |
| Librispeech dev_clean/dev_other/test_clean/test_other | N/A | 1.9/4.9/2.1/4.9 | [link](https://github.com/espnet/espnet/blob/master/egs/librispeech/asr1/RESULTS.md#pytorch-large-conformer-with-specaug--speed-perturbation-8-gpus--transformer-lm-4-gpus) |
| **ESPnet2** Librispeech dev_clean/dev_other/test_clean/test_other | 0.6/1.5/0.6/1.4 | 1.7/3.4/1.8/3.6 | [link](https://github.com/espnet/espnet/tree/master/egs2/librispeech/asr1#self-supervised-learning-features-hubert_large_ll60k-conformer-utt_mvn-with-transformer-lm) |
| Switchboard (eval2000) callhm/swbd | N/A | 14.0/6.8 | [link](https://github.com/espnet/espnet/blob/master/egs/swbd/asr1/RESULTS.md#conformer-with-bpe-2000-specaug-speed-perturbation-transformer-lm-decoding) |
| **ESPnet2** Switchboard (eval2000) callhm/swbd | N/A | 13.4/7.3 | [link](https://github.com/espnet/espnet/tree/master/egs2/swbd/asr1#e-branchformer) |
| TEDLIUM2 dev/test | N/A | 8.6/7.2 | [link](https://github.com/espnet/espnet/blob/master/egs/tedlium2/asr1/RESULTS.md#conformer-large-model--specaug--speed-perturbation--rnnlm) |
| **ESPnet2** TEDLIUM2 dev/test | N/A | 7.3/7.1 | [link](https://github.com/espnet/espnet/blob/master/egs2/tedlium2/asr1/README.md#e-branchformer-12-encoder-layers) |
| TEDLIUM3 dev/test | N/A | 9.6/7.6 | [link](https://github.com/espnet/espnet/blob/master/egs/tedlium3/asr1/RESULTS.md) |
| WSJ dev93/eval92 | 3.2/2.1 | 7.0/4.7 | N/A |
| **ESPnet2** WSJ dev93/eval92 | 1.1/0.8 | 2.8/1.8 | [link](https://github.com/espnet/espnet/tree/master/egs2/wsj/asr1#self-supervised-learning-features-wav2vec2_large_ll60k-conformer-utt_mvn-with-transformer-lm) |
Note that the performance of the CSJ, HKUST, and Librispeech tasks was significantly improved by using the wide network (#units = 1024) and large subword units if necessary reported by [RWTH](https://arxiv.org/pdf/1805.03294.pdf).
If you want to check the results of the other recipes, please check `egs/<name_of_recipe>/asr1/RESULTS.md`.
</div></details>
### ASR demo
<details><summary>expand</summary><div>
You can recognize speech in a WAV file using pretrained models.
Go to a recipe directory and run `utils/recog_wav.sh` as follows:
```sh
# go to recipe directory and source path of espnet tools
cd egs/tedlium2/asr1 && . ./path.sh
# let's recognize speech!
recog_wav.sh --models tedlium2.transformer.v1 example.wav
```
where `example.wav` is a WAV file to be recognized.
The sampling rate must be consistent with that of data used in training.
Available pretrained models in the demo script are listed as below.
| Model | Notes |
| :----------------------------------------------------------------------------------------------- | :--------------------------------------------------------- |
| [tedlium2.rnn.v1](https://drive.google.com/open?id=1UqIY6WJMZ4sxNxSugUqp3mrGb3j6h7xe) | Streaming decoding based on CTC-based VAD |
| [tedlium2.rnn.v2](https://drive.google.com/open?id=1cac5Uc09lJrCYfWkLQsF8eapQcxZnYdf) | Streaming decoding based on CTC-based VAD (batch decoding) |
| [tedlium2.transformer.v1](https://drive.google.com/open?id=1cVeSOYY1twOfL9Gns7Z3ZDnkrJqNwPow) | Joint-CTC attention Transformer trained on Tedlium 2 |
| [tedlium3.transformer.v1](https://drive.google.com/open?id=1zcPglHAKILwVgfACoMWWERiyIquzSYuU) | Joint-CTC attention Transformer trained on Tedlium 3 |
| [librispeech.transformer.v1](https://drive.google.com/open?id=1BtQvAnsFvVi-dp_qsaFP7n4A_5cwnlR6) | Joint-CTC attention Transformer trained on Librispeech |
| [commonvoice.transformer.v1](https://drive.google.com/open?id=1tWccl6aYU67kbtkm8jv5H6xayqg1rzjh) | Joint-CTC attention Transformer trained on CommonVoice |
| [csj.transformer.v1](https://drive.google.com/open?id=120nUQcSsKeY5dpyMWw_kI33ooMRGT2uF) | Joint-CTC attention Transformer trained on CSJ |
| [csj.rnn.v1](https://drive.google.com/open?id=1ALvD4nHan9VDJlYJwNurVr7H7OV0j2X9) | Joint-CTC attention VGGBLSTM trained on CSJ |
</div></details>
### SE results
<details><summary>expand</summary><div>
We list results from three different models on WSJ0-2mix, which is one the most widely used benchmark dataset for speech separation.
| Model | STOI | SAR | SDR | SIR |
| ------------------------------------------------- | ---- | ----- | ----- | ----- |
| [TF Masking](https://zenodo.org/record/4498554) | 0.89 | 11.40 | 10.24 | 18.04 |
| [Conv-Tasnet](https://zenodo.org/record/4498562) | 0.95 | 16.62 | 15.94 | 25.90 |
| [DPRNN-Tasnet](https://zenodo.org/record/4688000) | 0.96 | 18.82 | 18.29 | 28.92 |
</div></details>
### SE demos
<details><summary>expand</summary><div>
You can try the interactive demo with Google Colab. Please click the following button to get access to the demos.
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1fjRJCh96SoYLZPRxsjF9VDv4Q2VoIckI?usp=sharing)
It is based on ESPnet2. Pretrained models are available for both speech enhancement and speech separation tasks.
</div></details>
### ST results
<details><summary>expand</summary><div>
We list 4-gram BLEU of major ST tasks.
#### end-to-end system
| Task | BLEU | Pretrained model |
| ------------------------------------------------- | :---: | :-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| Fisher-CallHome Spanish fisher_test (Es->En) | 51.03 | [link](https://github.com/espnet/espnet/blob/master/egs/fisher_callhome_spanish/st1/RESULTS.md#train_spen_lcrm_pytorch_train_pytorch_transformer_bpe_short_long_bpe1000_specaug_asrtrans_mttrans) |
| Fisher-CallHome Spanish callhome_evltest (Es->En) | 20.44 | [link](https://github.com/espnet/espnet/blob/master/egs/fisher_callhome_spanish/st1/RESULTS.md#train_spen_lcrm_pytorch_train_pytorch_transformer_bpe_short_long_bpe1000_specaug_asrtrans_mttrans) |
| Libri-trans test (En->Fr) | 16.70 | [link](https://github.com/espnet/espnet/blob/master/egs/libri_trans/st1/RESULTS.md#train_spfr_lc_pytorch_train_pytorch_transformer_bpe_short_long_bpe1000_specaug_asrtrans_mttrans-1) |
| How2 dev5 (En->Pt) | 45.68 | [link](https://github.com/espnet/espnet/blob/master/egs/how2/st1/RESULTS.md#trainpt_tc_pytorch_train_pytorch_transformer_short_long_bpe8000_specaug_asrtrans_mttrans-1) |
| Must-C tst-COMMON (En->De) | 22.91 | [link](https://github.com/espnet/espnet/blob/master/egs/must_c/st1/RESULTS.md#train_spen-dede_tc_pytorch_train_pytorch_transformer_short_long_bpe8000_specaug_asrtrans_mttrans) |
| Mboshi-French dev (Fr->Mboshi) | 6.18 | N/A |
#### cascaded system
| Task | BLEU | Pretrained model |
| ------------------------------------------------- | :---: | :--------------: |
| Fisher-CallHome Spanish fisher_test (Es->En) | 42.16 | N/A |
| Fisher-CallHome Spanish callhome_evltest (Es->En) | 19.82 | N/A |
| Libri-trans test (En->Fr) | 16.96 | N/A |
| How2 dev5 (En->Pt) | 44.90 | N/A |
| Must-C tst-COMMON (En->De) | 23.65 | N/A |
If you want to check the results of the other recipes, please check `egs/<name_of_recipe>/st1/RESULTS.md`.
</div></details>
### ST demo
<details><summary>expand</summary><div>
(**New!**) We made a new real-time E2E-ST + TTS demonstration in Google Colab.
Please access the notebook from the following button and enjoy the real-time speech-to-speech translation!
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/espnet/notebook/blob/master/st_demo.ipynb)
---
You can translate speech in a WAV file using pretrained models.
Go to a recipe directory and run `utils/translate_wav.sh` as follows:
```sh
# go to recipe directory and source path of espnet tools
cd egs/fisher_callhome_spanish/st1 && . ./path.sh
# download example wav file
wget -O - https://github.com/espnet/espnet/files/4100928/test.wav.tar.gz | tar zxvf -
# let's translate speech!
translate_wav.sh --models fisher_callhome_spanish.transformer.v1.es-en test.wav
```
where `test.wav` is a WAV file to be translated.
The sampling rate must be consistent with that of data used in training.
Available pretrained models in the demo script are listed as below.
| Model | Notes |
| :----------------------------------------------------------------------------------------------------------- | :------------------------------------------------------- |
| [fisher_callhome_spanish.transformer.v1](https://drive.google.com/open?id=1hawp5ZLw4_SIHIT3edglxbKIIkPVe8n3) | Transformer-ST trained on Fisher-CallHome Spanish Es->En |
</div></details>
### MT results
<details><summary>expand</summary><div>
| Task | BLEU | Pretrained model |
| ------------------------------------------------- | :---: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| Fisher-CallHome Spanish fisher_test (Es->En) | 61.45 | [link](https://github.com/espnet/espnet/blob/master/egs/fisher_callhome_spanish/mt1/RESULTS.md#trainen_lcrm_lcrm_pytorch_train_pytorch_transformer_bpe_bpe1000) |
| Fisher-CallHome Spanish callhome_evltest (Es->En) | 29.86 | [link](https://github.com/espnet/espnet/blob/master/egs/fisher_callhome_spanish/mt1/RESULTS.md#trainen_lcrm_lcrm_pytorch_train_pytorch_transformer_bpe_bpe1000) |
| Libri-trans test (En->Fr) | 18.09 | [link](https://github.com/espnet/espnet/blob/master/egs/libri_trans/mt1/RESULTS.md#trainfr_lcrm_tc_pytorch_train_pytorch_transformer_bpe1000) |
| How2 dev5 (En->Pt) | 58.61 | [link](https://github.com/espnet/espnet/blob/master/egs/how2/mt1/RESULTS.md#trainpt_tc_tc_pytorch_train_pytorch_transformer_bpe8000) |
| Must-C tst-COMMON (En->De) | 27.63 | [link](https://github.com/espnet/espnet/blob/master/egs/must_c/mt1/RESULTS.md#summary-4-gram-bleu) |
| IWSLT'14 test2014 (En->De) | 24.70 | [link](https://github.com/espnet/espnet/blob/master/egs/iwslt16/mt1/RESULTS.md#result) |
| IWSLT'14 test2014 (De->En) | 29.22 | [link](https://github.com/espnet/espnet/blob/master/egs/iwslt16/mt1/RESULTS.md#result) |
| IWSLT'14 test2014 (De->En) | 32.2 | [link](https://github.com/espnet/espnet/blob/master/egs2/iwslt14/mt1/README.md) |
| IWSLT'16 test2014 (En->De) | 24.05 | [link](https://github.com/espnet/espnet/blob/master/egs/iwslt16/mt1/RESULTS.md#result) |
| IWSLT'16 test2014 (De->En) | 29.13 | [link](https://github.com/espnet/espnet/blob/master/egs/iwslt16/mt1/RESULTS.md#result) |
</div></details>
### TTS results
<details><summary>ESPnet2</summary><div>
You can listen to the generated samples in the following URL.
- [ESPnet2 TTS generated samples](https://drive.google.com/drive/folders/1H3fnlBbWMEkQUfrHqosKN_ZX_WjO29ma?usp=sharing)
> Note that in the generation we use Griffin-Lim (`wav/`) and [Parallel WaveGAN](https://github.com/kan-bayashi/ParallelWaveGAN) (`wav_pwg/`).
You can download pretrained models via `espnet_model_zoo`.
- [ESPnet model zoo](https://github.com/espnet/espnet_model_zoo)
- [Pretrained model list](https://github.com/espnet/espnet_model_zoo/blob/master/espnet_model_zoo/table.csv)
You can download pretrained vocoders via `kan-bayashi/ParallelWaveGAN`.
- [kan-bayashi/ParallelWaveGAN](https://github.com/kan-bayashi/ParallelWaveGAN)
- [Pretrained vocoder list](https://github.com/kan-bayashi/ParallelWaveGAN#results)
</div></details>
<details><summary>ESPnet1</summary><div>
> NOTE: We are moving on ESPnet2-based development for TTS. Please check the latest results in the above ESPnet2 results.
You can listen to our samples in demo HP [espnet-tts-sample](https://espnet.github.io/espnet-tts-sample/).
Here we list some notable ones:
- [Single English speaker Tacotron2](https://drive.google.com/open?id=18JgsOCWiP_JkhONasTplnHS7yaF_konr)
- [Single Japanese speaker Tacotron2](https://drive.google.com/open?id=1fEgS4-K4dtgVxwI4Pr7uOA1h4PE-zN7f)
- [Single other language speaker Tacotron2](https://drive.google.com/open?id=1q_66kyxVZGU99g8Xb5a0Q8yZ1YVm2tN0)
- [Multi English speaker Tacotron2](https://drive.google.com/open?id=18S_B8Ogogij34rIfJOeNF8D--uG7amz2)
- [Single English speaker Transformer](https://drive.google.com/open?id=14EboYVsMVcAq__dFP1p6lyoZtdobIL1X)
- [Single English speaker FastSpeech](https://drive.google.com/open?id=1PSxs1VauIndwi8d5hJmZlppGRVu2zuy5)
- [Multi English speaker Transformer](https://drive.google.com/open?id=1_vrdqjM43DdN1Qz7HJkvMQ6lCMmWLeGp)
- [Single Italian speaker FastSpeech](https://drive.google.com/open?id=13I5V2w7deYFX4DlVk1-0JfaXmUR2rNOv)
- [Single Mandarin speaker Transformer](https://drive.google.com/open?id=1mEnZfBKqA4eT6Bn0eRZuP6lNzL-IL3VD)
- [Single Mandarin speaker FastSpeech](https://drive.google.com/open?id=1Ol_048Tuy6BgvYm1RpjhOX4HfhUeBqdK)
- [Multi Japanese speaker Transformer](https://drive.google.com/open?id=1fFMQDF6NV5Ysz48QLFYE8fEvbAxCsMBw)
- [Single English speaker models with Parallel WaveGAN](https://drive.google.com/open?id=1HvB0_LDf1PVinJdehiuCt5gWmXGguqtx)
- [Single English speaker knowledge distillation-based FastSpeech](https://drive.google.com/open?id=1wG-Y0itVYalxuLAHdkAHO7w1CWFfRPF4)
You can download all of the pretrained models and generated samples:
- [All of the pretrained E2E-TTS models](https://drive.google.com/open?id=1k9RRyc06Zl0mM2A7mi-hxNiNMFb_YzTF)
- [All of the generated samples](https://drive.google.com/open?id=1bQGuqH92xuxOX__reWLP4-cif0cbpMLX)
Note that in the generated samples we use the following vocoders: Griffin-Lim (**GL**), WaveNet vocoder (**WaveNet**), Parallel WaveGAN (**ParallelWaveGAN**), and MelGAN (**MelGAN**).
The neural vocoders are based on following repositories.
- [kan-bayashi/ParallelWaveGAN](https://github.com/kan-bayashi/ParallelWaveGAN): Parallel WaveGAN / MelGAN / Multi-band MelGAN
- [r9y9/wavenet_vocoder](https://github.com/r9y9/wavenet_vocoder): 16 bit mixture of Logistics WaveNet vocoder
- [kan-bayashi/PytorchWaveNetVocoder](https://github.com/kan-bayashi/PytorchWaveNetVocoder): 8 bit Softmax WaveNet Vocoder with the noise shaping
If you want to build your own neural vocoder, please check the above repositories.
[kan-bayashi/ParallelWaveGAN](https://github.com/kan-bayashi/ParallelWaveGAN) provides [the manual](https://github.com/kan-bayashi/ParallelWaveGAN#decoding-with-espnet-tts-models-features) about how to decode ESPnet-TTS model's features with neural vocoders. Please check it.
Here we list all of the pretrained neural vocoders. Please download and enjoy the generation of high quality speech!
| Model link | Lang | Fs [Hz] | Mel range [Hz] | FFT / Shift / Win [pt] | Model type |
| :--------------------------------------------------------------------------------------------------- | :---: | :-----: | :------------: | :--------------------: | :---------------------------------------------------------------------- |
| [ljspeech.wavenet.softmax.ns.v1](https://drive.google.com/open?id=1eA1VcRS9jzFa-DovyTgJLQ_jmwOLIi8L) | EN | 22.05k | None | 1024 / 256 / None | [Softmax WaveNet](https://github.com/kan-bayashi/PytorchWaveNetVocoder) |
| [ljspeech.wavenet.mol.v1](https://drive.google.com/open?id=1sY7gEUg39QaO1szuN62-Llst9TrFno2t) | EN | 22.05k | None | 1024 / 256 / None | [MoL WaveNet](https://github.com/r9y9/wavenet_vocoder) |
| [ljspeech.parallel_wavegan.v1](https://drive.google.com/open?id=1tv9GKyRT4CDsvUWKwH3s_OfXkiTi0gw7) | EN | 22.05k | None | 1024 / 256 / None | [Parallel WaveGAN](https://github.com/kan-bayashi/ParallelWaveGAN) |
| [ljspeech.wavenet.mol.v2](https://drive.google.com/open?id=1es2HuKUeKVtEdq6YDtAsLNpqCy4fhIXr) | EN | 22.05k | 80-7600 | 1024 / 256 / None | [MoL WaveNet](https://github.com/r9y9/wavenet_vocoder) |
| [ljspeech.parallel_wavegan.v2](https://drive.google.com/open?id=1Grn7X9wD35UcDJ5F7chwdTqTa4U7DeVB) | EN | 22.05k | 80-7600 | 1024 / 256 / None | [Parallel WaveGAN](https://github.com/kan-bayashi/ParallelWaveGAN) |
| [ljspeech.melgan.v1](https://drive.google.com/open?id=1ipPWYl8FBNRlBFaKj1-i23eQpW_W_YcR) | EN | 22.05k | 80-7600 | 1024 / 256 / None | [MelGAN](https://github.com/kan-bayashi/ParallelWaveGAN) |
| [ljspeech.melgan.v3](https://drive.google.com/open?id=1_a8faVA5OGCzIcJNw4blQYjfG4oA9VEt) | EN | 22.05k | 80-7600 | 1024 / 256 / None | [MelGAN](https://github.com/kan-bayashi/ParallelWaveGAN) |
| [libritts.wavenet.mol.v1](https://drive.google.com/open?id=1jHUUmQFjWiQGyDd7ZeiCThSjjpbF_B4h) | EN | 24k | None | 1024 / 256 / None | [MoL WaveNet](https://github.com/r9y9/wavenet_vocoder) |
| [jsut.wavenet.mol.v1](https://drive.google.com/open?id=187xvyNbmJVZ0EZ1XHCdyjZHTXK9EcfkK) | JP | 24k | 80-7600 | 2048 / 300 / 1200 | [MoL WaveNet](https://github.com/r9y9/wavenet_vocoder) |
| [jsut.parallel_wavegan.v1](https://drive.google.com/open?id=1OwrUQzAmvjj1x9cDhnZPp6dqtsEqGEJM) | JP | 24k | 80-7600 | 2048 / 300 / 1200 | [Parallel WaveGAN](https://github.com/kan-bayashi/ParallelWaveGAN) |
| [csmsc.wavenet.mol.v1](https://drive.google.com/open?id=1PsjFRV5eUP0HHwBaRYya9smKy5ghXKzj) | ZH | 24k | 80-7600 | 2048 / 300 / 1200 | [MoL WaveNet](https://github.com/r9y9/wavenet_vocoder) |
| [csmsc.parallel_wavegan.v1](https://drive.google.com/open?id=10M6H88jEUGbRWBmU1Ff2VaTmOAeL8CEy) | ZH | 24k | 80-7600 | 2048 / 300 / 1200 | [Parallel WaveGAN](https://github.com/kan-bayashi/ParallelWaveGAN) |
If you want to use the above pretrained vocoders, please exactly match the feature setting with them.
</div></details>
### TTS demo
<details><summary>ESPnet2</summary><div>
You can try the real-time demo in Google Colab.
Please access the notebook from the following button and enjoy the real-time synthesis!
- Real-time TTS demo with ESPnet2 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/espnet/notebook/blob/master/espnet2_tts_realtime_demo.ipynb)
English, Japanese, and Mandarin models are available in the demo.
</div></details>
<details><summary>ESPnet1</summary><div>
> NOTE: We are moving on ESPnet2-based development for TTS. Please check the latest demo in the above ESPnet2 demo.
You can try the real-time demo in Google Colab.
Please access the notebook from the following button and enjoy the real-time synthesis.
- Real-time TTS demo with ESPnet1 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/espnet/notebook/blob/master/tts_realtime_demo.ipynb)
We also provide shell script to perform synthesize.
Go to a recipe directory and run `utils/synth_wav.sh` as follows:
```sh
# go to recipe directory and source path of espnet tools
cd egs/ljspeech/tts1 && . ./path.sh
# we use upper-case char sequence for the default model.
echo "THIS IS A DEMONSTRATION OF TEXT TO SPEECH." > example.txt
# let's synthesize speech!
synth_wav.sh example.txt
# also you can use multiple sentences
echo "THIS IS A DEMONSTRATION OF TEXT TO SPEECH." > example_multi.txt
echo "TEXT TO SPEECH IS A TECHNIQUE TO CONVERT TEXT INTO SPEECH." >> example_multi.txt
synth_wav.sh example_multi.txt
```
You can change the pretrained model as follows:
```sh
synth_wav.sh --models ljspeech.fastspeech.v1 example.txt
```
Waveform synthesis is performed with Griffin-Lim algorithm and neural vocoders (WaveNet and ParallelWaveGAN).
You can change the pretrained vocoder model as follows:
```sh
synth_wav.sh --vocoder_models ljspeech.wavenet.mol.v1 example.txt
```
WaveNet vocoder provides very high quality speech but it takes time to generate.
See more details or available models via `--help`.
```sh
synth_wav.sh --help
```
</div></details>
### VC results
<details><summary>expand</summary><div>
- Transformer and Tacotron2 based VC
You can listen to some samples on the [demo webpage](https://unilight.github.io/Publication-Demos/publications/transformer-vc/).
- Cascade ASR+TTS as one of the baseline systems of VCC2020
The [Voice Conversion Challenge 2020](http://www.vc-challenge.org/) (VCC2020) adopts ESPnet to build an end-to-end based baseline system.
In VCC2020, the objective is intra/cross lingual nonparallel VC.
You can download converted samples of the cascade ASR+TTS baseline system [here](https://drive.google.com/drive/folders/1oeZo83GrOgtqxGwF7KagzIrfjr8X59Ue?usp=sharing).
</div></details>
### SLU results
<details><summary>expand</summary><div>
We list the performance on various SLU tasks and dataset using the metric reported in the original dataset paper
| Task | Dataset | Metric | Result | Pretrained Model |
| ----------------------------------------------------------------- | :-------------: | :-------------: | :-------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| Intent Classification | SLURP | Acc | 86.3 | [link](https://github.com/espnet/espnet/tree/master/egs2/slurp/asr1/README.md) |
| Intent Classification | FSC | Acc | 99.6 | [link](https://github.com/espnet/espnet/tree/master/egs2/fsc/asr1/README.md) |
| Intent Classification | FSC Unseen Speaker Set | Acc | 98.6 | [link](https://github.com/espnet/espnet/tree/master/egs2/fsc_unseen/asr1/README.md) |
| Intent Classification | FSC Unseen Utterance Set | Acc | 86.4 | [link](https://github.com/espnet/espnet/tree/master/egs2/fsc_unseen/asr1/README.md) |
| Intent Classification | FSC Challenge Speaker Set | Acc | 97.5 | [link](https://github.com/espnet/espnet/tree/master/egs2/fsc_challenge/asr1/README.md) |
| Intent Classification | FSC Challenge Utterance Set | Acc | 78.5 | [link](https://github.com/espnet/espnet/tree/master/egs2/fsc_challenge/asr1/README.md) |
| Intent Classification | SNIPS | F1 | 91.7 | [link](https://github.com/espnet/espnet/tree/master/egs2/snips/asr1/README.md) |
| Intent Classification | Grabo (Nl) | Acc | 97.2 | [link](https://github.com/espnet/espnet/tree/master/egs2/grabo/asr1/README.md) |
| Intent Classification | CAT SLU MAP (Zn) | Acc | 78.9 | [link](https://github.com/espnet/espnet/tree/master/egs2/catslu/asr1/README.md) |
| Intent Classification | Google Speech Commands | Acc | 98.4 | [link](https://github.com/espnet/espnet/tree/master/egs2/speechcommands/asr1/README.md) |
| Slot Filling | SLURP | SLU-F1 | 71.9 | [link](https://github.com/espnet/espnet/tree/master/egs2/slurp_entity/asr1/README.md) |
| Dialogue Act Classification | Switchboard | Acc | 67.5 | [link](https://github.com/espnet/espnet/tree/master/egs2/swbd_da/asr1/README.md) |
| Dialogue Act Classification | Jdcinal (Jp) | Acc | 67.4 | [link](https://github.com/espnet/espnet/tree/master/egs2/jdcinal/asr1/README.md) |
| Emotion Recognition | IEMOCAP | Acc | 69.4 | [link](https://github.com/espnet/espnet/tree/master/egs2/iemocap/asr1/README.md) |
| Emotion Recognition | swbd_sentiment | Macro F1 | 61.4 | [link](https://github.com/espnet/espnet/tree/master/egs2/swbd_sentiment/asr1/README.md) |
| Emotion Recognition | slue_voxceleb | Macro F1 | 44.0 | [link](https://github.com/espnet/espnet/tree/master/egs2/slue-voxceleb/asr1/README.md) |
If you want to check the results of the other recipes, please check `egs2/<name_of_recipe>/asr1/RESULTS.md`.
</div></details>
### CTC Segmentation demo
<details><summary>ESPnet1</summary><div>
[CTC segmentation](https://arxiv.org/abs/2007.09127) determines utterance segments within audio files.
Aligned utterance segments constitute the labels of speech datasets.
As demo, we align start and end of utterances within the audio file `ctc_align_test.wav`, using the example script `utils/asr_align_wav.sh`.
For preparation, set up a data directory:
```sh
cd egs/tedlium2/align1/
# data directory
align_dir=data/demo
mkdir -p ${align_dir}
# wav file
base=ctc_align_test
wav=../../../test_utils/${base}.wav
# recipe files
echo "batchsize: 0" > ${align_dir}/align.yaml
cat << EOF > ${align_dir}/utt_text
${base} THE SALE OF THE HOTELS
${base} IS PART OF HOLIDAY'S STRATEGY
${base} TO SELL OFF ASSETS
${base} AND CONCENTRATE
${base} ON PROPERTY MANAGEMENT
EOF
```
Here, `utt_text` is the file containing the list of utterances.
Choose a pre-trained ASR model that includes a CTC layer to find utterance segments:
```sh
# pre-trained ASR model
model=wsj.transformer_small.v1
mkdir ./conf && cp ../../wsj/asr1/conf/no_preprocess.yaml ./conf
../../../utils/asr_align_wav.sh \
--models ${model} \
--align_dir ${align_dir} \
--align_config ${align_dir}/align.yaml \
${wav} ${align_dir}/utt_text
```
Segments are written to `aligned_segments` as a list of file/utterance name, utterance start and end times in seconds and a confidence score.
The confidence score is a probability in log space that indicates how good the utterance was aligned. If needed, remove bad utterances:
```sh
min_confidence_score=-5
awk -v ms=${min_confidence_score} '{ if ($5 > ms) {print} }' ${align_dir}/aligned_segments
```
The demo script `utils/ctc_align_wav.sh` uses an already pretrained ASR model (see list above for more models).
It is recommended to use models with RNN-based encoders (such as BLSTMP) for aligning large audio files;
rather than using Transformer models that have a high memory consumption on longer audio data.
The sample rate of the audio must be consistent with that of the data used in training; adjust with `sox` if needed.
A full example recipe is in `egs/tedlium2/align1/`.
</div></details>
<details><summary>ESPnet2</summary><div>
[CTC segmentation](https://arxiv.org/abs/2007.09127) determines utterance segments within audio files.
Aligned utterance segments constitute the labels of speech datasets.
As demo, we align start and end of utterances within the audio file `ctc_align_test.wav`.
This can be done either directly from the Python command line or using the script `espnet2/bin/asr_align.py`.
From the Python command line interface:
```python
# load a model with character tokens
from espnet_model_zoo.downloader import ModelDownloader
d = ModelDownloader(cachedir="./modelcache")
wsjmodel = d.download_and_unpack("kamo-naoyuki/wsj")
# load the example file included in the ESPnet repository
import soundfile
speech, rate = soundfile.read("./test_utils/ctc_align_test.wav")
# CTC segmentation
from espnet2.bin.asr_align import CTCSegmentation
aligner = CTCSegmentation( **wsjmodel , fs=rate )
text = """
utt1 THE SALE OF THE HOTELS
utt2 IS PART OF HOLIDAY'S STRATEGY
utt3 TO SELL OFF ASSETS
utt4 AND CONCENTRATE ON PROPERTY MANAGEMENT
"""
segments = aligner(speech, text)
print(segments)
# utt1 utt 0.26 1.73 -0.0154 THE SALE OF THE HOTELS
# utt2 utt 1.73 3.19 -0.7674 IS PART OF HOLIDAY'S STRATEGY
# utt3 utt 3.19 4.20 -0.7433 TO SELL OFF ASSETS
# utt4 utt 4.20 6.10 -0.4899 AND CONCENTRATE ON PROPERTY MANAGEMENT
```
Aligning also works with fragments of the text.
For this, set the `gratis_blank` option that allows skipping unrelated audio sections without penalty.
It's also possible to omit the utterance names at the beginning of each line, by setting `kaldi_style_text` to False.
```python
aligner.set_config( gratis_blank=True, kaldi_style_text=False )
text = ["SALE OF THE HOTELS", "PROPERTY MANAGEMENT"]
segments = aligner(speech, text)
print(segments)
# utt_0000 utt 0.37 1.72 -2.0651 SALE OF THE HOTELS
# utt_0001 utt 4.70 6.10 -5.0566 PROPERTY MANAGEMENT
```
The script `espnet2/bin/asr_align.py` uses a similar interface. To align utterances:
```sh
# ASR model and config files from pretrained model (e.g. from cachedir):
asr_config=<path-to-model>/config.yaml
asr_model=<path-to-model>/valid.*best.pth
# prepare the text file
wav="test_utils/ctc_align_test.wav"
text="test_utils/ctc_align_text.txt"
cat << EOF > ${text}
utt1 THE SALE OF THE HOTELS
utt2 IS PART OF HOLIDAY'S STRATEGY
utt3 TO SELL OFF ASSETS
utt4 AND CONCENTRATE
utt5 ON PROPERTY MANAGEMENT
EOF
# obtain alignments:
python espnet2/bin/asr_align.py --asr_train_config ${asr_config} --asr_model_file ${asr_model} --audio ${wav} --text ${text}
# utt1 ctc_align_test 0.26 1.73 -0.0154 THE SALE OF THE HOTELS
# utt2 ctc_align_test 1.73 3.19 -0.7674 IS PART OF HOLIDAY'S STRATEGY
# utt3 ctc_align_test 3.19 4.20 -0.7433 TO SELL OFF ASSETS
# utt4 ctc_align_test 4.20 4.97 -0.6017 AND CONCENTRATE
# utt5 ctc_align_test 4.97 6.10 -0.3477 ON PROPERTY MANAGEMENT
```
The output of the script can be redirected to a `segments` file by adding the argument `--output segments`.
Each line contains file/utterance name, utterance start and end times in seconds and a confidence score; optionally also the utterance text.
The confidence score is a probability in log space that indicates how good the utterance was aligned. If needed, remove bad utterances:
```sh
min_confidence_score=-7
# here, we assume that the output was written to the file `segments`
awk -v ms=${min_confidence_score} '{ if ($5 > ms) {print} }' segments
```
See the module documentation for more information.
It is recommended to use models with RNN-based encoders (such as BLSTMP) for aligning large audio files;
rather than using Transformer models that have a high memory consumption on longer audio data.
The sample rate of the audio must be consistent with that of the data used in training; adjust with `sox` if needed.
Also, we can use this tool to provide token-level segmentation information if we prepare a list of tokens instead of that of utterances in the `text` file. See the discussion in https://github.com/espnet/espnet/issues/4278#issuecomment-1100756463.
</div></details>
## Citations
```
@inproceedings{watanabe2018espnet,
author={Shinji Watanabe and Takaaki Hori and Shigeki Karita and Tomoki Hayashi and Jiro Nishitoba and Yuya Unno and Nelson {Enrique Yalta Soplin} and Jahn Heymann and Matthew Wiesner and Nanxin Chen and Adithya Renduchintala and Tsubasa Ochiai},
title={{ESPnet}: End-to-End Speech Processing Toolkit},
year={2018},
booktitle={Proceedings of Interspeech},
pages={2207--2211},
doi={10.21437/Interspeech.2018-1456},
url={http://dx.doi.org/10.21437/Interspeech.2018-1456}
}
@inproceedings{hayashi2020espnet,
title={{Espnet-TTS}: Unified, reproducible, and integratable open source end-to-end text-to-speech toolkit},
author={Hayashi, Tomoki and Yamamoto, Ryuichi and Inoue, Katsuki and Yoshimura, Takenori and Watanabe, Shinji and Toda, Tomoki and Takeda, Kazuya and Zhang, Yu and Tan, Xu},
booktitle={Proceedings of IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
pages={7654--7658},
year={2020},
organization={IEEE}
}
@inproceedings{inaguma-etal-2020-espnet,
title = "{ESP}net-{ST}: All-in-One Speech Translation Toolkit",
author = "Inaguma, Hirofumi and
Kiyono, Shun and
Duh, Kevin and
Karita, Shigeki and
Yalta, Nelson and
Hayashi, Tomoki and
Watanabe, Shinji",
booktitle = "Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics: System Demonstrations",
month = jul,
year = "2020",
address = "Online",
publisher = "Association for Computational Linguistics",
url = "https://www.aclweb.org/anthology/2020.acl-demos.34",
pages = "302--311",
}
@inproceedings{li2020espnet,
title={{ESPnet-SE}: End-to-End Speech Enhancement and Separation Toolkit Designed for {ASR} Integration},
author={Chenda Li and Jing Shi and Wangyou Zhang and Aswin Shanmugam Subramanian and Xuankai Chang and Naoyuki Kamo and Moto Hira and Tomoki Hayashi and Christoph Boeddeker and Zhuo Chen and Shinji Watanabe},
booktitle={Proceedings of IEEE Spoken Language Technology Workshop (SLT)},
pages={785--792},
year={2021},
organization={IEEE},
}
@inproceedings{arora2021espnet,
title={{ESPnet-SLU}: Advancing Spoken Language Understanding through ESPnet},
author={Arora, Siddhant and Dalmia, Siddharth and Denisov, Pavel and Chang, Xuankai and Ueda, Yushi and Peng, Yifan and Zhang, Yuekai and Kumar, Sujay and Ganesan, Karthik and Yan, Brian and others},
booktitle={ICASSP 2022-2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
pages={7167--7171},
year={2022},
organization={IEEE}
}
@inproceedings{shi2022muskits,
author={Shi, Jiatong and Guo, Shuai and Qian, Tao and Huo, Nan and Hayashi, Tomoki and Wu, Yuning and Xu, Frank and Chang, Xuankai and Li, Huazhe and Wu, Peter and Watanabe, Shinji and Jin, Qin},
title={{Muskits}: an End-to-End Music Processing Toolkit for Singing Voice Synthesis},
year={2022},
booktitle={Proceedings of Interspeech},
pages={4277-4281},
url={https://www.isca-speech.org/archive/pdfs/interspeech_2022/shi22d_interspeech.pdf}
}
@article{gao2022euro,
title={{EURO}: {ESPnet} Unsupervised ASR Open-source Toolkit},
author={Gao, Dongji and Shi, Jiatong and Chuang, Shun-Po and Garcia, Leibny Paola and Lee, Hung-yi and Watanabe, Shinji and Khudanpur, Sanjeev},
journal={arXiv preprint arXiv:2211.17196},
year={2022}
}
```
"""Initialize espnet package."""
import os
dirname = os.path.dirname(__file__)
version_file = os.path.join(dirname, "version.txt")
with open(version_file, "r") as f:
__version__ = f.read().strip()
#!/usr/bin/env python3
"""
This script is used to provide utility functions designed for multi-speaker ASR.
Copyright 2017 Johns Hopkins University (Shinji Watanabe)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
Most functions can be directly used as in asr_utils.py:
CompareValueTrigger, restore_snapshot, adadelta_eps_decay, chainer_load,
torch_snapshot, torch_save, torch_resume, AttributeDict, get_model_conf.
"""
import copy
import logging
import os
from chainer.training import extension
from espnet.asr.asr_utils import parse_hypothesis
# * -------------------- chainer extension related -------------------- *
class PlotAttentionReport(extension.Extension):
"""Plot attention reporter.
Args:
att_vis_fn (espnet.nets.*_backend.e2e_asr.calculate_all_attentions):
Function of attention visualization.
data (list[tuple(str, dict[str, dict[str, Any]])]): List json utt key items.
outdir (str): Directory to save figures.
converter (espnet.asr.*_backend.asr.CustomConverter):
CustomConverter object. Function to convert data.
device (torch.device): The destination device to send tensor.
reverse (bool): If True, input and output length are reversed.
"""
def __init__(self, att_vis_fn, data, outdir, converter, device, reverse=False):
"""Initialize PlotAttentionReport."""
self.att_vis_fn = att_vis_fn
self.data = copy.deepcopy(data)
self.outdir = outdir
self.converter = converter
self.device = device
self.reverse = reverse
if not os.path.exists(self.outdir):
os.makedirs(self.outdir)
def __call__(self, trainer):
"""Plot and save imaged matrix of att_ws."""
att_ws_sd = self.get_attention_weights()
for ns, att_ws in enumerate(att_ws_sd):
for idx, att_w in enumerate(att_ws):
filename = "%s/%s.ep.{.updater.epoch}.output%d.png" % (
self.outdir,
self.data[idx][0],
ns + 1,
)
att_w = self.get_attention_weight(idx, att_w, ns)
self._plot_and_save_attention(att_w, filename.format(trainer))
def log_attentions(self, logger, step):
"""Add image files of attention matrix to tensorboard."""
att_ws_sd = self.get_attention_weights()
for ns, att_ws in enumerate(att_ws_sd):
for idx, att_w in enumerate(att_ws):
att_w = self.get_attention_weight(idx, att_w, ns)
plot = self.draw_attention_plot(att_w)
logger.add_figure("%s" % (self.data[idx][0]), plot.gcf(), step)
plot.clf()
def get_attention_weights(self):
"""Return attention weights.
Returns:
arr_ws_sd (numpy.ndarray): attention weights. It's shape would be
differ from bachend.dtype=float
* pytorch-> 1) multi-head case => (B, H, Lmax, Tmax). 2)
other case => (B, Lmax, Tmax).
* chainer-> attention weights (B, Lmax, Tmax).
"""
batch = self.converter([self.converter.transform(self.data)], self.device)
att_ws_sd = self.att_vis_fn(*batch)
return att_ws_sd
def get_attention_weight(self, idx, att_w, spkr_idx):
"""Transform attention weight in regard to self.reverse."""
if self.reverse:
dec_len = int(self.data[idx][1]["input"][0]["shape"][0])
enc_len = int(self.data[idx][1]["output"][spkr_idx]["shape"][0])
else:
dec_len = int(self.data[idx][1]["output"][spkr_idx]["shape"][0])
enc_len = int(self.data[idx][1]["input"][0]["shape"][0])
if len(att_w.shape) == 3:
att_w = att_w[:, :dec_len, :enc_len]
else:
att_w = att_w[:dec_len, :enc_len]
return att_w
def draw_attention_plot(self, att_w):
"""Visualize attention weights matrix.
Args:
att_w(Tensor): Attention weight matrix.
Returns:
matplotlib.pyplot: pyplot object with attention matrix image.
"""
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
if len(att_w.shape) == 3:
for h, aw in enumerate(att_w, 1):
plt.subplot(1, len(att_w), h)
plt.imshow(aw, aspect="auto")
plt.xlabel("Encoder Index")
plt.ylabel("Decoder Index")
else:
plt.imshow(att_w, aspect="auto")
plt.xlabel("Encoder Index")
plt.ylabel("Decoder Index")
plt.tight_layout()
return plt
def _plot_and_save_attention(self, att_w, filename):
plt = self.draw_attention_plot(att_w)
plt.savefig(filename)
plt.close()
def add_results_to_json(js, nbest_hyps_sd, char_list):
"""Add N-best results to json.
Args:
js (dict[str, Any]): Groundtruth utterance dict.
nbest_hyps_sd (list[dict[str, Any]]):
List of hypothesis for multi_speakers (# Utts x # Spkrs).
char_list (list[str]): List of characters.
Returns:
dict[str, Any]: N-best results added utterance dict.
"""
# copy old json info
new_js = dict()
new_js["utt2spk"] = js["utt2spk"]
num_spkrs = len(nbest_hyps_sd)
new_js["output"] = []
for ns in range(num_spkrs):
tmp_js = []
nbest_hyps = nbest_hyps_sd[ns]
for n, hyp in enumerate(nbest_hyps, 1):
# parse hypothesis
rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp, char_list)
# copy ground-truth
out_dic = dict(js["output"][ns].items())
# update name
out_dic["name"] += "[%d]" % n
# add recognition results
out_dic["rec_text"] = rec_text
out_dic["rec_token"] = rec_token
out_dic["rec_tokenid"] = rec_tokenid
out_dic["score"] = score
# add to list of N-best result dicts
tmp_js.append(out_dic)
# show 1-best result
if n == 1:
logging.info("groundtruth: %s" % out_dic["text"])
logging.info("prediction : %s" % out_dic["rec_text"])
new_js["output"].append(tmp_js)
return new_js
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import copy
import json
import logging
import os
import shutil
import tempfile
import numpy as np
import torch
# * -------------------- training iterator related -------------------- *
class CompareValueTrigger(object):
"""Trigger invoked when key value getting bigger or lower than before.
Args:
key (str) : Key of value.
compare_fn ((float, float) -> bool) : Function to compare the values.
trigger (tuple(int, str)) : Trigger that decide the comparison interval.
"""
def __init__(self, key, compare_fn, trigger=(1, "epoch")):
from chainer import training
self._key = key
self._best_value = None
self._interval_trigger = training.util.get_trigger(trigger)
self._init_summary()
self._compare_fn = compare_fn
def __call__(self, trainer):
"""Get value related to the key and compare with current value."""
observation = trainer.observation
summary = self._summary
key = self._key
if key in observation:
summary.add({key: observation[key]})
if not self._interval_trigger(trainer):
return False
stats = summary.compute_mean()
value = float(stats[key]) # copy to CPU
self._init_summary()
if self._best_value is None:
# initialize best value
self._best_value = value
return False
elif self._compare_fn(self._best_value, value):
return True
else:
self._best_value = value
return False
def _init_summary(self):
import chainer
self._summary = chainer.reporter.DictSummary()
try:
from chainer.training import extension
except ImportError:
PlotAttentionReport = None
else:
class PlotAttentionReport(extension.Extension):
"""Plot attention reporter.
Args:
att_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_attentions):
Function of attention visualization.
data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
outdir (str): Directory to save figures.
converter (espnet.asr.*_backend.asr.CustomConverter):
Function to convert data.
device (int | torch.device): Device.
reverse (bool): If True, input and output length are reversed.
ikey (str): Key to access input
(for ASR/ST ikey="input", for MT ikey="output".)
iaxis (int): Dimension to access input
(for ASR/ST iaxis=0, for MT iaxis=1.)
okey (str): Key to access output
(for ASR/ST okey="input", MT okay="output".)
oaxis (int): Dimension to access output
(for ASR/ST oaxis=0, for MT oaxis=0.)
subsampling_factor (int): subsampling factor in encoder
"""
def __init__(
self,
att_vis_fn,
data,
outdir,
converter,
transform,
device,
reverse=False,
ikey="input",
iaxis=0,
okey="output",
oaxis=0,
subsampling_factor=1,
):
self.att_vis_fn = att_vis_fn
self.data = copy.deepcopy(data)
self.data_dict = {k: v for k, v in copy.deepcopy(data)}
# key is utterance ID
self.outdir = outdir
self.converter = converter
self.transform = transform
self.device = device
self.reverse = reverse
self.ikey = ikey
self.iaxis = iaxis
self.okey = okey
self.oaxis = oaxis
self.factor = subsampling_factor
if not os.path.exists(self.outdir):
os.makedirs(self.outdir)
def __call__(self, trainer):
"""Plot and save image file of att_ws matrix."""
att_ws, uttid_list = self.get_attention_weights()
if isinstance(att_ws, list): # multi-encoder case
num_encs = len(att_ws) - 1
# atts
for i in range(num_encs):
for idx, att_w in enumerate(att_ws[i]):
filename = "%s/%s.ep.{.updater.epoch}.att%d.png" % (
self.outdir,
uttid_list[idx],
i + 1,
)
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
np_filename = "%s/%s.ep.{.updater.epoch}.att%d.npy" % (
self.outdir,
uttid_list[idx],
i + 1,
)
np.save(np_filename.format(trainer), att_w)
self._plot_and_save_attention(att_w, filename.format(trainer))
# han
for idx, att_w in enumerate(att_ws[num_encs]):
filename = "%s/%s.ep.{.updater.epoch}.han.png" % (
self.outdir,
uttid_list[idx],
)
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
np_filename = "%s/%s.ep.{.updater.epoch}.han.npy" % (
self.outdir,
uttid_list[idx],
)
np.save(np_filename.format(trainer), att_w)
self._plot_and_save_attention(
att_w, filename.format(trainer), han_mode=True
)
else:
for idx, att_w in enumerate(att_ws):
filename = "%s/%s.ep.{.updater.epoch}.png" % (
self.outdir,
uttid_list[idx],
)
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
self.outdir,
uttid_list[idx],
)
np.save(np_filename.format(trainer), att_w)
self._plot_and_save_attention(att_w, filename.format(trainer))
def log_attentions(self, logger, step):
"""Add image files of att_ws matrix to the tensorboard."""
att_ws, uttid_list = self.get_attention_weights()
if isinstance(att_ws, list): # multi-encoder case
num_encs = len(att_ws) - 1
# atts
for i in range(num_encs):
for idx, att_w in enumerate(att_ws[i]):
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
plot = self.draw_attention_plot(att_w)
logger.add_figure(
"%s_att%d" % (uttid_list[idx], i + 1),
plot.gcf(),
step,
)
# han
for idx, att_w in enumerate(att_ws[num_encs]):
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
plot = self.draw_han_plot(att_w)
logger.add_figure(
"%s_han" % (uttid_list[idx]),
plot.gcf(),
step,
)
else:
for idx, att_w in enumerate(att_ws):
att_w = self.trim_attention_weight(uttid_list[idx], att_w)
plot = self.draw_attention_plot(att_w)
logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)
def get_attention_weights(self):
"""Return attention weights.
Returns:
numpy.ndarray: attention weights. float. Its shape would be
differ from backend.
* pytorch-> 1) multi-head case => (B, H, Lmax, Tmax), 2)
other case => (B, Lmax, Tmax).
* chainer-> (B, Lmax, Tmax)
"""
return_batch, uttid_list = self.transform(self.data, return_uttid=True)
batch = self.converter([return_batch], self.device)
if isinstance(batch, tuple):
att_ws = self.att_vis_fn(*batch)
else:
att_ws = self.att_vis_fn(**batch)
return att_ws, uttid_list
def trim_attention_weight(self, uttid, att_w):
"""Transform attention matrix with regard to self.reverse."""
if self.reverse:
enc_key, enc_axis = self.okey, self.oaxis
dec_key, dec_axis = self.ikey, self.iaxis
else:
enc_key, enc_axis = self.ikey, self.iaxis
dec_key, dec_axis = self.okey, self.oaxis
dec_len = int(self.data_dict[uttid][dec_key][dec_axis]["shape"][0])
enc_len = int(self.data_dict[uttid][enc_key][enc_axis]["shape"][0])
if self.factor > 1:
enc_len //= self.factor
if len(att_w.shape) == 3:
att_w = att_w[:, :dec_len, :enc_len]
else:
att_w = att_w[:dec_len, :enc_len]
return att_w
def draw_attention_plot(self, att_w):
"""Plot the att_w matrix.
Returns:
matplotlib.pyplot: pyplot object with attention matrix image.
"""
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
plt.clf()
att_w = att_w.astype(np.float32)
if len(att_w.shape) == 3:
for h, aw in enumerate(att_w, 1):
plt.subplot(1, len(att_w), h)
plt.imshow(aw, aspect="auto")
plt.xlabel("Encoder Index")
plt.ylabel("Decoder Index")
else:
plt.imshow(att_w, aspect="auto")
plt.xlabel("Encoder Index")
plt.ylabel("Decoder Index")
plt.tight_layout()
return plt
def draw_han_plot(self, att_w):
"""Plot the att_w matrix for hierarchical attention.
Returns:
matplotlib.pyplot: pyplot object with attention matrix image.
"""
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
plt.clf()
if len(att_w.shape) == 3:
for h, aw in enumerate(att_w, 1):
legends = []
plt.subplot(1, len(att_w), h)
for i in range(aw.shape[1]):
plt.plot(aw[:, i])
legends.append("Att{}".format(i))
plt.ylim([0, 1.0])
plt.xlim([0, aw.shape[0]])
plt.grid(True)
plt.ylabel("Attention Weight")
plt.xlabel("Decoder Index")
plt.legend(legends)
else:
legends = []
for i in range(att_w.shape[1]):
plt.plot(att_w[:, i])
legends.append("Att{}".format(i))
plt.ylim([0, 1.0])
plt.xlim([0, att_w.shape[0]])
plt.grid(True)
plt.ylabel("Attention Weight")
plt.xlabel("Decoder Index")
plt.legend(legends)
plt.tight_layout()
return plt
def _plot_and_save_attention(self, att_w, filename, han_mode=False):
if han_mode:
plt = self.draw_han_plot(att_w)
else:
plt = self.draw_attention_plot(att_w)
plt.savefig(filename)
plt.close()
try:
from chainer.training import extension
except ImportError:
PlotCTCReport = None
else:
class PlotCTCReport(extension.Extension):
"""Plot CTC reporter.
Args:
ctc_vis_fn (espnet.nets.*_backend.e2e_asr.E2E.calculate_all_ctc_probs):
Function of CTC visualization.
data (list[tuple(str, dict[str, list[Any]])]): List json utt key items.
outdir (str): Directory to save figures.
converter (espnet.asr.*_backend.asr.CustomConverter):
Function to convert data.
device (int | torch.device): Device.
reverse (bool): If True, input and output length are reversed.
ikey (str): Key to access input
(for ASR/ST ikey="input", for MT ikey="output".)
iaxis (int): Dimension to access input
(for ASR/ST iaxis=0, for MT iaxis=1.)
okey (str): Key to access output
(for ASR/ST okey="input", MT okay="output".)
oaxis (int): Dimension to access output
(for ASR/ST oaxis=0, for MT oaxis=0.)
subsampling_factor (int): subsampling factor in encoder
"""
def __init__(
self,
ctc_vis_fn,
data,
outdir,
converter,
transform,
device,
reverse=False,
ikey="input",
iaxis=0,
okey="output",
oaxis=0,
subsampling_factor=1,
):
self.ctc_vis_fn = ctc_vis_fn
self.data = copy.deepcopy(data)
self.data_dict = {k: v for k, v in copy.deepcopy(data)}
# key is utterance ID
self.outdir = outdir
self.converter = converter
self.transform = transform
self.device = device
self.reverse = reverse
self.ikey = ikey
self.iaxis = iaxis
self.okey = okey
self.oaxis = oaxis
self.factor = subsampling_factor
if not os.path.exists(self.outdir):
os.makedirs(self.outdir)
def __call__(self, trainer):
"""Plot and save image file of ctc prob."""
ctc_probs, uttid_list = self.get_ctc_probs()
if isinstance(ctc_probs, list): # multi-encoder case
num_encs = len(ctc_probs) - 1
for i in range(num_encs):
for idx, ctc_prob in enumerate(ctc_probs[i]):
filename = "%s/%s.ep.{.updater.epoch}.ctc%d.png" % (
self.outdir,
uttid_list[idx],
i + 1,
)
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
np_filename = "%s/%s.ep.{.updater.epoch}.ctc%d.npy" % (
self.outdir,
uttid_list[idx],
i + 1,
)
np.save(np_filename.format(trainer), ctc_prob)
self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
else:
for idx, ctc_prob in enumerate(ctc_probs):
filename = "%s/%s.ep.{.updater.epoch}.png" % (
self.outdir,
uttid_list[idx],
)
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
np_filename = "%s/%s.ep.{.updater.epoch}.npy" % (
self.outdir,
uttid_list[idx],
)
np.save(np_filename.format(trainer), ctc_prob)
self._plot_and_save_ctc(ctc_prob, filename.format(trainer))
def log_ctc_probs(self, logger, step):
"""Add image files of ctc probs to the tensorboard."""
ctc_probs, uttid_list = self.get_ctc_probs()
if isinstance(ctc_probs, list): # multi-encoder case
num_encs = len(ctc_probs) - 1
for i in range(num_encs):
for idx, ctc_prob in enumerate(ctc_probs[i]):
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
plot = self.draw_ctc_plot(ctc_prob)
logger.add_figure(
"%s_ctc%d" % (uttid_list[idx], i + 1),
plot.gcf(),
step,
)
else:
for idx, ctc_prob in enumerate(ctc_probs):
ctc_prob = self.trim_ctc_prob(uttid_list[idx], ctc_prob)
plot = self.draw_ctc_plot(ctc_prob)
logger.add_figure("%s" % (uttid_list[idx]), plot.gcf(), step)
def get_ctc_probs(self):
"""Return CTC probs.
Returns:
numpy.ndarray: CTC probs. float. Its shape would be
differ from backend. (B, Tmax, vocab).
"""
return_batch, uttid_list = self.transform(self.data, return_uttid=True)
batch = self.converter([return_batch], self.device)
if isinstance(batch, tuple):
probs = self.ctc_vis_fn(*batch)
else:
probs = self.ctc_vis_fn(**batch)
return probs, uttid_list
def trim_ctc_prob(self, uttid, prob):
"""Trim CTC posteriors accoding to input lengths."""
enc_len = int(self.data_dict[uttid][self.ikey][self.iaxis]["shape"][0])
if self.factor > 1:
enc_len //= self.factor
prob = prob[:enc_len]
return prob
def draw_ctc_plot(self, ctc_prob):
"""Plot the ctc_prob matrix.
Returns:
matplotlib.pyplot: pyplot object with CTC prob matrix image.
"""
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
ctc_prob = ctc_prob.astype(np.float32)
plt.clf()
topk_ids = np.argsort(ctc_prob, axis=1)
n_frames, vocab = ctc_prob.shape
times_probs = np.arange(n_frames)
plt.figure(figsize=(20, 8))
# NOTE: index 0 is reserved for blank
for idx in set(topk_ids.reshape(-1).tolist()):
if idx == 0:
plt.plot(
times_probs, ctc_prob[:, 0], ":", label="<blank>", color="grey"
)
else:
plt.plot(times_probs, ctc_prob[:, idx])
plt.xlabel("Input [frame]", fontsize=12)
plt.ylabel("Posteriors", fontsize=12)
plt.xticks(list(range(0, int(n_frames) + 1, 10)))
plt.yticks(list(range(0, 2, 1)))
plt.tight_layout()
return plt
def _plot_and_save_ctc(self, ctc_prob, filename):
plt = self.draw_ctc_plot(ctc_prob)
plt.savefig(filename)
plt.close()
def restore_snapshot(model, snapshot, load_fn=None):
"""Extension to restore snapshot.
Returns:
An extension function.
"""
import chainer
from chainer import training
if load_fn is None:
load_fn = chainer.serializers.load_npz
@training.make_extension(trigger=(1, "epoch"))
def restore_snapshot(trainer):
_restore_snapshot(model, snapshot, load_fn)
return restore_snapshot
def _restore_snapshot(model, snapshot, load_fn=None):
if load_fn is None:
import chainer
load_fn = chainer.serializers.load_npz
load_fn(snapshot, model)
logging.info("restored from " + str(snapshot))
def adadelta_eps_decay(eps_decay):
"""Extension to perform adadelta eps decay.
Args:
eps_decay (float): Decay rate of eps.
Returns:
An extension function.
"""
from chainer import training
@training.make_extension(trigger=(1, "epoch"))
def adadelta_eps_decay(trainer):
_adadelta_eps_decay(trainer, eps_decay)
return adadelta_eps_decay
def _adadelta_eps_decay(trainer, eps_decay):
optimizer = trainer.updater.get_optimizer("main")
# for chainer
if hasattr(optimizer, "eps"):
current_eps = optimizer.eps
setattr(optimizer, "eps", current_eps * eps_decay)
logging.info("adadelta eps decayed to " + str(optimizer.eps))
# pytorch
else:
for p in optimizer.param_groups:
p["eps"] *= eps_decay
logging.info("adadelta eps decayed to " + str(p["eps"]))
def adam_lr_decay(eps_decay):
"""Extension to perform adam lr decay.
Args:
eps_decay (float): Decay rate of lr.
Returns:
An extension function.
"""
from chainer import training
@training.make_extension(trigger=(1, "epoch"))
def adam_lr_decay(trainer):
_adam_lr_decay(trainer, eps_decay)
return adam_lr_decay
def _adam_lr_decay(trainer, eps_decay):
optimizer = trainer.updater.get_optimizer("main")
# for chainer
if hasattr(optimizer, "lr"):
current_lr = optimizer.lr
setattr(optimizer, "lr", current_lr * eps_decay)
logging.info("adam lr decayed to " + str(optimizer.lr))
# pytorch
else:
for p in optimizer.param_groups:
p["lr"] *= eps_decay
logging.info("adam lr decayed to " + str(p["lr"]))
def torch_snapshot(savefun=torch.save, filename="snapshot.ep.{.updater.epoch}"):
"""Extension to take snapshot of the trainer for pytorch.
Returns:
An extension function.
"""
from chainer.training import extension
@extension.make_extension(trigger=(1, "epoch"), priority=-100)
def torch_snapshot(trainer):
_torch_snapshot_object(trainer, trainer, filename.format(trainer), savefun)
return torch_snapshot
def _torch_snapshot_object(trainer, target, filename, savefun):
from chainer.serializers import DictionarySerializer
# make snapshot_dict dictionary
s = DictionarySerializer()
s.save(trainer)
if hasattr(trainer.updater.model, "model"):
# (for TTS)
if hasattr(trainer.updater.model.model, "module"):
model_state_dict = trainer.updater.model.model.module.state_dict()
else:
model_state_dict = trainer.updater.model.model.state_dict()
else:
# (for ASR)
if hasattr(trainer.updater.model, "module"):
model_state_dict = trainer.updater.model.module.state_dict()
else:
model_state_dict = trainer.updater.model.state_dict()
snapshot_dict = {
"trainer": s.target,
"model": model_state_dict,
"optimizer": trainer.updater.get_optimizer("main").state_dict(),
}
# save snapshot dictionary
fn = filename.format(trainer)
prefix = "tmp" + fn
tmpdir = tempfile.mkdtemp(prefix=prefix, dir=trainer.out)
tmppath = os.path.join(tmpdir, fn)
try:
savefun(snapshot_dict, tmppath)
shutil.move(tmppath, os.path.join(trainer.out, fn))
finally:
shutil.rmtree(tmpdir)
def add_gradient_noise(model, iteration, duration=100, eta=1.0, scale_factor=0.55):
"""Adds noise from a standard normal distribution to the gradients.
The standard deviation (`sigma`) is controlled by the three hyper-parameters below.
`sigma` goes to zero (no noise) with more iterations.
Args:
model (torch.nn.model): Model.
iteration (int): Number of iterations.
duration (int) {100, 1000}:
Number of durations to control the interval of the `sigma` change.
eta (float) {0.01, 0.3, 1.0}: The magnitude of `sigma`.
scale_factor (float) {0.55}: The scale of `sigma`.
"""
interval = (iteration // duration) + 1
sigma = eta / interval**scale_factor
for param in model.parameters():
if param.grad is not None:
_shape = param.grad.size()
noise = sigma * torch.randn(_shape).to(param.device)
param.grad += noise
# * -------------------- general -------------------- *
def get_model_conf(model_path, conf_path=None):
"""Get model config information by reading a model config file (model.json).
Args:
model_path (str): Model path.
conf_path (str): Optional model config path.
Returns:
list[int, int, dict[str, Any]]: Config information loaded from json file.
"""
if conf_path is None:
model_conf = os.path.dirname(model_path) + "/model.json"
else:
model_conf = conf_path
with open(model_conf, "rb") as f:
logging.info("reading a config file from " + model_conf)
confs = json.load(f)
if isinstance(confs, dict):
# for lm
args = confs
return argparse.Namespace(**args)
else:
# for asr, tts, mt
idim, odim, args = confs
return idim, odim, argparse.Namespace(**args)
def chainer_load(path, model):
"""Load chainer model parameters.
Args:
path (str): Model path or snapshot file path to be loaded.
model (chainer.Chain): Chainer model.
"""
import chainer
if "snapshot" in os.path.basename(path):
chainer.serializers.load_npz(path, model, path="updater/model:main/")
else:
chainer.serializers.load_npz(path, model)
def torch_save(path, model):
"""Save torch model states.
Args:
path (str): Model path to be saved.
model (torch.nn.Module): Torch model.
"""
if hasattr(model, "module"):
torch.save(model.module.state_dict(), path)
else:
torch.save(model.state_dict(), path)
def snapshot_object(target, filename):
"""Returns a trainer extension to take snapshots of a given object.
Args:
target (model): Object to serialize.
filename (str): Name of the file into which the object is serialized.It can
be a format string, where the trainer object is passed to
the :meth: `str.format` method. For example,
``'snapshot_{.updater.iteration}'`` is converted to
``'snapshot_10000'`` at the 10,000th iteration.
Returns:
An extension function.
"""
from chainer.training import extension
@extension.make_extension(trigger=(1, "epoch"), priority=-100)
def snapshot_object(trainer):
torch_save(os.path.join(trainer.out, filename.format(trainer)), target)
return snapshot_object
def torch_load(path, model):
"""Load torch model states.
Args:
path (str): Model path or snapshot file path to be loaded.
model (torch.nn.Module): Torch model.
"""
if "snapshot" in os.path.basename(path):
model_state_dict = torch.load(path, map_location=lambda storage, loc: storage)[
"model"
]
else:
model_state_dict = torch.load(path, map_location=lambda storage, loc: storage)
if hasattr(model, "module"):
model.module.load_state_dict(model_state_dict)
else:
model.load_state_dict(model_state_dict)
del model_state_dict
def torch_resume(snapshot_path, trainer):
"""Resume from snapshot for pytorch.
Args:
snapshot_path (str): Snapshot file path.
trainer (chainer.training.Trainer): Chainer's trainer instance.
"""
from chainer.serializers import NpzDeserializer
# load snapshot
snapshot_dict = torch.load(snapshot_path, map_location=lambda storage, loc: storage)
# restore trainer states
d = NpzDeserializer(snapshot_dict["trainer"])
d.load(trainer)
# restore model states
if hasattr(trainer.updater.model, "model"):
# (for TTS model)
if hasattr(trainer.updater.model.model, "module"):
trainer.updater.model.model.module.load_state_dict(snapshot_dict["model"])
else:
trainer.updater.model.model.load_state_dict(snapshot_dict["model"])
else:
# (for ASR model)
if hasattr(trainer.updater.model, "module"):
trainer.updater.model.module.load_state_dict(snapshot_dict["model"])
else:
trainer.updater.model.load_state_dict(snapshot_dict["model"])
# retore optimizer states
trainer.updater.get_optimizer("main").load_state_dict(snapshot_dict["optimizer"])
# delete opened snapshot
del snapshot_dict
# * ------------------ recognition related ------------------ *
def parse_hypothesis(hyp, char_list):
"""Parse hypothesis.
Args:
hyp (list[dict[str, Any]]): Recognition hypothesis.
char_list (list[str]): List of characters.
Returns:
tuple(str, str, str, float)
"""
# remove sos and get results
tokenid_as_list = list(map(int, hyp["yseq"][1:]))
token_as_list = [char_list[idx] for idx in tokenid_as_list]
score = float(hyp["score"])
# convert to string
tokenid = " ".join([str(idx) for idx in tokenid_as_list])
token = " ".join(token_as_list)
text = "".join(token_as_list).replace("<space>", " ")
return text, token, tokenid, score
def add_results_to_json(js, nbest_hyps, char_list):
"""Add N-best results to json.
Args:
js (dict[str, Any]): Groundtruth utterance dict.
nbest_hyps_sd (list[dict[str, Any]]):
List of hypothesis for multi_speakers: nutts x nspkrs.
char_list (list[str]): List of characters.
Returns:
dict[str, Any]: N-best results added utterance dict.
"""
# copy old json info
new_js = dict()
new_js["utt2spk"] = js["utt2spk"]
new_js["output"] = []
for n, hyp in enumerate(nbest_hyps, 1):
# parse hypothesis
rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp, char_list)
# copy ground-truth
if len(js["output"]) > 0:
out_dic = dict(js["output"][0].items())
else:
# for no reference case (e.g., speech translation)
out_dic = {"name": ""}
# update name
out_dic["name"] += "[%d]" % n
# add recognition results
out_dic["rec_text"] = rec_text
out_dic["rec_token"] = rec_token
out_dic["rec_tokenid"] = rec_tokenid
out_dic["score"] = score
# add to list of N-best result dicts
new_js["output"].append(out_dic)
# show 1-best result
if n == 1:
if "text" in out_dic.keys():
logging.info("groundtruth: %s" % out_dic["text"])
logging.info("prediction : %s" % out_dic["rec_text"])
return new_js
def plot_spectrogram(
plt,
spec,
mode="db",
fs=None,
frame_shift=None,
bottom=True,
left=True,
right=True,
top=False,
labelbottom=True,
labelleft=True,
labelright=True,
labeltop=False,
cmap="inferno",
):
"""Plot spectrogram using matplotlib.
Args:
plt (matplotlib.pyplot): pyplot object.
spec (numpy.ndarray): Input stft (Freq, Time)
mode (str): db or linear.
fs (int): Sample frequency. To convert y-axis to kHz unit.
frame_shift (int): The frame shift of stft. To convert x-axis to second unit.
bottom (bool):Whether to draw the respective ticks.
left (bool):
right (bool):
top (bool):
labelbottom (bool):Whether to draw the respective tick labels.
labelleft (bool):
labelright (bool):
labeltop (bool):
cmap (str): Colormap defined in matplotlib.
"""
spec = np.abs(spec)
if mode == "db":
x = 20 * np.log10(spec + np.finfo(spec.dtype).eps)
elif mode == "linear":
x = spec
else:
raise ValueError(mode)
if fs is not None:
ytop = fs / 2000
ylabel = "kHz"
else:
ytop = x.shape[0]
ylabel = "bin"
if frame_shift is not None and fs is not None:
xtop = x.shape[1] * frame_shift / fs
xlabel = "s"
else:
xtop = x.shape[1]
xlabel = "frame"
extent = (0, xtop, 0, ytop)
plt.imshow(x[::-1], cmap=cmap, extent=extent)
if labelbottom:
plt.xlabel("time [{}]".format(xlabel))
if labelleft:
plt.ylabel("freq [{}]".format(ylabel))
plt.colorbar().set_label("{}".format(mode))
plt.tick_params(
bottom=bottom,
left=left,
right=right,
top=top,
labelbottom=labelbottom,
labelleft=labelleft,
labelright=labelright,
labeltop=labeltop,
)
plt.axis("auto")
# * ------------------ recognition related ------------------ *
def format_mulenc_args(args):
"""Format args for multi-encoder setup.
It deals with following situations: (when args.num_encs=2):
1. args.elayers = None -> args.elayers = [4, 4];
2. args.elayers = 4 -> args.elayers = [4, 4];
3. args.elayers = [4, 4, 4] -> args.elayers = [4, 4].
"""
# default values when None is assigned.
default_dict = {
"etype": "blstmp",
"elayers": 4,
"eunits": 300,
"subsample": "1",
"dropout_rate": 0.0,
"atype": "dot",
"adim": 320,
"awin": 5,
"aheads": 4,
"aconv_chans": -1,
"aconv_filts": 100,
}
for k in default_dict.keys():
if isinstance(vars(args)[k], list):
if len(vars(args)[k]) != args.num_encs:
logging.warning(
"Length mismatch {}: Convert {} to {}.".format(
k, vars(args)[k], vars(args)[k][: args.num_encs]
)
)
vars(args)[k] = vars(args)[k][: args.num_encs]
else:
if not vars(args)[k]:
# assign default value if it is None
vars(args)[k] = default_dict[k]
logging.warning(
"{} is not specified, use default value {}.".format(
k, default_dict[k]
)
)
# duplicate
logging.warning(
"Type mismatch {}: Convert {} to {}.".format(
k, vars(args)[k], [vars(args)[k] for _ in range(args.num_encs)]
)
)
vars(args)[k] = [vars(args)[k] for _ in range(args.num_encs)]
return args
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Training/decoding definition for the speech recognition task."""
import json
import logging
import os
# chainer related
import chainer
from chainer import training
from chainer.datasets import TransformDataset
from chainer.training import extensions
# rnnlm
import espnet.lm.chainer_backend.extlm as extlm_chainer
import espnet.lm.chainer_backend.lm as lm_chainer
# espnet related
from espnet.asr.asr_utils import (
CompareValueTrigger,
adadelta_eps_decay,
add_results_to_json,
chainer_load,
get_model_conf,
restore_snapshot,
)
from espnet.nets.asr_interface import ASRInterface
from espnet.utils.deterministic_utils import set_deterministic_chainer
from espnet.utils.dynamic_import import dynamic_import
from espnet.utils.io_utils import LoadInputsAndTargets
from espnet.utils.training.batchfy import make_batchset
from espnet.utils.training.evaluator import BaseEvaluator
from espnet.utils.training.iterators import (
ShufflingEnabler,
ToggleableShufflingMultiprocessIterator,
ToggleableShufflingSerialIterator,
)
from espnet.utils.training.tensorboard_logger import TensorboardLogger
from espnet.utils.training.train_utils import check_early_stop, set_early_stop
def train(args):
"""Train with the given args.
Args:
args (namespace): The program arguments.
"""
# display chainer version
logging.info("chainer version = " + chainer.__version__)
set_deterministic_chainer(args)
# check cuda and cudnn availability
if not chainer.cuda.available:
logging.warning("cuda is not available")
if not chainer.cuda.cudnn_enabled:
logging.warning("cudnn is not available")
# get input and output dimension info
with open(args.valid_json, "rb") as f:
valid_json = json.load(f)["utts"]
utts = list(valid_json.keys())
idim = int(valid_json[utts[0]]["input"][0]["shape"][1])
odim = int(valid_json[utts[0]]["output"][0]["shape"][1])
logging.info("#input dims : " + str(idim))
logging.info("#output dims: " + str(odim))
# specify attention, CTC, hybrid mode
if args.mtlalpha == 1.0:
mtl_mode = "ctc"
logging.info("Pure CTC mode")
elif args.mtlalpha == 0.0:
mtl_mode = "att"
logging.info("Pure attention mode")
else:
mtl_mode = "mtl"
logging.info("Multitask learning mode")
# specify model architecture
logging.info("import model module: " + args.model_module)
model_class = dynamic_import(args.model_module)
model = model_class(idim, odim, args, flag_return=False)
assert isinstance(model, ASRInterface)
total_subsampling_factor = model.get_total_subsampling_factor()
# write model config
if not os.path.exists(args.outdir):
os.makedirs(args.outdir)
model_conf = args.outdir + "/model.json"
with open(model_conf, "wb") as f:
logging.info("writing a model config file to " + model_conf)
f.write(
json.dumps(
(idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True
).encode("utf_8")
)
for key in sorted(vars(args).keys()):
logging.info("ARGS: " + key + ": " + str(vars(args)[key]))
# Set gpu
ngpu = args.ngpu
if ngpu == 1:
gpu_id = 0
# Make a specified GPU current
chainer.cuda.get_device_from_id(gpu_id).use()
model.to_gpu() # Copy the model to the GPU
logging.info("single gpu calculation.")
elif ngpu > 1:
gpu_id = 0
devices = {"main": gpu_id}
for gid in range(1, ngpu):
devices["sub_%d" % gid] = gid
logging.info("multi gpu calculation (#gpus = %d)." % ngpu)
logging.warning(
"batch size is automatically increased (%d -> %d)"
% (args.batch_size, args.batch_size * args.ngpu)
)
else:
gpu_id = -1
logging.info("cpu calculation")
# Setup an optimizer
if args.opt == "adadelta":
optimizer = chainer.optimizers.AdaDelta(eps=args.eps)
elif args.opt == "adam":
optimizer = chainer.optimizers.Adam()
elif args.opt == "noam":
optimizer = chainer.optimizers.Adam(alpha=0, beta1=0.9, beta2=0.98, eps=1e-9)
else:
raise NotImplementedError("args.opt={}".format(args.opt))
optimizer.setup(model)
optimizer.add_hook(chainer.optimizer.GradientClipping(args.grad_clip))
# Setup a converter
converter = model.custom_converter(subsampling_factor=model.subsample[0])
# read json data
with open(args.train_json, "rb") as f:
train_json = json.load(f)["utts"]
with open(args.valid_json, "rb") as f:
valid_json = json.load(f)["utts"]
# set up training iterator and updater
load_tr = LoadInputsAndTargets(
mode="asr",
load_output=True,
preprocess_conf=args.preprocess_conf,
preprocess_args={"train": True}, # Switch the mode of preprocessing
)
load_cv = LoadInputsAndTargets(
mode="asr",
load_output=True,
preprocess_conf=args.preprocess_conf,
preprocess_args={"train": False}, # Switch the mode of preprocessing
)
use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
accum_grad = args.accum_grad
if ngpu <= 1:
# make minibatch list (variable length)
train = make_batchset(
train_json,
args.batch_size,
args.maxlen_in,
args.maxlen_out,
args.minibatches,
min_batch_size=args.ngpu if args.ngpu > 1 else 1,
shortest_first=use_sortagrad,
count=args.batch_count,
batch_bins=args.batch_bins,
batch_frames_in=args.batch_frames_in,
batch_frames_out=args.batch_frames_out,
batch_frames_inout=args.batch_frames_inout,
iaxis=0,
oaxis=0,
)
# hack to make batchsize argument as 1
# actual batchsize is included in a list
if args.n_iter_processes > 0:
train_iters = [
ToggleableShufflingMultiprocessIterator(
TransformDataset(train, load_tr),
batch_size=1,
n_processes=args.n_iter_processes,
n_prefetch=8,
maxtasksperchild=20,
shuffle=not use_sortagrad,
)
]
else:
train_iters = [
ToggleableShufflingSerialIterator(
TransformDataset(train, load_tr),
batch_size=1,
shuffle=not use_sortagrad,
)
]
# set up updater
updater = model.custom_updater(
train_iters[0],
optimizer,
converter=converter,
device=gpu_id,
accum_grad=accum_grad,
)
else:
if args.batch_count not in ("auto", "seq") and args.batch_size == 0:
raise NotImplementedError(
"--batch-count 'bin' and 'frame' are not implemented "
"in chainer multi gpu"
)
# set up minibatches
train_subsets = []
for gid in range(ngpu):
# make subset
train_json_subset = {
k: v for i, (k, v) in enumerate(train_json.items()) if i % ngpu == gid
}
# make minibatch list (variable length)
train_subsets += [
make_batchset(
train_json_subset,
args.batch_size,
args.maxlen_in,
args.maxlen_out,
args.minibatches,
)
]
# each subset must have same length for MultiprocessParallelUpdater
maxlen = max([len(train_subset) for train_subset in train_subsets])
for train_subset in train_subsets:
if maxlen != len(train_subset):
for i in range(maxlen - len(train_subset)):
train_subset += [train_subset[i]]
# hack to make batchsize argument as 1
# actual batchsize is included in a list
if args.n_iter_processes > 0:
train_iters = [
ToggleableShufflingMultiprocessIterator(
TransformDataset(train_subsets[gid], load_tr),
batch_size=1,
n_processes=args.n_iter_processes,
n_prefetch=8,
maxtasksperchild=20,
shuffle=not use_sortagrad,
)
for gid in range(ngpu)
]
else:
train_iters = [
ToggleableShufflingSerialIterator(
TransformDataset(train_subsets[gid], load_tr),
batch_size=1,
shuffle=not use_sortagrad,
)
for gid in range(ngpu)
]
# set up updater
updater = model.custom_parallel_updater(
train_iters, optimizer, converter=converter, devices=devices
)
# Set up a trainer
trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)
if use_sortagrad:
trainer.extend(
ShufflingEnabler(train_iters),
trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
)
if args.opt == "noam":
from espnet.nets.chainer_backend.transformer.training import VaswaniRule
trainer.extend(
VaswaniRule(
"alpha",
d=args.adim,
warmup_steps=args.transformer_warmup_steps,
scale=args.transformer_lr,
),
trigger=(1, "iteration"),
)
# Resume from a snapshot
if args.resume:
chainer.serializers.load_npz(args.resume, trainer)
# set up validation iterator
valid = make_batchset(
valid_json,
args.batch_size,
args.maxlen_in,
args.maxlen_out,
args.minibatches,
min_batch_size=args.ngpu if args.ngpu > 1 else 1,
count=args.batch_count,
batch_bins=args.batch_bins,
batch_frames_in=args.batch_frames_in,
batch_frames_out=args.batch_frames_out,
batch_frames_inout=args.batch_frames_inout,
iaxis=0,
oaxis=0,
)
if args.n_iter_processes > 0:
valid_iter = chainer.iterators.MultiprocessIterator(
TransformDataset(valid, load_cv),
batch_size=1,
repeat=False,
shuffle=False,
n_processes=args.n_iter_processes,
n_prefetch=8,
maxtasksperchild=20,
)
else:
valid_iter = chainer.iterators.SerialIterator(
TransformDataset(valid, load_cv), batch_size=1, repeat=False, shuffle=False
)
# Evaluate the model with the test dataset for each epoch
trainer.extend(BaseEvaluator(valid_iter, model, converter=converter, device=gpu_id))
# Save attention weight each epoch
if args.num_save_attention > 0 and args.mtlalpha != 1.0:
data = sorted(
list(valid_json.items())[: args.num_save_attention],
key=lambda x: int(x[1]["input"][0]["shape"][1]),
reverse=True,
)
if hasattr(model, "module"):
att_vis_fn = model.module.calculate_all_attentions
plot_class = model.module.attention_plot_class
else:
att_vis_fn = model.calculate_all_attentions
plot_class = model.attention_plot_class
logging.info("Using custom PlotAttentionReport")
att_reporter = plot_class(
att_vis_fn,
data,
args.outdir + "/att_ws",
converter=converter,
transform=load_cv,
device=gpu_id,
subsampling_factor=total_subsampling_factor,
)
trainer.extend(att_reporter, trigger=(1, "epoch"))
else:
att_reporter = None
# Take a snapshot for each specified epoch
trainer.extend(
extensions.snapshot(filename="snapshot.ep.{.updater.epoch}"),
trigger=(1, "epoch"),
)
# Make a plot for training and validation values
trainer.extend(
extensions.PlotReport(
[
"main/loss",
"validation/main/loss",
"main/loss_ctc",
"validation/main/loss_ctc",
"main/loss_att",
"validation/main/loss_att",
],
"epoch",
file_name="loss.png",
)
)
trainer.extend(
extensions.PlotReport(
["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
)
)
# Save best models
trainer.extend(
extensions.snapshot_object(model, "model.loss.best"),
trigger=training.triggers.MinValueTrigger("validation/main/loss"),
)
if mtl_mode != "ctc":
trainer.extend(
extensions.snapshot_object(model, "model.acc.best"),
trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
)
# epsilon decay in the optimizer
if args.opt == "adadelta":
if args.criterion == "acc" and mtl_mode != "ctc":
trainer.extend(
restore_snapshot(model, args.outdir + "/model.acc.best"),
trigger=CompareValueTrigger(
"validation/main/acc",
lambda best_value, current_value: best_value > current_value,
),
)
trainer.extend(
adadelta_eps_decay(args.eps_decay),
trigger=CompareValueTrigger(
"validation/main/acc",
lambda best_value, current_value: best_value > current_value,
),
)
elif args.criterion == "loss":
trainer.extend(
restore_snapshot(model, args.outdir + "/model.loss.best"),
trigger=CompareValueTrigger(
"validation/main/loss",
lambda best_value, current_value: best_value < current_value,
),
)
trainer.extend(
adadelta_eps_decay(args.eps_decay),
trigger=CompareValueTrigger(
"validation/main/loss",
lambda best_value, current_value: best_value < current_value,
),
)
# Write a log of evaluation statistics for each epoch
trainer.extend(
extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
)
report_keys = [
"epoch",
"iteration",
"main/loss",
"main/loss_ctc",
"main/loss_att",
"validation/main/loss",
"validation/main/loss_ctc",
"validation/main/loss_att",
"main/acc",
"validation/main/acc",
"elapsed_time",
]
if args.opt == "adadelta":
trainer.extend(
extensions.observe_value(
"eps", lambda trainer: trainer.updater.get_optimizer("main").eps
),
trigger=(args.report_interval_iters, "iteration"),
)
report_keys.append("eps")
trainer.extend(
extensions.PrintReport(report_keys),
trigger=(args.report_interval_iters, "iteration"),
)
trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
set_early_stop(trainer, args)
if args.tensorboard_dir is not None and args.tensorboard_dir != "":
try:
from tensorboardX import SummaryWriter
except Exception:
logging.error("Please install tensorboardx")
raise
writer = SummaryWriter(args.tensorboard_dir)
trainer.extend(
TensorboardLogger(writer, att_reporter),
trigger=(args.report_interval_iters, "iteration"),
)
# Run the training
trainer.run()
check_early_stop(trainer, args.epochs)
def recog(args):
"""Decode with the given args.
Args:
args (namespace): The program arguments.
"""
# display chainer version
logging.info("chainer version = " + chainer.__version__)
set_deterministic_chainer(args)
# read training config
idim, odim, train_args = get_model_conf(args.model, args.model_conf)
for key in sorted(vars(args).keys()):
logging.info("ARGS: " + key + ": " + str(vars(args)[key]))
# specify model architecture
logging.info("reading model parameters from " + args.model)
# To be compatible with v.0.3.0 models
if hasattr(train_args, "model_module"):
model_module = train_args.model_module
else:
model_module = "espnet.nets.chainer_backend.e2e_asr:E2E"
model_class = dynamic_import(model_module)
model = model_class(idim, odim, train_args)
assert isinstance(model, ASRInterface)
chainer_load(args.model, model)
# read rnnlm
if args.rnnlm:
rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
rnnlm = lm_chainer.ClassifierWithState(
lm_chainer.RNNLM(
len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit
)
)
chainer_load(args.rnnlm, rnnlm)
else:
rnnlm = None
if args.word_rnnlm:
rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
word_dict = rnnlm_args.char_list_dict
char_dict = {x: i for i, x in enumerate(train_args.char_list)}
word_rnnlm = lm_chainer.ClassifierWithState(
lm_chainer.RNNLM(len(word_dict), rnnlm_args.layer, rnnlm_args.unit)
)
chainer_load(args.word_rnnlm, word_rnnlm)
if rnnlm is not None:
rnnlm = lm_chainer.ClassifierWithState(
extlm_chainer.MultiLevelLM(
word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict
)
)
else:
rnnlm = lm_chainer.ClassifierWithState(
extlm_chainer.LookAheadWordLM(
word_rnnlm.predictor, word_dict, char_dict
)
)
# read json data
with open(args.recog_json, "rb") as f:
js = json.load(f)["utts"]
load_inputs_and_targets = LoadInputsAndTargets(
mode="asr",
load_output=False,
sort_in_input_length=False,
preprocess_conf=train_args.preprocess_conf
if args.preprocess_conf is None
else args.preprocess_conf,
preprocess_args={"train": False}, # Switch the mode of preprocessing
)
# decode each utterance
new_js = {}
with chainer.no_backprop_mode():
for idx, name in enumerate(js.keys(), 1):
logging.info("(%d/%d) decoding " + name, idx, len(js.keys()))
batch = [(name, js[name])]
feat = load_inputs_and_targets(batch)[0][0]
nbest_hyps = model.recognize(feat, args, train_args.char_list, rnnlm)
new_js[name] = add_results_to_json(
js[name], nbest_hyps, train_args.char_list
)
with open(args.result_label, "wb") as f:
f.write(
json.dumps(
{"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
).encode("utf_8")
)
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Training/decoding definition for the speech recognition task."""
import copy
import itertools
import json
import logging
import math
import os
import numpy as np
import torch
import torch.distributed as dist
from chainer import reporter as reporter_module
from chainer import training
from chainer.training import extensions
from chainer.training.updater import StandardUpdater
from packaging.version import parse as V
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.parallel import data_parallel
from torch.utils.data.distributed import DistributedSampler
import espnet.lm.pytorch_backend.extlm as extlm_pytorch
import espnet.nets.pytorch_backend.lm.default as lm_pytorch
from espnet.asr.asr_utils import (
CompareValueTrigger,
adadelta_eps_decay,
add_results_to_json,
format_mulenc_args,
get_model_conf,
plot_spectrogram,
restore_snapshot,
snapshot_object,
torch_load,
torch_resume,
torch_snapshot,
)
from espnet.asr.pytorch_backend.asr_init import (
freeze_modules,
load_trained_model,
load_trained_modules,
)
from espnet.nets.asr_interface import ASRInterface
from espnet.nets.beam_search_transducer import BeamSearchTransducer
from espnet.nets.pytorch_backend.e2e_asr import pad_list
from espnet.nets.pytorch_backend.streaming.segment import SegmentStreamingE2E
from espnet.nets.pytorch_backend.streaming.window import WindowStreamingE2E
from espnet.transform.spectrogram import IStft
from espnet.transform.transformation import Transformation
from espnet.utils.cli_writers import file_writer_helper
from espnet.utils.dataset import ChainerDataLoader, Transform, TransformDataset
from espnet.utils.deterministic_utils import set_deterministic_pytorch
from espnet.utils.dynamic_import import dynamic_import
from espnet.utils.io_utils import LoadInputsAndTargets
from espnet.utils.training.batchfy import make_batchset
from espnet.utils.training.evaluator import BaseEvaluator
from espnet.utils.training.iterators import ShufflingEnabler
from espnet.utils.training.tensorboard_logger import TensorboardLogger
from espnet.utils.training.train_utils import check_early_stop, set_early_stop
def _recursive_to(xs, device):
if torch.is_tensor(xs):
return xs.to(device)
if isinstance(xs, tuple):
return tuple(_recursive_to(x, device) for x in xs)
return xs
class DistributedDictSummary:
"""Distributed version of DictSummary.
This implementation is based on an official implementation below.
https://github.com/chainer/chainer/blob/v6.7.0/chainer/reporter.py
To gather stats information from all processes and calculate exact mean values,
this class is running AllReduce operation in compute_mean().
"""
def __init__(self, device=None):
self._local_summary = reporter_module.DictSummary()
self._summary_names = None
self._device = device
def add(self, d):
if self._summary_names is None:
# This assumes that `d` always includes the same name list,
# and the name list is identical accross all processes.
self._summary_names = frozenset(d.keys())
return self._local_summary.add(d)
def compute_mean(self):
# Even if `self._local_summary` doesn't have a few keys
# due to invalid observations like NaN, zero, etc,
# `raw_values` can properly these entries
# thanks to zero as an initial value.
raw_values = {name: [0.0, 0] for name in self._summary_names}
for name, summary in self._local_summary._summaries.items():
raw_values[name][0] += summary._x
raw_values[name][1] += summary._n
sum_list = []
count_list = []
for name in sorted(self._summary_names):
sum_list.append(raw_values[name][0])
count_list.append(raw_values[name][1])
sum_tensor = torch.tensor(sum_list, device=self._device)
count_tensor = torch.tensor(count_list, device=self._device)
# AllReduce both of sum and count in parallel.
sum_handle = dist.all_reduce(sum_tensor, async_op=True)
count_handle = dist.all_reduce(count_tensor, async_op=True)
sum_handle.wait()
count_handle.wait()
# Once both ops are enqueued, putting an op to calculate actual average value.
mean_tensor = sum_tensor / count_tensor
result_dict = {}
for idx, name in enumerate(sorted(self._summary_names)):
if name not in self._local_summary._summaries:
# If an entry with a target name doesn't exist in `self._local_summary`,
# this entry must be removed from `result_dict`.
# This behavior is the same with original DictSummary.
continue
result_dict[name] = mean_tensor[idx].item()
return result_dict
class CustomEvaluator(BaseEvaluator):
"""Custom Evaluator for Pytorch.
Args:
model (torch.nn.Module): The model to evaluate.
iterator (chainer.dataset.Iterator) : The train iterator.
target (link | dict[str, link]) :Link object or a dictionary of
links to evaluate. If this is just a link object, the link is
registered by the name ``'main'``.
device (torch.device): The device used.
ngpu (int): The number of GPUs.
use_ddp (bool): The flag to use DDP.
"""
def __init__(self, model, iterator, target, device, ngpu=None, use_ddp=False):
super(CustomEvaluator, self).__init__(iterator, target)
self.model = model
self.device = device
if ngpu is not None:
self.ngpu = ngpu
elif device.type == "cpu":
self.ngpu = 0
else:
self.ngpu = 1
self.use_ddp = use_ddp
# The core part of the update routine can be customized by overriding
def evaluate(self):
"""Main evaluate routine for CustomEvaluator."""
iterator = self._iterators["main"]
if self.eval_hook:
self.eval_hook(self)
if hasattr(iterator, "reset"):
iterator.reset()
it = iterator
else:
it = copy.copy(iterator)
if self.use_ddp:
summary = DistributedDictSummary(self.device)
else:
summary = reporter_module.DictSummary()
self.model.eval()
with torch.no_grad():
for batch in it:
x = _recursive_to(batch, self.device)
observation = {}
with reporter_module.report_scope(observation):
# read scp files
# x: original json with loaded features
# will be converted to chainer variable later
if self.ngpu == 0 or self.use_ddp:
self.model(*x)
else:
# apex does not support torch.nn.DataParallel
data_parallel(self.model, x, range(self.ngpu))
summary.add(observation)
self.model.train()
return summary.compute_mean()
class CustomUpdater(StandardUpdater):
"""Custom Updater for Pytorch.
Args:
model (torch.nn.Module): The model to update.
grad_clip_threshold (float): The gradient clipping value to use.
train_iter (chainer.dataset.Iterator): The training iterator.
optimizer (torch.optim.optimizer): The training optimizer.
device (torch.device): The device to use.
ngpu (int): The number of gpus to use.
use_apex (bool): The flag to use Apex in backprop.
use_ddp (bool): The flag to use DDP for multi-GPU training.
"""
def __init__(
self,
model,
grad_clip_threshold,
train_iter,
optimizer,
device,
ngpu,
grad_noise=False,
accum_grad=1,
use_apex=False,
use_ddp=False,
):
super(CustomUpdater, self).__init__(train_iter, optimizer)
self.model = model
self.grad_clip_threshold = grad_clip_threshold
self.device = device
self.ngpu = ngpu
self.accum_grad = accum_grad
self.forward_count = 0
self.grad_noise = grad_noise
self.iteration = 0
self.use_apex = use_apex
self.use_ddp = use_ddp
# The core part of the update routine can be customized by overriding.
def update_core(self):
"""Main update routine of the CustomUpdater."""
# When we pass one iterator and optimizer to StandardUpdater.__init__,
# they are automatically named 'main'.
train_iter = self.get_iterator("main")
optimizer = self.get_optimizer("main")
epoch = train_iter.epoch
# Get the next batch (a list of json files)
batch = train_iter.next()
# self.iteration += 1 # Increase may result in early report,
# which is done in other place automatically.
x = _recursive_to(batch, self.device)
is_new_epoch = train_iter.epoch != epoch
# When the last minibatch in the current epoch is given,
# gradient accumulation is turned off in order to evaluate the model
# on the validation set in every epoch.
# see details in https://github.com/espnet/espnet/pull/1388
# Compute the loss at this time step and accumulate it
if self.ngpu == 0 or self.use_ddp:
loss = self.model(*x).mean() / self.accum_grad
else:
# apex does not support torch.nn.DataParallel
loss = (
data_parallel(self.model, x, range(self.ngpu)).mean() / self.accum_grad
)
if self.use_apex:
from apex import amp
# NOTE: for a compatibility with noam optimizer
opt = optimizer.optimizer if hasattr(optimizer, "optimizer") else optimizer
with amp.scale_loss(loss, opt) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()
# gradient noise injection
if self.grad_noise:
from espnet.asr.asr_utils import add_gradient_noise
add_gradient_noise(
self.model, self.iteration, duration=100, eta=1.0, scale_factor=0.55
)
# update parameters
self.forward_count += 1
if not is_new_epoch and self.forward_count != self.accum_grad:
return
self.forward_count = 0
# compute the gradient norm to check if it is normal or not
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.grad_clip_threshold
)
if self.use_ddp:
# NOTE: assuming gradients have not been reduced yet here.
# Try to gather the norm of gradients from all workers,
# and calculate average grad norm.
dist.all_reduce(grad_norm)
logging.info("grad norm={}".format(grad_norm))
if math.isnan(grad_norm):
logging.warning("grad norm is nan. Do not update model.")
else:
optimizer.step()
optimizer.zero_grad()
def update(self):
self.update_core()
# #iterations with accum_grad > 1
# Ref.: https://github.com/espnet/espnet/issues/777
if self.forward_count == 0:
self.iteration += 1
class CustomConverter(object):
"""Custom batch converter for Pytorch.
Args:
subsampling_factor (int): The subsampling factor.
dtype (torch.dtype): Data type to convert.
"""
def __init__(self, subsampling_factor=1, dtype=torch.float32):
"""Construct a CustomConverter object."""
self.subsampling_factor = subsampling_factor
self.ignore_id = -1
self.dtype = dtype
def __call__(self, batch, device=torch.device("cpu")):
"""Transform a batch and send it to a device.
Args:
batch (list): The batch to transform.
device (torch.device): The device to send to.
Returns:
tuple(torch.Tensor, torch.Tensor, torch.Tensor)
"""
# batch should be located in list
assert len(batch) == 1
xs, ys = batch[0]
# perform subsampling
if self.subsampling_factor > 1:
xs = [x[:: self.subsampling_factor, :] for x in xs]
# get batch of lengths of input sequences
ilens = np.array([x.shape[0] for x in xs])
# perform padding and convert to tensor
# currently only support real number
if xs[0].dtype.kind == "c":
xs_pad_real = pad_list(
[torch.from_numpy(x.real).float() for x in xs], 0
).to(device, dtype=self.dtype)
xs_pad_imag = pad_list(
[torch.from_numpy(x.imag).float() for x in xs], 0
).to(device, dtype=self.dtype)
# Note(kamo):
# {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.
# Don't create ComplexTensor and give it E2E here
# because torch.nn.DataParellel can't handle it.
xs_pad = {"real": xs_pad_real, "imag": xs_pad_imag}
else:
xs_pad = pad_list([torch.from_numpy(x).float() for x in xs], 0).to(
device, dtype=self.dtype
)
ilens = torch.from_numpy(ilens).to(device)
# NOTE: this is for multi-output (e.g., speech translation)
ys_pad = pad_list(
[
torch.from_numpy(
np.array(y[0][:]) if isinstance(y, tuple) else y
).long()
for y in ys
],
self.ignore_id,
).to(device)
return xs_pad, ilens, ys_pad
class CustomConverterMulEnc(object):
"""Custom batch converter for Pytorch in multi-encoder case.
Args:
subsampling_factors (list): List of subsampling factors for each encoder.
dtype (torch.dtype): Data type to convert.
"""
def __init__(self, subsampling_factors=[1, 1], dtype=torch.float32):
"""Initialize the converter."""
self.subsampling_factors = subsampling_factors
self.ignore_id = -1
self.dtype = dtype
self.num_encs = len(subsampling_factors)
def __call__(self, batch, device=torch.device("cpu")):
"""Transform a batch and send it to a device.
Args:
batch (list): The batch to transform.
device (torch.device): The device to send to.
Returns:
tuple( list(torch.Tensor), list(torch.Tensor), torch.Tensor)
"""
# batch should be located in list
assert len(batch) == 1
xs_list = batch[0][: self.num_encs]
ys = batch[0][-1]
# perform subsampling
if np.sum(self.subsampling_factors) > self.num_encs:
xs_list = [
[x[:: self.subsampling_factors[i], :] for x in xs_list[i]]
for i in range(self.num_encs)
]
# get batch of lengths of input sequences
ilens_list = [
np.array([x.shape[0] for x in xs_list[i]]) for i in range(self.num_encs)
]
# perform padding and convert to tensor
# currently only support real number
xs_list_pad = [
pad_list([torch.from_numpy(x).float() for x in xs_list[i]], 0).to(
device, dtype=self.dtype
)
for i in range(self.num_encs)
]
ilens_list = [
torch.from_numpy(ilens_list[i]).to(device) for i in range(self.num_encs)
]
# NOTE: this is for multi-task learning (e.g., speech translation)
ys_pad = pad_list(
[
torch.from_numpy(np.array(y[0]) if isinstance(y, tuple) else y).long()
for y in ys
],
self.ignore_id,
).to(device)
return xs_list_pad, ilens_list, ys_pad
def is_writable_process(args, worldsize, rank, localrank):
return not args.use_ddp or rank == 0
def train(args):
"""Train with the given args.
Args:
args (namespace): The program arguments.
"""
if args.use_ddp:
# initialize distributed environment.
# NOTE: current implementation supports
# only single-node training.
# get process information.
worldsize = os.environ.get("WORLD_SIZE", None)
assert worldsize is not None
worldsize = int(worldsize)
assert worldsize == args.ngpu
rank = os.environ.get("RANK", None)
assert rank is not None
rank = int(rank)
localrank = os.environ.get("LOCAL_RANK", None)
assert localrank is not None
localrank = int(localrank)
dist.init_process_group(
backend="nccl",
init_method="env://",
rank=rank,
world_size=worldsize,
)
if rank != 0:
# Disable all logs in non-master process.
logging.disable()
else:
worldsize = 1
rank = 0
localrank = 0
set_deterministic_pytorch(args)
if args.num_encs > 1:
args = format_mulenc_args(args)
# check cuda availability
if not torch.cuda.is_available():
logging.warning("cuda is not available")
# get input and output dimension info
with open(args.valid_json, "rb") as f:
valid_json = json.load(f)["utts"]
utts = list(valid_json.keys())
idim_list = [
int(valid_json[utts[0]]["input"][i]["shape"][-1]) for i in range(args.num_encs)
]
odim = int(valid_json[utts[0]]["output"][0]["shape"][-1])
for i in range(args.num_encs):
logging.info("stream{}: input dims : {}".format(i + 1, idim_list[i]))
logging.info("#output dims: " + str(odim))
# specify attention, CTC, hybrid mode
if "transducer" in args.model_module:
if (
getattr(args, "etype", False) == "custom"
or getattr(args, "dtype", False) == "custom"
):
mtl_mode = "custom_transducer"
else:
mtl_mode = "transducer"
logging.info("Pure transducer mode")
elif args.mtlalpha == 1.0:
mtl_mode = "ctc"
logging.info("Pure CTC mode")
elif args.mtlalpha == 0.0:
mtl_mode = "att"
logging.info("Pure attention mode")
else:
mtl_mode = "mtl"
logging.info("Multitask learning mode")
if (args.enc_init is not None or args.dec_init is not None) and args.num_encs == 1:
model = load_trained_modules(idim_list[0], odim, args)
else:
model_class = dynamic_import(args.model_module)
model = model_class(
idim_list[0] if args.num_encs == 1 else idim_list, odim, args
)
assert isinstance(model, ASRInterface)
total_subsampling_factor = model.get_total_subsampling_factor()
logging.info(
" Total parameter of the model = "
+ str(sum(p.numel() for p in model.parameters()))
)
if args.rnnlm is not None:
rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
rnnlm = lm_pytorch.ClassifierWithState(
lm_pytorch.RNNLM(len(args.char_list), rnnlm_args.layer, rnnlm_args.unit)
)
torch_load(args.rnnlm, rnnlm)
model.rnnlm = rnnlm
if is_writable_process(args, worldsize, rank, localrank):
# write model config
if not os.path.exists(args.outdir):
os.makedirs(args.outdir)
model_conf = args.outdir + "/model.json"
with open(model_conf, "wb") as f:
logging.info("writing a model config file to " + model_conf)
f.write(
json.dumps(
(
idim_list[0] if args.num_encs == 1 else idim_list,
odim,
vars(args),
),
indent=4,
ensure_ascii=False,
sort_keys=True,
).encode("utf_8")
)
for key in sorted(vars(args).keys()):
logging.info("ARGS: " + key + ": " + str(vars(args)[key]))
reporter = model.reporter
if args.use_ddp:
if args.num_encs > 1:
# TODO(ruizhili): implement data parallel for multi-encoder setup.
raise NotImplementedError(
"Data parallel is not supported for multi-encoder setup."
)
else:
# check the use of multi-gpu
if args.ngpu > 1:
if args.batch_size != 0:
logging.warning(
"batch size is automatically increased (%d -> %d)"
% (args.batch_size, args.batch_size * args.ngpu)
)
args.batch_size *= args.ngpu
if args.num_encs > 1:
# TODO(ruizhili): implement data parallel for multi-encoder setup.
raise NotImplementedError(
"Data parallel is not supported for multi-encoder setup."
)
# set torch device
if args.use_ddp:
device = torch.device(f"cuda:{localrank}")
else:
device = torch.device("cuda" if args.ngpu > 0 else "cpu")
if args.train_dtype in ("float16", "float32", "float64"):
dtype = getattr(torch, args.train_dtype)
else:
dtype = torch.float32
model = model.to(device=device, dtype=dtype)
if args.freeze_mods:
model, model_params = freeze_modules(model, args.freeze_mods)
else:
model_params = model.parameters()
logging.warning(
"num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
sum(p.numel() for p in model.parameters()),
sum(p.numel() for p in model.parameters() if p.requires_grad),
sum(p.numel() for p in model.parameters() if p.requires_grad)
* 100.0
/ sum(p.numel() for p in model.parameters()),
)
)
# Setup an optimizer
if args.opt == "adadelta":
optimizer = torch.optim.Adadelta(
model_params, rho=0.95, eps=args.eps, weight_decay=args.weight_decay
)
elif args.opt == "adam":
optimizer = torch.optim.Adam(model_params, weight_decay=args.weight_decay)
elif args.opt == "noam":
from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
if "transducer" in mtl_mode:
if args.noam_adim > 0:
optimizer = get_std_opt(
model_params,
args.noam_adim,
args.optimizer_warmup_steps,
args.noam_lr,
)
else:
raise ValueError("noam-adim option should be set to use Noam scheduler")
else:
optimizer = get_std_opt(
model_params,
args.adim,
args.transformer_warmup_steps,
args.transformer_lr,
)
else:
raise NotImplementedError("unknown optimizer: " + args.opt)
# setup apex.amp
if args.train_dtype in ("O0", "O1", "O2", "O3"):
try:
from apex import amp
except ImportError as e:
logging.error(
f"You need to install apex for --train-dtype {args.train_dtype}. "
"See https://github.com/NVIDIA/apex#linux"
)
raise e
if args.opt == "noam":
model, optimizer.optimizer = amp.initialize(
model, optimizer.optimizer, opt_level=args.train_dtype
)
else:
model, optimizer = amp.initialize(
model, optimizer, opt_level=args.train_dtype
)
use_apex = True
from espnet.nets.pytorch_backend.ctc import CTC
amp.register_float_function(CTC, "loss_fn")
amp.init()
logging.warning("register ctc as float function")
else:
use_apex = False
# FIXME: TOO DIRTY HACK
setattr(optimizer, "target", reporter)
setattr(optimizer, "serialize", lambda s: reporter.serialize(s))
# Setup a converter
if args.num_encs == 1:
converter = CustomConverter(subsampling_factor=model.subsample[0], dtype=dtype)
else:
converter = CustomConverterMulEnc(
[i[0] for i in model.subsample_list], dtype=dtype
)
# read json data
with open(args.train_json, "rb") as f:
train_json = json.load(f)["utts"]
with open(args.valid_json, "rb") as f:
valid_json = json.load(f)["utts"]
use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
# make minibatch list (variable length)
if args.use_ddp:
# When using DDP, minimum batch size for each process is 1.
min_batch_size = 1
else:
min_batch_size = args.ngpu if args.ngpu > 1 else 1
train = make_batchset(
train_json,
args.batch_size,
args.maxlen_in,
args.maxlen_out,
args.minibatches,
min_batch_size=min_batch_size,
shortest_first=use_sortagrad,
count=args.batch_count,
batch_bins=args.batch_bins,
batch_frames_in=args.batch_frames_in,
batch_frames_out=args.batch_frames_out,
batch_frames_inout=args.batch_frames_inout,
iaxis=0,
oaxis=0,
)
valid = make_batchset(
valid_json,
args.batch_size,
args.maxlen_in,
args.maxlen_out,
args.minibatches,
min_batch_size=min_batch_size,
count=args.batch_count,
batch_bins=args.batch_bins,
batch_frames_in=args.batch_frames_in,
batch_frames_out=args.batch_frames_out,
batch_frames_inout=args.batch_frames_inout,
iaxis=0,
oaxis=0,
)
load_tr = LoadInputsAndTargets(
mode="asr",
load_output=True,
preprocess_conf=args.preprocess_conf,
preprocess_args={"train": True}, # Switch the mode of preprocessing
)
load_cv = LoadInputsAndTargets(
mode="asr",
load_output=True,
preprocess_conf=args.preprocess_conf,
preprocess_args={"train": False}, # Switch the mode of preprocessing
)
# hack to make batchsize argument as 1
# actual bathsize is included in a list
# default collate function converts numpy array to pytorch tensor
# we used an empty collate function instead which returns list
train_ds = TransformDataset(train, Transform(converter, load_tr))
val_ds = TransformDataset(valid, Transform(converter, load_cv))
train_sampler = None
val_sampler = None
shuffle = not use_sortagrad
if args.use_ddp:
train_sampler = DistributedSampler(train_ds)
val_sampler = DistributedSampler(val_ds)
shuffle = False
train_iter = ChainerDataLoader(
dataset=train_ds,
batch_size=1,
num_workers=args.n_iter_processes,
shuffle=shuffle,
sampler=train_sampler,
collate_fn=ChainerDataLoader.get_first_element,
)
valid_iter = ChainerDataLoader(
dataset=val_ds,
batch_size=1,
shuffle=False,
sampler=val_sampler,
collate_fn=ChainerDataLoader.get_first_element,
num_workers=args.n_iter_processes,
)
# Set up a trainer
if args.use_ddp:
model = DDP(model, device_ids=[localrank])
updater = CustomUpdater(
model,
args.grad_clip,
{"main": train_iter},
optimizer,
device,
args.ngpu,
args.grad_noise,
args.accum_grad,
use_apex=use_apex,
use_ddp=args.use_ddp,
)
trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)
# call DistributedSampler.set_epoch at begining of each epoch.
if args.use_ddp:
@training.make_extension(trigger=(1, "epoch"))
def set_epoch_to_distributed_sampler(trainer):
# NOTE: at the first time when this fuction is called,
# `sampler.epoch` should be 0, and a given trainer object
# has 1 as a `trainer.updater.epoch`.
# This means that, in the first epoch,
# dataset is shuffled with random seed and a value 0,
# and, in the second epoch, dataset is shuffled
# with the same random seed and a value 1.
#
# See a link below for more details.
# https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler
train_sampler.set_epoch(trainer.updater.epoch)
val_sampler.set_epoch(trainer.updater.epoch)
trainer.extend(set_epoch_to_distributed_sampler)
if use_sortagrad:
trainer.extend(
ShufflingEnabler([train_iter]),
trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
)
# Resume from a snapshot
if args.resume:
logging.info("resumed from %s" % args.resume)
torch_resume(args.resume, trainer)
# Evaluate the model with the test dataset for each epoch
if args.save_interval_iters > 0:
trainer.extend(
CustomEvaluator(
model, {"main": valid_iter}, reporter, device, args.ngpu, args.use_ddp
),
trigger=(args.save_interval_iters, "iteration"),
)
else:
trainer.extend(
CustomEvaluator(
model, {"main": valid_iter}, reporter, device, args.ngpu, args.use_ddp
)
)
if is_writable_process(args, worldsize, rank, localrank):
# Save attention weight each epoch
is_attn_plot = (
"transformer" in args.model_module
or "conformer" in args.model_module
or mtl_mode in ["att", "mtl", "custom_transducer"]
)
if args.num_save_attention > 0 and is_attn_plot:
data = sorted(
list(valid_json.items())[: args.num_save_attention],
key=lambda x: int(x[1]["input"][0]["shape"][1]),
reverse=True,
)
if hasattr(model, "module"):
att_vis_fn = model.module.calculate_all_attentions
plot_class = model.module.attention_plot_class
else:
att_vis_fn = model.calculate_all_attentions
plot_class = model.attention_plot_class
att_reporter = plot_class(
att_vis_fn,
data,
args.outdir + "/att_ws",
converter=converter,
transform=load_cv,
device=device,
subsampling_factor=total_subsampling_factor,
)
trainer.extend(att_reporter, trigger=(1, "epoch"))
else:
att_reporter = None
# Save CTC prob at each epoch
if mtl_mode in ["ctc", "mtl"] and args.num_save_ctc > 0:
# NOTE: sort it by output lengths
data = sorted(
list(valid_json.items())[: args.num_save_ctc],
key=lambda x: int(x[1]["output"][0]["shape"][0]),
reverse=True,
)
if hasattr(model, "module"):
ctc_vis_fn = model.module.calculate_all_ctc_probs
plot_class = model.module.ctc_plot_class
else:
ctc_vis_fn = model.calculate_all_ctc_probs
plot_class = model.ctc_plot_class
ctc_reporter = plot_class(
ctc_vis_fn,
data,
args.outdir + "/ctc_prob",
converter=converter,
transform=load_cv,
device=device,
subsampling_factor=total_subsampling_factor,
)
trainer.extend(ctc_reporter, trigger=(1, "epoch"))
else:
ctc_reporter = None
# Make a plot for training and validation values
if args.num_encs > 1:
report_keys_loss_ctc = [
"main/loss_ctc{}".format(i + 1) for i in range(model.num_encs)
] + [
"validation/main/loss_ctc{}".format(i + 1)
for i in range(model.num_encs)
]
report_keys_cer_ctc = [
"main/cer_ctc{}".format(i + 1) for i in range(model.num_encs)
] + [
"validation/main/cer_ctc{}".format(i + 1) for i in range(model.num_encs)
]
if hasattr(model, "is_transducer"):
trans_keys = [
"main/loss",
"validation/main/loss",
"main/loss_trans",
"validation/main/loss_trans",
]
ctc_keys = (
["main/loss_ctc", "validation/main/loss_ctc"]
if args.use_ctc_loss
else []
)
aux_trans_keys = (
[
"main/loss_aux_trans",
"validation/main/loss_aux_trans",
]
if args.use_aux_transducer_loss
else []
)
symm_kl_div_keys = (
[
"main/loss_symm_kl_div",
"validation/main/loss_symm_kl_div",
]
if args.use_symm_kl_div_loss
else []
)
lm_keys = (
[
"main/loss_lm",
"validation/main/loss_lm",
]
if args.use_lm_loss
else []
)
transducer_keys = (
trans_keys + ctc_keys + aux_trans_keys + symm_kl_div_keys + lm_keys
)
trainer.extend(
extensions.PlotReport(
transducer_keys,
"epoch",
file_name="loss.png",
)
)
else:
trainer.extend(
extensions.PlotReport(
[
"main/loss",
"validation/main/loss",
"main/loss_ctc",
"validation/main/loss_ctc",
"main/loss_att",
"validation/main/loss_att",
]
+ ([] if args.num_encs == 1 else report_keys_loss_ctc),
"epoch",
file_name="loss.png",
)
)
trainer.extend(
extensions.PlotReport(
["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
)
)
trainer.extend(
extensions.PlotReport(
["main/cer_ctc", "validation/main/cer_ctc"]
+ ([] if args.num_encs == 1 else report_keys_loss_ctc),
"epoch",
file_name="cer.png",
)
)
# Save best models
trainer.extend(
snapshot_object(model, "model.loss.best"),
trigger=training.triggers.MinValueTrigger("validation/main/loss"),
)
if mtl_mode not in ["ctc", "transducer", "custom_transducer"]:
trainer.extend(
snapshot_object(model, "model.acc.best"),
trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
)
# save snapshot which contains model and optimizer states
if args.save_interval_iters > 0:
trainer.extend(
torch_snapshot(filename="snapshot.iter.{.updater.iteration}"),
trigger=(args.save_interval_iters, "iteration"),
)
# save snapshot at every epoch - for model averaging
trainer.extend(torch_snapshot(), trigger=(1, "epoch"))
# epsilon decay in the optimizer
if args.opt == "adadelta":
if args.criterion == "acc" and mtl_mode != "ctc":
trainer.extend(
restore_snapshot(
model, args.outdir + "/model.acc.best", load_fn=torch_load
),
trigger=CompareValueTrigger(
"validation/main/acc",
lambda best_value, current_value: best_value > current_value,
),
)
trainer.extend(
adadelta_eps_decay(args.eps_decay),
trigger=CompareValueTrigger(
"validation/main/acc",
lambda best_value, current_value: best_value > current_value,
),
)
elif args.criterion == "loss":
trainer.extend(
restore_snapshot(
model, args.outdir + "/model.loss.best", load_fn=torch_load
),
trigger=CompareValueTrigger(
"validation/main/loss",
lambda best_value, current_value: best_value < current_value,
),
)
trainer.extend(
adadelta_eps_decay(args.eps_decay),
trigger=CompareValueTrigger(
"validation/main/loss",
lambda best_value, current_value: best_value < current_value,
),
)
# NOTE: In some cases, it may take more than one epoch for the model's loss
# to escape from a local minimum.
# Thus, restore_snapshot extension is not used here.
# see details in https://github.com/espnet/espnet/pull/2171
elif args.criterion == "loss_eps_decay_only":
trainer.extend(
adadelta_eps_decay(args.eps_decay),
trigger=CompareValueTrigger(
"validation/main/loss",
lambda best_value, current_value: best_value < current_value,
),
)
if is_writable_process(args, worldsize, rank, localrank):
# Write a log of evaluation statistics for each epoch
trainer.extend(
extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
)
if hasattr(model, "is_transducer"):
report_keys = (
[
"epoch",
"iteration",
]
+ transducer_keys
+ ["elapsed_time"]
)
else:
report_keys = [
"epoch",
"iteration",
"main/loss",
"main/loss_ctc",
"main/loss_att",
"validation/main/loss",
"validation/main/loss_ctc",
"validation/main/loss_att",
"main/acc",
"validation/main/acc",
"main/cer_ctc",
"validation/main/cer_ctc",
"elapsed_time",
] + (
[] if args.num_encs == 1 else report_keys_cer_ctc + report_keys_loss_ctc
)
if args.opt == "adadelta":
trainer.extend(
extensions.observe_value(
"eps",
lambda trainer: trainer.updater.get_optimizer("main").param_groups[
0
]["eps"],
),
trigger=(args.report_interval_iters, "iteration"),
)
report_keys.append("eps")
if args.report_cer:
report_keys.append("validation/main/cer")
if args.report_wer:
report_keys.append("validation/main/wer")
trainer.extend(
extensions.PrintReport(report_keys),
trigger=(args.report_interval_iters, "iteration"),
)
trainer.extend(
extensions.ProgressBar(update_interval=args.report_interval_iters)
)
set_early_stop(trainer, args)
if is_writable_process(args, worldsize, rank, localrank):
if args.tensorboard_dir is not None and args.tensorboard_dir != "":
from torch.utils.tensorboard import SummaryWriter
trainer.extend(
TensorboardLogger(
SummaryWriter(args.tensorboard_dir),
att_reporter=att_reporter,
ctc_reporter=ctc_reporter,
),
trigger=(args.report_interval_iters, "iteration"),
)
if args.use_ddp:
# To avoid busy wait on non-main processes
# during a main process is writing plot, logs, etc,
# one additional extension must be added at the last.
# Within this additional extension,
# a main process will send a notification tensor
# to other processes when the main process finishes
# all operations like writing plot, log, etc.
src_rank = 0 # TODO(lazykyama): removing hard-coded value.
@training.make_extension(trigger=(1, "epoch"))
def barrier_extension_per_epoch(trainer):
notification = torch.zeros(1, device=device)
dist.broadcast(notification, src=src_rank)
torch.cuda.synchronize(device=device)
trainer.extend(barrier_extension_per_epoch)
# Run the training
trainer.run()
if is_writable_process(args, worldsize, rank, localrank):
check_early_stop(trainer, args.epochs)
def recog(args):
"""Decode with the given args.
Args:
args (namespace): The program arguments.
"""
set_deterministic_pytorch(args)
model, train_args = load_trained_model(args.model, training=False)
assert isinstance(model, ASRInterface)
model.recog_args = args
if args.quantize_config is not None:
q_config = set([getattr(torch.nn, q) for q in args.quantize_config])
else:
q_config = {torch.nn.Linear}
if args.quantize_asr_model:
logging.info("Use a quantized ASR model for decoding.")
# It seems quantized LSTM only supports non-packed sequence before torch 1.4.0.
# Reference issue: https://github.com/pytorch/pytorch/issues/27963
if (
V(torch.__version__) < V("1.4.0")
and "lstm" in train_args.etype
and torch.nn.LSTM in q_config
):
raise ValueError(
"Quantized LSTM in ESPnet is only supported with torch 1.4+."
)
# Dunno why but weight_observer from dynamic quantized module must have
# dtype=torch.qint8 with torch < 1.5 although dtype=torch.float16 is supported.
if args.quantize_dtype == "float16" and V(torch.__version__) < V("1.5.0"):
raise ValueError(
"float16 dtype for dynamic quantization is not supported with torch "
"version < 1.5.0. Switching to qint8 dtype instead."
)
dtype = getattr(torch, args.quantize_dtype)
model = torch.quantization.quantize_dynamic(model, q_config, dtype=dtype)
if args.streaming_mode and "transformer" in train_args.model_module:
raise NotImplementedError("streaming mode for transformer is not implemented")
logging.info(
" Total parameter of the model = "
+ str(sum(p.numel() for p in model.parameters()))
)
# read rnnlm
if args.rnnlm:
rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
if getattr(rnnlm_args, "model_module", "default") != "default":
raise ValueError(
"use '--api v2' option to decode with non-default language model"
)
rnnlm = lm_pytorch.ClassifierWithState(
lm_pytorch.RNNLM(
len(train_args.char_list),
rnnlm_args.layer,
rnnlm_args.unit,
getattr(rnnlm_args, "embed_unit", None), # for backward compatibility
)
)
torch_load(args.rnnlm, rnnlm)
if args.quantize_lm_model:
dtype = getattr(torch, args.quantize_dtype)
rnnlm = torch.quantization.quantize_dynamic(rnnlm, q_config, dtype=dtype)
rnnlm.eval()
else:
rnnlm = None
if args.word_rnnlm:
rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
word_dict = rnnlm_args.char_list_dict
char_dict = {x: i for i, x in enumerate(train_args.char_list)}
word_rnnlm = lm_pytorch.ClassifierWithState(
lm_pytorch.RNNLM(
len(word_dict),
rnnlm_args.layer,
rnnlm_args.unit,
getattr(rnnlm_args, "embed_unit", None), # for backward compatibility
)
)
torch_load(args.word_rnnlm, word_rnnlm)
word_rnnlm.eval()
if rnnlm is not None:
rnnlm = lm_pytorch.ClassifierWithState(
extlm_pytorch.MultiLevelLM(
word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict
)
)
else:
rnnlm = lm_pytorch.ClassifierWithState(
extlm_pytorch.LookAheadWordLM(
word_rnnlm.predictor, word_dict, char_dict
)
)
# gpu
if args.ngpu == 1:
gpu_id = list(range(args.ngpu))
logging.info("gpu id: " + str(gpu_id))
model.cuda()
if rnnlm:
rnnlm.cuda()
# read json data
with open(args.recog_json, "rb") as f:
js = json.load(f)["utts"]
new_js = {}
load_inputs_and_targets = LoadInputsAndTargets(
mode="asr",
load_output=False,
sort_in_input_length=False,
preprocess_conf=train_args.preprocess_conf
if args.preprocess_conf is None
else args.preprocess_conf,
preprocess_args={"train": False},
)
# load transducer beam search
if hasattr(model, "is_transducer"):
if hasattr(model, "dec"):
trans_decoder = model.dec
else:
trans_decoder = model.decoder
joint_network = model.transducer_tasks.joint_network
beam_search_transducer = BeamSearchTransducer(
decoder=trans_decoder,
joint_network=joint_network,
beam_size=args.beam_size,
lm=rnnlm,
lm_weight=args.lm_weight,
search_type=args.search_type,
max_sym_exp=args.max_sym_exp,
u_max=args.u_max,
nstep=args.nstep,
prefix_alpha=args.prefix_alpha,
expansion_gamma=args.expansion_gamma,
expansion_beta=args.expansion_beta,
score_norm=args.score_norm,
softmax_temperature=args.softmax_temperature,
nbest=args.nbest,
quantization=args.quantize_asr_model,
)
if args.batchsize == 0:
with torch.no_grad():
for idx, name in enumerate(js.keys(), 1):
logging.info("(%d/%d) decoding " + name, idx, len(js.keys()))
batch = [(name, js[name])]
feat = load_inputs_and_targets(batch)
feat = (
feat[0][0]
if args.num_encs == 1
else [feat[idx][0] for idx in range(model.num_encs)]
)
if args.streaming_mode == "window" and args.num_encs == 1:
logging.info(
"Using streaming recognizer with window size %d frames",
args.streaming_window,
)
se2e = WindowStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm)
for i in range(0, feat.shape[0], args.streaming_window):
logging.info(
"Feeding frames %d - %d", i, i + args.streaming_window
)
se2e.accept_input(feat[i : i + args.streaming_window])
logging.info("Running offline attention decoder")
se2e.decode_with_attention_offline()
logging.info("Offline attention decoder finished")
nbest_hyps = se2e.retrieve_recognition()
elif args.streaming_mode == "segment" and args.num_encs == 1:
logging.info(
"Using streaming recognizer with threshold value %d",
args.streaming_min_blank_dur,
)
nbest_hyps = []
for n in range(args.nbest):
nbest_hyps.append({"yseq": [], "score": 0.0})
se2e = SegmentStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm)
r = np.prod(model.subsample)
for i in range(0, feat.shape[0], r):
hyps = se2e.accept_input(feat[i : i + r])
if hyps is not None:
text = "".join(
[
train_args.char_list[int(x)]
for x in hyps[0]["yseq"][1:-1]
if int(x) != -1
]
)
text = text.replace(
"\u2581", " "
).strip() # for SentencePiece
text = text.replace(model.space, " ")
text = text.replace(model.blank, "")
logging.info(text)
for n in range(args.nbest):
nbest_hyps[n]["yseq"].extend(hyps[n]["yseq"])
nbest_hyps[n]["score"] += hyps[n]["score"]
elif hasattr(model, "is_transducer"):
nbest_hyps = model.recognize(feat, beam_search_transducer)
else:
nbest_hyps = model.recognize(
feat, args, train_args.char_list, rnnlm
)
new_js[name] = add_results_to_json(
js[name], nbest_hyps, train_args.char_list
)
else:
def grouper(n, iterable, fillvalue=None):
kargs = [iter(iterable)] * n
return itertools.zip_longest(*kargs, fillvalue=fillvalue)
# sort data if batchsize > 1
keys = list(js.keys())
if args.batchsize > 1:
feat_lens = [js[key]["input"][0]["shape"][0] for key in keys]
sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
keys = [keys[i] for i in sorted_index]
with torch.no_grad():
for names in grouper(args.batchsize, keys, None):
names = [name for name in names if name]
batch = [(name, js[name]) for name in names]
feats = (
load_inputs_and_targets(batch)[0]
if args.num_encs == 1
else load_inputs_and_targets(batch)
)
if args.streaming_mode == "window" and args.num_encs == 1:
raise NotImplementedError
elif args.streaming_mode == "segment" and args.num_encs == 1:
if args.batchsize > 1:
raise NotImplementedError
feat = feats[0]
nbest_hyps = []
for n in range(args.nbest):
nbest_hyps.append({"yseq": [], "score": 0.0})
se2e = SegmentStreamingE2E(e2e=model, recog_args=args, rnnlm=rnnlm)
r = np.prod(model.subsample)
for i in range(0, feat.shape[0], r):
hyps = se2e.accept_input(feat[i : i + r])
if hyps is not None:
text = "".join(
[
train_args.char_list[int(x)]
for x in hyps[0]["yseq"][1:-1]
if int(x) != -1
]
)
text = text.replace(
"\u2581", " "
).strip() # for SentencePiece
text = text.replace(model.space, " ")
text = text.replace(model.blank, "")
logging.info(text)
for n in range(args.nbest):
nbest_hyps[n]["yseq"].extend(hyps[n]["yseq"])
nbest_hyps[n]["score"] += hyps[n]["score"]
nbest_hyps = [nbest_hyps]
else:
nbest_hyps = model.recognize_batch(
feats, args, train_args.char_list, rnnlm=rnnlm
)
for i, nbest_hyp in enumerate(nbest_hyps):
name = names[i]
new_js[name] = add_results_to_json(
js[name], nbest_hyp, train_args.char_list
)
with open(args.result_label, "wb") as f:
f.write(
json.dumps(
{"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
).encode("utf_8")
)
def enhance(args):
"""Dumping enhanced speech and mask.
Args:
args (namespace): The program arguments.
"""
set_deterministic_pytorch(args)
# read training config
idim, odim, train_args = get_model_conf(args.model, args.model_conf)
# TODO(ruizhili): implement enhance for multi-encoder model
assert args.num_encs == 1, "number of encoder should be 1 ({} is given)".format(
args.num_encs
)
# load trained model parameters
logging.info("reading model parameters from " + args.model)
model_class = dynamic_import(train_args.model_module)
model = model_class(idim, odim, train_args)
assert isinstance(model, ASRInterface)
torch_load(args.model, model)
model.recog_args = args
# gpu
if args.ngpu == 1:
gpu_id = list(range(args.ngpu))
logging.info("gpu id: " + str(gpu_id))
model.cuda()
# read json data
with open(args.recog_json, "rb") as f:
js = json.load(f)["utts"]
load_inputs_and_targets = LoadInputsAndTargets(
mode="asr",
load_output=False,
sort_in_input_length=False,
preprocess_conf=None, # Apply pre_process in outer func
)
if args.batchsize == 0:
args.batchsize = 1
# Creates writers for outputs from the network
if args.enh_wspecifier is not None:
enh_writer = file_writer_helper(args.enh_wspecifier, filetype=args.enh_filetype)
else:
enh_writer = None
# Creates a Transformation instance
preprocess_conf = (
train_args.preprocess_conf
if args.preprocess_conf is None
else args.preprocess_conf
)
if preprocess_conf is not None:
logging.info(f"Use preprocessing: {preprocess_conf}")
transform = Transformation(preprocess_conf)
else:
transform = None
# Creates a IStft instance
istft = None
frame_shift = args.istft_n_shift # Used for plot the spectrogram
if args.apply_istft:
if preprocess_conf is not None:
# Read the conffile and find stft setting
with open(preprocess_conf) as f:
# Json format: e.g.
# {"process": [{"type": "stft",
# "win_length": 400,
# "n_fft": 512, "n_shift": 160,
# "window": "han"},
# {"type": "foo", ...}, ...]}
conf = json.load(f)
assert "process" in conf, conf
# Find stft setting
for p in conf["process"]:
if p["type"] == "stft":
istft = IStft(
win_length=p["win_length"],
n_shift=p["n_shift"],
window=p.get("window", "hann"),
)
logging.info(
"stft is found in {}. "
"Setting istft config from it\n{}".format(
preprocess_conf, istft
)
)
frame_shift = p["n_shift"]
break
if istft is None:
# Set from command line arguments
istft = IStft(
win_length=args.istft_win_length,
n_shift=args.istft_n_shift,
window=args.istft_window,
)
logging.info(
"Setting istft config from the command line args\n{}".format(istft)
)
# sort data
keys = list(js.keys())
feat_lens = [js[key]["input"][0]["shape"][0] for key in keys]
sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
keys = [keys[i] for i in sorted_index]
def grouper(n, iterable, fillvalue=None):
kargs = [iter(iterable)] * n
return itertools.zip_longest(*kargs, fillvalue=fillvalue)
num_images = 0
if not os.path.exists(args.image_dir):
os.makedirs(args.image_dir)
for names in grouper(args.batchsize, keys, None):
batch = [(name, js[name]) for name in names]
# May be in time region: (Batch, [Time, Channel])
org_feats = load_inputs_and_targets(batch)[0]
if transform is not None:
# May be in time-freq region: : (Batch, [Time, Channel, Freq])
feats = transform(org_feats, train=False)
else:
feats = org_feats
with torch.no_grad():
enhanced, mask, ilens = model.enhance(feats)
for idx, name in enumerate(names):
# Assuming mask, feats : [Batch, Time, Channel. Freq]
# enhanced : [Batch, Time, Freq]
enh = enhanced[idx][: ilens[idx]]
mas = mask[idx][: ilens[idx]]
feat = feats[idx]
# Plot spectrogram
if args.image_dir is not None and num_images < args.num_images:
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
num_images += 1
ref_ch = 0
plt.figure(figsize=(20, 10))
plt.subplot(4, 1, 1)
plt.title("Mask [ref={}ch]".format(ref_ch))
plot_spectrogram(
plt,
mas[:, ref_ch].T,
fs=args.fs,
mode="linear",
frame_shift=frame_shift,
bottom=False,
labelbottom=False,
)
plt.subplot(4, 1, 2)
plt.title("Noisy speech [ref={}ch]".format(ref_ch))
plot_spectrogram(
plt,
feat[:, ref_ch].T,
fs=args.fs,
mode="db",
frame_shift=frame_shift,
bottom=False,
labelbottom=False,
)
plt.subplot(4, 1, 3)
plt.title("Masked speech [ref={}ch]".format(ref_ch))
plot_spectrogram(
plt,
(feat[:, ref_ch] * mas[:, ref_ch]).T,
frame_shift=frame_shift,
fs=args.fs,
mode="db",
bottom=False,
labelbottom=False,
)
plt.subplot(4, 1, 4)
plt.title("Enhanced speech")
plot_spectrogram(
plt, enh.T, fs=args.fs, mode="db", frame_shift=frame_shift
)
plt.savefig(os.path.join(args.image_dir, name + ".png"))
plt.clf()
# Write enhanced wave files
if enh_writer is not None:
if istft is not None:
enh = istft(enh)
else:
enh = enh
if args.keep_length:
if len(org_feats[idx]) < len(enh):
# Truncate the frames added by stft padding
enh = enh[: len(org_feats[idx])]
elif len(org_feats) > len(enh):
padwidth = [(0, (len(org_feats[idx]) - len(enh)))] + [
(0, 0)
] * (enh.ndim - 1)
enh = np.pad(enh, padwidth, mode="constant")
if args.enh_filetype in ("sound", "sound.hdf5"):
enh_writer[name] = (args.fs, enh)
else:
# Hint: To dump stft_signal, mask or etc,
# enh_filetype='hdf5' might be convenient.
enh_writer[name] = enh
if num_images >= args.num_images and enh_writer is None:
logging.info("Breaking the process.")
break
"""Finetuning methods."""
import logging
import os
import re
from collections import OrderedDict
import torch
from espnet.asr.asr_utils import get_model_conf, torch_load
from espnet.nets.asr_interface import ASRInterface
from espnet.nets.mt_interface import MTInterface
from espnet.nets.pytorch_backend.transducer.utils import custom_torch_load
from espnet.nets.tts_interface import TTSInterface
from espnet.utils.dynamic_import import dynamic_import
def freeze_modules(model, modules):
"""Freeze model parameters according to modules list.
Args:
model (torch.nn.Module): Main model.
modules (List): Specified module(s) to freeze.
Return:
model (torch.nn.Module) : Updated main model.
model_params (filter): Filtered model parameters.
"""
for mod, param in model.named_parameters():
if any(mod.startswith(m) for m in modules):
logging.warning(f"Freezing {mod}. It will not be updated during training.")
param.requires_grad = False
model_params = filter(lambda x: x.requires_grad, model.parameters())
return model, model_params
def transfer_verification(model_state_dict, partial_state_dict, modules):
"""Verify tuples (key, shape) for input model modules match specified modules.
Args:
model_state_dict (Dict) : Main model state dict.
partial_state_dict (Dict): Pre-trained model state dict.
modules (List): Specified module(s) to transfer.
Return:
(bool): Whether transfer learning is allowed.
"""
model_modules = []
partial_modules = []
for key_m, value_m in model_state_dict.items():
if any(key_m.startswith(m) for m in modules):
model_modules += [(key_m, value_m.shape)]
model_modules = sorted(model_modules, key=lambda x: (x[0], x[1]))
for key_p, value_p in partial_state_dict.items():
if any(key_p.startswith(m) for m in modules):
partial_modules += [(key_p, value_p.shape)]
partial_modules = sorted(partial_modules, key=lambda x: (x[0], x[1]))
module_match = model_modules == partial_modules
if not module_match:
logging.error(
"Some specified modules from the pre-trained model "
"don't match with the new model modules:"
)
logging.error(f"Pre-trained: {set(partial_modules) - set(model_modules)}")
logging.error(f"New model: {set(model_modules) - set(partial_modules)}")
exit(1)
return module_match
def get_partial_state_dict(model_state_dict, modules):
"""Create state dict with specified modules matching input model modules.
Args:
model_state_dict (Dict): Pre-trained model state dict.
modules (Dict): Specified module(s) to transfer.
Return:
new_state_dict (Dict): State dict with specified modules weights.
"""
new_state_dict = OrderedDict()
for key, value in model_state_dict.items():
if any(key.startswith(m) for m in modules):
new_state_dict[key] = value
return new_state_dict
def get_lm_state_dict(lm_state_dict):
"""Create compatible ASR decoder state dict from LM state dict.
Args:
lm_state_dict (Dict): Pre-trained LM state dict.
Return:
new_state_dict (Dict): State dict with compatible key names.
"""
new_state_dict = OrderedDict()
for key, value in list(lm_state_dict.items()):
if key == "predictor.embed.weight":
new_state_dict["dec.embed.weight"] = value
elif key.startswith("predictor.rnn."):
_split = key.split(".")
new_key = "dec.decoder." + _split[2] + "." + _split[3] + "_l0"
new_state_dict[new_key] = value
return new_state_dict
def filter_modules(model_state_dict, modules):
"""Filter non-matched modules in model state dict.
Args:
model_state_dict (Dict): Pre-trained model state dict.
modules (List): Specified module(s) to transfer.
Return:
new_mods (List): Filtered module list.
"""
new_mods = []
incorrect_mods = []
mods_model = list(model_state_dict.keys())
for mod in modules:
if any(key.startswith(mod) for key in mods_model):
new_mods += [mod]
else:
incorrect_mods += [mod]
if incorrect_mods:
logging.error(
"Specified module(s) don't match or (partially match) "
f"available modules in model. You specified: {incorrect_mods}."
)
logging.error("The existing modules in model are:")
logging.error(f"{mods_model}")
exit(1)
return new_mods
def create_transducer_compatible_state_dict(
model_state_dict, encoder_type, encoder_units
):
"""Create a compatible transducer model state dict for transfer learning.
If RNN encoder modules from a non-Transducer model are found in
the pre-trained model state dict, the corresponding modules keys are
renamed for compatibility.
Args:
model_state_dict (Dict): Pre-trained model state dict
encoder_type (str): Type of pre-trained encoder.
encoder_units (int): Number of encoder units in pre-trained model.
Returns:
new_state_dict (Dict): Transducer compatible pre-trained model state dict.
"""
if encoder_type.endswith("p") or not encoder_type.endswith(("lstm", "gru")):
return model_state_dict
new_state_dict = OrderedDict()
rnn_key_name = "birnn" if "b" in encoder_type else "rnn"
for key, value in list(model_state_dict.items()):
if any(k in key for k in ["l_last", "nbrnn"]):
if "nbrnn" in key:
layer_name = rnn_key_name + re.search("_l([0-9]+)", key).group(1)
key = re.sub(
"_l([0-9]+)",
"_l0",
key.replace("nbrnn", layer_name),
)
if (encoder_units * 2) == value.size(-1):
value = value[:, :encoder_units] + value[:, encoder_units:]
new_state_dict[key] = value
return new_state_dict
def load_trained_model(model_path, training=True):
"""Load the trained model for recognition.
Args:
model_path (str): Path to model.***.best
training (bool): Training mode specification for transducer model.
Returns:
model (torch.nn.Module): Trained model.
train_args (Namespace): Trained model arguments.
"""
idim, odim, train_args = get_model_conf(
model_path, os.path.join(os.path.dirname(model_path), "model.json")
)
logging.info(f"Reading model parameters from {model_path}")
if hasattr(train_args, "model_module"):
model_module = train_args.model_module
else:
model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
# CTC Loss is not needed, default to builtin to prevent import errors
if hasattr(train_args, "ctc_type"):
train_args.ctc_type = "builtin"
model_class = dynamic_import(model_module)
if "transducer" in model_module:
model = model_class(idim, odim, train_args, training=training)
custom_torch_load(model_path, model, training=training)
else:
model = model_class(idim, odim, train_args)
torch_load(model_path, model)
return model, train_args
def get_trained_model_state_dict(model_path, new_is_transducer):
"""Extract the trained model state dict for pre-initialization.
Args:
model_path (str): Path to trained model.
new_is_transducer (bool): Whether the new model is Transducer-based.
Return:
(Dict): Trained model state dict.
"""
logging.info(f"Reading model parameters from {model_path}")
conf_path = os.path.join(os.path.dirname(model_path), "model.json")
if "rnnlm" in model_path:
return get_lm_state_dict(torch.load(model_path))
idim, odim, args = get_model_conf(model_path, conf_path)
if hasattr(args, "model_module"):
model_module = args.model_module
else:
model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
model_class = dynamic_import(model_module)
model = model_class(idim, odim, args)
torch_load(model_path, model)
assert (
isinstance(model, MTInterface)
or isinstance(model, ASRInterface)
or isinstance(model, TTSInterface)
)
if new_is_transducer and "transducer" not in args.model_module:
return create_transducer_compatible_state_dict(
model.state_dict(),
args.etype,
args.eunits,
)
return model.state_dict()
def load_trained_modules(idim, odim, args, interface=ASRInterface):
"""Load ASR/MT/TTS model with pre-trained weights for specified modules.
Args:
idim (int): Input dimension.
odim (int): Output dimension.
args Namespace: Model arguments.
interface (ASRInterface|MTInterface|TTSInterface): Model interface.
Return:
main_model (torch.nn.Module): Model with pre-initialized weights.
"""
def print_new_keys(state_dict, modules, model_path):
logging.info(f"Loading {modules} from model: {model_path}")
for k in state_dict.keys():
logging.warning(f"Overriding module {k}")
enc_model_path = args.enc_init
dec_model_path = args.dec_init
enc_modules = args.enc_init_mods
dec_modules = args.dec_init_mods
model_class = dynamic_import(args.model_module)
main_model = model_class(idim, odim, args)
assert isinstance(main_model, interface)
main_state_dict = main_model.state_dict()
logging.warning("Model(s) found for pre-initialization.")
for model_path, modules in [
(enc_model_path, enc_modules),
(dec_model_path, dec_modules),
]:
if model_path is not None:
if os.path.isfile(model_path):
model_state_dict = get_trained_model_state_dict(
model_path, "transducer" in args.model_module
)
modules = filter_modules(model_state_dict, modules)
partial_state_dict = get_partial_state_dict(model_state_dict, modules)
if partial_state_dict:
if transfer_verification(
main_state_dict, partial_state_dict, modules
):
print_new_keys(partial_state_dict, modules, model_path)
main_state_dict.update(partial_state_dict)
else:
logging.error(f"Specified model was not found: {model_path}")
exit(1)
main_model.load_state_dict(main_state_dict)
return main_model
#!/usr/bin/env python3
"""
This script is used for multi-speaker speech recognition.
Copyright 2017 Johns Hopkins University (Shinji Watanabe)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
import json
import logging
import os
from itertools import zip_longest as zip_longest
import numpy as np
import torch
# chainer related
from chainer import training
from chainer.training import extensions
import espnet.lm.pytorch_backend.extlm as extlm_pytorch
import espnet.nets.pytorch_backend.lm.default as lm_pytorch
from espnet.asr.asr_mix_utils import add_results_to_json
from espnet.asr.asr_utils import (
CompareValueTrigger,
adadelta_eps_decay,
get_model_conf,
restore_snapshot,
snapshot_object,
torch_load,
torch_resume,
torch_snapshot,
)
from espnet.asr.pytorch_backend.asr import (
CustomEvaluator,
CustomUpdater,
load_trained_model,
)
from espnet.nets.asr_interface import ASRInterface
from espnet.nets.pytorch_backend.e2e_asr_mix import pad_list
from espnet.utils.dataset import ChainerDataLoader, TransformDataset
from espnet.utils.deterministic_utils import set_deterministic_pytorch
from espnet.utils.dynamic_import import dynamic_import
from espnet.utils.io_utils import LoadInputsAndTargets
from espnet.utils.training.batchfy import make_batchset
from espnet.utils.training.iterators import ShufflingEnabler
from espnet.utils.training.tensorboard_logger import TensorboardLogger
from espnet.utils.training.train_utils import check_early_stop, set_early_stop
class CustomConverter(object):
"""Custom batch converter for Pytorch.
Args:
subsampling_factor (int): The subsampling factor.
dtype (torch.dtype): Data type to convert.
"""
def __init__(self, subsampling_factor=1, dtype=torch.float32, num_spkrs=2):
"""Initialize the converter."""
self.subsampling_factor = subsampling_factor
self.ignore_id = -1
self.dtype = dtype
self.num_spkrs = num_spkrs
def __call__(self, batch, device=torch.device("cpu")):
"""Transform a batch and send it to a device.
Args:
batch (list(tuple(str, dict[str, dict[str, Any]]))): The batch to transform.
device (torch.device): The device to send to.
Returns:
tuple(torch.Tensor, torch.Tensor, torch.Tensor): Transformed batch.
"""
# batch should be located in list
assert len(batch) == 1
xs, ys = batch[0][0], batch[0][-self.num_spkrs :]
# perform subsampling
if self.subsampling_factor > 1:
xs = [x[:: self.subsampling_factor, :] for x in xs]
# get batch of lengths of input sequences
ilens = np.array([x.shape[0] for x in xs])
# perform padding and convert to tensor
# currently only support real number
if xs[0].dtype.kind == "c":
xs_pad_real = pad_list(
[torch.from_numpy(x.real).float() for x in xs], 0
).to(device, dtype=self.dtype)
xs_pad_imag = pad_list(
[torch.from_numpy(x.imag).float() for x in xs], 0
).to(device, dtype=self.dtype)
# Note(kamo):
# {'real': ..., 'imag': ...} will be changed to ComplexTensor in E2E.
# Don't create ComplexTensor and give it to E2E here
# because torch.nn.DataParallel can't handle it.
xs_pad = {"real": xs_pad_real, "imag": xs_pad_imag}
else:
xs_pad = pad_list([torch.from_numpy(x).float() for x in xs], 0).to(
device, dtype=self.dtype
)
ilens = torch.from_numpy(ilens).to(device)
if not isinstance(ys[0], np.ndarray):
ys_pad = []
for i in range(len(ys)): # speakers
ys_pad += [torch.from_numpy(y).long() for y in ys[i]]
ys_pad = pad_list(ys_pad, self.ignore_id)
ys_pad = (
ys_pad.view(self.num_spkrs, -1, ys_pad.size(1))
.transpose(0, 1)
.to(device)
) # (B, num_spkrs, Tmax)
else:
ys_pad = pad_list(
[torch.from_numpy(y).long() for y in ys], self.ignore_id
).to(device)
return xs_pad, ilens, ys_pad
def train(args):
"""Train with the given args.
Args:
args (namespace): The program arguments.
"""
set_deterministic_pytorch(args)
# check cuda availability
if not torch.cuda.is_available():
logging.warning("cuda is not available")
# get input and output dimension info
with open(args.valid_json, "rb") as f:
valid_json = json.load(f)["utts"]
utts = list(valid_json.keys())
idim = int(valid_json[utts[0]]["input"][0]["shape"][-1])
odim = int(valid_json[utts[0]]["output"][0]["shape"][-1])
logging.info("#input dims : " + str(idim))
logging.info("#output dims: " + str(odim))
# specify attention, CTC, hybrid mode
if args.mtlalpha == 1.0:
mtl_mode = "ctc"
logging.info("Pure CTC mode")
elif args.mtlalpha == 0.0:
mtl_mode = "att"
logging.info("Pure attention mode")
else:
mtl_mode = "mtl"
logging.info("Multitask learning mode")
# specify model architecture
model_class = dynamic_import(args.model_module)
model = model_class(idim, odim, args)
assert isinstance(model, ASRInterface)
subsampling_factor = model.subsample[0]
if args.rnnlm is not None:
rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
rnnlm = lm_pytorch.ClassifierWithState(
lm_pytorch.RNNLM(
len(args.char_list),
rnnlm_args.layer,
rnnlm_args.unit,
getattr(rnnlm_args, "embed_unit", None), # for backward compatibility
)
)
torch.load(args.rnnlm, rnnlm)
model.rnnlm = rnnlm
# write model config
if not os.path.exists(args.outdir):
os.makedirs(args.outdir)
model_conf = args.outdir + "/model.json"
with open(model_conf, "wb") as f:
logging.info("writing a model config file to " + model_conf)
f.write(
json.dumps(
(idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True
).encode("utf_8")
)
for key in sorted(vars(args).keys()):
logging.info("ARGS: " + key + ": " + str(vars(args)[key]))
reporter = model.reporter
# check the use of multi-gpu
if args.ngpu > 1:
if args.batch_size != 0:
logging.warning(
"batch size is automatically increased (%d -> %d)"
% (args.batch_size, args.batch_size * args.ngpu)
)
args.batch_size *= args.ngpu
# set torch device
device = torch.device("cuda" if args.ngpu > 0 else "cpu")
if args.train_dtype in ("float16", "float32", "float64"):
dtype = getattr(torch, args.train_dtype)
else:
dtype = torch.float32
model = model.to(device=device, dtype=dtype)
logging.warning(
"num. model params: {:,} (num. trained: {:,} ({:.1f}%))".format(
sum(p.numel() for p in model.parameters()),
sum(p.numel() for p in model.parameters() if p.requires_grad),
sum(p.numel() for p in model.parameters() if p.requires_grad)
* 100.0
/ sum(p.numel() for p in model.parameters()),
)
)
# Setup an optimizer
if args.opt == "adadelta":
optimizer = torch.optim.Adadelta(
model.parameters(), rho=0.95, eps=args.eps, weight_decay=args.weight_decay
)
elif args.opt == "adam":
optimizer = torch.optim.Adam(model.parameters(), weight_decay=args.weight_decay)
elif args.opt == "noam":
from espnet.nets.pytorch_backend.transformer.optimizer import get_std_opt
optimizer = get_std_opt(
model.parameters(),
args.adim,
args.transformer_warmup_steps,
args.transformer_lr,
)
else:
raise NotImplementedError("unknown optimizer: " + args.opt)
# setup apex.amp
if args.train_dtype in ("O0", "O1", "O2", "O3"):
try:
from apex import amp
except ImportError as e:
logging.error(
f"You need to install apex for --train-dtype {args.train_dtype}. "
"See https://github.com/NVIDIA/apex#linux"
)
raise e
if args.opt == "noam":
model, optimizer.optimizer = amp.initialize(
model, optimizer.optimizer, opt_level=args.train_dtype
)
else:
model, optimizer = amp.initialize(
model, optimizer, opt_level=args.train_dtype
)
use_apex = True
else:
use_apex = False
# FIXME: TOO DIRTY HACK
setattr(optimizer, "target", reporter)
setattr(optimizer, "serialize", lambda s: reporter.serialize(s))
# Setup a converter
converter = CustomConverter(
subsampling_factor=subsampling_factor, dtype=dtype, num_spkrs=args.num_spkrs
)
# read json data
with open(args.train_json, "rb") as f:
train_json = json.load(f)["utts"]
with open(args.valid_json, "rb") as f:
valid_json = json.load(f)["utts"]
use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
# make minibatch list (variable length)
train = make_batchset(
train_json,
args.batch_size,
args.maxlen_in,
args.maxlen_out,
args.minibatches,
min_batch_size=args.ngpu if args.ngpu > 1 else 1,
shortest_first=use_sortagrad,
count=args.batch_count,
batch_bins=args.batch_bins,
batch_frames_in=args.batch_frames_in,
batch_frames_out=args.batch_frames_out,
batch_frames_inout=args.batch_frames_inout,
iaxis=0,
oaxis=-1,
)
valid = make_batchset(
valid_json,
args.batch_size,
args.maxlen_in,
args.maxlen_out,
args.minibatches,
min_batch_size=args.ngpu if args.ngpu > 1 else 1,
count=args.batch_count,
batch_bins=args.batch_bins,
batch_frames_in=args.batch_frames_in,
batch_frames_out=args.batch_frames_out,
batch_frames_inout=args.batch_frames_inout,
iaxis=0,
oaxis=-1,
)
load_tr = LoadInputsAndTargets(
mode="asr",
load_output=True,
preprocess_conf=args.preprocess_conf,
preprocess_args={"train": True}, # Switch the mode of preprocessing
)
load_cv = LoadInputsAndTargets(
mode="asr",
load_output=True,
preprocess_conf=args.preprocess_conf,
preprocess_args={"train": False}, # Switch the mode of preprocessing
)
# hack to make batchsize argument as 1
# actual bathsize is included in a list
# default collate function converts numpy array to pytorch tensor
# we used an empty collate function instead which returns list
train_iter = {
"main": ChainerDataLoader(
dataset=TransformDataset(train, lambda data: converter([load_tr(data)])),
batch_size=1,
num_workers=args.n_iter_processes,
shuffle=True,
collate_fn=lambda x: x[0],
)
}
valid_iter = {
"main": ChainerDataLoader(
dataset=TransformDataset(valid, lambda data: converter([load_cv(data)])),
batch_size=1,
shuffle=False,
collate_fn=lambda x: x[0],
num_workers=args.n_iter_processes,
)
}
# Set up a trainer
updater = CustomUpdater(
model,
args.grad_clip,
train_iter,
optimizer,
device,
args.ngpu,
args.grad_noise,
args.accum_grad,
use_apex=use_apex,
)
trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)
if use_sortagrad:
trainer.extend(
ShufflingEnabler([train_iter]),
trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
)
# Resume from a snapshot
if args.resume:
logging.info("resumed from %s" % args.resume)
torch_resume(args.resume, trainer)
# Evaluate the model with the test dataset for each epoch
trainer.extend(CustomEvaluator(model, valid_iter, reporter, device, args.ngpu))
# Save attention weight each epoch
if args.num_save_attention > 0 and args.mtlalpha != 1.0:
data = sorted(
list(valid_json.items())[: args.num_save_attention],
key=lambda x: int(x[1]["input"][0]["shape"][1]),
reverse=True,
)
if hasattr(model, "module"):
att_vis_fn = model.module.calculate_all_attentions
plot_class = model.module.attention_plot_class
else:
att_vis_fn = model.calculate_all_attentions
plot_class = model.attention_plot_class
att_reporter = plot_class(
att_vis_fn,
data,
args.outdir + "/att_ws",
converter=converter,
transform=load_cv,
device=device,
)
trainer.extend(att_reporter, trigger=(1, "epoch"))
else:
att_reporter = None
# Make a plot for training and validation values
trainer.extend(
extensions.PlotReport(
[
"main/loss",
"validation/main/loss",
"main/loss_ctc",
"validation/main/loss_ctc",
"main/loss_att",
"validation/main/loss_att",
],
"epoch",
file_name="loss.png",
)
)
trainer.extend(
extensions.PlotReport(
["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
)
)
trainer.extend(
extensions.PlotReport(
["main/cer_ctc", "validation/main/cer_ctc"], "epoch", file_name="cer.png"
)
)
# Save best models
trainer.extend(
snapshot_object(model, "model.loss.best"),
trigger=training.triggers.MinValueTrigger("validation/main/loss"),
)
if mtl_mode != "ctc":
trainer.extend(
snapshot_object(model, "model.acc.best"),
trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
)
# save snapshot which contains model and optimizer states
trainer.extend(torch_snapshot(), trigger=(1, "epoch"))
# epsilon decay in the optimizer
if args.opt == "adadelta":
if args.criterion == "acc" and mtl_mode != "ctc":
trainer.extend(
restore_snapshot(
model, args.outdir + "/model.acc.best", load_fn=torch_load
),
trigger=CompareValueTrigger(
"validation/main/acc",
lambda best_value, current_value: best_value > current_value,
),
)
trainer.extend(
adadelta_eps_decay(args.eps_decay),
trigger=CompareValueTrigger(
"validation/main/acc",
lambda best_value, current_value: best_value > current_value,
),
)
elif args.criterion == "loss":
trainer.extend(
restore_snapshot(
model, args.outdir + "/model.loss.best", load_fn=torch_load
),
trigger=CompareValueTrigger(
"validation/main/loss",
lambda best_value, current_value: best_value < current_value,
),
)
trainer.extend(
adadelta_eps_decay(args.eps_decay),
trigger=CompareValueTrigger(
"validation/main/loss",
lambda best_value, current_value: best_value < current_value,
),
)
# Write a log of evaluation statistics for each epoch
trainer.extend(
extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
)
report_keys = [
"epoch",
"iteration",
"main/loss",
"main/loss_ctc",
"main/loss_att",
"validation/main/loss",
"validation/main/loss_ctc",
"validation/main/loss_att",
"main/acc",
"validation/main/acc",
"main/cer_ctc",
"validation/main/cer_ctc",
"elapsed_time",
]
if args.opt == "adadelta":
trainer.extend(
extensions.observe_value(
"eps",
lambda trainer: trainer.updater.get_optimizer("main").param_groups[0][
"eps"
],
),
trigger=(args.report_interval_iters, "iteration"),
)
report_keys.append("eps")
if args.report_cer:
report_keys.append("validation/main/cer")
if args.report_wer:
report_keys.append("validation/main/wer")
trainer.extend(
extensions.PrintReport(report_keys),
trigger=(args.report_interval_iters, "iteration"),
)
trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))
set_early_stop(trainer, args)
if args.tensorboard_dir is not None and args.tensorboard_dir != "":
from torch.utils.tensorboard import SummaryWriter
trainer.extend(
TensorboardLogger(SummaryWriter(args.tensorboard_dir), att_reporter),
trigger=(args.report_interval_iters, "iteration"),
)
# Run the training
trainer.run()
check_early_stop(trainer, args.epochs)
def recog(args):
"""Decode with the given args.
Args:
args (namespace): The program arguments.
"""
set_deterministic_pytorch(args)
model, train_args = load_trained_model(args.model)
assert isinstance(model, ASRInterface)
model.recog_args = args
# read rnnlm
if args.rnnlm:
rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
if getattr(rnnlm_args, "model_module", "default") != "default":
raise ValueError(
"use '--api v2' option to decode with non-default language model"
)
rnnlm = lm_pytorch.ClassifierWithState(
lm_pytorch.RNNLM(
len(train_args.char_list),
rnnlm_args.layer,
rnnlm_args.unit,
getattr(rnnlm_args, "embed_unit", None), # for backward compatibility
)
)
torch_load(args.rnnlm, rnnlm)
rnnlm.eval()
else:
rnnlm = None
if args.word_rnnlm:
rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
word_dict = rnnlm_args.char_list_dict
char_dict = {x: i for i, x in enumerate(train_args.char_list)}
word_rnnlm = lm_pytorch.ClassifierWithState(
lm_pytorch.RNNLM(len(word_dict), rnnlm_args.layer, rnnlm_args.unit)
)
torch_load(args.word_rnnlm, word_rnnlm)
word_rnnlm.eval()
if rnnlm is not None:
rnnlm = lm_pytorch.ClassifierWithState(
extlm_pytorch.MultiLevelLM(
word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict
)
)
else:
rnnlm = lm_pytorch.ClassifierWithState(
extlm_pytorch.LookAheadWordLM(
word_rnnlm.predictor, word_dict, char_dict
)
)
# gpu
if args.ngpu == 1:
gpu_id = list(range(args.ngpu))
logging.info("gpu id: " + str(gpu_id))
model.cuda()
if rnnlm:
rnnlm.cuda()
# read json data
with open(args.recog_json, "rb") as f:
js = json.load(f)["utts"]
new_js = {}
load_inputs_and_targets = LoadInputsAndTargets(
mode="asr",
load_output=False,
sort_in_input_length=False,
preprocess_conf=train_args.preprocess_conf
if args.preprocess_conf is None
else args.preprocess_conf,
preprocess_args={"train": False},
)
if args.batchsize == 0:
with torch.no_grad():
for idx, name in enumerate(js.keys(), 1):
logging.info("(%d/%d) decoding " + name, idx, len(js.keys()))
batch = [(name, js[name])]
feat = load_inputs_and_targets(batch)[0][0]
nbest_hyps = model.recognize(feat, args, train_args.char_list, rnnlm)
new_js[name] = add_results_to_json(
js[name], nbest_hyps, train_args.char_list
)
else:
def grouper(n, iterable, fillvalue=None):
kargs = [iter(iterable)] * n
return zip_longest(*kargs, fillvalue=fillvalue)
# sort data if batchsize > 1
keys = list(js.keys())
if args.batchsize > 1:
feat_lens = [js[key]["input"][0]["shape"][0] for key in keys]
sorted_index = sorted(range(len(feat_lens)), key=lambda i: -feat_lens[i])
keys = [keys[i] for i in sorted_index]
with torch.no_grad():
for names in grouper(args.batchsize, keys, None):
names = [name for name in names if name]
batch = [(name, js[name]) for name in names]
feats = load_inputs_and_targets(batch)[0]
nbest_hyps = model.recognize_batch(
feats, args, train_args.char_list, rnnlm=rnnlm
)
for i, name in enumerate(names):
nbest_hyp = [hyp[i] for hyp in nbest_hyps]
new_js[name] = add_results_to_json(
js[name], nbest_hyp, train_args.char_list
)
with open(args.result_label, "wb") as f:
f.write(
json.dumps(
{"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
).encode("utf_8")
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment