Commit 60a2c57a authored by sunzhq2's avatar sunzhq2 Committed by xuxo
Browse files

update conformer

parent 4a699441
---
# CI workflow: verify that espnet can be installed from a wheel and that all
# modules import cleanly, first with the minimal and then the full extras.
name: Test import espnet

on:
  push:
    branches:
      - master
  pull_request:
    branches:
      - master

jobs:
  test_import:
    runs-on: ${{ matrix.os }}
    strategy:
      max-parallel: 20
      matrix:
        os: [ubuntu-latest]
        # Quoted: 3.10 would otherwise be read as the float 3.1.
        python-version: ["3.10"]
        pytorch-version: ["1.13.1"]
    steps:
      # NOTE(review): checkout@v2 / setup-python@v1 are older than the
      # versions used in the sibling Windows workflow (cache@v3,
      # setup-python@v4) — consider upgrading for consistency.
      - uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v1
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          sudo apt-get install -qq -y libsndfile1-dev
          python3 -m pip install --upgrade pip setuptools wheel
      - name: Install espnet with the least requirement
        env:
          TH_VERSION: ${{ matrix.pytorch-version }}
        run: |
          python3 -m pip install -U numba
          ./tools/installers/install_torch.sh false ${TH_VERSION} CPU
          ./tools/installers/install_chainer.sh CPU
          python3 setup.py bdist_wheel
          python3 -m pip install dist/espnet-*.whl
          # log
          python3 -m pip freeze
      - name: Import all modules (Try1)
        run: |
          python3 ./ci/test_import_all.py
      - name: Install espnet with the full requirement
        env:
          TH_VERSION: ${{ matrix.pytorch-version }}
        run: |
          python3 -m pip install "$(ls dist/espnet-*.whl)[all]"
          # log
          python3 -m pip freeze
      - name: Import all modules (Try2)
        run: |
          python3 ./ci/test_import_all.py
---
# CI workflow: install espnet on a Windows runner (bash shell via Git Bash)
# to catch Windows-specific installation breakage.
name: Windows

on:
  push:
    branches:
      - master
  pull_request:
    branches:
      - master

jobs:
  test_windows:
    # Runner labels are matched case-insensitively, but the documented
    # convention is lowercase.
    runs-on: windows-latest
    strategy:
      matrix:
        # Quoted: 3.10 would otherwise be read as the float 3.1.
        python-version: ["3.10"]
        pytorch-version: ["1.13.1"]
    defaults:
      run:
        shell: bash
    steps:
      # NOTE(review): @master is a mutable ref; consider pinning to a tagged
      # release (e.g. actions/checkout@v3) like the pinned actions below.
      - uses: actions/checkout@master
      - uses: actions/cache@v3
        with:
          path: ~/.cache/pip
          key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ matrix.pytorch-version }}-${{ hashFiles('**/setup.py') }}-${{ hashFiles('**/Makefile') }}
      - uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
          architecture: 'x64'
      - name: install dependencies
        run: |
          choco install -y wget
      - name: install espnet
        env:
          ESPNET_PYTHON_VERSION: ${{ matrix.python-version }}
          TH_VERSION: ${{ matrix.pytorch-version }}
          CHAINER_VERSION: "6.0.0"
          # Quoted so the consumer (ci/install.sh) receives the literal
          # string "false" rather than a YAML boolean.
          USE_CONDA: "false"
        run: |
          ./ci/install.sh
# general
*~
*.pyc
\#*\#
.\#*
*DS_Store
out.txt
espnet.egg-info/
doc/_build
slurm-*.out
tmp*
.eggs/
.hypothesis/
.idea
.pytest_cache/
__pycache__/
check_autopep8
.coverage
htmlcov
coverage.xml*
bats-core/
shellcheck*
check_shellcheck*
test_spm.vocab
test_spm.model
.vscode*
*.vim
*.swp
*.nfs*
# recipe related
egs*/*/*/data*
egs*/*/*/db
egs*/*/*/downloads
egs*/*/*/dump
egs*/*/*/enhan
egs*/*/*/exp
egs*/*/*/fbank
egs*/*/*/mfcc
egs*/*/*/stft
egs*/*/*/tensorboard
egs*/*/*/wav*
egs*/*/*/score*
egs*/*/*/nltk*
egs*/*/*/.cache*
egs*/*/*/pretrained_models*
egs*/fisher_callhome_spanish/*/local/mapping*
# tools related
tools/chainer
tools/bin
tools/include
tools/lib
tools/lib64
tools/bats-core
tools/chainer_ctc/
tools/kaldi*
tools/activate_python.sh
tools/miniconda.sh
tools/moses/
tools/mwerSegmenter/
tools/nkf/
tools/venv/
tools/sentencepiece/
tools/swig/
tools/warp-ctc/
tools/warp-transducer/
tools/*.done
tools/PESQ*
tools/hts_engine_API*
tools/open_jtalk*
tools/pyopenjtalk*
tools/tdmelodic_openjtalk*
tools/s3prl
tools/sctk*
tools/sph2pipe*
tools/espeak-ng*
tools/MBROLA*
tools/festival*
tools/speech_tools*
tools/phonemizer*
tools/py3mmseg
tools/anaconda
tools/ice-g2p
tools/fairseq
tools/RawNet
tools/._*
tools/ice-g2p*
tools/fairseq*
tools/featbin*
tools/miniconda
---
# Mergify configuration: auto-merge gated on CI check names, plus automatic
# labeling of PRs by the files they touch.
pull_request_rules:
  - name: automatic merge if label=auto-merge
    conditions:
      - "label=auto-merge"
      - "check-success=test_centos7"
      - "check-success=test_debian11"
      - "check-success=linter_and_test (ubuntu-latest, 3.7, 1.10.2, 6.0.0, false)"
      - "check-success=linter_and_test (ubuntu-latest, 3.7, 1.11.0, 6.0.0, false)"
      - "check-success=linter_and_test (ubuntu-latest, 3.7, 1.12.1, 6.0.0, false)"
      - "check-success=linter_and_test (ubuntu-latest, 3.7, 1.13.1, 6.0.0, false)"
      - "check-success=linter_and_test (ubuntu-latest, 3.8, 1.10.2, 6.0.0, false)"
      - "check-success=linter_and_test (ubuntu-latest, 3.8, 1.11.0, 6.0.0, false)"
      - "check-success=linter_and_test (ubuntu-latest, 3.8, 1.12.1, 6.0.0, false)"
      - "check-success=linter_and_test (ubuntu-latest, 3.8, 1.13.1, 6.0.0, false)"
      - "check-success=linter_and_test (ubuntu-latest, 3.9, 1.10.2, 6.0.0, false)"
      - "check-success=linter_and_test (ubuntu-latest, 3.9, 1.11.0, 6.0.0, false)"
      - "check-success=linter_and_test (ubuntu-latest, 3.9, 1.12.1, 6.0.0, false)"
      - "check-success=linter_and_test (ubuntu-latest, 3.9, 1.13.1, 6.0.0, false)"
      # NOTE(review): the trailing arguments here are "false, 6.0.0" while
      # every other row uses "6.0.0, false" — confirm this matches the check
      # name actually generated by the CI matrix for Python 3.10.
      - "check-success=linter_and_test (ubuntu-latest, 3.10, 1.13.1, false, 6.0.0)"
      - "check-success=test_import (ubuntu-latest, 3.10, 1.13.1)"
      - "check-success=check_kaldi_symlinks"
    actions:
      merge:
        method: merge
  - name: delete head branch after merged
    conditions:
      - merged
    actions:
      delete_head_branch: {}
  - name: "add label=auto-merge for PR by mergify"
    conditions:
      - author=mergify[bot]
    actions:
      label:
        add: ["auto-merge"]
  - name: warn on conflicts
    conditions:
      - conflict
    actions:
      comment:
        message: This pull request is now in conflict :(
      label:
        add: ["conflicts"]
  - name: unlabel conflicts
    conditions:
      - -conflict
    actions:
      label:
        remove: ["conflicts"]
  - name: "auto add label=ESPnet1"
    conditions:
      - files~=^(espnet/|egs/)
    actions:
      label:
        add: ["ESPnet1"]
  - name: "auto add label=ESPnet2"
    conditions:
      - files~=^(espnet2/|egs2/)
    actions:
      label:
        add: ["ESPnet2"]
  - name: "auto add label=ASR"
    conditions:
      - files~=^(espnet*/asr|egs*/*/asr1)
    actions:
      label:
        add: ["ASR"]
  - name: "auto add label=TTS"
    conditions:
      - files~=^(espnet*/tts|egs*/*/tts1)
    actions:
      label:
        add: ["TTS"]
  - name: "auto add label=MT"
    conditions:
      - files~=^(espnet*/mt|egs*/*/mt1)
    actions:
      label:
        add: ["MT"]
  - name: "auto add label=LM"
    conditions:
      - files~=^(espnet*/lm)
    actions:
      label:
        add: ["LM"]
  - name: "auto add label=README"
    conditions:
      - files~=README.md
    actions:
      label:
        add: ["README"]
  - name: "auto add label=Documentation"
    conditions:
      - files~=^doc/
    actions:
      label:
        add: ["Documentation"]
  - name: "auto add label=CI"
    conditions:
      - files~=^(.circleci/|ci/|.github/|.travis.yml)
    actions:
      label:
        add: ["CI"]
  - name: "auto add label=Installation"
    conditions:
      - files~=^(tools/|setup.py)
    actions:
      label:
        add: ["Installation"]
  - name: "auto add label=mergify"
    conditions:
      - files~=^.mergify.yml
    actions:
      label:
        add: ["mergify"]
  - name: "auto add label=Docker"
    conditions:
      - files~=^docker/
    actions:
      label:
        add: ["Docker"]
---
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: trailing-whitespace
        exclude: ^(egs2/TEMPLATE/asr1/utils|egs2/TEMPLATE/asr1/steps|egs2/TEMPLATE/tts1/sid|tools/installers/patch_mwerSegmenter)
      - id: end-of-file-fixer
        exclude: ^(egs2/TEMPLATE/asr1/utils|egs2/TEMPLATE/asr1/steps|egs2/TEMPLATE/tts1/sid|tools/installers/patch_mwerSegmenter)
      - id: check-yaml
        exclude: ^(egs2/TEMPLATE/asr1/utils|egs2/TEMPLATE/asr1/steps|egs2/TEMPLATE/tts1/sid|tools/installers/patch_mwerSegmenter)
      - id: check-added-large-files
        exclude: ^(egs2/TEMPLATE/asr1/utils|egs2/TEMPLATE/asr1/steps|egs2/TEMPLATE/tts1/sid|tools/installers/patch_mwerSegmenter)
  - repo: https://github.com/psf/black
    # Quoted so a number-looking revision is never retyped by tooling.
    rev: "23.3.0"
    hooks:
      - id: black
        exclude: ^(egs2/TEMPLATE/asr1/utils|egs2/TEMPLATE/asr1/steps|egs2/TEMPLATE/tts1/sid|doc)
  - repo: https://github.com/pycqa/isort
    rev: "5.12.0"
    hooks:
      - id: isort
        exclude: ^(egs2/TEMPLATE/asr1/utils|egs2/TEMPLATE/asr1/steps|egs2/TEMPLATE/tts1/sid|doc)
Requirement already satisfied: typeguard in /root/miniconda3/lib/python3.10/site-packages (2.13.3)
# How to contribute to ESPnet
## 1. What to contribute
If you are interested in contributing to ESPnet, your contributions will fall into three categories: major features, minor updates, and recipes.
### 1.1 Major features
If you want to ask or propose a new feature, please first open a new issue with the tag `Feature request`
or directly contact Shinji Watanabe <shinjiw@ieee.org> or other main developers. Each feature implementation
and design should be discussed and modified according to ongoing and future works.
You can find ongoing major development plans at https://github.com/espnet/espnet/milestones
or in https://github.com/espnet/espnet/issues (pinned issues)
### 1.2 Minor Updates (minor feature, bug-fix for an issue)
If you want to propose a minor feature, update an existing minor feature, or fix a bug, please first take a look at
the existing [issues](https://github.com/espnet/espnet/issues) and/or [pull requests](https://github.com/espnet/espnet/pulls).
Pick an issue and comment on the task that you want to work on this feature.
If you need help or additional information to propose the feature, you can open a new issue with the tag `Discussion` and ask ESPnet members.
### 1.3 Recipes
ESPnet provides and maintains many example scripts, called `recipes`, that demonstrate how to
use the toolkit. The recipes for ESPnet1 are put under `egs` directory, while ESPnet2 ones are put under `egs2`.
Similar to Kaldi, each subdirectory of `egs` and `egs2` corresponds to a corpus that we have example scripts for.
#### 1.3.1 ESPnet1 recipes
ESPnet1 recipes (`egs/X`) follow the convention from [Kaldi](https://github.com/kaldi-asr/kaldi) and may rely on
several utilities available in Kaldi. As such, porting a new recipe from Kaldi to ESPnet is natural, and the user
may refer to [port-kaldi-recipe](https://github.com/espnet/espnet/wiki/How-to-port-the-Kaldi-recipe-to-the-ESPnet-recipe%3F)
and other existing recipes for new additions. For the Kaldi-style recipe architecture, please refer to
[Prepare-Kaldi-Style-Directory](https://kaldi-asr.org/doc/data_prep.html).
For each recipe, we ask you to report the following: experiment results and environment, model information.
For reproducibility, a link to upload the pre-trained model may also be added. All this information should be written
in a markdown file called `RESULTS.md` and put at the recipe root. You can refer to
[tedlium2-example](https://github.com/espnet/espnet/blob/master/egs/tedlium2/asr1/RESULTS.md) for an example.
To generate `RESULTS.md` for a recipe, please follow the following instructions:
- Execute `~/espnet/utils/show_result.sh` at the recipe root (where `run.sh` is located).
You'll get your environment information and evaluation results for each experiment in a markdown format.
From here, you can copy or redirect text output to `RESULTS.md`.
- Execute `~/espnet/utils/pack_model.sh` at the recipe root to generate a packed ESPnet model called `model.tar.gz`
and output model information. Executing the utility script without argument will give you the expected arguments.
- Put the model information in `RESULTS.md` and model link if you're using a private web storage
- If you don't have private web storage, please contact Shinji Watanabe <shinjiw@ieee.org> to give you access to ESPnet storage.
#### 1.3.2 ESPnet2 recipes
ESPnet2's recipes correspond to `egs2`. ESPnet2 applies a new paradigm without dependencies of Kaldi's binaries, which makes it lighter and more generalized.
For ESPnet2, we do not recommend preparing the recipe's stages for each corpus but using the common pipelines we provided in `asr.sh`, `tts.sh`, and
`enh.sh`. For details of creating ESPnet2 recipes, please refer to [egs2-readme](https://github.com/espnet/espnet/blob/master/egs2/TEMPLATE/README.md).
The common pipeline of ESPnet2 recipes will take care of the `RESULTS.md` generation, model packing, and uploading. ESPnet2 models are maintained at Hugging Face and Zenodo (Deprecated).
You can also refer to the document in https://github.com/espnet/espnet_model_zoo
To upload your model, you first need to (this is currently deprecated; uploading to the Hugging Face Hub is preferred):
1. Sign up to Zenodo: https://zenodo.org/
2. Create access token: https://zenodo.org/account/settings/applications/tokens/new/
3. Set your environment: % export ACCESS_TOKEN="<your token>"
To port models from zenodo using Hugging Face hub,
1. Create a Hugging Face account - https://huggingface.co/
2. Request to be added to espnet organisation - https://huggingface.co/espnet
3. Go to `egs2/RECIPE/*` and run `./scripts/utils/upload_models_to_hub.sh "ZENODO_MODEL_NAME"`
To upload models using Huggingface-cli follow the following steps:
You can also refer to https://huggingface.co/docs/transformers/model_sharing
1. Create a Hugging Face account - https://huggingface.co/
2. Request to be added to espnet organisation - https://huggingface.co/espnet
3. Run `huggingface-cli login` (you can get the token at this step under Settings > Access Tokens > espnet token)
4. `huggingface-cli repo create your-model-name --organization espnet`
5. `git clone https://huggingface.co/username/your-model-name` (clone this outside ESPNet to avoid issues as this a git repo)
6. `cd your-model-name`
7. `git lfs install`
8. copy the contents from the exp directory of your recipe into this directory (check other models of a similar task under ESPnet to confirm your directory structure)
9. `git add . `
10. `git commit -m "Add model files"`
11. `git push`
12. Check if the inference demo on HF is running successfully to verify the upload
#### 1.3.3 Additional requirements for new recipe
- Common/shared files and directories such as `utils`, `steps`, `asr.sh`, etc, should be linked using
a symbolic link (e.g.: `ln -s <source-path> <target-path>`). Please refer to existing recipes if you're
unaware which files/directories are shared. Noted that in espnet2, some of them are automatically generated by https://github.com/espnet/espnet/blob/master/egs2/TEMPLATE/asr1/setup.sh.
- Default training and decoding configurations (i.e.: the default one in `run.sh`) should be named respectively `train.yaml`
and `decode.yaml` and put in `conf/`. Additional or variant configurations should be put in `conf/tuning/` and named accordingly
to its differences.
- If a recipe for a new corpus is proposed, you should add its name and information to:
https://github.com/espnet/espnet/blob/master/egs/README.md if it's an ESPnet1 recipe,
or https://github.com/espnet/espnet/blob/master/egs2/README.md + `db.sh` if it's an ESPnet2 recipe.
#### 1.3.4 Checklist before you submit the recipe-based PR
- [ ] be careful about the name for the recipe. It is recommended to follow naming conventions of the other recipes
- [ ] common/shared files are linked with **soft link** (see Section 1.3.3)
- [ ] modified or new python scripts should be passed through **latest** black formatting (by using the python package black). The command to be executed could be `black espnet espnet2 test utils setup.py egs*/*/*/local egs2/TEMPLATE/*/pyscripts tools/*.py ci/*.py`
- [ ] modified or new python scripts should be passed through **latest** isort formatting (by using the python package isort). The command to be executed could be `isort espnet espnet2 test utils setup.py egs*/*/*/local egs2/TEMPLATE/*/pyscripts tools/*.py ci/*.py`
- [ ] cluster settings should be set as **default** (e.g., cmd.sh conf/slurm.conf conf/queue.conf conf/pbs.conf)
- [ ] update `egs/README.md` or `egs2/README.md` with corresponding recipes
- [ ] add corresponding entry in `egs2/TEMPLATE/db.sh` for a new corpus
- [ ] try to **simplify** the model configurations. We recommend to have only the best configuration for the start of a recipe. Please also follow the default rule defined in Section 1.3.3
- [ ] large meta-information for a corpus should be maintained elsewhere other than in the recipe itself
- [ ] recommend to also include results and pre-trained model with the recipe
## 2 Pull Request
If your proposed feature or bugfix is ready, please open a Pull Request (PR) at https://github.com/espnet/espnet
or use the Pull Request button in your forked repo. If you're not familiar with the process, please refer to the following guides:
- http://stackoverflow.com/questions/14680711/how-to-do-a-github-pull-request
- https://help.github.com/articles/creating-a-pull-request/
## 3 Version policy and development branches
We basically develop in the `master` branch.
1. We will keep the first version digit `0` until we have some super major changes in the project organization level.
2. The second version digit will be updated when we have major updates, including new functions and refactoring, and
their related bug fix and recipe changes.
This version update will be done roughly every half year so far (but it depends on the development plan).
3. The third version digit will be updated when we fix serious bugs or accumulate some minor changes, including
recipe related changes periodically (every two months or so).
## 4 Unit testing
ESPnet's testing is located under `test/`. You can install additional packages for testing as follows:
``` console
$ cd <espnet_root>
$ . ./tools/activate_python.sh
$ pip install -e ".[test]"
```
### 4.1 Python
Then you can run the entire test suite using [flake8](http://flake8.pycqa.org/en/latest/), [autopep8](https://github.com/hhatto/autopep8), [black](https://github.com/psf/black), [isort](https://github.com/PyCQA/isort) and [pytest](https://docs.pytest.org/en/latest/) with [coverage](https://pytest-cov.readthedocs.io/en/latest/reporting.html) by
``` console
./ci/test_python.sh
```
Followings are some useful tips when you are using pytest:
- New test file should be put under `test/` directory and named `test_xxx.py`. Each method in the test file should
have the format `def test_yyy(...)`. [Pytest](https://docs.pytest.org/en/latest/) will automatically find and test them.
- We recommend adding several small test files instead of grouping them in one big file (e.g.: `test_e2e_xxx.py`).
Technically, a test file should only cover methods from one file (e.g.: `test_transformer_utils.py` to test `transformer_utils.py`).
- To monitor test coverage and avoid the overlapping test, we recommend using `pytest --cov-report term-missing <test_file|dir>`
to highlight covered and missed lines. For more details, please refer to [coverage-test](https://pytest-cov.readthedocs.io/en/latest/readme.html).
- We limited test running time to 2.0 seconds (see: [pytest-timeouts](https://pypi.org/project/pytest-timeouts/)). As such,
we recommend using small model parameters and avoiding dynamic imports, file access, and unnecessary loops. If a unit test needs
more running time, you can annotate your test with `@pytest.mark.execution_timeout(sec)`.
- For test initialization (parameters, modules, etc), you can use pytest fixtures. Refer to [pytest fixtures](https://docs.pytest.org/en/latest/fixture.html#using-fixtures-from-classes-modules-or-projects) for more information.
In addition, please follow the [PEP 8 convention](https://peps.python.org/pep-0008/) for the coding style and [Google's convention for docstrings](https://google.github.io/styleguide/pyguide.html#383-functions-and-methods).
Below are some specific points that should be taken care of in particular:
- [import ordering](https://peps.python.org/pep-0008/#imports)
- Avoid writing python2-style code. For example, `super().__init__()` is preferred over `super(CLASS_NAME, self).__init__()`.
### 4.2 Bash scripts
You can also test the scripts in `utils` with [bats-core](https://github.com/bats-core/bats-core) and [shellcheck](https://github.com/koalaman/shellcheck).
To test:
``` console
./ci/test_shell.sh
```
## 5 Integration testing
Write new integration tests in [ci/test_integration_espnet1.sh](ci/test_integration_espnet1.sh) or [ci/test_integration_espnet2.sh](ci/test_integration_espnet2.sh) when you add new features in [espnet/bin](espnet/bin) or [espnet2/bin](espnet2/bin), respectively. They use our smallest dataset [egs/mini_an4](egs/mini_an4) or [egs2/mini_an4](egs2/mini_an4) to test `run.sh`. **Don't call `python` directly in integration tests. Instead, use `coverage run --append`** as a python interpreter. Especially, `run.sh` should support `--python ${python}` to call the custom interpreter.
```bash
# ci/test_integration_espnet{1,2}.sh
python="coverage run --append"
cd egs/mini_an4/your_task
./run.sh --python "${python}"
```
### 5.1 Configuration files
- [setup.cfg](setup.cfg) configures pytest, black and flake8.
- [.travis.yml](.travis.yml) configures Travis-CI (unittests, doc deploy).
- [.circleci/config.yml](.circleci/config.yml) configures Circle-CI (unittests, integration tests).
- [.github/workflows](.github/workflows/) configures Github Actions (unittests, integration tests).
- [codecov.yml](codecov.yml) configures CodeCov (code coverage).
## 6 Writing new tools
You can place your new tools under
- `espnet/bin`: heavy and large (e.g., neural network related) core tools.
- `utils`: lightweight self-contained python/bash scripts.
For `utils` scripts, do not forget to add help messages and test scripts under `test_utils`.
### 6.1 Python tools guideline
To generate doc, do not forget `def get_parser(): -> ArgumentParser` in the main file.
```python
#!/usr/bin/env python3
# Copyright XXX
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
# NOTE: do not forget this
def get_parser():
parser = argparse.ArgumentParser(
description="awsome tool", # DO NOT forget this
)
...
return parser
if __name__ == '__main__':
args = get_parser().parse_args()
...
```
### 6.2 Bash tools guideline
To generate doc, support `--help` to show its usage. If you use Kaldi's `utils/parse_option.sh`, define `help_message="Usage: $0 ..."`.
## 7 Writing documentation
See [doc](doc/README.md).
## 8 Adding pretrained models
Pack your trained models using `utils/pack_model.sh` and upload it [here](https://drive.google.com/open?id=1k9RRyc06Zl0mM2A7mi-hxNiNMFb_YzTF) (You require permission).
Add the shared link to `utils/recog_wav.sh` or `utils/synth_wav.sh` as follows:
```sh
"tedlium.demo") share_url="https://drive.google.com/open?id=1UqIY6WJMZ4sxNxSugUqp3mrGb3j6h7xe" ;;
```
The model name is arbitrary for now.
## 9 On CI failure
### 9.1 Travis CI and Github Actions
1. read the log from PR checks > details
### 9.2 Circle CI
1. read the log from PR checks > details
2. turn on Rerun workflow > Rerun job with SSH
3. open your local terminal and `ssh -p xxx xxx` (check circle ci log for the exact address)
4. try anything you can to pass the CI
### 9.3 Codecov
1. write more tests to increase coverage
2. explain to reviewers why you can't increase coverage
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright 2017 Johns Hopkins University (Shinji Watanabe)
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
This diff is collapsed.
"""Initialize espnet package."""
import os
dirname = os.path.dirname(__file__)
version_file = os.path.join(dirname, "version.txt")
with open(version_file, "r") as f:
__version__ = f.read().strip()
#!/usr/bin/env python3
"""
This script is used to provide utility functions designed for multi-speaker ASR.
Copyright 2017 Johns Hopkins University (Shinji Watanabe)
Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
Most functions can be directly used as in asr_utils.py:
CompareValueTrigger, restore_snapshot, adadelta_eps_decay, chainer_load,
torch_snapshot, torch_save, torch_resume, AttributeDict, get_model_conf.
"""
import copy
import logging
import os
from chainer.training import extension
from espnet.asr.asr_utils import parse_hypothesis
# * -------------------- chainer extension related -------------------- *
class PlotAttentionReport(extension.Extension):
    """Plot attention reporter for multi-speaker ASR.

    Args:
        att_vis_fn (espnet.nets.*_backend.e2e_asr.calculate_all_attentions):
            Function of attention visualization.
        data (list[tuple(str, dict[str, dict[str, Any]])]): List json utt key items.
        outdir (str): Directory to save figures.
        converter (espnet.asr.*_backend.asr.CustomConverter):
            CustomConverter object. Function to convert data.
        device (torch.device): The destination device to send tensor.
        reverse (bool): If True, input and output length are reversed.

    """

    def __init__(self, att_vis_fn, data, outdir, converter, device, reverse=False):
        """Initialize PlotAttentionReport."""
        self.att_vis_fn = att_vis_fn
        # deep-copy so later mutation of the caller's list cannot affect plotting
        self.data = copy.deepcopy(data)
        self.outdir = outdir
        self.converter = converter
        self.device = device
        self.reverse = reverse
        if not os.path.exists(self.outdir):
            os.makedirs(self.outdir)

    def __call__(self, trainer):
        """Plot and save imaged matrix of att_ws."""
        # att_ws_sd is indexed by speaker, then by utterance within the batch
        att_ws_sd = self.get_attention_weights()
        for ns, att_ws in enumerate(att_ws_sd):
            for idx, att_w in enumerate(att_ws):
                # "{.updater.epoch}" is a str.format placeholder filled with
                # the trainer object below, yielding the current epoch number
                filename = "%s/%s.ep.{.updater.epoch}.output%d.png" % (
                    self.outdir,
                    self.data[idx][0],
                    ns + 1,
                )
                att_w = self.get_attention_weight(idx, att_w, ns)
                self._plot_and_save_attention(att_w, filename.format(trainer))

    def log_attentions(self, logger, step):
        """Add image files of attention matrix to tensorboard."""
        att_ws_sd = self.get_attention_weights()
        for ns, att_ws in enumerate(att_ws_sd):
            for idx, att_w in enumerate(att_ws):
                att_w = self.get_attention_weight(idx, att_w, ns)
                plot = self.draw_attention_plot(att_w)
                logger.add_figure("%s" % (self.data[idx][0]), plot.gcf(), step)
                plot.clf()

    def get_attention_weights(self):
        """Return attention weights computed on self.data.

        Returns:
            arr_ws_sd (numpy.ndarray): attention weights (dtype=float), whose
                shape depends on the backend:
                * pytorch -> 1) multi-head case => (B, H, Lmax, Tmax),
                  2) other case => (B, Lmax, Tmax).
                * chainer -> attention weights (B, Lmax, Tmax).

        """
        # wrap the transformed data in a one-element list: the converter
        # expects a list of minibatches
        batch = self.converter([self.converter.transform(self.data)], self.device)
        att_ws_sd = self.att_vis_fn(*batch)
        return att_ws_sd

    def get_attention_weight(self, idx, att_w, spkr_idx):
        """Transform attention weight in regard to self.reverse.

        Crops the weight matrix to the true (unpadded) decoder/encoder lengths
        of utterance ``idx`` for speaker ``spkr_idx``, swapping the roles of
        "input" and "output" lengths when self.reverse is set.
        """
        if self.reverse:
            dec_len = int(self.data[idx][1]["input"][0]["shape"][0])
            enc_len = int(self.data[idx][1]["output"][spkr_idx]["shape"][0])
        else:
            dec_len = int(self.data[idx][1]["output"][spkr_idx]["shape"][0])
            enc_len = int(self.data[idx][1]["input"][0]["shape"][0])
        if len(att_w.shape) == 3:
            # multi-head case: keep all heads, crop time axes
            att_w = att_w[:, :dec_len, :enc_len]
        else:
            att_w = att_w[:dec_len, :enc_len]
        return att_w

    def draw_attention_plot(self, att_w):
        """Visualize attention weights matrix.

        Args:
            att_w(Tensor): Attention weight matrix.

        Returns:
            matplotlib.pyplot: pyplot object with attention matrix image.

        """
        import matplotlib

        # use the non-interactive backend; figures are only saved, never shown
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt

        if len(att_w.shape) == 3:
            # one subplot per attention head, laid out side by side
            for h, aw in enumerate(att_w, 1):
                plt.subplot(1, len(att_w), h)
                plt.imshow(aw, aspect="auto")
                plt.xlabel("Encoder Index")
                plt.ylabel("Decoder Index")
        else:
            plt.imshow(att_w, aspect="auto")
            plt.xlabel("Encoder Index")
            plt.ylabel("Decoder Index")
        plt.tight_layout()
        return plt

    def _plot_and_save_attention(self, att_w, filename):
        # Render the matrix and write it to disk, then release the figure.
        plt = self.draw_attention_plot(att_w)
        plt.savefig(filename)
        plt.close()
def add_results_to_json(js, nbest_hyps_sd, char_list):
    """Add N-best results to json.

    Args:
        js (dict[str, Any]): Groundtruth utterance dict.
        nbest_hyps_sd (list[dict[str, Any]]):
            List of hypothesis for multi_speakers (# Utts x # Spkrs).
        char_list (list[str]): List of characters.

    Returns:
        dict[str, Any]: N-best results added utterance dict.

    """
    # start from the ground-truth speaker mapping only; results are appended below
    new_js = {"utt2spk": js["utt2spk"], "output": []}

    for ns, nbest_hyps in enumerate(nbest_hyps_sd):
        per_speaker = []
        for rank, hyp in enumerate(nbest_hyps, 1):
            # decode the hypothesis into text / token / token-id strings
            rec_text, rec_token, rec_tokenid, score = parse_hypothesis(hyp, char_list)

            # clone this speaker's ground-truth entry and attach the result,
            # tagging the name with the N-best rank
            out_dic = dict(js["output"][ns].items())
            out_dic["name"] += "[%d]" % rank
            out_dic["rec_text"] = rec_text
            out_dic["rec_token"] = rec_token
            out_dic["rec_tokenid"] = rec_tokenid
            out_dic["score"] = score
            per_speaker.append(out_dic)

            # show 1-best result
            if rank == 1:
                logging.info("groundtruth: %s" % out_dic["text"])
                logging.info("prediction : %s" % out_dic["rec_text"])
        new_js["output"].append(per_speaker)

    return new_js
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Training/decoding definition for the speech recognition task."""
import json
import logging
import os
# chainer related
import chainer
from chainer import training
from chainer.datasets import TransformDataset
from chainer.training import extensions
# rnnlm
import espnet.lm.chainer_backend.extlm as extlm_chainer
import espnet.lm.chainer_backend.lm as lm_chainer
# espnet related
from espnet.asr.asr_utils import (
CompareValueTrigger,
adadelta_eps_decay,
add_results_to_json,
chainer_load,
get_model_conf,
restore_snapshot,
)
from espnet.nets.asr_interface import ASRInterface
from espnet.utils.deterministic_utils import set_deterministic_chainer
from espnet.utils.dynamic_import import dynamic_import
from espnet.utils.io_utils import LoadInputsAndTargets
from espnet.utils.training.batchfy import make_batchset
from espnet.utils.training.evaluator import BaseEvaluator
from espnet.utils.training.iterators import (
ShufflingEnabler,
ToggleableShufflingMultiprocessIterator,
ToggleableShufflingSerialIterator,
)
from espnet.utils.training.tensorboard_logger import TensorboardLogger
from espnet.utils.training.train_utils import check_early_stop, set_early_stop
def train(args):
    """Train an E2E ASR model (chainer backend) with the given args.

    Args:
        args (namespace): The program arguments.

    """
    # display chainer version
    logging.info("chainer version = " + chainer.__version__)

    set_deterministic_chainer(args)

    # check cuda and cudnn availability
    if not chainer.cuda.available:
        logging.warning("cuda is not available")
    if not chainer.cuda.cudnn_enabled:
        logging.warning("cudnn is not available")

    # get input and output dimension info
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]
    utts = list(valid_json.keys())
    # NOTE(review): dims are read from the first utterance; assumes every
    # utterance in valid_json shares the same feature/output dimensions.
    idim = int(valid_json[utts[0]]["input"][0]["shape"][1])
    odim = int(valid_json[utts[0]]["output"][0]["shape"][1])
    logging.info("#input dims : " + str(idim))
    logging.info("#output dims: " + str(odim))

    # specify attention, CTC, hybrid mode
    if args.mtlalpha == 1.0:
        mtl_mode = "ctc"
        logging.info("Pure CTC mode")
    elif args.mtlalpha == 0.0:
        mtl_mode = "att"
        logging.info("Pure attention mode")
    else:
        mtl_mode = "mtl"
        logging.info("Multitask learning mode")

    # specify model architecture
    logging.info("import model module: " + args.model_module)
    model_class = dynamic_import(args.model_module)
    model = model_class(idim, odim, args, flag_return=False)
    assert isinstance(model, ASRInterface)
    total_subsampling_factor = model.get_total_subsampling_factor()

    # write model config
    if not os.path.exists(args.outdir):
        os.makedirs(args.outdir)
    model_conf = args.outdir + "/model.json"
    with open(model_conf, "wb") as f:
        logging.info("writing a model config file to " + model_conf)
        f.write(
            json.dumps(
                (idim, odim, vars(args)), indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    # Set gpu
    ngpu = args.ngpu
    if ngpu == 1:
        gpu_id = 0
        # Make a specified GPU current
        chainer.cuda.get_device_from_id(gpu_id).use()
        model.to_gpu()  # Copy the model to the GPU
        logging.info("single gpu calculation.")
    elif ngpu > 1:
        gpu_id = 0
        devices = {"main": gpu_id}
        for gid in range(1, ngpu):
            devices["sub_%d" % gid] = gid
        logging.info("multi gpu calculation (#gpus = %d)." % ngpu)
        logging.warning(
            "batch size is automatically increased (%d -> %d)"
            % (args.batch_size, args.batch_size * args.ngpu)
        )
    else:
        gpu_id = -1
        logging.info("cpu calculation")

    # Setup an optimizer
    if args.opt == "adadelta":
        optimizer = chainer.optimizers.AdaDelta(eps=args.eps)
    elif args.opt == "adam":
        optimizer = chainer.optimizers.Adam()
    elif args.opt == "noam":
        # alpha=0 here: the effective learning rate is driven by the
        # VaswaniRule extension registered further below
        optimizer = chainer.optimizers.Adam(alpha=0, beta1=0.9, beta2=0.98, eps=1e-9)
    else:
        raise NotImplementedError("args.opt={}".format(args.opt))

    optimizer.setup(model)
    optimizer.add_hook(chainer.optimizer.GradientClipping(args.grad_clip))

    # Setup a converter
    converter = model.custom_converter(subsampling_factor=model.subsample[0])

    # read json data
    with open(args.train_json, "rb") as f:
        train_json = json.load(f)["utts"]
    with open(args.valid_json, "rb") as f:
        valid_json = json.load(f)["utts"]

    # set up training iterator and updater
    load_tr = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": True},  # Switch the mode of preprocessing
    )
    load_cv = LoadInputsAndTargets(
        mode="asr",
        load_output=True,
        preprocess_conf=args.preprocess_conf,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
    )

    # sortagrad == -1 means "always shortest-first"; > 0 means that many epochs
    use_sortagrad = args.sortagrad == -1 or args.sortagrad > 0
    accum_grad = args.accum_grad
    if ngpu <= 1:
        # make minibatch list (variable length)
        train = make_batchset(
            train_json,
            args.batch_size,
            args.maxlen_in,
            args.maxlen_out,
            args.minibatches,
            min_batch_size=args.ngpu if args.ngpu > 1 else 1,
            shortest_first=use_sortagrad,
            count=args.batch_count,
            batch_bins=args.batch_bins,
            batch_frames_in=args.batch_frames_in,
            batch_frames_out=args.batch_frames_out,
            batch_frames_inout=args.batch_frames_inout,
            iaxis=0,
            oaxis=0,
        )
        # hack to make batchsize argument as 1
        # actual batchsize is included in a list
        if args.n_iter_processes > 0:
            train_iters = [
                ToggleableShufflingMultiprocessIterator(
                    TransformDataset(train, load_tr),
                    batch_size=1,
                    n_processes=args.n_iter_processes,
                    n_prefetch=8,
                    maxtasksperchild=20,
                    shuffle=not use_sortagrad,
                )
            ]
        else:
            train_iters = [
                ToggleableShufflingSerialIterator(
                    TransformDataset(train, load_tr),
                    batch_size=1,
                    shuffle=not use_sortagrad,
                )
            ]

        # set up updater
        updater = model.custom_updater(
            train_iters[0],
            optimizer,
            converter=converter,
            device=gpu_id,
            accum_grad=accum_grad,
        )
    else:
        if args.batch_count not in ("auto", "seq") and args.batch_size == 0:
            raise NotImplementedError(
                "--batch-count 'bin' and 'frame' are not implemented "
                "in chainer multi gpu"
            )
        # set up minibatches: one round-robin subset of utterances per GPU
        train_subsets = []
        for gid in range(ngpu):
            # make subset
            train_json_subset = {
                k: v for i, (k, v) in enumerate(train_json.items()) if i % ngpu == gid
            }
            # make minibatch list (variable length)
            train_subsets += [
                make_batchset(
                    train_json_subset,
                    args.batch_size,
                    args.maxlen_in,
                    args.maxlen_out,
                    args.minibatches,
                )
            ]

        # each subset must have same length for MultiprocessParallelUpdater,
        # so shorter subsets are padded by repeating their leading batches
        maxlen = max([len(train_subset) for train_subset in train_subsets])
        for train_subset in train_subsets:
            if maxlen != len(train_subset):
                for i in range(maxlen - len(train_subset)):
                    train_subset += [train_subset[i]]

        # hack to make batchsize argument as 1
        # actual batchsize is included in a list
        if args.n_iter_processes > 0:
            train_iters = [
                ToggleableShufflingMultiprocessIterator(
                    TransformDataset(train_subsets[gid], load_tr),
                    batch_size=1,
                    n_processes=args.n_iter_processes,
                    n_prefetch=8,
                    maxtasksperchild=20,
                    shuffle=not use_sortagrad,
                )
                for gid in range(ngpu)
            ]
        else:
            train_iters = [
                ToggleableShufflingSerialIterator(
                    TransformDataset(train_subsets[gid], load_tr),
                    batch_size=1,
                    shuffle=not use_sortagrad,
                )
                for gid in range(ngpu)
            ]

        # set up updater
        updater = model.custom_parallel_updater(
            train_iters, optimizer, converter=converter, devices=devices
        )

    # Set up a trainer
    trainer = training.Trainer(updater, (args.epochs, "epoch"), out=args.outdir)

    if use_sortagrad:
        # re-enable shuffling once the sortagrad (shortest-first) phase ends
        trainer.extend(
            ShufflingEnabler(train_iters),
            trigger=(args.sortagrad if args.sortagrad != -1 else args.epochs, "epoch"),
        )
    if args.opt == "noam":
        from espnet.nets.chainer_backend.transformer.training import VaswaniRule

        # "Attention is all you need" warmup schedule applied to Adam's alpha
        trainer.extend(
            VaswaniRule(
                "alpha",
                d=args.adim,
                warmup_steps=args.transformer_warmup_steps,
                scale=args.transformer_lr,
            ),
            trigger=(1, "iteration"),
        )

    # Resume from a snapshot
    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # set up validation iterator
    valid = make_batchset(
        valid_json,
        args.batch_size,
        args.maxlen_in,
        args.maxlen_out,
        args.minibatches,
        min_batch_size=args.ngpu if args.ngpu > 1 else 1,
        count=args.batch_count,
        batch_bins=args.batch_bins,
        batch_frames_in=args.batch_frames_in,
        batch_frames_out=args.batch_frames_out,
        batch_frames_inout=args.batch_frames_inout,
        iaxis=0,
        oaxis=0,
    )

    if args.n_iter_processes > 0:
        valid_iter = chainer.iterators.MultiprocessIterator(
            TransformDataset(valid, load_cv),
            batch_size=1,
            repeat=False,
            shuffle=False,
            n_processes=args.n_iter_processes,
            n_prefetch=8,
            maxtasksperchild=20,
        )
    else:
        valid_iter = chainer.iterators.SerialIterator(
            TransformDataset(valid, load_cv), batch_size=1, repeat=False, shuffle=False
        )

    # Evaluate the model with the test dataset for each epoch
    trainer.extend(BaseEvaluator(valid_iter, model, converter=converter, device=gpu_id))

    # Save attention weight each epoch
    if args.num_save_attention > 0 and args.mtlalpha != 1.0:
        # plot the longest utterances first (sorted by input length, desc)
        data = sorted(
            list(valid_json.items())[: args.num_save_attention],
            key=lambda x: int(x[1]["input"][0]["shape"][1]),
            reverse=True,
        )
        if hasattr(model, "module"):
            att_vis_fn = model.module.calculate_all_attentions
            plot_class = model.module.attention_plot_class
        else:
            att_vis_fn = model.calculate_all_attentions
            plot_class = model.attention_plot_class
        logging.info("Using custom PlotAttentionReport")
        # NOTE(review): plot_class must accept `transform` and
        # `subsampling_factor` kwargs; the multi-speaker PlotAttentionReport
        # defined earlier in this file does not — confirm which class is used.
        att_reporter = plot_class(
            att_vis_fn,
            data,
            args.outdir + "/att_ws",
            converter=converter,
            transform=load_cv,
            device=gpu_id,
            subsampling_factor=total_subsampling_factor,
        )
        trainer.extend(att_reporter, trigger=(1, "epoch"))
    else:
        att_reporter = None

    # Take a snapshot for each specified epoch
    trainer.extend(
        extensions.snapshot(filename="snapshot.ep.{.updater.epoch}"),
        trigger=(1, "epoch"),
    )

    # Make a plot for training and validation values
    trainer.extend(
        extensions.PlotReport(
            [
                "main/loss",
                "validation/main/loss",
                "main/loss_ctc",
                "validation/main/loss_ctc",
                "main/loss_att",
                "validation/main/loss_att",
            ],
            "epoch",
            file_name="loss.png",
        )
    )
    trainer.extend(
        extensions.PlotReport(
            ["main/acc", "validation/main/acc"], "epoch", file_name="acc.png"
        )
    )

    # Save best models
    trainer.extend(
        extensions.snapshot_object(model, "model.loss.best"),
        trigger=training.triggers.MinValueTrigger("validation/main/loss"),
    )
    if mtl_mode != "ctc":
        trainer.extend(
            extensions.snapshot_object(model, "model.acc.best"),
            trigger=training.triggers.MaxValueTrigger("validation/main/acc"),
        )

    # epsilon decay in the optimizer
    if args.opt == "adadelta":
        if args.criterion == "acc" and mtl_mode != "ctc":
            # when validation accuracy degrades: restore the best model and
            # shrink AdaDelta's eps
            trainer.extend(
                restore_snapshot(model, args.outdir + "/model.acc.best"),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/acc",
                    lambda best_value, current_value: best_value > current_value,
                ),
            )
        elif args.criterion == "loss":
            # same strategy but keyed on validation loss increasing
            trainer.extend(
                restore_snapshot(model, args.outdir + "/model.loss.best"),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )
            trainer.extend(
                adadelta_eps_decay(args.eps_decay),
                trigger=CompareValueTrigger(
                    "validation/main/loss",
                    lambda best_value, current_value: best_value < current_value,
                ),
            )

    # Write a log of evaluation statistics for each epoch
    trainer.extend(
        extensions.LogReport(trigger=(args.report_interval_iters, "iteration"))
    )
    report_keys = [
        "epoch",
        "iteration",
        "main/loss",
        "main/loss_ctc",
        "main/loss_att",
        "validation/main/loss",
        "validation/main/loss_ctc",
        "validation/main/loss_att",
        "main/acc",
        "validation/main/acc",
        "elapsed_time",
    ]
    if args.opt == "adadelta":
        trainer.extend(
            extensions.observe_value(
                "eps", lambda trainer: trainer.updater.get_optimizer("main").eps
            ),
            trigger=(args.report_interval_iters, "iteration"),
        )
        report_keys.append("eps")
    trainer.extend(
        extensions.PrintReport(report_keys),
        trigger=(args.report_interval_iters, "iteration"),
    )
    trainer.extend(extensions.ProgressBar(update_interval=args.report_interval_iters))

    set_early_stop(trainer, args)
    if args.tensorboard_dir is not None and args.tensorboard_dir != "":
        try:
            from tensorboardX import SummaryWriter
        except Exception:
            logging.error("Please install tensorboardx")
            raise
        writer = SummaryWriter(args.tensorboard_dir)
        trainer.extend(
            TensorboardLogger(writer, att_reporter),
            trigger=(args.report_interval_iters, "iteration"),
        )

    # Run the training
    trainer.run()
    check_early_stop(trainer, args.epochs)
def recog(args):
    """Decode with the given args (chainer backend).

    Args:
        args (namespace): The program arguments.

    """
    # display chainer version
    logging.info("chainer version = " + chainer.__version__)

    set_deterministic_chainer(args)

    # read training config
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    for key in sorted(vars(args).keys()):
        logging.info("ARGS: " + key + ": " + str(vars(args)[key]))

    # specify model architecture
    logging.info("reading model parameters from " + args.model)
    # To be compatible with v.0.3.0 models
    if hasattr(train_args, "model_module"):
        model_module = train_args.model_module
    else:
        model_module = "espnet.nets.chainer_backend.e2e_asr:E2E"
    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, train_args)
    assert isinstance(model, ASRInterface)
    chainer_load(args.model, model)

    # read rnnlm (character-level language model), if given
    if args.rnnlm:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_chainer.ClassifierWithState(
            lm_chainer.RNNLM(
                len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit
            )
        )
        chainer_load(args.rnnlm, rnnlm)
    else:
        rnnlm = None

    if args.word_rnnlm:
        rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
        word_dict = rnnlm_args.char_list_dict
        char_dict = {x: i for i, x in enumerate(train_args.char_list)}
        word_rnnlm = lm_chainer.ClassifierWithState(
            lm_chainer.RNNLM(len(word_dict), rnnlm_args.layer, rnnlm_args.unit)
        )
        chainer_load(args.word_rnnlm, word_rnnlm)

        if rnnlm is not None:
            # both word- and character-level LMs: combine them
            rnnlm = lm_chainer.ClassifierWithState(
                extlm_chainer.MultiLevelLM(
                    word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict
                )
            )
        else:
            # word LM only: score character hypotheses with look-ahead
            rnnlm = lm_chainer.ClassifierWithState(
                extlm_chainer.LookAheadWordLM(
                    word_rnnlm.predictor, word_dict, char_dict
                )
            )

    # read json data
    with open(args.recog_json, "rb") as f:
        js = json.load(f)["utts"]

    load_inputs_and_targets = LoadInputsAndTargets(
        mode="asr",
        load_output=False,
        sort_in_input_length=False,
        # fall back to the preprocessing config used at training time
        preprocess_conf=train_args.preprocess_conf
        if args.preprocess_conf is None
        else args.preprocess_conf,
        preprocess_args={"train": False},  # Switch the mode of preprocessing
    )

    # decode each utterance
    new_js = {}
    with chainer.no_backprop_mode():
        for idx, name in enumerate(js.keys(), 1):
            logging.info("(%d/%d) decoding " + name, idx, len(js.keys()))
            batch = [(name, js[name])]
            feat = load_inputs_and_targets(batch)[0][0]
            nbest_hyps = model.recognize(feat, args, train_args.char_list, rnnlm)
            new_js[name] = add_results_to_json(
                js[name], nbest_hyps, train_args.char_list
            )

    with open(args.result_label, "wb") as f:
        f.write(
            json.dumps(
                {"utts": new_js}, indent=4, ensure_ascii=False, sort_keys=True
            ).encode("utf_8")
        )
"""Finetuning methods."""
import logging
import os
import re
from collections import OrderedDict
import torch
from espnet.asr.asr_utils import get_model_conf, torch_load
from espnet.nets.asr_interface import ASRInterface
from espnet.nets.mt_interface import MTInterface
from espnet.nets.pytorch_backend.transducer.utils import custom_torch_load
from espnet.nets.tts_interface import TTSInterface
from espnet.utils.dynamic_import import dynamic_import
def freeze_modules(model, modules):
    """Freeze model parameters according to modules list.

    Args:
        model (torch.nn.Module): Main model.
        modules (List): Specified module(s) to freeze.

    Return:
        model (torch.nn.Module) : Updated main model.
        model_params (filter): Filtered model parameters.

    """
    for name, parameter in model.named_parameters():
        # freeze every parameter whose qualified name matches a requested prefix
        if any(name.startswith(prefix) for prefix in modules):
            logging.warning(f"Freezing {name}. It will not be updated during training.")
            parameter.requires_grad = False

    # expose only the parameters that remain trainable to the optimizer
    model_params = filter(lambda p: p.requires_grad, model.parameters())

    return model, model_params
def transfer_verification(model_state_dict, partial_state_dict, modules):
    """Verify tuples (key, shape) for input model modules match specified modules.

    Args:
        model_state_dict (Dict) : Main model state dict.
        partial_state_dict (Dict): Pre-trained model state dict.
        modules (List): Specified module(s) to transfer.

    Return:
        (bool): Whether transfer learning is allowed.

    """

    def _collect_entries(state_dict):
        # gather (key, shape) pairs for every key under a requested prefix,
        # sorted so the two sides can be compared positionally
        entries = [
            (key, value.shape)
            for key, value in state_dict.items()
            if any(key.startswith(prefix) for prefix in modules)
        ]
        return sorted(entries, key=lambda item: (item[0], item[1]))

    model_modules = _collect_entries(model_state_dict)
    partial_modules = _collect_entries(partial_state_dict)

    module_match = model_modules == partial_modules
    if not module_match:
        logging.error(
            "Some specified modules from the pre-trained model "
            "don't match with the new model modules:"
        )
        logging.error(f"Pre-trained: {set(partial_modules) - set(model_modules)}")
        logging.error(f"New model: {set(model_modules) - set(partial_modules)}")
        exit(1)

    return module_match
def get_partial_state_dict(model_state_dict, modules):
    """Create state dict with specified modules matching input model modules.

    Args:
        model_state_dict (Dict): Pre-trained model state dict.
        modules (Dict): Specified module(s) to transfer.

    Return:
        new_state_dict (Dict): State dict with specified modules weights.

    """
    # keep only the entries whose key falls under a requested module prefix,
    # preserving the original iteration order
    new_state_dict = OrderedDict(
        (key, value)
        for key, value in model_state_dict.items()
        if any(key.startswith(prefix) for prefix in modules)
    )
    return new_state_dict
def get_lm_state_dict(lm_state_dict):
    """Create compatible ASR decoder state dict from LM state dict.

    Args:
        lm_state_dict (Dict): Pre-trained LM state dict.

    Return:
        new_state_dict (Dict): State dict with compatible key names.

    """
    new_state_dict = OrderedDict()

    for key, value in list(lm_state_dict.items()):
        if key == "predictor.embed.weight":
            # LM embedding table becomes the decoder embedding table
            new_state_dict["dec.embed.weight"] = value
        elif key.startswith("predictor.rnn."):
            # predictor.rnn.<layer>.<param> -> dec.decoder.<layer>.<param>_l0
            parts = key.split(".")
            new_state_dict["dec.decoder.%s.%s_l0" % (parts[2], parts[3])] = value
        # any other key is intentionally dropped

    return new_state_dict
def filter_modules(model_state_dict, modules):
    """Filter non-matched modules in model state dict.

    Args:
        model_state_dict (Dict): Pre-trained model state dict.
        modules (List): Specified module(s) to transfer.

    Return:
        new_mods (List): Filtered module list.

    """
    available_keys = list(model_state_dict.keys())

    # split the requested prefixes into those matching at least one key
    # of the state dict and those matching none
    new_mods = [
        m for m in modules if any(key.startswith(m) for key in available_keys)
    ]
    incorrect_mods = [m for m in modules if m not in new_mods]

    if incorrect_mods:
        logging.error(
            "Specified module(s) don't match or (partially match) "
            f"available modules in model. You specified: {incorrect_mods}."
        )
        logging.error("The existing modules in model are:")
        logging.error(f"{available_keys}")
        exit(1)

    return new_mods
def create_transducer_compatible_state_dict(
    model_state_dict, encoder_type, encoder_units
):
    """Create a compatible transducer model state dict for transfer learning.

    If RNN encoder modules from a non-Transducer model are found in
    the pre-trained model state dict, the corresponding modules keys are
    renamed for compatibility.

    Args:
        model_state_dict (Dict): Pre-trained model state dict
        encoder_type (str): Type of pre-trained encoder.
        encoder_units (int): Number of encoder units in pre-trained model.

    Returns:
        new_state_dict (Dict): Transducer compatible pre-trained model state dict.

    """
    # projected ("...p") and non-LSTM/GRU encoders need no key translation
    if encoder_type.endswith("p") or not encoder_type.endswith(("lstm", "gru")):
        return model_state_dict

    new_state_dict = OrderedDict()
    # bidirectional encoders ("b" in the type string) use "birnn" key stems
    rnn_key_name = "birnn" if "b" in encoder_type else "rnn"

    for key, value in list(model_state_dict.items()):
        if any(k in key for k in ["l_last", "nbrnn"]):
            if "nbrnn" in key:
                # fold the RNN layer index into the module name:
                # "...nbrnn..._l<N>..." -> "...<rnn_key_name><N>..._l0..."
                layer_name = rnn_key_name + re.search("_l([0-9]+)", key).group(1)
                key = re.sub(
                    "_l([0-9]+)",
                    "_l0",
                    key.replace("nbrnn", layer_name),
                )
            # NOTE(review): a bidirectional weight (last dim == 2 * units) is
            # collapsed by summing its forward/backward halves — confirm this
            # matches the target Transducer encoder's expectations.
            if (encoder_units * 2) == value.size(-1):
                value = value[:, :encoder_units] + value[:, encoder_units:]
        new_state_dict[key] = value

    return new_state_dict
def load_trained_model(model_path, training=True):
    """Load the trained model for recognition.

    Args:
        model_path (str): Path to model.***.best
        training (bool): Training mode specification for transducer model.

    Returns:
        model (torch.nn.Module): Trained model.
        train_args (Namespace): Trained model arguments.

    """
    # model.json sits next to the snapshot and stores (idim, odim, args)
    idim, odim, train_args = get_model_conf(
        model_path, os.path.join(os.path.dirname(model_path), "model.json")
    )
    logging.info(f"Reading model parameters from {model_path}")
    # fall back to the default E2E module for models trained before
    # model_module was recorded in the config
    if hasattr(train_args, "model_module"):
        model_module = train_args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
    # CTC Loss is not needed, default to builtin to prevent import errors
    if hasattr(train_args, "ctc_type"):
        train_args.ctc_type = "builtin"
    model_class = dynamic_import(model_module)
    if "transducer" in model_module:
        # transducer models have a dedicated loader honoring `training`
        model = model_class(idim, odim, train_args, training=training)
        custom_torch_load(model_path, model, training=training)
    else:
        model = model_class(idim, odim, train_args)
        torch_load(model_path, model)
    return model, train_args
def get_trained_model_state_dict(model_path, new_is_transducer):
    """Extract the trained model state dict for pre-initialization.

    Args:
        model_path (str): Path to trained model.
        new_is_transducer (bool): Whether the new model is Transducer-based.

    Return:
        (Dict): Trained model state dict.

    """
    logging.info(f"Reading model parameters from {model_path}")
    conf_path = os.path.join(os.path.dirname(model_path), "model.json")
    # an RNNLM checkpoint is translated into decoder-compatible key names
    # instead of being instantiated as an ASR/MT/TTS model
    if "rnnlm" in model_path:
        return get_lm_state_dict(torch.load(model_path))
    idim, odim, args = get_model_conf(model_path, conf_path)
    if hasattr(args, "model_module"):
        model_module = args.model_module
    else:
        model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
    model_class = dynamic_import(model_module)
    model = model_class(idim, odim, args)
    torch_load(model_path, model)
    assert (
        isinstance(model, MTInterface)
        or isinstance(model, ASRInterface)
        or isinstance(model, TTSInterface)
    )
    # when transferring a non-Transducer RNN encoder into a Transducer
    # model, rename its encoder keys for compatibility first
    if new_is_transducer and "transducer" not in args.model_module:
        return create_transducer_compatible_state_dict(
            model.state_dict(),
            args.etype,
            args.eunits,
        )
    return model.state_dict()
def load_trained_modules(idim, odim, args, interface=ASRInterface):
    """Load ASR/MT/TTS model with pre-trained weights for specified modules.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        args Namespace: Model arguments.
        interface (ASRInterface|MTInterface|TTSInterface): Model interface.

    Return:
        main_model (torch.nn.Module): Model with pre-initialized weights.

    """

    def print_new_keys(state_dict, modules, model_path):
        # log every parameter that is about to be overridden by the
        # pre-trained weights
        logging.info(f"Loading {modules} from model: {model_path}")
        for k in state_dict.keys():
            logging.warning(f"Overriding module {k}")

    enc_model_path = args.enc_init
    dec_model_path = args.dec_init
    enc_modules = args.enc_init_mods
    dec_modules = args.dec_init_mods
    model_class = dynamic_import(args.model_module)
    main_model = model_class(idim, odim, args)
    assert isinstance(main_model, interface)
    main_state_dict = main_model.state_dict()
    logging.warning("Model(s) found for pre-initialization.")
    # encoder and decoder may each come from a different pre-trained model
    for model_path, modules in [
        (enc_model_path, enc_modules),
        (dec_model_path, dec_modules),
    ]:
        if model_path is not None:
            if os.path.isfile(model_path):
                model_state_dict = get_trained_model_state_dict(
                    model_path, "transducer" in args.model_module
                )
                # drop requested prefixes that have no match in the checkpoint
                # (filter_modules exits the process if any are unknown)
                modules = filter_modules(model_state_dict, modules)
                partial_state_dict = get_partial_state_dict(model_state_dict, modules)
                if partial_state_dict:
                    # shapes must agree before the weights are merged in
                    if transfer_verification(
                        main_state_dict, partial_state_dict, modules
                    ):
                        print_new_keys(partial_state_dict, modules, model_path)
                        main_state_dict.update(partial_state_dict)
            else:
                logging.error(f"Specified model was not found: {model_path}")
                exit(1)
    main_model.load_state_dict(main_state_dict)
    return main_model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment