"...composable_kernel_rocm.git" did not exist on "dcd3d21a4452fc74ede2b0d94eb750e5e55c3eb4"
Commit 12c90639 authored by “change”'s avatar “change”
Browse files

init

parent 417b607b
__pycache__/
\ No newline at end of file
[submodule "SpeechT5/fairseq"]
path = SpeechT5/fairseq
url = https://github.com/pytorch/fairseq
[submodule "Speech2C/fairseq"]
path = Speech2C/fairseq
url = https://github.com/facebookresearch/fairseq.git
[submodule "YiTrans/fairseq"]
path = YiTrans/fairseq
url = https://github.com/facebookresearch/fairseq
[submodule "SpeechLM/fairseq"]
path = SpeechLM/fairseq
url = https://github.com/facebookresearch/fairseq.git
[submodule "SpeechUT/fairseq"]
path = SpeechUT/fairseq
url = https://github.com/facebookresearch/fairseq.git
[submodule "VATLM/fairseq"]
path = VATLM/fairseq
url = https://github.com/facebookresearch/fairseq.git
[submodule "Speech2S/fairseq"]
path = Speech2S/fairseq
url = https://github.com/facebookresearch/fairseq.git
branch = adding_womenbios
[submodule "WavLLM/fairseq"]
path = WavLLM/fairseq
url = https://github.com/pytorch/fairseq.git
# Microsoft Open Source Code of Conduct
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
Resources:
- [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
- [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
- Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
MIT License
Copyright (c) Microsoft Corporation.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
# SpeechT5
## 论文
- https://arxiv.org/abs/2110.07205
## 开源代码
- https://github.com/microsoft/SpeechT5
## 模型结构
speechT5的核心是一个常规的**Transformer编码器-解码器**,为了使得同一个Transformer可以同时处理文本和语音数据,添加了**pre-nets****post-nets****pre-net**将输入的文本或语音转换为Transformer使用的隐藏表示;**post-net**从Transformer中获取输出并转换为文本或语音。
<div align="center">
<img src="./images/model_architecture.png"/>
</div>
## 算法原理
在预训练期间,同时使用所有的 per-nets 和 post-nets 。预训练后,整个编码器 - 解码器主干在单个任务上进行微调。这种经过微调的模型仅使用特定于给定任务的 per-nets 和 post-nets 。*例如:要将 SpeechT5 用于文本到语音转换,您需要将文本编码器 per-nets 交换为文本输入,将语音解码器 per-nets 和 post-nets 交换为语音输出。*
<div align="center">
<img src="./images/algorithm.png"/>
</div>
**注意: 即使微调模型一开始使用共享预训练模型的同一组权重,但最终版本最终还是完全不同。例如,您不能采用经过微调的 ASR 模型并换掉 per-nets 和 post-nets 来获得有效的 TTS 模型。SpeechT5 很灵活,但不是那么灵活。**
## 环境配置
### Docker (方法一)
**注意修改路径参数**
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10
docker run -it --network=host --ipc=host --name=your_container_name --shm-size=32G --device=/dev/kfd --device=/dev/mkfd --device=/dev/dri -v /opt/hyhal:/opt/hyhal:ro -v /path/your_code_data/:/path/your_code_data/ --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-centos7.6-dtk24.04-py310 /bin/bash
# 独立安装项目依赖包fairseq
git clone -b 1.0.0a0 http://developer.hpccube.com/codes/OpenDAS/fairseq.git
cd fairseq
pip3 install --editable ./
# 安装项目依赖包
cd speechT5_pytorch
pip3 install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/
```
### Dockerfile (方法二)
```
cd ./docker
docker build --no-cache -t speecht5 .
docker run -it -v /path/your_code_data/:/path/your_code_data/ --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash
git clone -b 1.0.0a0 http://developer.hpccube.com/codes/OpenDAS/fairseq.git
cd fairseq
pip3 install --editable ./
cd speechT5_pytorch
pip3 install -r requirements.txt
```
### Anaconda (方法三)
1、关于本项目DCU显卡所需的特殊深度学习库可从光合开发者社区下载安装: https://developer.hpccube.com/tool/
```
DTK软件栈:dtk24.04
python:python3.10
torch:2.1.0
torchvision:0.16.0
```
Tips:以上dtk软件栈、python、torch等DCU相关工具版本需要严格一一对应
2、fairseq库需要单独安装,可参考如下命令:
```
git clone -b 1.0.0a0 http://developer.hpccube.com/codes/OpenDAS/fairseq.git
cd fairseq
pip3 install --editable ./
```
3、其他非特殊库直接按照requirements.txt安装
```
cd speechT5_pytorch
pip3 install -r requirements.txt
```
## 数据集
**内网可从此地址拷贝:/public/home/changhl/dataset/LibriSpeech**
- SCnet快速下载链接:
- [librispeech_asr数据集下载](http://113.200.138.88:18080/aidatasets/librispeech_asr_dummy)
- 官方下载链接:
- [librispeech_asr数据集下载](http://www.openslr.org/12)
`librispeech_asr`:语音识别数据集,数据集中包括音频文件以及文本转录文件。其中音频文件以flac格式存储,文本转录文件以txt格式存储。
```
LibriSpeech
├── train-clean-100
│ ├── 1272
│ │ ├── 1272-128104
│ │ │ ├── 1272-128104-0000.flac
│ │ │ ├── 1272-128104-0001.flac
│ │ │ ├── 1272-128104-0002.flac
│ │ │ ├── 1272-128104-0003.flac
│ │ │ ├── ...
│ │ │ ├── 1272-128104.trans.txt
│ │ └── ...
│ └── ...
├── train-clean-360
├── train-other-500
├── dev-clean
├── dev-other
├── test-clean
└── test-othe
```
- `train-clean-100`:包含大约 100 小时的清晰语音。
- `1272`:说话人ID(1272)。
- `1272-128104`:说话人ID(1272)-文本章节ID(128204)。
- `1272-128104-0000.flac`:说话人ID(1272)-文本章节ID(128204)-文本片段ID(0)的音频文件。
- `1272-128104.trans.txt`:说话人ID(1272)-文本章节ID(128204)的转录文本文件。
## 预训练模型
**ASR任务训练前先下载预训练好的权重文件、SPM_TOKENIZER、字典等文件**
- 官方下载地址:
- [speecht5初始权重](https://huggingface.co/ajyy/SpeechT5/resolve/main/speecht5_base.pt)
- [SPM_TOKENIZER下载地址](https://drive.google.com/uc?export=download&id=1wClgQjXXoU2lmpbaEa1v2SqMbg7cAutq)
- [字典](https://huggingface.co/ajyy/SpeechT5/resolve/main/speecht5_base.pt)
## 训练
### ASR训练
**step1:生成训练集的train.tsv文件和valid.tsv文件**
```
cd speecht5_pytorch/SpeechT5/fairseq
python examples/wav2vec/wav2vec_manifest.py dataset/LibriSpeech/dev-clean --dest /public/home/changhl/py_project/train_0910 --ext flac --valid-percent 0.1
# python wav2vec_manifest.py 数据集目录 --dest tsv文件的存储目录 --ext 语音文件类型(flac) --valid-percent 0.1
# 数据集目录为dataset/LibriSpeech/dev-clean,因为选取整个数据集,训练耗时过长,可以选择所有的librispeech数据集。
```
**step2:生成step1生成的tsv文件的标签文件:train.txt和valid.txt**
```
cd speecht5_pytorch/SpeechT5/fairseq
python examples/wav2vec/libri_labels.py /public/home/changhl/py_project/train_0910/train.tsv --output-dir /public/home/changhl/py_project/train_0910 --output-name train
python examples/wav2vec/libri_labels.py /public/home/changhl/py_project/train_0910/valid.tsv --output-dir /public/home/changhl/py_project/train_0910 --output-name valid
# python libri_labels.py tsv文件目录 \
# --output-dir txt文件保存目录 \
# --output-name txt文件名模型(和tsv的文件名相同)
```
**step3:字典文件迁移**
将下载的dict.txt字典文件迁移至train.tsv的同一目录下
**step4:训练**
```
cd speecht5_pytorch
export HIP_VISIBLE_DEVICES=0 # 自行指定可见卡
bash asr_train.sh \
--dcu 1 \
--log /public/home/changhl/py_project/SpeechT5/log \
--td /public/home/changhl/py_project/train_0910 \
--res /public/home/changhl/py_project/SpeechT5/res \
--lab /public/home/changhl/py_project/train_0910 \
--token /public/home/changhl/dataset/spm_char.model \
--speecht5 /public/home/changhl/py_project/SpeechT5/SpeechT5/speecht5 \
--checkpoint /public/home/changhl/dataset/speecht5_base.pt \
--epoch 3
# 具体参数解析-h查看
```
- `dcu`:训练所用的卡数
- `log`:训练日志文件保存目录
- `td`:train.tsv文件和valid.tsv文件的所在目录
- `res`:训练结果.pt文件保存目录
- `lab`:train.txt和valid.txt文件的所在目录
- `token`:下载的spm_char.model文件路径——SPM_TOKENIZER文件
- `speecht5`:/speecht5_pytorch/SpeechT5/speecht5该目录的绝对路径
- `checkpoint`:预训练的初始权重路径
- `epoch`:训练次数
## 应用场景
### 算法分类
```
语音识别
```
### 热点应用行业
```
金融,通信,广媒
```
## 源码仓库及问题反馈
https://developer.hpccube.com/codes/modelzoo/SpeechT5_pytorch
## 参考
[GitHub - microsoft/SpeechT5](https://github.com/microsoft/SpeechT5/tree/main/SpeechT5)
This diff is collapsed.
<!-- BEGIN MICROSOFT SECURITY.MD V0.0.7 BLOCK -->
## Security
Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below.
## Reporting Security Issues
**Please do not report security vulnerabilities through public GitHub issues.**
Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report).
If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey).
You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc).
Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
* Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
* Full paths of source file(s) related to the manifestation of the issue
* The location of the affected source code (tag/branch/commit or direct URL)
* Any special configuration required to reproduce the issue
* Step-by-step instructions to reproduce the issue
* Proof-of-concept or exploit code (if possible)
* Impact of the issue, including how an attacker might exploit the issue
This information will help us triage your report more quickly.
If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs.
## Preferred Languages
We prefer all communications to be in English.
## Policy
Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd).
<!-- END MICROSOFT SECURITY.MD BLOCK -->
# Speech2C
> [**Speech2C**](https://arxiv.org/abs/2203.17113) (```INTERSPEECH 2022```): **Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data**
## Pre-Trained and Fine-tuned Models
| Model | Pre-training Dataset | Fine-tuning Dataset | Model |
| :------: | :----------------------------------------------: | :-----------------: | :-----: |
| Speech2C | [960 hrs LibriSpeech](http://www.openslr.org/12) | - | [Google Drive](https://drive.google.com/file/d/1nGZ0LWEwlLq2pz7o805YALsMr9irV0Za/view?usp=sharing) |
| Speech2C | [960 hrs LibriSpeech](http://www.openslr.org/12) | [10 hrs LibriSpeech](http://www.openslr.org/12) | [Google Drive](https://drive.google.com/file/d/1nWSAc-33LmcDQHzH8IjXVJsuk0JZTWgN/view?usp=sharing) |
| Speech2C | [960 hrs LibriSpeech](http://www.openslr.org/12) | [100 hrs LibriSpeech](http://www.openslr.org/12) | [Google Drive](https://drive.google.com/file/d/1LwbQ5Y3tKZoK3s1ayLQgsfLTFnmkKNZs/view?usp=sharing) |
## Language Model and Vocabulary
| Model | Dataset | Model | Vocabulary |
| :------: | :------: | :---: | :--------: |
| LM | [LibriSpeech LM Dataset](https://www.openslr.org/11/) | [Model](https://drive.google.com/file/d/1UDCcNJT1DlquSRw0iRAXH6GHlf6zK6-8/view?usp=sharing) | [Vocabulary](https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt) |
## Setup
```
git submodule update --init Speech2C/fairseq
cd Speech2C/
pip install --editable fairseq/
```
## Data Preparation
Please follow the steps of data preparation for HuBERT in [here](https://github.com/facebookresearch/fairseq/tree/main/examples/hubert#data-preparation).
## Pre-Training
```
DATA_DIR=
LABEL_DIR=
FAIRSEQ_PATH=
python ${FAIRSEQ_PATH}/fairseq_cli/hydra_train.py \
--config-dir speech2c/config \
--config-name speech2c_base_librispeech \
task.data=${DATA_DIR} task.label_dir=${LABEL_DIR} task.labels='["km"]' \
model.label_rate=50 common.user_dir=SpeechT5/Speech2C/speech2c \
```
## Finetune
```
DATA_DIR=
LABEL_DIR=
FAIRSEQ_PATH=
W2V_PATH=
CONFIG_NAME=
python ${FAIRSEQ_PATH}/fairseq_cli/hydra_train.py \
--config-dir speech2c/config \
--config-name ${CONFIG_NAME} \
task.data=${DATA_DIR} task.label_dir=${LABEL_DIR} \
model.w2v_path=${W2V_PATH} common.user_dir=SpeechT5/Speech2C/speech2c \
```
## Inference
Note that joint CTC and decoder inference is only supported when the batch size is 1.
```
FAIRSEQ_PATH=
DATA_DIR=
LABEL_DIR=
BEAM_SIZE=
CTC_WEIGHT=
TEST_SET=
CHECKPOINT_PATH=
W2V_PATH=
python ${FAIRSEQ_PATH}/fairseq_cli/generate.py ${DATA_DIR} \
--label-dir ${LABEL_DIR} \
--path ${CHECKPOINT_PATH} \
--user-dir SpeechT5/Speech2C/speech2c \
--model-overrides "{'w2v_path': '${W2V_PATH}'}" \
--gen-subset ${TEST_SET} \
--task speech2c_pretraining \
--post-process letter \
--add-decoder \
--labels '["ltr"]' \
--fine-tuning \
--scoring wer \
--max-len-a 0 \
--max-len-b 620 \
--pad-audio \
--random-crop \
--ctc-weight ${CTC_WEIGHT} \
--max-tokens 8000000 \
--beam ${BEAM_SIZE} \
--single-target \
```
## Results on Librispeech
### Evaluation on the [LibriSpeech](http://www.openslr.org/12) 10hr subset
| Model |LM | test-clean | test-other |
| ------------- |------------- | ----| ----|
| wav2vec2.0 Base | - | 11.1 | 17.6 |
| HuBERT Base | - | 10.1 | 16.8 |
| **Speech2C** | - | **7.8** | **13.1** |
| wav2vec 2.0 Base | 4-gram | 4.3 |9.5 |
| wav2vec 2.0 Base | Transf. |3.2 |7.8 |
| HuBERT Base | 4-gram |4.3 |9.4 |
| **Speech2C** | **Transf.** | **3.1** | **7.0** |
### Evaluation on the [LibriSpeech](http://www.openslr.org/12) 100hr subset
| Model |LM | test-clean | test-other |
| ------------- |------------- | ----| ----|
| wav2vec2.0 Base | - | 6.1 | 13.3 |
| wav2vec2.0 Large | - | 4.7 | 9.0 |
| HuBERT Base | - | 6.3 | 13.2 |
| SpeechT5 | - | 4.4 | 10.4 |
| Baseline | - | 5.0 | 11.9 |
| **Speech2C** | - | **4.3** |**9.0** |
| wav2vec 2.0 Base | 4-gram | 3.4 |8.0 |
| wav2vec 2.0 Base | Transf. | 2.6 | 6.3 |
| HuBERT Base | 4-gram | 3.4 |8.1 |
| SpeechT5 | Transf. | 2.4 |5.8 |
| Baseline | Transf. | 2.5 |6.3 |
| **Speech2C** | **Transf.** | **2.4** |**5.2** |
## License
This project is licensed under the license found in the LICENSE file in the root directory of this source tree.
Portions of the source code are based on the [FAIRSEQ](https://github.com/pytorch/fairseq).
[Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct)
## Reference
If you find our work is useful in your research, please cite the following paper:
```bibtex
@article{Ao2022Speech2C,
title = {Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data},
author = {Junyi Ao and Ziqiang Zhang and Long Zhou and Shujie Liu and Haizhou Li and Tom Ko and Lirong Dai and Jinyu Li and Yao Qian and Furu Wei},
eprint={2203.17113},
archivePrefix={arXiv},
primaryClass={cs.SD},
year={2022}
}
```
from . import data, tasks, criterions, models # noqa
\ No newline at end of file
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
tensorboard_logdir: tblog
seed: 1337
checkpoint:
no_epoch_checkpoints: true
best_checkpoint_metric: dec_accuracy
maximize_best_checkpoint_metric: true
distributed_training:
ddp_backend: c10d
find_unused_parameters: true
distributed_world_size: 1
distributed_port: 29671
nprocs_per_node: 8
task:
_name: speech2c_pretraining
data: ???
fine_tuning: true
label_dir: ???
normalize: false # must be consistent with pre-training
labels: ["ltr"]
single_target: true
add_decoder: true
pad_audio: true
random_crop: false
dataset:
num_workers: 6
max_tokens: 3200000
skip_invalid_size_inputs_valid_test: true
train_subset: train_100h
valid_subset: dev_other
criterion:
_name: ctc_ce
zero_infinity: true
optimization:
max_update: 80000
lr: [0.00004]
sentence_avg: true
update_freq: [1]
optimizer:
_name: adam
adam_betas: (0.9,0.98)
adam_eps: 1e-08
lr_scheduler:
_name: tri_stage
phase_ratio: [0.1, 0.4, 0.5]
final_lr_scale: 0.05
model:
_name: speech2c_ctc
w2v_path: ???
apply_mask: true
mask_prob: 0.65
mask_channel_prob: 0.5
mask_channel_length: 64
layerdrop: 0.1
decoder_layerdrop: 0.1
activation_dropout: 0.1
feature_grad_mult: 0.0
freeze_finetune_updates: 25000
hydra:
job:
config:
override_dirname:
kv_sep: '-'
item_sep: '__'
exclude_keys:
- run
- task.data
- task.label_dir
- model.w2v_path
- dataset.train_subset
- dataset.valid_subset
- criterion.wer_kenlm_model
- criterion.wer_lexicon
run:
dir: ???
sweep:
dir: ???
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
tensorboard_logdir: tblog
seed: 1337
checkpoint:
save_interval: 5
keep_interval_updates: 1
no_epoch_checkpoints: true
best_checkpoint_metric: dec_accuracy
maximize_best_checkpoint_metric: true
distributed_training:
ddp_backend: c10d
find_unused_parameters: true
distributed_world_size: 1
distributed_port: 29671
nprocs_per_node: 8
task:
_name: speech2c_pretraining
data: ???
fine_tuning: true
label_dir: ???
normalize: false # must be consistent with pre-training
labels: ["ltr"]
single_target: true
add_decoder: true
pad_audio: true
random_crop: false
dataset:
num_workers: 6
max_tokens: 3200000
skip_invalid_size_inputs_valid_test: true
validate_after_updates: ${model.freeze_finetune_updates}
validate_interval: 5
train_subset: train_10h
valid_subset: dev_other
criterion:
_name: ctc_ce
zero_infinity: true
optimization:
max_update: 25000
lr: [2e-5]
sentence_avg: true
update_freq: [1]
optimizer:
_name: adam
adam_betas: (0.9,0.98)
adam_eps: 1e-08
lr_scheduler:
_name: tri_stage
phase_ratio: [0.1, 0.4, 0.5]
final_lr_scale: 0.05
model:
_name: speech2c_ctc
w2v_path: ???
apply_mask: true
mask_selection: static
mask_length: 10
mask_other: 0
mask_prob: 0.75
mask_channel_selection: static
mask_channel_length: 64
mask_channel_other: 0
mask_channel_prob: 0.5
layerdrop: 0.1
decoder_layerdrop: 0.1
dropout: 0.0
activation_dropout: 0.1
attention_dropout: 0.0
feature_grad_mult: 0.0
freeze_finetune_updates: 10000
hydra:
job:
config:
override_dirname:
kv_sep: '-'
item_sep: '__'
exclude_keys:
- run
- task.data
- task.label_dir
- model.w2v_path
- dataset.train_subset
- dataset.valid_subset
- criterion.wer_kenlm_model
- criterion.wer_lexicon
run:
dir: ???
sweep:
dir: ???
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
# @package _group_
common:
fp16: true
log_format: json
log_interval: 200
seed: 1337
tensorboard_logdir: tblog
checkpoint:
save_interval_updates: 25000
keep_interval_updates: 1
no_epoch_checkpoints: true
distributed_training:
ddp_backend: no_c10d
distributed_backend: 'nccl'
distributed_world_size: 32
distributed_port: 29671
nprocs_per_node: 8
find_unused_parameters: true
task:
_name: speech2c_pretraining
data: ???
label_dir: ???
labels: ???
label_rate: ${model.label_rate}
sample_rate: 16000
max_sample_size: 250000
min_sample_size: 32000
pad_audio: false
random_crop: true
normalize: false # must be consistent with extractor
add_decoder: true
dataset:
num_workers: 6
max_tokens: 1400000
skip_invalid_size_inputs_valid_test: true
validate_interval: 5
validate_interval_updates: 10000
criterion:
_name: speech2c
pred_masked_weight: 1.0
pred_nomask_weight: 0.0
loss_weights: [10,]
optimization:
max_update: 400000
lr: [0.0005]
clip_norm: 10.0
optimizer:
_name: adam
adam_betas: (0.9,0.98)
adam_eps: 1e-06
weight_decay: 0.01
lr_scheduler:
_name: polynomial_decay
warmup_updates: 32000
model:
_name: speech2c
label_rate: ???
skip_masked: false
skip_nomask: false
mask_prob: 0.80
extractor_mode: default
conv_feature_layers: '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2'
final_dim: 256
encoder_layerdrop: 0.05
dropout_input: 0.1
dropout_features: 0.1
dropout: 0.1
attention_dropout: 0.1
feature_grad_mult: 0.1
untie_final_proj: true
activation_dropout: 0.0
use_rel_pos_enc: true
decoder_dict_size: -1
hydra:
job:
config:
override_dirname:
kv_sep: '-'
item_sep: '__'
exclude_keys:
- run
- task.data
- task.label_dir
run:
dir: ???
sweep:
dir: ???
subdir: ${hydra.job.config_name}__${hydra.job.override_dirname}
import importlib
import os
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith(".py") and not file.startswith("_"):
criterion_name = file[: file.find(".py")]
importlib.import_module(
"speech2c.criterions." + criterion_name
)
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import math
from argparse import Namespace
from dataclasses import dataclass, field
from omegaconf import II
from typing import Optional
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss
from fairseq.dataclass import FairseqDataclass
from fairseq.data.data_utils import post_process
from fairseq.tasks import FairseqTask
from fairseq.logging.meters import safe_round
@dataclass
class CtcCeCriterionConfig(FairseqDataclass):
zero_infinity: bool = field(
default=False,
metadata={"help": "zero inf loss when source length <= target length"},
)
sentence_avg: bool = II("optimization.sentence_avg")
post_process: str = field(
default="letter",
metadata={
"help": "how to post process predictions into words. can be letter, "
"wordpiece, BPE symbols, etc. "
"See fairseq.data.data_utils.post_process() for full list of options"
},
)
wer_kenlm_model: Optional[str] = field(
default=None,
metadata={
"help": "if this is provided, use kenlm to compute wer (along with other wer_* args)"
},
)
wer_lexicon: Optional[str] = field(
default=None,
metadata={"help": "lexicon to use with wer_kenlm_model"},
)
wer_lm_weight: float = field(
default=2.0,
metadata={"help": "lm weight to use with wer_kenlm_model"},
)
wer_word_score: float = field(
default=-1.0,
metadata={"help": "lm word score to use with wer_kenlm_model"},
)
wer_args: Optional[str] = field(
default=None,
metadata={
"help": "DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)"
},
)
dec_weight: float = field(
default=0.5,
metadata={"help": "weights for decoder CE Loss, loss will be ((1 - dec_weight) * hubert_loss + dec_weight * CE_Loss)"},
)
report_accuracy: bool = field(
default=True,
metadata={"help": "report decoder accuracy metric"},
)
ignore_prefix_size: int = field(
default=0,
metadata={"help": "Ignore first N tokens"},
)
label_smoothing: float = field(
default=0.1,
metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
)
@register_criterion("ctc_ce", dataclass=CtcCeCriterionConfig)
class CtcCeCriterion(FairseqCriterion):
def __init__(self, cfg: CtcCeCriterionConfig, task: FairseqTask):
super().__init__(task)
self.blank_idx = (
task.target_dictionary.index(task.blank_symbol)
if hasattr(task, "blank_symbol")
else 0
)
self.pad_idx = task.target_dictionary.pad()
self.eos_idx = task.target_dictionary.eos()
self.post_process = cfg.post_process
if cfg.wer_args is not None:
(
cfg.wer_kenlm_model,
cfg.wer_lexicon,
cfg.wer_lm_weight,
cfg.wer_word_score,
) = eval(cfg.wer_args)
if cfg.wer_kenlm_model is not None:
from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
dec_args = Namespace()
dec_args.nbest = 1
dec_args.criterion = "ctc"
dec_args.kenlm_model = cfg.wer_kenlm_model
dec_args.lexicon = cfg.wer_lexicon
dec_args.beam = 50
dec_args.beam_size_token = min(50, len(task.target_dictionary))
dec_args.beam_threshold = min(50, len(task.target_dictionary))
dec_args.lm_weight = cfg.wer_lm_weight
dec_args.word_score = cfg.wer_word_score
dec_args.unk_weight = -math.inf
dec_args.sil_weight = 0
self.w2l_decoder = W2lKenLMDecoder(dec_args, task.target_dictionary)
else:
self.w2l_decoder = None
self.zero_infinity = cfg.zero_infinity
self.sentence_avg = cfg.sentence_avg
self.dec_weight = cfg.dec_weight
self.report_accuracy = cfg.report_accuracy
self.ignore_prefix_size = cfg.ignore_prefix_size
self.eps = cfg.label_smoothing
def forward(self, model, sample, reduce=True):
net_output = model(**sample["net_input"])
lprobs = model.get_normalized_probs(
net_output, log_probs=True
).contiguous() # (T, B, C) from the encoder
if "src_lengths" in sample["net_input"]:
input_lengths = sample["net_input"]["src_lengths"]
else:
if net_output["padding_mask"] is not None:
non_padding_mask = ~net_output["padding_mask"]
input_lengths = non_padding_mask.long().sum(-1)
else:
input_lengths = lprobs.new_full(
(lprobs.size(1),), lprobs.size(0), dtype=torch.long
)
pad_mask = (sample["target"] != self.pad_idx) & (
sample["target"] != self.eos_idx
)
targets_flat = sample["target"].masked_select(pad_mask)
if "target_lengths" in sample:
target_lengths = sample["target_lengths"]
else:
target_lengths = pad_mask.sum(-1)
with torch.backends.cudnn.flags(enabled=False):
loss = F.ctc_loss(
lprobs,
targets_flat,
input_lengths,
target_lengths,
blank=self.blank_idx,
reduction="sum",
zero_infinity=self.zero_infinity,
)
ntokens = (
sample["ntokens"] if "ntokens" in sample else target_lengths.sum().item()
)
sample_size = sample["target"].size(0) if self.sentence_avg else ntokens
logging_output = {}
if "decoder_target" in sample:
dec_sample_size = sample["target"].size(0) if self.sentence_avg else sample["dec_ntokens"]
dec_loss, dec_nll_loss = self.compute_ce_loss(model, net_output["decoder_out"], sample, reduce=reduce)
logging_output["ctc_loss"] = loss.item()
loss = (1 - self.dec_weight) * loss + (self.dec_weight * dec_loss * sample_size / dec_sample_size)
logging_output["dec_loss"] = dec_loss.item()
logging_output["dec_nll_loss"] = dec_nll_loss.item()
logging_output["dec_sample_size"] = dec_sample_size
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, net_output["decoder_out"], sample)
logging_output["dec_n_correct"] = utils.item(n_correct.data)
logging_output["total"] = utils.item(total.data)
logging_output = {
"loss": utils.item(loss.data), # * sample['ntokens'],
"ntokens": ntokens,
"nsentences": sample["id"].numel(),
"sample_size": sample_size,
**logging_output,
}
if not model.training:
import editdistance
with torch.no_grad():
lprobs_t = lprobs.transpose(0, 1).float().contiguous().cpu()
c_err = 0
c_len = 0
w_errs = 0
w_len = 0
wv_errs = 0
for lp, t, inp_l in zip(
lprobs_t,
sample["target_label"]
if "target_label" in sample
else sample["target"],
input_lengths,
):
lp = lp[:inp_l].unsqueeze(0)
decoded = None
if self.w2l_decoder is not None:
decoded = self.w2l_decoder.decode(lp)
if len(decoded) < 1:
decoded = None
else:
decoded = decoded[0]
if len(decoded) < 1:
decoded = None
else:
decoded = decoded[0]
p = (t != self.task.target_dictionary.pad()) & (
t != self.task.target_dictionary.eos()
)
targ = t[p]
targ_units = self.task.target_dictionary.string(targ)
targ_units_arr = targ.tolist()
toks = lp.argmax(dim=-1).unique_consecutive()
pred_units_arr = toks[toks != self.blank_idx].tolist()
c_err += editdistance.eval(pred_units_arr, targ_units_arr)
c_len += len(targ_units_arr)
targ_words = post_process(targ_units, self.post_process).split()
pred_units = self.task.target_dictionary.string(pred_units_arr)
pred_words_raw = post_process(pred_units, self.post_process).split()
if decoded is not None and "words" in decoded:
pred_words = decoded["words"]
w_errs += editdistance.eval(pred_words, targ_words)
wv_errs += editdistance.eval(pred_words_raw, targ_words)
else:
dist = editdistance.eval(pred_words_raw, targ_words)
w_errs += dist
wv_errs += dist
w_len += len(targ_words)
logging_output["wv_errors"] = wv_errs
logging_output["w_errors"] = w_errs
logging_output["w_total"] = w_len
logging_output["c_errors"] = c_err
logging_output["c_total"] = c_len
return loss, sample_size, logging_output
def compute_ce_loss(self, model, net_output, sample, reduce=True):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
loss, nll_loss = label_smoothed_nll_loss(
lprobs,
target,
self.eps,
ignore_index=self.pad_idx,
reduce=reduce,
)
return loss, nll_loss
def compute_accuracy(self, model, net_output, sample):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
mask = target.ne(self.pad_idx)
n_correct = torch.sum(
lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
)
total = torch.sum(mask)
return n_correct, total
def get_lprobs_and_target(self, model, net_output, sample):
lprobs = model.get_normalized_probs(net_output, log_probs=True)
target = sample["decoder_target"]
if self.ignore_prefix_size > 0:
if getattr(lprobs, "batch_first", False):
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
target = target[:, self.ignore_prefix_size :].contiguous()
else:
lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
target = target[self.ignore_prefix_size :, :].contiguous()
return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training."""
loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs))
ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs))
nsentences = utils.item(
sum(log.get("nsentences", 0) for log in logging_outputs)
)
sample_size = utils.item(
sum(log.get("sample_size", 0) for log in logging_outputs)
)
metrics.log_scalar(
"loss", loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar("ntokens", ntokens)
metrics.log_scalar("nsentences", nsentences)
if sample_size != ntokens:
metrics.log_scalar(
"nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3
)
c_errors = sum(log.get("c_errors", 0) for log in logging_outputs)
metrics.log_scalar("_c_errors", c_errors)
c_total = sum(log.get("c_total", 0) for log in logging_outputs)
metrics.log_scalar("_c_total", c_total)
w_errors = sum(log.get("w_errors", 0) for log in logging_outputs)
metrics.log_scalar("_w_errors", w_errors)
wv_errors = sum(log.get("wv_errors", 0) for log in logging_outputs)
metrics.log_scalar("_wv_errors", wv_errors)
w_total = sum(log.get("w_total", 0) for log in logging_outputs)
metrics.log_scalar("_w_total", w_total)
if c_total > 0:
metrics.log_derived(
"uer",
lambda meters: safe_round(
meters["_c_errors"].sum * 100.0 / meters["_c_total"].sum, 3
)
if meters["_c_total"].sum > 0
else float("nan"),
)
if w_total > 0:
metrics.log_derived(
"wer",
lambda meters: safe_round(
meters["_w_errors"].sum * 100.0 / meters["_w_total"].sum, 3
)
if meters["_w_total"].sum > 0
else float("nan"),
)
metrics.log_derived(
"raw_wer",
lambda meters: safe_round(
meters["_wv_errors"].sum * 100.0 / meters["_w_total"].sum, 3
)
if meters["_w_total"].sum > 0
else float("nan"),
)
if "dec_loss" in logging_outputs[0]:
ctc_loss_sum = sum(log.get("ctc_loss", 0) for log in logging_outputs)
dec_loss_sum = sum(log.get("dec_loss", 0) for log in logging_outputs)
dec_nll_loss_sum = sum(log.get("dec_nll_loss", 0) for log in logging_outputs)
dec_sample_size = sum(log.get("dec_sample_size", 0) for log in logging_outputs)
metrics.log_scalar(
"dec_loss", dec_loss_sum / dec_sample_size / math.log(2), dec_sample_size, round=3
)
metrics.log_scalar(
"ctc_loss", ctc_loss_sum / sample_size / math.log(2), sample_size, round=3
)
metrics.log_scalar(
"dec_nll_loss", dec_nll_loss_sum / dec_sample_size / math.log(2), dec_sample_size, round=3
)
metrics.log_derived(
"dec_ppl", lambda meters: utils.get_perplexity(meters["dec_nll_loss"].avg)
)
total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
if total > 0:
metrics.log_scalar("total", total)
n_correct = utils.item(
sum(log.get("dec_n_correct", 0) for log in logging_outputs)
)
metrics.log_scalar("dec_n_correct", n_correct)
metrics.log_derived(
"dec_accuracy",
lambda meters: round(
meters["dec_n_correct"].sum * 100.0 / meters["total"].sum, 3
)
if meters["total"].sum > 0
else float("nan"),
)
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return True
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import math
import re
from dataclasses import dataclass, field
import torch
import torch.nn.functional as F
from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import label_smoothed_nll_loss
from fairseq.criterions.hubert_criterion import HubertCriterionConfig
@dataclass
class Speech2cCriterionConfig(HubertCriterionConfig):
dec_weight: float = field(
default=1.0,
metadata={"help": "weights for decoder CE Loss, loss will be (hubert_loss + dec_weight * CE_Loss)"},
)
report_accuracy: bool = field(
default=True,
metadata={"help": "report decoder accuracy metric"},
)
ignore_prefix_size: int = field(
default=0,
metadata={"help": "Ignore first N tokens"},
)
label_smoothing: float = field(
default=0.0,
metadata={"help": "epsilon for label smoothing, 0 means no label smoothing"},
)
@register_criterion("speech2c", dataclass=Speech2cCriterionConfig)
class Speech2cCriterion(FairseqCriterion):
def __init__(self, task, pred_masked_weight, pred_nomask_weight, loss_weights=None, log_keys=None, dec_weight=1.0, report_accuracy=False, ignore_prefix_size=0, label_smoothing=0.0):
super().__init__(task)
self.pred_masked_weight = pred_masked_weight
self.pred_nomask_weight = pred_nomask_weight
self.loss_weights = loss_weights
self.log_keys = [] if log_keys is None else log_keys
self.dec_weight = dec_weight
self.report_accuracy = report_accuracy
self.ignore_prefix_size = ignore_prefix_size
self.eps = label_smoothing
self.padding_idx = task.dictionaries[0].pad()
def forward(self, model, sample, reduce=True, log_pred=False):
"""Compute the loss for the given sample.
Returns a tuple with three elements:
1) the loss
2) the sample size, which is used as the denominator for the gradient
3) logging outputs to display while training
"""
net_output = model(target_list=sample["target_list"], **sample["net_input"])
loss = 0.0
sample_size = 0
logging_output = {}
reduction = "sum" if reduce else "none"
loss_m_list = []
logp_m_list = model.get_logits(net_output, True)
targ_m_list = model.get_targets(net_output, True)
assert self.pred_masked_weight == 0 or len(logp_m_list) > 0
for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)):
loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction)
loss_m_list.append(loss_m)
logging_output[f"loss_m_{i}"] = loss_m.detach().item()
if self.pred_masked_weight > 0:
loss += self.pred_masked_weight * sum(loss_m_list)
sample_size += targ_m_list[0].numel()
loss_u_list = []
logp_u_list = model.get_logits(net_output, False)
targ_u_list = model.get_targets(net_output, False)
assert self.pred_nomask_weight == 0 or len(logp_u_list) > 0
for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)):
loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction)
loss_u_list.append(loss_u)
logging_output[f"loss_u_{i}"] = loss_u.detach().item()
if self.pred_nomask_weight > 0:
loss += self.pred_nomask_weight * sum(loss_u_list)
sample_size += targ_u_list[0].numel()
if self.loss_weights is not None:
assert hasattr(model, "get_extra_losses")
extra_losses, names = model.get_extra_losses(net_output)
if torch.is_tensor(extra_losses):
extra_losses = [extra_losses]
names = [names]
if len(self.loss_weights) == 1 and len(extra_losses) != 1:
self.loss_weights = [self.loss_weights[0]] * len(extra_losses)
assert len(extra_losses) == len(
self.loss_weights
), f"{len(extra_losses)}, {len(self.loss_weights)}"
for p, n, coef in zip(extra_losses, names, self.loss_weights):
if coef != 0 and p is not None:
p = coef * p.float() * sample_size
loss += p
logging_output[f"loss_{n}"] = p.item()
if "decoder_target" in sample:
dec_sample_size = sample["dec_ntokens"]
dec_loss, dec_nll_loss = self.compute_ce_loss(model, net_output["decoder_out"], sample, reduce=reduce)
loss = loss + (self.dec_weight * dec_loss * sample_size / dec_sample_size)
logging_output["dec_loss"] = dec_loss.item()
logging_output["dec_nll_loss"] = dec_nll_loss.item()
logging_output["dec_sample_size"] = dec_sample_size
if self.report_accuracy:
n_correct, total = self.compute_accuracy(model, net_output["decoder_out"], sample)
logging_output["dec_n_correct"] = utils.item(n_correct.data)
logging_output["total"] = utils.item(total.data)
logging_output = {
"loss": loss.item() if reduce else loss,
"ntokens": sample_size,
"nsentences": sample["id"].numel(),
"sample_size": sample_size,
**logging_output,
}
for lk in self.log_keys:
if lk in net_output:
logging_output[lk] = float((net_output[lk]))
def compute_correct(logits):
if logits.numel() == 0:
return 0, 0
else:
assert logits.dim() > 1, logits.shape
max = logits.argmax(-1) == 0
min = logits.argmin(-1) == 0
both = max & min
corr = max.long().sum().item() - both.long().sum().item()
count = max.numel()
return corr, count
with torch.no_grad():
for i, logp_m in enumerate(logp_m_list):
corr_m, count_m = compute_correct(logp_m)
logging_output[f"correct_m_{i}"] = corr_m
logging_output[f"count_m_{i}"] = count_m
for i, logp_u in enumerate(logp_u_list):
corr_u, count_u = compute_correct(logp_u)
logging_output[f"correct_u_{i}"] = corr_u
logging_output[f"count_u_{i}"] = count_u
return loss, sample_size, logging_output
def compute_ce_loss(self, model, net_output, sample, reduce=True):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
loss, nll_loss = label_smoothed_nll_loss(
lprobs,
target,
self.eps,
ignore_index=self.padding_idx,
reduce=reduce,
)
return loss, nll_loss
def compute_accuracy(self, model, net_output, sample):
lprobs, target = self.get_lprobs_and_target(model, net_output, sample)
mask = target.ne(self.padding_idx)
n_correct = torch.sum(
lprobs.argmax(1).masked_select(mask).eq(target.masked_select(mask))
)
total = torch.sum(mask)
return n_correct, total
def get_lprobs_and_target(self, model, net_output, sample):
lprobs = model.get_normalized_probs(net_output, log_probs=True)
target = sample["decoder_target"]
if self.ignore_prefix_size > 0:
if getattr(lprobs, "batch_first", False):
lprobs = lprobs[:, self.ignore_prefix_size :, :].contiguous()
target = target[:, self.ignore_prefix_size :].contiguous()
else:
lprobs = lprobs[self.ignore_prefix_size :, :, :].contiguous()
target = target[self.ignore_prefix_size :, :].contiguous()
return lprobs.view(-1, lprobs.size(-1)), target.view(-1)
@staticmethod
def reduce_metrics(logging_outputs) -> None:
"""Aggregate logging outputs from data parallel training (copied from normal cross entropy)."""
loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3)
if sample_size != ntokens:
metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3)
metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg))
else:
metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg))
counts = {}
for lk in logging_outputs[0].keys():
if lk.startswith("count_"):
val = sum(log[lk] for log in logging_outputs)
metrics.log_scalar(lk, val)
counts[lk] = val
for lk in logging_outputs[0].keys():
if lk.startswith("loss_"):
val = sum(log[lk] for log in logging_outputs)
metrics.log_scalar(lk, val / sample_size / math.log(2), round=3)
elif lk.startswith("correct_"):
val = sum(log[lk] for log in logging_outputs)
metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)])
if "dec_loss" in logging_outputs[0]:
dec_loss_sum = sum(log.get("dec_loss", 0) for log in logging_outputs)
dec_nll_loss_sum = sum(log.get("dec_nll_loss", 0) for log in logging_outputs)
dec_sample_size = sum(log.get("dec_sample_size", 0) for log in logging_outputs)
metrics.log_scalar(
"dec_loss", dec_loss_sum / dec_sample_size / math.log(2), dec_sample_size, round=3
)
metrics.log_scalar(
"dec_nll_loss", dec_nll_loss_sum / dec_sample_size / math.log(2), dec_sample_size, round=3
)
metrics.log_derived(
"dec_ppl", lambda meters: utils.get_perplexity(meters["dec_nll_loss"].avg)
)
total = utils.item(sum(log.get("total", 0) for log in logging_outputs))
if total > 0:
metrics.log_scalar("total", total)
n_correct = utils.item(
sum(log.get("dec_n_correct", 0) for log in logging_outputs)
)
metrics.log_scalar("dec_n_correct", n_correct)
metrics.log_derived(
"dec_accuracy",
lambda meters: round(
meters["dec_n_correct"].sum * 100.0 / meters["total"].sum, 3
)
if meters["total"].sum > 0
else float("nan"),
)
@staticmethod
def aggregate_logging_outputs(logging_outputs):
"""Aggregate logging outputs from data parallel training."""
raise NotImplementedError()
@staticmethod
def logging_outputs_can_be_summed() -> bool:
"""
Whether the logging outputs returned by `forward` can be summed
across workers prior to calling `reduce_metrics`. Setting this
to True will improves distributed training speed.
"""
return False
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import logging
from typing import Any, List, Optional, Union
import torch
from fairseq.data import data_utils, Dictionary
from fairseq.data.audio.hubert_dataset import HubertDataset
logger = logging.getLogger(__name__)
class Speech2cDataset(HubertDataset):
def __init__(
self,
manifest_path: str,
sample_rate: float,
label_paths: List[str],
label_rates: Union[List[float], float], # -1 for sequence labels
pad_list: List[str],
eos_list: List[str],
label_processors: Optional[List[Any]] = None,
max_keep_sample_size: Optional[int] = None,
min_keep_sample_size: Optional[int] = None,
max_sample_size: Optional[int] = None,
shuffle: bool = True,
pad_audio: bool = False,
normalize: bool = False,
store_labels: bool = True,
random_crop: bool = False,
single_target: bool = False,
tgt_dict: Optional[Dictionary] = None,
add_decoder: bool = False,
fine_tuning: bool = False,
):
super().__init__(
manifest_path,
sample_rate,
label_paths,
label_rates,
pad_list,
eos_list,
label_processors,
max_keep_sample_size,
min_keep_sample_size,
max_sample_size,
shuffle,
pad_audio,
normalize,
store_labels,
random_crop,
single_target
)
self.tgt_dict = tgt_dict
self.add_decoder = add_decoder
self.fine_tuning = fine_tuning
def collater(self, samples):
# target = max(sizes) -> random_crop not used
# target = max_sample_size -> random_crop used for long
samples = [s for s in samples if s["source"] is not None]
if len(samples) == 0:
return {}
audios = [s["source"] for s in samples]
audio_sizes = [len(s) for s in audios]
if self.pad_audio:
audio_size = min(max(audio_sizes), self.max_sample_size)
else:
audio_size = min(min(audio_sizes), self.max_sample_size)
collated_audios, padding_mask, audio_starts = self.collater_audio(
audios, audio_size
)
targets_by_label = [
[s["label_list"][i] for s in samples] for i in range(self.num_labels)
]
targets_list, lengths_list, ntokens_list = self.collater_label(
targets_by_label, audio_size, audio_starts
)
if self.add_decoder:
if self.fine_tuning:
decoder_label = [
torch.cat((targets_list[0][i, :lengths_list[0][i]], torch.tensor([self.tgt_dict.eos()])), 0).long()
for i in range(targets_list[0].size(0))
]
else:
decoder_label = [
torch.cat((targets_list[0][i, :lengths_list[0][i]].unique_consecutive(), torch.tensor([self.tgt_dict.eos()])), 0).long()
for i in range(targets_list[0].size(0))
]
dec_ntokens = sum(x.size(0) for x in decoder_label)
decoder_target = data_utils.collate_tokens(
decoder_label,
self.tgt_dict.pad(),
self.tgt_dict.eos(),
left_pad=False,
move_eos_to_beginning=False,
)
decoder_target_lengths = torch.tensor(
[x.size(0) for x in decoder_label], dtype=torch.long
)
prev_output_tokens = data_utils.collate_tokens(
decoder_label,
self.tgt_dict.pad(),
self.tgt_dict.eos(),
left_pad=False,
move_eos_to_beginning=True,
)
net_input = {
"source": collated_audios,
"padding_mask": padding_mask,
"prev_output_tokens": prev_output_tokens,
}
batch = {
"id": torch.LongTensor([s["id"] for s in samples]),
"net_input": net_input,
"decoder_target": decoder_target,
"decoder_target_lengths": decoder_target_lengths,
"dec_ntokens": dec_ntokens,
}
else:
net_input = {"source": collated_audios, "padding_mask": padding_mask}
batch = {
"id": torch.LongTensor([s["id"] for s in samples]),
"net_input": net_input,
}
if self.single_target:
batch["target_lengths"] = lengths_list[0]
batch["ntokens"] = ntokens_list[0]
batch["target"] = targets_list[0]
else:
batch["target_lengths_list"] = lengths_list
batch["ntokens_list"] = ntokens_list
batch["target_list"] = targets_list
return batch
#!/usr/bin/env python3
# Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import numpy as np
import six
class CTCPrefixScore(object):
"""Compute CTC label sequence scores
which is based on Algorithm 2 in WATANABE et al.
"HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
but extended to efficiently compute the probablities of multiple labels
simultaneously
"""
def __init__(self, x, blank, eos, xp):
self.xp = xp
self.logzero = -10000000000.0
self.blank = blank
self.eos = eos
self.input_length = len(x)
self.x = x
def initial_state(self):
"""Obtain an initial CTC state
:return: CTC state
"""
# initial CTC state is made of a frame x 2 tensor that corresponds to
# r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent
# superscripts n and b (non-blank and blank), respectively.
r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
r[0, 1] = self.x[0, self.blank]
for i in six.moves.range(1, self.input_length):
r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
return r
def __call__(self, y, cs, r_prev):
"""Compute CTC prefix scores for next labels
:param y : prefix label sequence
:param cs : array of next labels
:param r_prev: previous CTC state
:return ctc_scores, ctc_states
"""
# initialize CTC states
output_length = len(y) - 1 # ignore sos
# new CTC states are prepared as a frame x (n or b) x n_labels tensor
# that corresponds to r_t^n(h) and r_t^b(h).
r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32)
xs = self.x[:, cs]
if output_length == 0:
r[0, 0] = xs[0]
r[0, 1] = self.logzero
else:
r[output_length - 1] = self.logzero
# prepare forward probabilities for the last label
r_sum = self.xp.logaddexp(
r_prev[:, 0], r_prev[:, 1]
) # log(r_t^n(g) + r_t^b(g))
last = y[-1]
if output_length > 0 and last in cs:
log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
for i in six.moves.range(len(cs)):
log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1]
else:
log_phi = r_sum
# compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
# and log prefix probabilities log(psi)
start = max(output_length, 1)
log_psi = r[start - 1, 0]
for t in six.moves.range(start, self.input_length):
r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
r[t, 1] = (
self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
)
log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])
# get P(...eos|X) that ends with the prefix itself
eos_pos = self.xp.where(cs == self.eos)[0]
if len(eos_pos) > 0:
log_psi[eos_pos] = r_sum[-1] # log(r_T^n(g) + r_T^b(g))
# exclude blank probs
blank_pos = self.xp.where(cs == self.blank)[0]
if len(blank_pos) > 0:
log_psi[blank_pos] = self.logzero
# return the log prefix probability and CTC states, where the label axis
# of the CTC states is moved to the first axis to slice it easily
return log_psi, self.xp.rollaxis(r, 2)
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
from typing import Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from fairseq import utils
from torch import Tensor
from fairseq.modules import MultiheadAttention as FairseqMultiheadAttention
class MultiheadAttention(FairseqMultiheadAttention):
"""Multi-headed attention.
See "Attention Is All You Need" for more details.
"""
def __init__(
self,
embed_dim,
num_heads,
kdim=None,
vdim=None,
dropout=0.0,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
self_attention=False,
encoder_decoder_attention=False,
q_noise=0.0,
qn_block_size=8,
):
super().__init__(
embed_dim,
num_heads,
kdim,
vdim,
dropout,
bias,
add_bias_kv,
add_zero_attn,
self_attention,
encoder_decoder_attention,
q_noise,
qn_block_size,
)
def forward(
self,
query,
key: Optional[Tensor],
value: Optional[Tensor],
key_padding_mask: Optional[Tensor] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
need_weights: bool = True,
static_kv: bool = False,
attn_mask: Optional[Tensor] = None,
before_softmax: bool = False,
need_head_weights: bool = False,
position_bias: Optional[Tensor] = None,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Input shape: Time x Batch x Channel
Args:
key_padding_mask (ByteTensor, optional): mask to exclude
keys that are pads, of shape `(batch, src_len)`, where
padding elements are indicated by 1s.
need_weights (bool, optional): return the attention weights,
averaged over heads (default: False).
attn_mask (ByteTensor, optional): typically used to
implement causal attention, where the mask prevents the
attention from looking forward in time (default: None).
before_softmax (bool, optional): return the raw attention
weights and values before the attention softmax.
need_head_weights (bool, optional): return the attention
weights for each head. Implies *need_weights*. Default:
return the average attention weights over all heads.
"""
if need_head_weights:
need_weights = True
is_tpu = query.device.type == "xla"
tgt_len, bsz, embed_dim = query.size()
src_len = tgt_len
assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
assert list(query.size()) == [tgt_len, bsz, embed_dim]
if key is not None:
src_len, key_bsz, _ = key.size()
if not torch.jit.is_scripting():
assert key_bsz == bsz
assert value is not None
assert src_len, bsz == value.shape[:2]
if (
not self.onnx_trace
and not is_tpu # don't use PyTorch version on TPUs
and incremental_state is None
and not static_kv
# A workaround for quantization to work. Otherwise JIT compilation
# treats bias in linear module as method.
and not torch.jit.is_scripting()
and position_bias is None
):
assert key is not None and value is not None
return F.multi_head_attention_forward(
query,
key,
value,
self.embed_dim,
self.num_heads,
torch.empty([0]),
torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
self.bias_k,
self.bias_v,
self.add_zero_attn,
self.dropout_module.p,
self.out_proj.weight,
self.out_proj.bias,
self.training or self.dropout_module.apply_during_inference,
key_padding_mask,
need_weights,
attn_mask,
use_separate_proj_weight=True,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
)
if incremental_state is not None:
saved_state = self._get_input_buffer(incremental_state)
if saved_state is not None and "prev_key" in saved_state:
# previous time steps are cached - no need to recompute
# key and value if they are static
if static_kv:
assert self.encoder_decoder_attention and not self.self_attention
key = value = None
else:
saved_state = None
if self.self_attention:
q = self.q_proj(query)
k = self.k_proj(query)
v = self.v_proj(query)
elif self.encoder_decoder_attention:
# encoder-decoder attention
q = self.q_proj(query)
if key is None:
assert value is None
k = v = None
else:
k = self.k_proj(key)
v = self.v_proj(key)
else:
assert key is not None and value is not None
q = self.q_proj(query)
k = self.k_proj(key)
v = self.v_proj(value)
q *= self.scaling
if self.bias_k is not None:
assert self.bias_v is not None
k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
],
dim=1,
)
q = (
q.contiguous()
.view(tgt_len, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if k is not None:
k = (
k.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if v is not None:
v = (
v.contiguous()
.view(-1, bsz * self.num_heads, self.head_dim)
.transpose(0, 1)
)
if saved_state is not None:
# saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
if "prev_key" in saved_state:
_prev_key = saved_state["prev_key"]
assert _prev_key is not None
prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
k = prev_key
else:
assert k is not None
k = torch.cat([prev_key, k], dim=1)
src_len = k.size(1)
if "prev_value" in saved_state:
_prev_value = saved_state["prev_value"]
assert _prev_value is not None
prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
if static_kv:
v = prev_value
else:
assert v is not None
v = torch.cat([prev_value, v], dim=1)
prev_key_padding_mask: Optional[Tensor] = None
if "prev_key_padding_mask" in saved_state:
prev_key_padding_mask = saved_state["prev_key_padding_mask"]
assert k is not None and v is not None
key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
key_padding_mask=key_padding_mask,
prev_key_padding_mask=prev_key_padding_mask,
batch_size=bsz,
src_len=k.size(1),
static_kv=static_kv,
)
saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
saved_state["prev_key_padding_mask"] = key_padding_mask
# In this branch incremental_state is never None
assert incremental_state is not None
incremental_state = self._set_input_buffer(incremental_state, saved_state)
assert k is not None
assert k.size(1) == src_len
# This is part of a workaround to get around fork/join parallelism
# not supporting Optional types.
if key_padding_mask is not None and key_padding_mask.dim() == 0:
key_padding_mask = None
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if self.add_zero_attn:
assert v is not None
src_len += 1
k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
if attn_mask is not None:
attn_mask = torch.cat(
[attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
)
if key_padding_mask is not None:
key_padding_mask = torch.cat(
[
key_padding_mask,
torch.zeros(key_padding_mask.size(0), 1).type_as(
key_padding_mask
),
],
dim=1,
)
attn_weights = torch.bmm(q, k.transpose(1, 2))
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
if position_bias is not None: ## first order
## position_bias: [241, 241, 64]
#print ("attn_weights: ", attn_weights.size()) # [492, 241, 241]
reshape_q = q.contiguous().view(bsz * self.num_heads, -1, self.head_dim).transpose(0,1) #[241, 492, 64]
#print ("reshape_q: ", reshape_q.size())
B = torch.matmul(reshape_q, position_bias.transpose(-2, -1))
#print ("B: ", B.size()) ## [241, 492, 241]
#B = B.transpose(0, 1).view(bsz, self.num_heads, position_bias.size(0), position_bias.size(1))
B = B.transpose(0, 1).view(bsz*self.num_heads, position_bias.size(0), position_bias.size(1))
#print ("B 2: ", B.size())
attn_weights += B
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None:
attn_mask = attn_mask.unsqueeze(0)
if self.onnx_trace:
attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
attn_weights += attn_mask
if key_padding_mask is not None:
# don't attend to padding symbols
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
if not is_tpu:
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
float("-inf"),
)
else:
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
attn_weights = attn_weights.transpose(0, 2)
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
if before_softmax:
return attn_weights, v
attn_weights_float = utils.softmax(
attn_weights, dim=-1, onnx_trace=self.onnx_trace
)
attn_weights = attn_weights_float.type_as(attn_weights)
attn_probs = self.dropout_module(attn_weights)
assert v is not None
attn = torch.bmm(attn_probs, v)
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
if self.onnx_trace and attn.size(1) == 1:
# when ONNX tracing a single decoder step (sequence length == 1)
# the transpose is a no-op copy before view, thus unnecessary
attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
else:
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn = self.out_proj(attn)
attn_weights: Optional[Tensor] = None
if need_weights:
attn_weights = attn_weights_float.view(
bsz, self.num_heads, tgt_len, src_len
).transpose(1, 0)
if not need_head_weights:
# average attention weights over heads
attn_weights = attn_weights.mean(dim=0)
return attn, attn_weights
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import torch
class RelativePositionalEncoding(torch.nn.Module):
def __init__(self, d_model, maxlen=1000, embed_v=False):
super(RelativePositionalEncoding, self).__init__()
self.d_model = d_model
self.maxlen = maxlen
self.pe_k = torch.nn.Embedding(2*maxlen, d_model)
if embed_v:
self.pe_v = torch.nn.Embedding(2*maxlen, d_model)
self.embed_v = embed_v
def forward(self, pos_seq, incremental_state=None):
pos_seq[pos_seq < -self.maxlen] = -self.maxlen
pos_seq[pos_seq >= self.maxlen] = self.maxlen - 1
pos_seq = pos_seq + self.maxlen
if incremental_state is not None:
pos_seq = pos_seq[-1:]
if self.embed_v:
return self.pe_k(pos_seq), self.pe_v(pos_seq)
else:
return self.pe_k(pos_seq), None
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------
import math
from typing import Any, Dict, List, Optional
import torch
import torch.nn as nn
from fairseq import utils
from fairseq.distributed import fsdp_wrap
from fairseq.models import FairseqIncrementalDecoder
from fairseq.models.transformer import TransformerConfig
from fairseq.models.transformer.transformer_decoder import module_name_fordropout, Linear
from fairseq.modules import (
AdaptiveSoftmax,
BaseLayer,
FairseqDropout,
LayerDropModuleList,
LayerNorm,
PositionalEmbedding,
SinusoidalPositionalEmbedding,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.quant_noise import quant_noise as apply_quant_noise_
from torch import Tensor
from speech2c.models.modules.transformer_decoder_layer import TransformerDecoderLayerBase
from speech2c.models.modules.relative_pos_enc import RelativePositionalEncoding
class TransformerDecoderBase(FairseqIncrementalDecoder):
"""
Transformer decoder consisting of *cfg.decoder.layers* layers. Each layer
is a :class:`TransformerDecoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def __init__(
self,
cfg,
dictionary,
embed_tokens,
no_encoder_attn=False,
output_projection=None,
use_rel_pos_enc=False,
):
self.cfg = cfg
super().__init__(dictionary)
self.register_buffer("version", torch.Tensor([3]))
self._future_mask = torch.empty(0)
self.dropout_module = FairseqDropout(
cfg.dropout, module_name=module_name_fordropout(self.__class__.__name__)
)
self.decoder_layerdrop = cfg.decoder.layerdrop
self.share_input_output_embed = cfg.share_decoder_input_output_embed
input_embed_dim = embed_tokens.embedding_dim
embed_dim = cfg.decoder.embed_dim
self.embed_dim = embed_dim
self.output_embed_dim = cfg.decoder.output_dim
self.padding_idx = embed_tokens.padding_idx
self.max_target_positions = cfg.max_target_positions
self.embed_tokens = embed_tokens
self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim)
if not cfg.adaptive_input and cfg.quant_noise.pq > 0:
self.quant_noise = apply_quant_noise_(
nn.Linear(embed_dim, embed_dim, bias=False),
cfg.quant_noise.pq,
cfg.quant_noise.pq_block_size,
)
else:
self.quant_noise = None
self.project_in_dim = (
Linear(input_embed_dim, embed_dim, bias=False)
if embed_dim != input_embed_dim
else None
)
self.embed_positions = (
PositionalEmbedding(
self.max_target_positions,
embed_dim,
self.padding_idx,
learned=cfg.decoder.learned_pos,
)
if not cfg.no_token_positional_embeddings
else None
)
if cfg.layernorm_embedding:
self.layernorm_embedding = LayerNorm(embed_dim, export=cfg.export)
else:
self.layernorm_embedding = None
self.cross_self_attention = cfg.cross_self_attention
self.use_rel_pos_enc = use_rel_pos_enc
if self.decoder_layerdrop > 0.0:
self.layers = LayerDropModuleList(p=self.decoder_layerdrop)
else:
self.layers = nn.ModuleList([])
self.layers.extend(
[
self.build_decoder_layer(cfg, no_encoder_attn)
for _ in range(cfg.decoder.layers)
]
)
self.num_layers = len(self.layers)
if cfg.decoder.normalize_before and not cfg.no_decoder_final_norm:
self.layer_norm = LayerNorm(embed_dim, export=cfg.export)
else:
self.layer_norm = None
self.project_out_dim = (
Linear(embed_dim, self.output_embed_dim, bias=False)
if embed_dim != self.output_embed_dim and not cfg.tie_adaptive_weights
else None
)
self.adaptive_softmax = None
self.output_projection = output_projection
if self.output_projection is None:
self.build_output_projection(cfg, dictionary, embed_tokens)
if self.use_rel_pos_enc:
self.pos_emb = RelativePositionalEncoding(self.embed_dim // cfg.decoder.attention_heads, 24)
def build_output_projection(self, cfg, dictionary, embed_tokens):
if cfg.adaptive_softmax_cutoff is not None:
self.adaptive_softmax = AdaptiveSoftmax(
len(dictionary),
self.output_embed_dim,
utils.eval_str_list(cfg.adaptive_softmax_cutoff, type=int),
dropout=cfg.adaptive_softmax_dropout,
adaptive_inputs=embed_tokens if cfg.tie_adaptive_weights else None,
factor=cfg.adaptive_softmax_factor,
tie_proj=cfg.tie_adaptive_proj,
)
elif self.share_input_output_embed:
self.output_projection = nn.Linear(
self.embed_tokens.weight.shape[1],
self.embed_tokens.weight.shape[0],
bias=False,
)
self.output_projection.weight = self.embed_tokens.weight
else:
self.output_projection = nn.Linear(
self.output_embed_dim, len(dictionary), bias=False
)
nn.init.normal_(
self.output_projection.weight, mean=0, std=self.output_embed_dim ** -0.5
)
num_base_layers = cfg.base_layers
for i in range(num_base_layers):
self.layers.insert(
((i + 1) * cfg.decoder.layers) // (num_base_layers + 1),
BaseLayer(cfg),
)
def build_decoder_layer(self, cfg, no_encoder_attn=False):
layer = TransformerDecoderLayerBase(cfg, no_encoder_attn, has_relative_attention_bias=self.use_rel_pos_enc)
checkpoint = cfg.checkpoint_activations
if checkpoint:
offload_to_cpu = cfg.offload_activations
layer = checkpoint_wrapper(layer, offload_to_cpu=offload_to_cpu)
# if we are checkpointing, enforce that FSDP always wraps the
# checkpointed layer, regardless of layer size
min_params_to_wrap = cfg.min_params_to_wrap if not checkpoint else 0
layer = fsdp_wrap(layer, min_num_params=min_params_to_wrap)
return layer
def forward(
self,
prev_output_tokens,
encoder_out: Optional[Dict[str, List[Tensor]]] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
features_only: bool = False,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
src_lengths: Optional[Any] = None,
return_all_hiddens: bool = False,
):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_out (optional): output from the encoder, used for
encoder-side attention, should be of size T x B x C
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
features_only (bool, optional): only return features without
applying output layer (default: False).
full_context_alignment (bool, optional): don't apply
auto-regressive mask to self-attention (default: False).
Returns:
tuple:
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
x, extra = self.extract_features(
prev_output_tokens,
encoder_out=encoder_out,
incremental_state=incremental_state,
full_context_alignment=full_context_alignment,
alignment_layer=alignment_layer,
alignment_heads=alignment_heads,
)
if not features_only:
x = self.output_layer(x)
return x, extra
def extract_features_scriptable(
self,
prev_output_tokens,
encoder_out: Optional[Dict[str, List[Tensor]]],
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
):
"""
Similar to *forward* but only return features.
Includes several features from "Jointly Learning to Align and
Translate with Transformer Models" (Garg et al., EMNLP 2019).
Args:
full_context_alignment (bool, optional): don't apply
auto-regressive mask to self-attention (default: False).
alignment_layer (int, optional): return mean alignment over
heads at this layer (default: last layer).
alignment_heads (int, optional): only average alignment over
this many heads (default: all heads).
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
bs, slen = prev_output_tokens.size()
if alignment_layer is None:
alignment_layer = self.num_layers - 1
enc: Optional[Tensor] = None
padding_mask: Optional[Tensor] = None
if encoder_out is not None and len(encoder_out["encoder_out"]) > 0:
enc = encoder_out["encoder_out"][0]
assert (
enc.size()[1] == bs
), f"Expected enc.shape == (t, {bs}, c) got {enc.shape}"
if encoder_out is not None and len(encoder_out["encoder_padding_mask"]) > 0:
padding_mask = encoder_out["encoder_padding_mask"][0]
# embed positions
positions = None
if self.embed_positions is not None:
positions = self.embed_positions(
prev_output_tokens, incremental_state=incremental_state
)
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
if positions is not None:
positions = positions[:, -1:]
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
if self.quant_noise is not None:
x = self.quant_noise(x)
if self.project_in_dim is not None:
x = self.project_in_dim(x)
if positions is not None:
x += positions
if self.layernorm_embedding is not None:
x = self.layernorm_embedding(x)
x = self.dropout_module(x)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
if self.use_rel_pos_enc:
pos_seq = torch.arange(0, slen).long().to(x.device)
pos_seq = pos_seq[:, None] - pos_seq[None, :]
pos_k, _ = self.pos_emb(pos_seq, incremental_state)
else:
pos_k = None
self_attn_padding_mask: Optional[Tensor] = None
if self.cross_self_attention or prev_output_tokens.eq(self.padding_idx).any():
self_attn_padding_mask = prev_output_tokens.eq(self.padding_idx)
# decoder layers
attn: Optional[Tensor] = None
inner_states: List[Optional[Tensor]] = [x]
for idx, layer in enumerate(self.layers):
if incremental_state is None and not full_context_alignment:
self_attn_mask = self.buffered_future_mask(x)
else:
self_attn_mask = None
x, layer_attn, _ = layer(
x,
enc,
padding_mask,
incremental_state,
self_attn_mask=self_attn_mask,
self_attn_padding_mask=self_attn_padding_mask,
need_attn=bool((idx == alignment_layer)),
need_head_weights=bool((idx == alignment_layer)),
pos_bias=pos_k,
)
inner_states.append(x)
if layer_attn is not None and idx == alignment_layer:
attn = layer_attn.float().to(x)
if attn is not None:
if alignment_heads is not None:
attn = attn[:alignment_heads]
# average probabilities over heads
attn = attn.mean(dim=0)
if self.layer_norm is not None:
x = self.layer_norm(x)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
if self.project_out_dim is not None:
x = self.project_out_dim(x)
return x, {"attn": [attn], "inner_states": inner_states}
def output_layer(self, features):
"""Project features to the vocabulary size."""
if self.adaptive_softmax is None:
# project back to size of vocabulary
return self.output_projection(features)
else:
return features
def max_positions(self):
"""Maximum output length supported by the decoder."""
if self.embed_positions is None:
return self.max_target_positions
return min(self.max_target_positions, self.embed_positions.max_positions)
def buffered_future_mask(self, tensor):
dim = tensor.size(0)
# self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
if (
self._future_mask.size(0) == 0
or (not self._future_mask.device == tensor.device)
or self._future_mask.size(0) < dim
):
self._future_mask = torch.triu(
utils.fill_with_neg_inf(torch.zeros([dim, dim])), 1
)
self._future_mask = self._future_mask.to(tensor)
return self._future_mask[:dim, :dim]
def upgrade_state_dict_named(self, state_dict, name):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
weights_key = "{}.embed_positions.weights".format(name)
if weights_key in state_dict:
del state_dict[weights_key]
state_dict[
"{}.embed_positions._float_tensor".format(name)
] = torch.FloatTensor(1)
if f"{name}.output_projection.weight" not in state_dict:
if self.share_input_output_embed:
embed_out_key = f"{name}.embed_tokens.weight"
else:
embed_out_key = f"{name}.embed_out"
if embed_out_key in state_dict:
state_dict[f"{name}.output_projection.weight"] = state_dict[
embed_out_key
]
if not self.share_input_output_embed:
del state_dict[embed_out_key]
for i in range(self.num_layers):
# update layer norms
layer_norm_map = {
"0": "self_attn_layer_norm",
"1": "encoder_attn_layer_norm",
"2": "final_layer_norm",
}
for old, new in layer_norm_map.items():
for m in ("weight", "bias"):
k = "{}.layers.{}.layer_norms.{}.{}".format(name, i, old, m)
if k in state_dict:
state_dict[
"{}.layers.{}.{}.{}".format(name, i, new, m)
] = state_dict[k]
del state_dict[k]
version_key = "{}.version".format(name)
if utils.item(state_dict.get(version_key, torch.Tensor([1]))[0]) <= 2:
# earlier checkpoints did not normalize after the stack of layers
self.layer_norm = None
self.normalize = False
state_dict[version_key] = torch.Tensor([1])
return state_dict
class TransformerDecoder(TransformerDecoderBase):
def __init__(
self,
args,
dictionary,
embed_tokens,
no_encoder_attn=False,
output_projection=None,
):
self.args = args
super().__init__(
TransformerConfig.from_namespace(args),
dictionary,
embed_tokens,
no_encoder_attn=no_encoder_attn,
output_projection=output_projection,
use_rel_pos_enc=args.use_rel_pos_enc,
)
def build_output_projection(self, args, dictionary, embed_tokens):
super().build_output_projection(
TransformerConfig.from_namespace(args), dictionary, embed_tokens
)
def build_decoder_layer(self, args, no_encoder_attn=False):
return super().build_decoder_layer(
TransformerConfig.from_namespace(args), no_encoder_attn=no_encoder_attn
)
class TransformerDecoderScriptable(TransformerDecoder):
def extract_features(
self,
prev_output_tokens,
encoder_out: Optional[Dict[str, List[Tensor]]] = None,
incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
full_context_alignment: bool = False,
alignment_layer: Optional[int] = None,
alignment_heads: Optional[int] = None,
):
# call scriptable method from parent class
x, _ = self.extract_features_scriptable(
prev_output_tokens,
encoder_out,
incremental_state,
full_context_alignment,
alignment_layer,
alignment_heads,
)
return x, None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment