Commit ab38d084 authored by Sugon_ldc's avatar Sugon_ldc
Browse files

create new model

parent ecc6eb8b
Pipeline #589 failed with stages
in 0 seconds
# Conformer_PyTorch
# Conformer_Wenet_PyTorch
## 模型介绍
## 论文
Conformer模型是一种基于自注意力机制(self-attention)的序列建模方法,被广泛应用于语音识别、语言建模、机器翻译、语义分析等自然语言处理任务中。本工程使用的是wenet工具包来调用conformer模型。
`Conformer: Local Features Coupling Global Representations for Visual Recognition`
Wenet是一个开源的端到端语音识别(ASR)工具包,基于PyTorch实现,旨在提供一个简单、高效、灵活的ASR框架,帮助研究人员和开发者快速构建自己的语音识别系统。
Wenet的核心架构是基于端到端的深度神经网络(DNN)和自注意力机制(self-attention)实现的,其中包含了多种不同的模型结构和技术,如Conformer、Transformer、LSTM-TDNN等,可根据不同的任务需求进行选择。此外,Wenet还提供了一系列的工具和接口,方便用户进行模型的训练、推理、评估等操作,支持多种硬件平台和操作系统,如CPU、GPU、FPGA等。
- [https://arxiv.org/abs/2105.03889](https://arxiv.org/abs/2105.03889)
## 模型结构
Conformer模型是一种结合了Transformer的自注意力机制和卷积神经网络的模型结构,用于语音识别和自然语言处理任务,具有时域和频域特征的建模能力。
![model](./img/model.png)
## 数据集
## 算法原理
使用的数据集为Aishell,Aishell是北京壳牌壳牌科技有限公司发布的开源中文普通话语音语料库。来自中国不同口音地区的400人被邀请参加录音,这是在一个安静的室内环境中进行的,使用高保真麦克风,并将采样降至16kHz。人工抄写准确率95%以上,经过专业的语音标注和严格的质量检测。这些数据对学术使用是免费的。我们希望为语音识别领域的新研究者提供适量的数据
Conformer算法原理是通过结合多层的Transformer编码器和深度卷积神经网络,实现对输入序列的时域和频域特征进行建模,从而提高语音识别和自然语言处理任务的性能
数据集下载地址:http://openslr.org/33/
![conformer encoder](./img/conformer_encoder.png)
## 环境配置
### Docker(方法一)
## 训练及推理
此处提供[光源](https://www.sourcefind.cn/#/service-details)拉取docker镜像的地址与使用步骤
### 环境配置
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.10.0-centos7.6-dtk-22.10-py38-latest
在光源可拉取训练的docker镜像,本工程推荐的镜像如下:
docker run -it -v /path/your_code_data/:/path/your_code_data/ --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.10.0-centos7.6-dtk-22.10-py38-latest
cd /path/workspace/
pip3 install typeguard==2.13.3
```
### Dockerfile(方法二)
此处提供dockerfile的使用方法
```
cd ./docker
docker build --no-cache -t conformer .
docker run -it -v /path/your_code_data/:/path/your_code_data/ --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash
```
### Anaconda(方法三)
此处提供本地配置、编译的详细步骤,例如:
进入镜像后需要安装所需的三方依赖
关于本项目DCU显卡所需的特殊深度学习库可从[光合](https://developer.hpccube.com/tool/)开发者社区下载安装。
```
DTK驱动:dtk22.10
python:python3.8
torch:1.10
torchvision:0.10
```
`Tips:以上dtk驱动、python、paddle等DCU相关工具版本需要严格一一对应`
其它非深度学习库参照requirements.txt安装:
```
pip3 install -r requirements.txt
pip3 install typeguard==2.13.3
```
### 数据预处理
## 数据集
`Aishell`
- [http://openslr.org/33/](http://openslr.org/33/)
此处提供数据预处理脚本的使用方法
```
#如果自行下载了aishell数据集,只需要在run.sh文件中修改数据集路径,然后执行如下指令即可
cd ./examples/aishell/s0
#设置stage为0会自动下载数据集,若有下载好的数据集,可手动设置run.sh脚本中的data路径即可省去下载过程
#设置stage为-1会自动下载数据集,若有下载好的数据集,可手动设置run.sh脚本中的data路径即可省去下载过程
bash run.sh --stage -1 --stop_stage -1
bash run.sh --stage 0 --stop_stage 0
bash run.sh --stage 1 --stop_stage 1
......@@ -46,34 +86,117 @@ bash run.sh --stage 1 --stop_stage 1
bash run.sh --stage 2 --stop_stage 2
bash run.sh --stage 3 --stop_stage 3
```
### 训练
预处理好的训练数据目录结构如下,用于正常训练的完整数据集请按此目录结构进行制备:
该工程数据集分为两个部分,一个是原始数据,另一个是索引和音频提取的特征文件
1、原始数据
```
bash train.sh
├── data_aishell
│   ├── transcript
│   │   └── aishell_transcript_v0.8.txt
│   └── wav
│   ├── dev
│   ├── test
│   └── train
├── data_aishell.tgz
├── resource_aishell
│   ├── lexicon.txt
│   └── speaker.info
└── resource_aishell.tgz
```
### 推理
2、索引和音频提取的特征文件
训练结束后,模型会保存在exp/conformer/final.pt路径下,可以直接执行如下指令查看推理结果(若需要使用其他预训练模型,请手动修改)
```
├── dev
│   ├── data.list
│   ├── text
│   └── wav.scp
├── dict
│   └── lang_char.txt
├── local
│   ├── dev
│   │   ├── text
│   │   ├── transcripts.txt
│   │   ├── utt.list
│   │   ├── wav.flist
│   │   ├── wav.scp
│   │   └── wav.scp_all
│   ├── test
│   │   ├── text
│   │   ├── transcripts.txt
│   │   ├── utt.list
│   │   ├── wav.flist
│   │   ├── wav.scp
│   │   └── wav.scp_all
│   └── train
│   ├── text
│   ├── transcripts.txt
│   ├── utt.list
│   ├── wav.flist
│   ├── wav.scp
│   └── wav.scp_all
├── test
│   ├── data.list
│   ├── text
│   └── wav.scp
└── train
├── data.list
├── global_cmvn
├── text
└── wav.scp
```
## 训练
```
# 默认是4卡,可以通过修改run_train.sh文件修改卡数
bash train.sh
```
## 推理
```
# 默认使用exp/conformer/final.pt进行推理,可以手动修改
bash validate.sh
```
## 模型精度数据
## result
此处填算法效果测试图
![img](C:\Users\baita\Downloads\doc\xxx.png)
### 精度
测试数据:[aishell](http://openslr.org/33/),使用的加速卡:Z100L。
根据测试结果情况填写表格:
| 卡数 | 数据精度 | 精度 |
| :--: | :------: | :-----: |
| 4 | fp32 | 93.1294 |
## 应用场景
### 算法类别
| 卡数 | 精度 |
| :--: | :-----: |
| 4卡 | 93.1294 |
`语音识别`
### 热点应用行业
`金融、通信、广媒`
## 源码仓库及问题反馈
http://developer.hpccube.com/codes/modelzoo/conformer_pytorch.git
- [https://developer.hpccube.com/codes/modelzoo/conformer_pytorch](https://developer.hpccube.com/codes/modelzoo/conformer_pytorch)
## 参考
## 参考资料
https://github.com/wenet-e2e/wenet
\ No newline at end of file
- https://github.com/wenet-e2e/wenet
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.10.0-centos7.6-dtk-22.10-py38-latest
RUN source /opt/dtk/env.sh
COPY requirments.txt requirments.txt
RUN pip3 install -r requirements.txt
pip3 install typeguard==2.13.3
......@@ -66,7 +66,7 @@ dataset_conf:
grad_clip: 5
accum_grad: 4
max_epoch: 240
max_epoch: 50
log_interval: 100
optim: adam
......
......@@ -24,7 +24,7 @@ num_nodes=1
node_rank=0
# The aishell dataset location, please change this to your own path
# make sure of using absolute path. DO-NOT-USE relatvie path!
data=/data/conformer/train/
data=/parastor/home/lidc/Modelzoo/conformer_pytorch/data
data_url=www.openslr.org/resources/33
nj=16
......
......@@ -52,6 +52,7 @@ checkpoint=
# use average_checkpoint will get better result
average_checkpoint=false
decode_checkpoint=$dir/final.pt
average_num=30
decode_modes="attention_rescoring"
......@@ -68,7 +69,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
{
test_dir=$dir/test_${mode}
mkdir -p $test_dir
python wenet/bin/recognize.py --gpu 0 \
python3 wenet/bin/recognize.py --gpu 4 \
--mode $mode \
--config $dir/train.yaml \
--data_type $data_type \
......
#模型编码
modelCode=428
# 模型名称
modelName=LPR
modelName=conformer_pytorch
# 模型描述
modelDescription=LPR是一个基于深度学习技术的车牌识别模型,主要识别目标是自然场景的车牌图像
modelDescription=Conformer模型是一种结合了Transformer的自注意力机制和卷积神经网络的模型结构,用于语音识别和自然语言处理任务,具有时域和频域特征的建模能力。
# 应用场景(多个标签以英文逗号分割)
appScenario=OCR,车牌识别,目标检测,训练,推理,pretrain,train,inference
appScenario=语音识别,金融,通信,广媒
# 框架类型(多个标签以英文逗号分割)
frameType=PyTorch,Migraphx,ONNXRuntime
frameType=PyTorch
......@@ -158,7 +158,7 @@ def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
os.environ['HIP_VISIBLE_DEVICES'] = str(args.gpu)
if args.mode in ['ctc_prefix_beam_search', 'attention_rescoring'
] and args.batch_size > 1:
......@@ -207,7 +207,7 @@ def main():
# Init asr model from configs
model = init_model(configs)
#print('############################')
# Load dict
char_dict = {v: k for k, v in symbol_table.items()}
eos = len(char_dict) - 1
......@@ -216,7 +216,7 @@ def main():
use_cuda = args.gpu >= 0 and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model = model.to(device)
#print('model to device############')
model.eval()
with torch.no_grad(), open(args.result_file, 'w') as fout:
for batch_idx, batch in enumerate(test_data_loader):
......@@ -309,8 +309,9 @@ def main():
simulate_streaming=args.simulate_streaming)
hyps = [hyp]
elif args.mode == 'attention_rescoring':
#print('11111111111 attention_resoring 1111111111111111')
assert (feats.size(0) == 1)
hyp, _ = model.attention_rescoring(
hyp, source = model.attention_rescoring(
feats,
feats_lengths,
args.beam_size,
......@@ -320,6 +321,8 @@ def main():
simulate_streaming=args.simulate_streaming,
reverse_weight=args.reverse_weight)
hyps = [hyp]
#print(hyps)
#print(source)
elif args.mode == 'hlg_onebest':
hyps = model.hlg_onebest(
feats,
......@@ -349,7 +352,7 @@ def main():
if w == eos:
break
content.append(char_dict[w])
#logging.info('{} {}'.format(key, args.connect_symbol.join(content)))
logging.info('{} {}'.format(key, args.connect_symbol.join(content)))
fout.write('{} {}\n'.format(key, args.connect_symbol.join(content)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment