"vscode:/vscode.git/clone" did not exist on "804735bf86569e626cd1be9eb2cf6ebefa13d3ca"
Commit b97afd54 authored by wangwei990215's avatar wangwei990215
Browse files

Initial commit

parents
Pipeline #1825 failed with stages
in 0 seconds
[format]
pretty = "%h %an %s"
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
The MooER from Moore Threads is licensed under the MIT License listed below. Copyright (c) 2023-2024 Moore Threads Technology Co., Ltd("Moore Threads"). All rights reserved.
Terms of the MIT License
-------------------------------------------------------------------------
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-------------------------------------------------------------------------
The following copyright statements and licenses apply to various open source software/model
packages (or portions thereof) that are distributed with this MooER. MooER that
includes this file does not necessarily use all the open source software packages referred
to below and may also only use portions of a given package. Some open source software
packages referred to below may have been modified by Moore Threads Technology Co., Ltd
-------------------------------------------------------------------------
SLAM-LLM
Copyright (c) 2024 Ziyang Ma
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
-------------------------------------------------------------------------
FunASR
MIT License
Copyright (c) 2022 Alibaba
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
-------------------------------------------------------------------------
FunASR model
Copyright (C) [2023-2028] Alibaba Group. All rights reserved.
Thank you for choosing the FunASR open source models. The FunASR open source models contain a series of open-source models that allow everyone to use, modify, share, and learn from it.
To ensure better community collaboration, we have developed the following agreement and hope that you carefully read and abide by it.
1 Definitions
In this agreement, [FunASR software] refers to the FunASR open source model, and its derivatives, including fine-tuned models. [You] refer to individuals or organizations who use, modify, share, and learn from [FunASR software].
2 License and Restrictions
2.1 License
You are free to use, copy, modify, and share [FunASR software] under the conditions of this agreement.
2.2 Restrictions
You should indicate the code and model source and author information when using, copying, modifying and sharing [FunASR software]. You should keep the relevant names of models in [FunASR software].
3 Responsibility and Risk
[FunASR software] is for reference and learning purposes only and is not responsible for any direct or indirect losses caused by your use or modification of [FunASR software]. You should take responsibility and risks for your use and modification of [FunASR software].
4 Termination
If you violate any terms of this agreement, your license will be automatically terminated, and you must stop using, copying, modifying, and sharing [FunASR software].
5 Revision
This agreement may be updated and revised from time to time. The revised agreement will be published in the FunASR official repository and automatically take effect. If you continue to use, copy, modify, and share [FunASR software], it means you agree to the revised agreement.
6 Other Provisions
This agreement is subject to the laws of [Country/Region]. If any provisions are found to be illegal, invalid, or unenforceable, they shall be deemed deleted from this agreement, and the remaining provisions shall remain valid and binding.
If you have any questions or comments about this agreement, please contact us.
Copyright (c) [2023-2028] Alibaba Group. All rights reserved.
FunASR 模型开源协议
版本号:1.0
版权所有 (C) [2023-2028] [阿里巴巴集团]。保留所有权利。
感谢您选择 FunASR 开源模型。FunASR 开源模型包含一系列免费且开源的工业模型,让大家可以使用、修改、分享和学习该模型。
为了保证更好的社区合作,我们制定了以下协议,希望您仔细阅读并遵守本协议。
1 定义
本协议中,[FunASR 软件]指 FunASR 开源模型权重及其衍生品,包括 Finetune 后的模型;[您]指使用、修改、分享和学习[FunASR 软件]的个人或组织。
2 许可和限制
2.1 许可
您可以在遵守本协议的前提下,自由地使用、复制、修改和分享[FunASR 软件]。
2.2 限制
您在使用、复制、修改和分享[FunASR 软件]时,必须注明出处以及作者信息,并保留[FunASR 软件]中相关模型名称。
3 责任和风险承担
[FunASR 软件]仅作为参考和学习使用,不对您使用或修改[FunASR 软件]造成的任何直接或间接损失承担任何责任。您对[FunASR 软件]的使用和修改应该自行承担风险。
4 终止
如果您违反本协议的任何条款,您的许可将自动终止,您必须停止使用、复制、修改和分享[FunASR 软件]。
5 修订
本协议可能会不时更新和修订。修订后的协议将在[FunASR 软件]官方仓库发布,并自动生效。如果您继续使用、复制、修改和分享[FunASR 软件],即表示您同意修订后的协议。
6 其他规定
本协议受到[国家/地区] 的法律管辖。如果任何条款被裁定为不合法、无效或无法执行,则该条款应被视为从本协议中删除,而其余条款应继续有效并具有约束力。
如果您对本协议有任何问题或意见,请联系我们。
版权所有© [2023-2028] [阿里巴巴集团]。保留所有权利。
-------------------------------------------------------------------------
Qwen/Qwen2-7B-Instruct
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
# MooER_pytorch
## Paper
- https://arxiv.org/abs/2408.05101
## Model Architecture
MooER is a large language model (LLM)-based speech recognition and speech translation system developed by Moore Threads. The model architecture is shown below:<br>
![Model architecture](images/model_structure.png)
## Algorithm Overview
With the MooER framework, you can use a large language model (LLM) to transcribe input speech into text (speech recognition) and translate it into another language (speech translation) in an end-to-end manner.
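For readers who prefer code to prose, the sketch below condenses the flow of `inference.py` (shipped later in this repository) into a single pass: fbank features, adapter downsampling, and LLM decoding. It assumes the repository's `src/` directory is on `PYTHONPATH` and the pretrained models described in the Inference section below are in place; the function names (`compute_fbank`, `apply_lfr`, `apply_cmvn`, `process_batch`, `load_cmvn`, `get_device`, `mooer_model.init_model`) follow that script, and this is an illustrative sketch rather than a supported entry point.
```python
# Condensed, illustrative sketch adapted from inference.py later in this repository.
# The full script additionally handles resampling, wav.scp batching and bf16 autocast.
import torch
import torchaudio
from mooer.datasets.speech_processor import compute_fbank, apply_lfr, apply_cmvn, process_batch
from mooer.configs import asr_config
from mooer.models import mooer_model
from mooer.utils.utils import load_cmvn, get_device

config = asr_config.ModelConfig()                        # encoder / adapter / LLM / LoRA paths
model, tokenizer = mooer_model.init_model(model_config=config)
device = str(get_device())
model.to(device)
model.eval()

# 1) Acoustic front end: fbank features, low-frame-rate stacking, CMVN normalization.
audio, _ = torchaudio.load("demo/resources/demo.wav")   # expects 16 kHz mono audio
mel = apply_cmvn(apply_lfr(compute_fbank(waveform=audio[0]), lfr_m=7, lfr_n=6),
                 cmvn=load_cmvn(config.get('cmvn_path')))

# 2) Build the LLM input: pseudo tokens reserve space for the adapter-downsampled audio
#    embeddings, followed by the task prompt in the Qwen chat template.
prompt = ("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
          "<|im_start|>user\nTranscribe speech to text. <|im_end|>\n<|im_start|>assistant\n")
audio_length = mel.shape[0] // config.get('adapter_downsample_rate')
prompt_ids = torch.tensor(tokenizer.encode(prompt), dtype=torch.int64)
input_ids = torch.cat((torch.full((audio_length,), -1), prompt_ids))
items = {"input_ids": input_ids, "attention_mask": input_ids.ge(-1), "audio_mel": mel,
         "audio_length": audio_length, "prompt_length": len(prompt_ids)}

# 3) Encoder + adapter + LLM decoding in a single generate() call.
batch = process_batch([items], tokenizer=tokenizer)
batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
with torch.no_grad():
    outputs = model.generate(**batch)
print(model.tokenizer.batch_decode(outputs, skip_special_tokens=True))
```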
## Environment Setup
### Docker (Option 1)
The prebuilt image can be pulled from the [光源](https://sourcefind.cn/#/main-page) (SourceFind) registry; the pull address and usage steps are given below.
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.3.0-ubuntu22.04-dtk24.04.2-py3.10
docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal:/opt/hyhal:ro --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash
# Install dependencies:
pip install -r requirements.txt
```
### Dockerfile (Option 2)
How to use the provided Dockerfile:
```
cd ./docker
docker build --no-cache -t mooer:latest .
docker run -it -v /path/your_code_data/:/path/your_code_data/ -v /opt/hyhal:/opt/hyhal:ro --shm-size=32G --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name docker_name imageID bash
pip install -r requirements.txt
```
### Anaconda (Option 3)
The DCU-specific deep learning libraries required by this project can be downloaded from the 光合 developer community: https://developer.hpccube.com/tool/
```
DTK stack: dtk24.04.2
Python: 3.10
torch: 2.3.0
torchaudio: 2.1.2
```
Tips: the versions of the DTK stack, Python, torch, and the other DCU-related tools above must strictly correspond to one another.
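As a quick sanity check before installing the remaining dependencies, you can print the installed interpreter and wheel versions; the expected values in the comments follow the list above.
```python
# Quick sanity check of the toolchain versions (expected values follow the list above).
import sys
import torch
import torchaudio

print("python    :", sys.version.split()[0])   # expected 3.10.x
print("torch     :", torch.__version__)        # expected 2.3.0 (DTK 24.04.2 build)
print("torchaudio:", torchaudio.__version__)   # expected 2.1.2
```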
Install the remaining non-deep-learning dependencies from requirements.txt:
```
pip install -r requirements.txt
```
## Dataset
## Training
## Inference
1. First, download the official pretrained models from [ModelScope](https://modelscope.cn/models/MooreThreadsSpeech/MooER-MTL-5K) or [HF-Mirror](https://hf-mirror.com/mtspeech/MooER-MTL-5K).
```
# Via ModelScope
git lfs clone https://modelscope.cn/models/MooreThreadsSpeech/MooER-MTL-5K
# Via HF-Mirror
git lfs clone https://hf-mirror.com/mtspeech/MooER-MTL-5K
```
Place the downloaded files in the `pretrained_models` folder.
```shell
cp MooER-MTL-5K/* pretrained_models
```
2. Download [`Qwen2-7B-Instruct`](https://modelscope.cn/models/qwen/qwen2-7b-instruct).
```
# Via ModelScope
git lfs clone https://modelscope.cn/models/qwen/qwen2-7b-instruct
# Via HF-Mirror
git lfs clone https://hf-mirror.com/Qwen/Qwen2-7B-Instruct
```
Place the downloaded files in the `pretrained_models/Qwen2-7B-Instruct` folder.
Finally, make sure the downloaded files are organized according to the structure below. Corrupted or misplaced model files will cause runtime errors.
```text
./pretrained_models/
|-- paraformer_encoder
| |-- am.mvn
| `-- paraformer-encoder.pth
|-- asr
| |-- adapter_project.pt
| `-- lora_weights
| |-- README.md
| |-- adapter_config.json
| `-- adapter_model.bin
|-- ast
| |-- adapter_project.pt
| `-- lora_weights
| |-- README.md
| |-- adapter_config.json
| `-- adapter_model.bin
|-- asr_ast_mtl
| |-- adapter_project.pt
| `-- lora_weights
| |-- README.md
| |-- adapter_config.json
| `-- adapter_model.bin
|-- Qwen2-7B-Instruct
| |-- model-00001-of-00004.safetensors
| |-- model-00002-of-00004.safetensors
| |-- model-00003-of-00004.safetensors
| |-- model-00004-of-00004.safetensors
| |-- model.safetensors.index.json
| |-- config.json
| |-- configuration.json
| |-- generation_config.json
| |-- merges.txt
| |-- tokenizer.json
| |-- tokenizer_config.json
| |-- vocab.json
| |-- LICENSE
| `-- README.md
|-- README.md
`-- configuration.json
```
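A small check like the following (a hypothetical helper, not shipped with the repository) can catch missing or misplaced files before running inference:
```python
# Hypothetical helper (not part of the repository) that verifies the layout shown above.
from pathlib import Path

expected = [
    "paraformer_encoder/am.mvn",
    "paraformer_encoder/paraformer-encoder.pth",
    "asr/adapter_project.pt",
    "asr/lora_weights/adapter_model.bin",
    "ast/adapter_project.pt",
    "ast/lora_weights/adapter_model.bin",
    "asr_ast_mtl/adapter_project.pt",
    "asr_ast_mtl/lora_weights/adapter_model.bin",
    "Qwen2-7B-Instruct/model.safetensors.index.json",
    "Qwen2-7B-Instruct/tokenizer.json",
]
missing = [p for p in expected if not (Path("pretrained_models") / p).exists()]
print("Layout OK" if not missing else f"Missing files: {missing}")
```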
3. Finally, once the above preparation is complete, run the code for inference:<br>
An example audio file for testing is provided in the `demo` folder.<br>
First, set the environment variables:
```
# Set environment variables
export PYTHONIOENCODING=UTF-8
export LC_ALL=C
export PYTHONPATH=$PWD/src:$PYTHONPATH
```
- Run ASR and AST simultaneously:
```
# Use the specified audio file
python inference.py --wav_path /path/to/your_audio_file
```
The command above invokes a multitask MooER model and outputs both the speech recognition and the speech translation results. If it runs successfully, you will see the following output in the terminal.<br>
![asr-and-ast](images/asr-and-ast.png)
- Use the dedicated speech recognition model and output only the transcription:
```
python inference.py --task asr \
--cmvn_path pretrained_models/paraformer_encoder/am.mvn \
--encoder_path pretrained_models/paraformer_encoder/paraformer-encoder.pth \
--llm_path pretrained_models/Qwen2-7B-Instruct \
--adapter_path pretrained_models/asr/adapter_project.pt \
--lora_dir pretrained_models/asr/lora_weights \
--wav_path /path/to/your_audio_file
```
The command above outputs only the speech recognition result. If it runs successfully, you will see the following output in the terminal.<br>
![only-asr](images/only-asr.png)
- Use the dedicated speech translation model and output only the Chinese-to-English translation:
```
python inference.py --task ast \
--cmvn_path pretrained_models/paraformer_encoder/am.mvn \
--encoder_path pretrained_models/paraformer_encoder/paraformer-encoder.pth \
--llm_path pretrained_models/Qwen2-7B-Instruct \
--adapter_path pretrained_models/ast/adapter_project.pt \
--lora_dir pretrained_models/ast/lora_weights \
--wav_path /path/to/your_audio_file
```
The command above outputs only the speech translation result. If it runs successfully, you will see the following output in the terminal.<br>
![only-ast](images/only-ast.png)
## Application Scenarios
### Algorithm Category
Speech recognition, speech translation
### Key Application Industries
Speech recognition, speech translation, education, healthcare
## Source Repository and Issue Feedback
https://developer.sourcefind.cn/codes/modelzoo/mooer_pytorch
## References
https://github.com/MooreThreads/MooER
import time
import sox
import torch
try:
import torch_musa
except ImportError as e:
print("You should install torch_musa if you want to run on Moore Threads GPU")
import os
import argparse
import torchaudio
from torchaudio.transforms import Resample
import logging
from mooer.datasets.speech_processor import *
from mooer.configs import asr_config
from mooer.models import mooer_model
from mooer.utils.utils import *
import gradio as gr
from transformers import TextIteratorStreamer
from threading import Thread
parser = argparse.ArgumentParser()
parser.add_argument("--task", default='mtl', choices=['asr', 'ast', 'mtl'], type=str, help="task: asr or ast. Please set ast if you choose a asr/ast multitask model")
parser.add_argument("--cmvn_path", default='', type=str, help="cmvn path. If not set, will use path in src/mooer/configs/asr_config.py")
parser.add_argument("--encoder_path", default='', type=str, help="encoder path. If not set, will use the path in src/mooer/configs/asr_config.py")
parser.add_argument("--llm_path", default='', type=str, help="llm path. If not set, will use the path in src/mooer/configs/asr_config.py")
parser.add_argument("--adapter_path", default='pretrained_models/asr_ast_mtl/adapter_project.pt', type=str, help="asr/ast multitask adapter path.")
parser.add_argument("--lora_dir", default='pretrained_models/asr_ast_mtl/lora_weights', type=str, help="asr/ast multitask lora path.")
parser.add_argument("--server_port", default=10010, type=int, help="gradio server port")
parser.add_argument("--server_name", default="0.0.0.0", type=str, help="gradio server name")
parser.add_argument("--share", default=False, type=lambda x: (str(x).lower() == 'true'), help="whether to share the server to public")
args = parser.parse_args()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
filemode='w'
)
logger = logging.getLogger()
logger.setLevel(logging.INFO)
PROMPT_TEMPLATE_DICT = {
'qwen': "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
}
PROMPT_DICT = {
'asr': "Transcribe speech to text. ",
'ast': "Translate speech to english text. ",
}
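# Example of the fully formatted prompt for the ASR task (the qwen template filled with PROMPT_DICT['asr']):
#   <|im_start|>system\nYou are a helpful assistant.<|im_end|>\n
#   <|im_start|>user\nTranscribe speech to text. <|im_end|>\n<|im_start|>assistant\n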
global_task = args.task
model_config = {
global_task: asr_config.ModelConfig(),
}
if args.llm_path and os.path.exists(args.llm_path):
model_config[global_task].llm_path = args.llm_path
if args.cmvn_path and os.path.exists(args.cmvn_path):
model_config[global_task].cmvn_path = args.cmvn_path
if args.encoder_path and os.path.exists(args.encoder_path):
model_config[global_task].encoder_path = args.encoder_path
if args.adapter_path and os.path.exists(args.adapter_path):
model_config[global_task].adapter_path = args.adapter_path
if args.lora_dir and os.path.exists(args.lora_dir):
model_config[global_task].lora_dir = args.lora_dir
if args.task:
model_config[global_task].prompt_key = 'ast' if args.task == 'mtl' else args.task
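# The multitask (mtl) checkpoint is prompted with the AST instruction and returns the transcription
# and the translation in a single output separated by a newline, which is split further below.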
device = str(get_device())
logger.info("This demo will run on {}".format(device.upper()))
model = {}
for index, task in enumerate(model_config):
logger.info(model_config[task])
this_model, this_tokenizer = mooer_model.init_model(
model_config=model_config[task])
model[task] = {
"model": this_model,
"tokenizer": this_tokenizer
}
model[task]['model'].to(device+f':{index}')
model[task]['model'].eval()
model[task]['device'] = device+f':{index}'
# shared models and parameters
prompt_template_key = model_config[global_task].get('prompt_template_key', 'qwen')
prompt_template = PROMPT_TEMPLATE_DICT[prompt_template_key]
prompt_key = model_config[global_task].get('prompt_key', 'asr')
prompt_org = PROMPT_DICT[prompt_key]
cmvn = load_cmvn(model_config[global_task].get('cmvn_path'))
adapter_downsample_rate = model_config[global_task].get('adapter_downsample_rate')
logger.info(f"Use LLM Type {prompt_template_key}, "
f"Prompt template {prompt_template}, "
f"Use task type {prompt_key}, "
f"Prompt {prompt_org}")
load_dtype = model_config[global_task].get('load_dtype', 'bfloat16')
dtype = torch.float32
if load_dtype == 'float16':
dtype = torch.float16
elif load_dtype == 'bfloat16':
dtype = torch.bfloat16
logging.info(f"Input data type: {dtype}")
context_scope = torch.musa.amp.autocast if 'musa' in device else torch.cuda.amp.autocast
def convert(inputfile, outfile):
sox_tfm = sox.Transformer()
sox_tfm.set_output_format(
file_type="wav", channels=1, encoding="signed-integer", rate=16000, bits=16
)
sox_tfm.build(inputfile, outfile)
def process_wav(task, wav_path):
audio_raw, sample_rate = torchaudio.load(wav_path)
assert sample_rate == 16000 and audio_raw.shape[0] == 1
audio_raw = audio_raw[0]
duration = audio_raw.shape[0] / 16000.
prompt = prompt_template.format(prompt_org)
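# Acoustic front end: compute fbank features, apply low-frame-rate stacking (stack 7 frames,
# stride 6) as expected by the Paraformer encoder, then normalize with the loaded CMVN stats.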
audio_mel = compute_fbank(waveform=audio_raw)
audio_mel = apply_lfr(inputs=audio_mel, lfr_m=7, lfr_n=6)
audio_mel = apply_cmvn(audio_mel, cmvn=cmvn)
audio_length = audio_mel.shape[0]
audio_length = audio_length // adapter_downsample_rate
audio_pseudo = torch.full((audio_length,), -1)
prompt_ids = model[task]["tokenizer"].encode(prompt)
prompt_length = len(prompt_ids)
prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64)
example_ids = torch.cat((audio_pseudo, prompt_ids)) # [audio, prompt]
example_mask = example_ids.ge(-1)
items = {
"input_ids": example_ids,
"attention_mask": example_mask,
"audio_mel": audio_mel,
"audio_length": audio_length,
"prompt_length": prompt_length,
"duration": duration,
}
return items
def unify_forward(task, audio_file):
this_tokenizer = model[task]['tokenizer']
this_model = model[task]['model']
this_device = model[task]['device']
overall_st = time.time()
with torch.no_grad():
st = time.time()
convert(audio_file, audio_file + '.16k.wav')
audio_file = audio_file + '.16k.wav'
items = process_wav(task, audio_file)
et = time.time()
logger.info(f"Process wav takes {et - st}s")
st = time.time()
batch = process_batch([items], tokenizer=this_tokenizer)
et = time.time()
logger.info(f"Process batch takes {et - st}s")
st = time.time()
for key in batch.keys():
batch[key] = batch[key].to(this_device) if isinstance(batch[key], torch.Tensor) else batch[key]
with context_scope(dtype=dtype):
model_outputs = this_model.generate(**batch)
et = time.time()
logger.info(f"Forward takes {et - st}s")
st = time.time()
output_text = this_model.tokenizer.batch_decode(
model_outputs, add_special_tokens=False, skip_special_tokens=True)
et = time.time()
logger.info(f"Decode takes {et - st}s")
asr_text = ''
ast_text = ''
for text in output_text:
if task == 'asr':
asr_text = text
ast_text = ''
elif task == 'ast':
asr_text = ''
ast_text = text
elif task == 'mtl':
if '\n' in text:
asr_text = text.split('\n')[0]
ast_text = text.split('\n')[1]
else:
asr_text = text
ast_text = ''
overall_et = time.time()
logger.info("Cost {}s to do the inference.".format(overall_et - overall_st))
return asr_text, ast_text
def mtl_inference(mic_input, file_input):
task = global_task
try:
if mic_input is not None:
asr_res, ast_res = unify_forward(task, mic_input)
elif file_input is not None:
asr_res, ast_res = unify_forward(task, file_input)
else:
logger.info("Empty input")
return '', ''
return asr_res, ast_res
except Exception as e:
logger.error(e)
return '', ''
def unify_forward_stream(task, audio_file):
this_tokenizer = model[task]['tokenizer']
this_model = model[task]['model']
this_device = model[task]['device']
with torch.no_grad():
convert(audio_file, audio_file + '.16k.wav')
audio_file = audio_file + '.16k.wav'
items = process_wav(task, audio_file)
batch = process_batch([items], tokenizer=this_tokenizer)
for key in batch.keys():
batch[key] = batch[key].to(this_device) if isinstance(batch[key], torch.Tensor) else batch[key]
with context_scope(dtype=dtype):
inputs_embeds, attention_mask, kwargs = this_model.generate(**batch, compute_llm=False)
streamer = TextIteratorStreamer(
tokenizer=this_tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True
)
def generate_and_signal_complete():
this_model.llm.generate(
inputs_embeds=inputs_embeds,
max_new_tokens=kwargs.get("max_new_tokens", 500),
num_beams=kwargs.get("num_beams", 1),
do_sample=kwargs.get("do_sample", False),
min_length=kwargs.get("min_length", 1),
top_p=kwargs.get("top_p", 1.0),
repetition_penalty=kwargs.get("repetition_penalty", 1.0),
length_penalty=kwargs.get("length_penalty", 1.0),
temperature=kwargs.get("temperature", 1.0),
attention_mask=attention_mask,
bos_token_id=this_model.tokenizer.bos_token_id,
eos_token_id=this_model.tokenizer.eos_token_id,
pad_token_id=this_model.tokenizer.pad_token_id,
streamer=streamer
)
t1 = Thread(target=generate_and_signal_complete)
t1.start()
partial_text = ""
for new_text in streamer:
partial_text += new_text
if task == 'asr':
asr_text = partial_text
ast_text = ''
elif task == 'ast':
asr_text = ''
ast_text = partial_text
elif task == 'mtl':
if '\n' in partial_text:
asr_text = partial_text.split('\n')[0]
ast_text = partial_text.split('\n')[1]
else:
asr_text = partial_text
ast_text = ''
yield asr_text, ast_text
def mtl_inference_stream(mic_input, file_input):
task = global_task
try:
if mic_input is not None:
yield from unify_forward_stream(task, mic_input)
elif file_input is not None:
yield from unify_forward_stream(task, file_input)
else:
logger.info("Empty input")
yield '', ''
except Exception as e:
logger.error(e)
yield '', ''
logo = '''
<div style="width: 130px;">
<img src="https://mt-ai-speech-public.tos-cn-beijing.volces.com/MTLogo.png" width="130">
</div>
'''
description = '''
# MooER 摩耳
*MooER* is an LLM-based speech recognition/translation model capable of transcribing input speech into text and translating it into another language in an end-to-end manner.
For more details, please refer to [the repo](https://github.com/MooreThreads/MooER).
Please note that the current version DOES NOT SUPPORT mobile phones. Use your PC or Mac instead.
'''
with gr.Blocks(title="MooER online demo") as interface:
gr.HTML(logo)
gr.Markdown(description)
with gr.Row():
mic_input = gr.Audio(
sources='microphone',
type="filepath",
label="record your voice",
show_download_button=True,
)
file_input = gr.Audio(sources="upload", type="filepath", label="upload a file", show_download_button=True)
with gr.Column():
text_output_asr = gr.Textbox(label="Speech Recognition", lines=3, max_lines=10)
text_output_ast = gr.Textbox(label="Speech Translation", lines=3, max_lines=10)
with gr.Row():
mtl_btn = gr.Button("Transcribe / Translate")
mtl_btn_stream = gr.Button("Transcribe / Translate in streaming mode. Faster but less accurate.")
mtl_btn.click(fn=mtl_inference, inputs=[mic_input, file_input], outputs=[text_output_asr, text_output_ast], concurrency_id="mtl")
mtl_btn_stream.click(fn=mtl_inference_stream, inputs=[mic_input, file_input], outputs=[text_output_asr, text_output_ast], concurrency_id="mtl")
interface.queue().launch(
favicon_path='demo/resources/mt_favicon.png',
server_name=args.server_name,
server_port=args.server_port,
share=args.share
)
import time
import torch
import gradio as gr
import sox
try:
import torch_musa
except ImportError as e:
print("You should install torch_musa if you want to run on Moore Threads GPU")
import os
import argparse
import torchaudio
from torchaudio.transforms import Resample
import logging
from mooer.datasets.speech_processor import *
from mooer.configs import asr_config
from mooer.models import mooer_model
from mooer.utils.utils import *
from mooer.models.hifigan import save_wav, get_hifigan_model, get_speaker_encoder, encode_prompt_wav
parser = argparse.ArgumentParser()
parser.add_argument("--task", default='s2s_chat', choices=['asr', 'ast', 's2s_trans', 's2s_chat'],
type=str, help="task: asr or ast or s2s_trans or s2s_chat. "
"Please set ast if you choose a asr/ast/s2s_trans/s2s_chat multitask model")
parser.add_argument("--cmvn_path", default='', type=str, help="cmvn path.")
parser.add_argument("--encoder_path", default='', type=str, help="encoder path.")
parser.add_argument("--llm_path", default='', type=str, help="llm path.")
parser.add_argument("--adapter_path", default='', type=str, help="adapter path.")
parser.add_argument("--lora_dir", default='', type=str, help="lora path.")
parser.add_argument("--vocoder_path", default='', type=str, help="vocoder path")
parser.add_argument("--spk_encoder_path", default='', type=str, help="spk encoder path")
parser.add_argument("--prompt_wav_path", default='', type=str, help="prompt wav path")
parser.add_argument("--server_port", default=10010, type=int, help="gradio server port")
parser.add_argument("--server_name", default="0.0.0.0", type=str, help="gradio server name")
parser.add_argument("--share", default=False, type=lambda x: (str(x).lower() == 'true'), help="whether to share the server to public")
args = parser.parse_args()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
filemode='w'
)
PROMPT_TEMPLATE_DICT = {
'qwen': "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
}
PROMPT_DICT = {
'asr': "Transcribe speech to text. ",
'ast': "Translate speech to english text. ",
's2s_trans': "Translate speech to english speech. ",
's2s_chat': "Answer my question with speech. "
}
model_config = asr_config.ModelConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# replace path
if args.llm_path and os.path.exists(args.llm_path):
model_config.llm_path = args.llm_path
if args.encoder_path and os.path.exists(args.encoder_path):
model_config.encoder_path = args.encoder_path
if args.adapter_path and os.path.exists(args.adapter_path):
model_config.adapter_path = args.adapter_path
if args.lora_dir and os.path.exists(args.lora_dir):
model_config.lora_dir = args.lora_dir
if args.cmvn_path and os.path.exists(args.cmvn_path):
model_config.cmvn_path = args.cmvn_path
if args.task:
model_config.prompt_key = args.task
device = str(get_device())
logger.info("This demo will run on {}".format(device.upper()))
logger.info(model_config)
model, tokenizer = mooer_model.init_model(
model_config=model_config)
AUDIO_START_TOKEN_INDEX = tokenizer.get_vocab()['<|audio_start|>']
model.to(device)
model.eval()
# data process
prompt_template_key = model_config.get('prompt_template_key', 'qwen')
prompt_template = PROMPT_TEMPLATE_DICT[prompt_template_key]
prompt_key = model_config.get('prompt_key', 'asr')
prompt_org = PROMPT_DICT[prompt_key]
logger.info(f"Use LLM Type {prompt_template_key}, "
f"Prompt template {prompt_template}, "
f"Use task type {prompt_key}, "
f"Prompt {prompt_org}")
cmvn = load_cmvn(model_config.get('cmvn_path'))
adapter_downsample_rate = model_config.get('adapter_downsample_rate')
hifigan_generator = get_hifigan_model(args.vocoder_path, device, decoder_dim=3584)
spk_encoder = get_speaker_encoder(args.spk_encoder_path, device)
spk_embedding = encode_prompt_wav(spk_encoder, args.prompt_wav_path, device)
def process_wav(wav_path):
audio_raw, sample_rate = torchaudio.load(wav_path)
if sample_rate != 16000:
# resample the data
resampler = Resample(orig_freq=sample_rate, new_freq=16000)
audio_raw = resampler(audio_raw)
if audio_raw.shape[0] > 1:
# convert to mono
audio_raw = audio_raw.mean(dim=0, keepdim=True)
audio_raw = audio_raw[0]
prompt = prompt_template.format(prompt_org)
audio_mel = compute_fbank(waveform=audio_raw)
audio_mel = apply_lfr(inputs=audio_mel, lfr_m=7, lfr_n=6)
audio_mel = apply_cmvn(audio_mel, cmvn=cmvn)
audio_length = audio_mel.shape[0]
audio_length = audio_length // adapter_downsample_rate
audio_pseudo = torch.full((audio_length,), -1)
prompt_ids = tokenizer.encode(prompt)
prompt_length = len(prompt_ids)
prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64)
example_ids = torch.cat((audio_pseudo, prompt_ids)) # [audio, prompt]
example_mask = example_ids.ge(-1)
items = {
"input_ids": example_ids,
"attention_mask": example_mask,
"audio_mel": audio_mel,
"audio_length": audio_length,
"prompt_length": prompt_length,
}
return items
load_dtype = model_config.get('load_dtype', 'bfloat16')
dtype = torch.float32
if load_dtype == 'float16':
dtype = torch.float16
elif load_dtype == 'bfloat16':
dtype = torch.bfloat16
logging.info(f"Input data type: {dtype}")
def convert(inputfile, outfile):
sox_tfm = sox.Transformer()
sox_tfm.set_output_format(
file_type="wav", channels=1, encoding="signed-integer", rate=16000, bits=16
)
sox_tfm.build(inputfile, outfile)
def inference(audio_file):
audio_file_out = audio_file + '.16k.wav'
convert(audio_file, audio_file_out)
audio_file_out_tts = audio_file + '.tts.wav'
with torch.no_grad():
try:
items = process_wav(audio_file_out)
batch = process_batch([items], tokenizer=tokenizer)
for key in batch.keys():
batch[key] = batch[key].to(device) if isinstance(batch[key], torch.Tensor) else batch[key]
with torch.cuda.amp.autocast(dtype=dtype):
inputs_embeds, attention_mask, kwargs = model.generate(**batch, compute_llm=False)
prompt_and_encoding_length = inputs_embeds.shape[1]
model_outputs = model.llm.generate(
inputs_embeds=inputs_embeds,
max_new_tokens=kwargs.get("max_new_tokens", 2000),
num_beams=kwargs.get("num_beams", 1),
do_sample=True,
min_length=kwargs.get("min_length", 1),
top_p=0.85,
repetition_penalty=kwargs.get("repetition_penalty", 1.0),
length_penalty=kwargs.get("length_penalty", 1.0),
temperature=kwargs.get("temperature", 1.0),
attention_mask=attention_mask,
bos_token_id=model.tokenizer.bos_token_id,
eos_token_id=model.tokenizer.eos_token_id,
pad_token_id=model.tokenizer.pad_token_id,
)
output_text = model.tokenizer.batch_decode(model_outputs, add_special_tokens=False,
skip_special_tokens=True)
if hasattr(model.llm.model, "embed_tokens"):
teacher_forcing_input_embeds = model.llm.model.embed_tokens(model_outputs)
teacher_forcing_input_att_mask = torch.ones((1, teacher_forcing_input_embeds.shape[1]),
dtype=torch.bool).to(device)
else:
raise NotImplementedError
inputs_embeds = torch.concat([inputs_embeds, teacher_forcing_input_embeds], dim=-2)
attention_mask = torch.concat([attention_mask, teacher_forcing_input_att_mask], dim=-1)
llm_output = model.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask,
output_hidden_states=True)
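# Locate the <|audio_start|> marker in the generated token ids and take the LLM hidden states
# after it as acoustic latents, which are vocoded into a waveform by HiFi-GAN via save_wav below.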
audio_start_index = prompt_and_encoding_length + model_outputs[0].tolist().index(AUDIO_START_TOKEN_INDEX)
audio_latents = llm_output.hidden_states[-1][:, audio_start_index:-6, :]
for text in output_text:
logging.info(f"{key} {text}")
text_ast = text.split("<|audio_start|>")[0]
text_ast = text_ast.replace('\\n', '\n')
save_wav(hifigan_generator, spk_embedding, audio_latents.float(), audio_file_out_tts)
return text_ast, audio_file_out_tts
except Exception as e:
logging.error(e)
return '', ''
logo = '''
<div style="width: 130px;">
<img src="https://mt-ai-speech-public.tos-cn-beijing.volces.com/MTLogo.png" width="130">
</div>
'''
description = '''
# MooER 摩耳
For more details about *MooER*, please refer to [the repo](https://github.com/MooreThreads/MooER).
Please note that the current version DOES NOT SUPPORT mobile phones. Use your PC or Mac instead.
'''
with gr.Blocks(title="MooER online demo") as interface:
gr.HTML(logo)
gr.Markdown(description)
wav_path = gr.Audio(sources="microphone", type="filepath")
text_output_ast = gr.Textbox(label="交互文本", lines=7, max_lines=50)
audio_output = gr.Audio(label="交互音频", type="filepath")
greet_btn = gr.Button("inference")
greet_btn.click(fn=inference, inputs=wav_path, outputs=[text_output_ast, audio_output], api_name="inference")
interface.queue().launch(
favicon_path='demo/resources/mt_favicon.png',
server_name=args.server_name,
server_port=args.server_port,
share=args.share
)
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.3.0-ubuntu22.04-dtk24.04.2-py3.10
RUN source /opt/dtk/env.sh
accelerate==0.33.0
appdirs
loralib
black
fire
peft
transformers==4.40.2
py7zr
optimum
sox
hydra-core
whisper
soundfile
gradio==4.41.0
lhotse==1.27.0
127.0.0.1 slots=8
import time
import torch
try:
import torch_musa
except ImportError as e:
print("You should install torch_musa if you want to run on Moore Threads GPU")
import os
import argparse
import torchaudio
from torchaudio.transforms import Resample
import logging
from mooer.datasets.speech_processor import *
from mooer.configs import asr_config
from mooer.models import mooer_model
from mooer.utils.utils import *
parser = argparse.ArgumentParser()
parser.add_argument("--wav_path", default='demo/resources/demo.wav', type=str, help="decode one wav file")
parser.add_argument("--wav_scp", default=None, type=str, help="decode scp if you want")
parser.add_argument("--task", default='ast', choices=['asr', 'ast'], type=str, help="task: asr or ast. Please set ast if you choose a asr/ast multitask model")
parser.add_argument("--batch_size", default=10, type=int, help="decode batch for scp")
parser.add_argument("--cmvn_path", default='', type=str, help="cmvn path. If not set, will use path in src/mooer/configs/asr_config.py")
parser.add_argument("--encoder_path", default='', type=str, help="encoder path. If not set, will use the path in src/mooer/configs/asr_config.py")
parser.add_argument("--llm_path", default='', type=str, help="llm path. If not set, will use the path in src/mooer/configs/asr_config.py")
parser.add_argument("--adapter_path", default='', type=str, help="adapter path. If not set, will use the path in src/mooer/configs/asr_config.py")
parser.add_argument("--lora_dir", default='', type=str, help="lora path. If not set, will use path in src/mooer/configs/asr_config.py")
args = parser.parse_args()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
filemode='w'
)
PROMPT_TEMPLATE_DICT = {
'qwen': "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
}
PROMPT_DICT = {
'asr': "Transcribe speech to text. ",
'ast': "Translate speech to english text. ",
}
model_config = asr_config.ModelConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)
# replace path
if args.llm_path and os.path.exists(args.llm_path):
model_config.llm_path = args.llm_path
if args.encoder_path and os.path.exists(args.encoder_path):
model_config.encoder_path = args.encoder_path
if args.adapter_path and os.path.exists(args.adapter_path):
model_config.adapter_path = args.adapter_path
if args.lora_dir and os.path.exists(args.lora_dir):
model_config.lora_dir = args.lora_dir
if args.cmvn_path and os.path.exists(args.cmvn_path):
model_config.cmvn_path = args.cmvn_path
if args.task:
model_config.prompt_key = args.task
device = str(get_device())
logger.info("This demo will run on {}".format(device.upper()))
logger.info(model_config)
model, tokenizer = mooer_model.init_model(
model_config=model_config)
model.to(device)
model.eval()
# data process
prompt_template_key = model_config.get('prompt_template_key', 'qwen')
prompt_template = PROMPT_TEMPLATE_DICT[prompt_template_key]
prompt_key = model_config.get('prompt_key', 'asr')
prompt_org = PROMPT_DICT[prompt_key]
logger.info(f"Use LLM Type {prompt_template_key}, "
f"Prompt template {prompt_template}, "
f"Use task type {prompt_key}, "
f"Prompt {prompt_org}")
cmvn = load_cmvn(model_config.get('cmvn_path'))
adapter_downsample_rate = model_config.get('adapter_downsample_rate')
def process_wav(wav_path):
audio_raw, sample_rate = torchaudio.load(wav_path)
if sample_rate != 16000:
# resample the data
resampler = Resample(orig_freq=sample_rate, new_freq=16000)
audio_raw = resampler(audio_raw)
if audio_raw.shape[0] > 1:
# convert to mono
audio_raw = audio_raw.mean(dim=0, keepdim=True)
audio_raw = audio_raw[0]
prompt = prompt_template.format(prompt_org)
audio_mel = compute_fbank(waveform=audio_raw)
audio_mel = apply_lfr(inputs=audio_mel, lfr_m=7, lfr_n=6)
audio_mel = apply_cmvn(audio_mel, cmvn=cmvn)
audio_length = audio_mel.shape[0]
audio_length = audio_length // adapter_downsample_rate
audio_pseudo = torch.full((audio_length,), -1)
prompt_ids = tokenizer.encode(prompt)
prompt_length = len(prompt_ids)
prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64)
example_ids = torch.cat((audio_pseudo, prompt_ids)) # [audio, prompt]
example_mask = example_ids.ge(-1)
items = {
"input_ids": example_ids,
"attention_mask": example_mask,
"audio_mel": audio_mel,
"audio_length": audio_length,
"prompt_length": prompt_length,
}
return items
load_dtype = model_config.get('load_dtype', 'bfloat16')
dtype = torch.float32
if load_dtype == 'float16':
dtype = torch.float16
elif load_dtype == 'bfloat16':
dtype = torch.bfloat16
logging.info(f"Input data type: {dtype}")
context_scope = torch.musa.amp.autocast if 'musa' in device else torch.cuda.amp.autocast
with torch.no_grad():
if args.wav_scp is not None and os.path.exists(args.wav_scp):
batch_size = args.batch_size
infer_time = []
items = parse_key_text(args.wav_scp)
uttids = list(items.keys())
num_batches = len(uttids) // batch_size + (0 if len(uttids) % batch_size == 0 else 1)
for i in range(num_batches):
batch_uttids = uttids[i * batch_size:(i + 1) * batch_size]
batch_wav_paths = [items[uttid] for uttid in batch_uttids]
samples = []
for wav_path in batch_wav_paths:
samples.append(process_wav(wav_path))
batch = process_batch(samples, tokenizer=tokenizer)
for key in batch.keys():
batch[key] = batch[key].to(device) if isinstance(batch[key], torch.Tensor) else batch[key]
with context_scope(dtype=dtype):
ss = time.perf_counter()
model_outputs = model.generate(**batch)
infer_time.append(time.perf_counter() - ss)
logging.info(f"Infer time: {time.perf_counter() - ss}")
output_text = model.tokenizer.batch_decode(model_outputs, add_special_tokens=False,
skip_special_tokens=True)
for idx, text in enumerate(output_text):
logging.info(f"uttid: {batch_uttids[idx]}")
text = text.split('\n')
if len(text) == 2:
logging.info(f"ASR: {text[0].strip()}")
logging.info(f"AST: {text[1].strip()}")
else:
logging.info(f"ASR: {text[0].strip()}")
logging.info("Total inference cost")
logging.info(sum(infer_time))
elif args.wav_path != '' and os.path.exists(args.wav_path):
try:
wav_path = args.wav_path
items = process_wav(wav_path)
batch = process_batch([items], tokenizer=tokenizer)
for key in batch.keys():
batch[key] = batch[key].to(device) if isinstance(batch[key], torch.Tensor) else batch[key]
with context_scope(dtype=dtype):
ss = time.perf_counter()
model_outputs = model.generate(**batch)
logging.info(f"Infer time: {time.perf_counter() - ss}")
output_text = model.tokenizer.batch_decode(model_outputs, add_special_tokens=False,
skip_special_tokens=True)
for text in output_text:
text = text.split('\n')
if len(text) == 2:
logging.info(f"ASR: {text[0].strip()}")
logging.info(f"AST: {text[1].strip()}")
else:
logging.info(f"ASR: {text[0].strip()}")
except Exception as e:
logging.error(e)
else:
raise IOError("You should specify --wav_scp or --wav_path as the input")
#!/bin/bash
# set your path
HOME_ROOT=/root/MooER
cd $HOME_ROOT || exit 0
test_data_dir=YOUR/testsets/root
test_sets=test-clean/test-other/aishell
decode_path=YOUR/decode/dir
export PYTHONPATH=${HOME_ROOT}/src:$PYTHONPATH
VISIBLE_DEVICES=0
################### For MUSA User #############################
# export MUSA_VISIBLE_DEVICES=$VISIBLE_DEVICES
# export DS_ACCELERATOR=musa
###############################################################
# For CUDA User
export CUDA_VISIBLE_DEVICES=$VISIBLE_DEVICES
###############################################################
python ${HOME_ROOT}/src/mooer/testing.py \
--test_config ${HOME_ROOT}/src/mooer/configs/asr_config_inference.py \
--test_data_dir $test_data_dir \
--test_sets $test_sets \
--decode_path $decode_path
# compute CER
for testset in `echo $test_sets | sed "s|/| |g"`; do
echo $testset
python ${HOME_ROOT}/src/tools/compute-wer.py --char=1 --v=1 ${test_data_dir}/${testset}/text \
${decode_path}/${testset}/text > ${decode_path}/${testset}/wer 2>&1
done