Commit 7d0435cd authored by SWHL's avatar SWHL
Browse files

Update to v2.0.2

parent 159403db
......@@ -10,6 +10,11 @@
#### 更新日志
- 2023-02-11 v2.0.2 update:
- 模型和推理代码解耦(`rapid_paraformer``resources`
- 支持批量推理(通过`resources/config.yaml``batch_size`指定)
- 增加多种输入方式(`Union[str, np.ndarray, List[str]]`
- 2023-02-10 v2.0.1 update:
- 添加对输入音频为噪音或者静音的文件推理结果捕捉。
......@@ -20,38 +25,53 @@
pip install -r requirements.txt
```
2. 下载模型
- 由于模型太大(881M),上传到仓库不容易下载,提供百度云下载连接:[asr_paraformerv2.onnx](https://pan.baidu.com/s/1-nEf2eUpkzlcRqiYEwub2A?pwd=dcr3)
- 模型下载之后,放在`rapid_paraformer/models`目录下即可,最终目录结构如下:
- 由于模型太大(823.8M),上传到仓库不容易下载,提供百度云下载连接:[asr_paraformerv2.onnx](https://pan.baidu.com/s/1-nEf2eUpkzlcRqiYEwub2A?pwd=dcr3)(模型MD5: `9ca331381a470bc4458cc6c0b0b165de`
- 模型下载之后,放在`resources/models`目录下即可,最终目录结构如下:
```text
rapid_paraformer
├── config.yaml
├── __init__.py
├── kaldifeat
│   ├── feature.py
.
├── demo.py
├── rapid_paraformer
│   ├── __init__.py
│   ├── ivector.py
│   ├── LICENSE
│   └── README.md
├── models
│   ├── am.mvn
│   ├── asr_paraformerv2.onnx # 放在这里
│   └── token_list.pkl
├── rapid_paraformer.py
└── utils.py
│   ├── kaldifeat
│   ├── __pycache__
│   ├── rapid_paraformer.py
│   └── utils.py
├── README.md
├── requirements.txt
├── resources
│   ├── config.yaml
│   └── models
│   ├── am.mvn
│   ├── asr_paraformerv2.onnx # 放在这里
│   └── token_list.pkl
├── test_onnx.py
├── tests
│   ├── __pycache__
│   └── test_infer.py
└── test_wavs
├── 0478_00017.wav
└── asr_example_zh.wav
```
3. 运行demo
```python
from rapid_paraformer import RapidParaformer
paraformer = RapidParaformer()
wav_path = 'test_wavs/example_test.wav'
result = paraformer(str(wav_path))
config_path = 'resources/config.yaml'
paraformer = RapidParaformer(config_path)
# 输入:支持Union[str, np.ndarray, List[str]] 三种方式传入
# 输出: List[asr_res]
wav_path = [
'test_wavs/0478_00017.wav',
]
result = paraformer(wav_path)
print(result)
```
4. 查看结果
```text
[['呃说不配合就不配合的好以上的话呢我们摘取八九十三条因为这三条的话呢比较典型啊一些数字比较明确尤其是时间那么我们要投资者就是了解这一点啊不要轻信这个市场可以快速回来啊这些配市公司啊后期又利好了可
以快速快速攻能包括像前一段时间啊有些媒体在二三月份的时候']]
['呃说不配合就不配合的好以上的话呢我们摘取八九十三条因为这三条的话呢比较典型啊一些数字比较明确尤其是时间那么我们要投资者就是了解这一点啊不要轻信这个市场可以快速回来啊这些配市公司啊后期又利好了可
以快速快速攻能包括像前一段时间啊有些媒体在二三月份的时候']
```
......@@ -4,9 +4,20 @@
from rapid_paraformer import RapidParaformer
paraformer = RapidParaformer()
config_path = 'resources/config.yaml'
paraformer = RapidParaformer(config_path)
wav_path = [
'test_wavs/0478_00017.wav',
'test_wavs/asr_example_zh.wav',
'test_wavs/0478_00017.wav',
'test_wavs/asr_example_zh.wav',
'test_wavs/0478_00017.wav',
'test_wavs/asr_example_zh.wav',
]
wav_path = 'test_wavs/0478_00017.wav'
print(wav_path)
result = paraformer(str(wav_path))
# wav_path = 'test_wavs/0478_00017.wav'
result = paraformer(wav_path)
print(result)
......@@ -3,50 +3,109 @@
# @Contact: liekkaskono@163.com
import traceback
from pathlib import Path
from typing import List
from typing import List, Union, Tuple
import librosa
import numpy as np
from .utils import (CharTokenizer, Hypothesis, ONNXRuntimeError, OrtInferSession,
TokenIDConverter, WavFrontend, read_yaml, get_logger)
from .utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
OrtInferSession, TokenIDConverter, WavFrontend, get_logger,
read_yaml)
cur_dir = Path(__file__).resolve().parent
logging = get_logger()
class RapidParaformer():
    """ONNX-based Paraformer ASR pipeline.

    Loads the YAML config, token converter/tokenizer, wav frontend and the
    onnxruntime session, then transcribes audio in batches via __call__.
    """

    def __init__(self, config_path: Union[str, Path]) -> None:
        """Build the pipeline from a config file.

        :param config_path: path to resources/config.yaml.
        :raises FileNotFoundError: if config_path does not exist
            (fail fast instead of a confusing late read_yaml error).
        """
        if not Path(config_path).exists():
            raise FileNotFoundError(f'{config_path} does not exist.')

        config = read_yaml(config_path)

        self.converter = TokenIDConverter(**config['TokenIDConverter'])
        self.tokenizer = CharTokenizer(**config['CharTokenizer'])
        self.frontend = WavFrontend(
            cmvn_file=config['WavFrontend']['cmvn_file'],
            **config['WavFrontend']['frontend_conf']
        )
        self.ort_infer = OrtInferSession(config['Model'])

        # Number of utterances fed to the acoustic model per forward pass.
        self.batch_size = config['Model']['batch_size']
def __call__(self, wav_content: Union[str, np.ndarray, List[str]]) -> List:
    """Transcribe audio and return one result string per utterance.

    Input may be a wav path, a waveform ndarray, or a list of wav paths.
    Utterances are processed in chunks of self.batch_size; a batch whose
    ONNX inference fails is logged and contributes no results.
    """
    waveforms = self.load_data(wav_content)
    total = len(waveforms)

    results: List = []
    for start in range(0, total, self.batch_size):
        stop = min(start + self.batch_size, total)
        feats, feats_len = self.extract_feat(waveforms[start:stop])
        try:
            am_scores, valid_token_lens = self.infer(feats, feats_len)
        except ONNXRuntimeError:
            # Best-effort: log the failing batch and keep going.
            logging.error(traceback.format_exc())
            batch_preds: List = []
        else:
            batch_preds = self.decode(am_scores, valid_token_lens)
        results.extend(batch_preds)
    return results
def load_data(self,
              wav_content: Union[str, np.ndarray, List[str]]) -> List:
    """Normalize the input into a list of (1, num_samples) waveforms.

    Accepts a ready-made numpy waveform (returned as-is), a single wav
    path, or a list of wav paths.

    :raises TypeError: for any other input type.
    """
    def read_one(wav_path: str) -> np.ndarray:
        # librosa.load yields a 1-D sample array; add a leading batch axis.
        samples, _ = librosa.load(wav_path)
        return samples[None, ...]

    if isinstance(wav_content, np.ndarray):
        return [wav_content]
    if isinstance(wav_content, str):
        return [read_one(wav_content)]
    if isinstance(wav_content, list):
        return [read_one(p) for p in wav_content]
    raise TypeError(
        f'The type of {wav_content} is not in [str, np.ndarray, list]')
def extract_feat(self,
                 waveform_list: List[np.ndarray]
                 ) -> Tuple[np.ndarray, np.ndarray]:
    """Compute fbank + LFR/CMVN features for a batch of waveforms.

    Returns (feats, feats_len): feats is a float32 array padded to the
    longest utterance in the batch; feats_len holds the int32 frame
    counts before padding.
    """
    batch_feats: List[np.ndarray] = []
    batch_lens: List = []
    for waveform in waveform_list:
        fbank_feat, _ = self.frontend.fbank(waveform)
        one_feat, one_len = self.frontend.lfr_cmvn(fbank_feat)
        batch_feats.append(one_feat)
        batch_lens.append(one_len)

    padded = self.pad_feats(batch_feats, np.max(batch_lens))
    lens = np.array(batch_lens).astype(np.int32)
    return padded, lens
@staticmethod
def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
    """Zero-pad each (T, dim) feature matrix along time to max_feat_len
    and stack them into one float32 (batch, max_feat_len, dim) array."""
    def zero_pad(feat: np.ndarray, cur_len: int) -> np.ndarray:
        # Pad only the time axis; the feature dimension is untouched.
        widths = ((0, max_feat_len - cur_len), (0, 0))
        return np.pad(feat, widths, 'constant', constant_values=0)

    stacked = [zero_pad(feat, feat.shape[0]) for feat in feats]
    return np.array(stacked).astype(np.float32)
def infer(self, feats: np.ndarray,
          feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Run the acoustic model on one padded batch.

    Returns (am_scores, valid_token_lens) as produced by the ONNX session.
    """
    outputs = self.ort_infer([feats, feats_len])
    am_scores, token_nums = outputs
    return am_scores, token_nums
def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
    """Greedy-decode each utterance's score matrix into a text string."""
    texts = []
    for one_score, one_num in zip(am_scores, token_nums):
        texts.append(self.decode_one(one_score, one_num))
    return texts
def decode_one(self,
am_score: np.ndarray,
valid_token_num: int) -> List[str]:
yseq = am_score.argmax(axis=-1)
score = am_score.max(axis=-1)
score = np.sum(score, axis=-1)
......@@ -54,27 +113,25 @@ class RapidParaformer():
# pad with mask tokens to ensure compatibility with sos/eos tokens
# asr_model.sos:1 asr_model.eos:2
yseq = np.array([1] + yseq.tolist() + [2])
nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
infer_res = []
for hyp in nbest_hyps:
# remove sos/eos and get results
last_pos = -1
token_int = hyp.yseq[1:last_pos].tolist()
hyp = Hypothesis(yseq=yseq, score=score)
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x not in (0, 2), token_int))
# remove sos/eos and get results
last_pos = -1
token_int = hyp.yseq[1:last_pos].tolist()
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x not in (0, 2), token_int))
text = self.tokenizer.tokens2text(token)
infer_res.append(text)
return infer_res
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
text = self.tokenizer.tokens2text(token)
return text[:valid_token_num-1]
if __name__ == '__main__':
paraformer = RapidParaformer()
project_dir = Path(__file__).resolve().parent.parent
cfg_path = project_dir / 'resources' / 'config.yaml'
paraformer = RapidParaformer(cfg_path)
wav_file = '0478_00017.wav'
for i in range(1000):
......
......@@ -25,7 +25,7 @@ class TokenIDConverter():
unk_symbol: str = "<unk>",):
check_argument_types()
self.token_list = self.load_token(root_dir / token_path)
self.token_list = self.load_token(token_path)
self.unk_symbol = unk_symbol
@staticmethod
......@@ -148,58 +148,37 @@ class WavFrontend():
self.filter_length_max = filter_length_max
self.lfr_m = lfr_m
self.lfr_n = lfr_n
self.cmvn_file = root_dir / cmvn_file
self.cmvn_file = cmvn_file
self.dither = dither
def forward_fbank(self,
input_content: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
feats, feats_lens = [], []
batch_size = input_content.shape[0]
input_lengths = np.array([input_content.shape[1]])
for i in range(batch_size):
waveform_length = input_lengths[i]
waveform = input_content[i][:waveform_length]
waveform = waveform * (1 << 15)
mat = compute_fbank_feats(waveform,
num_mel_bins=self.n_mels,
frame_length=self.frame_length,
frame_shift=self.frame_shift,
dither=self.dither,
energy_floor=0.0,
sample_frequency=self.fs)
feats.append(mat)
feats_lens.append(mat.shape[0])
feats_pad = np.array(feats).astype(np.float32)
feats_lens = np.array(feats_lens).astype(np.int64)
return feats_pad, feats_lens
def forward_lfr_cmvn(self,
input_content: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
feats, feats_lens = [], []
batch_size = input_content.shape[0]
if self.cmvn_file:
cmvn = self.load_cmvn()
input_lengths = np.array([input_content.shape[1]])
for i in range(batch_size):
mat = input_content[i, :input_lengths[i], :]
if self.lfr_m != 1 or self.lfr_n != 1:
mat = self.apply_lfr(mat, self.lfr_m, self.lfr_n)
if self.cmvn_file:
mat = self.apply_cmvn(mat, cmvn)
self.cmvn = self.load_cmvn()
def fbank(self,
          input_content: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Compute kaldi-style fbank features for one (1, num_samples) waveform.

    Returns (feat, feat_len): float32 feature matrix and an int32 scalar
    frame count.
    """
    num_samples = input_content.shape[1]
    # Scale to 16-bit integer range; presumably what compute_fbank_feats
    # expects — confirm against the kaldifeat implementation.
    samples = input_content[0][:num_samples] * (1 << 15)
    fbank_mat = compute_fbank_feats(samples,
                                    num_mel_bins=self.n_mels,
                                    frame_length=self.frame_length,
                                    frame_shift=self.frame_shift,
                                    dither=self.dither,
                                    energy_floor=0.0,
                                    sample_frequency=self.fs)
    return (fbank_mat.astype(np.float32),
            np.array(fbank_mat.shape[0]).astype(np.int32))
def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
if self.lfr_m != 1 or self.lfr_n != 1:
feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)
feats.append(mat)
feats_lens.append(mat.shape[0])
if self.cmvn_file:
feat = self.apply_cmvn(feat)
feats_pad = np.array(feats).astype(np.float32)
feats_lens = np.array(feats_lens).astype(np.int32)
return feats_pad, feats_lens
feat_len = np.array(feat.shape[0]).astype(np.int32)
return feat, feat_len
@staticmethod
def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
......@@ -225,13 +204,13 @@ class WavFrontend():
LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
return LFR_outputs
def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
    """
    Apply CMVN with mvn data.

    Uses the stats loaded into self.cmvn: row 0 is added to the input,
    row 1 multiplies it.
    NOTE(review): this suggests the stats are stored as (-mean, 1/std),
    as kaldi mvn files usually are — confirm against load_cmvn.
    """
    frame, dim = inputs.shape
    # Broadcast the per-dimension stats over all frames.
    means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
    scales = np.tile(self.cmvn[1:2, :dim], (frame, 1))
    return (inputs + means) * scales
......@@ -306,7 +285,7 @@ class OrtInferSession():
EP_list = [(cuda_ep, config[cuda_ep])]
EP_list.append((cpu_ep, cpu_provider_options))
config['model_path'] = str(root_dir / config['model_path'])
config['model_path'] = config['model_path']
self._verify_model(config['model_path'])
self.session = InferenceSession(config['model_path'],
sess_options=sess_opt,
......@@ -323,7 +302,7 @@ class OrtInferSession():
input_content: List[Union[np.ndarray, np.ndarray]]) -> np.ndarray:
input_dict = dict(zip(self.get_input_names(), input_content))
try:
return self.session.run(None, input_dict)[0]
return self.session.run(None, input_dict)
except Exception as e:
raise ONNXRuntimeError('ONNXRuntime inferece failed.') from e
......@@ -361,20 +340,14 @@ def read_yaml(yaml_path: Union[str, Path]) -> Dict:
@functools.lru_cache()
def get_logger(name='xxx',
log_file=root_dir.joinpath('error.log')):
def get_logger(name='rapdi_paraformer'):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added. If `log_file` is specified a FileHandler will also be added.
added.
Args:
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level. Note that only the process of
rank 0 is affected, and other processes will set the level to
"Error" thus be silent most of the time.
Returns:
logging.Logger: The expected logger.
"""
......@@ -390,15 +363,9 @@ def get_logger(name='xxx',
'[%(asctime)s] %(name)s %(levelname)s: %(message)s',
datefmt="%Y/%m/%d %H:%M:%S")
if log_file:
log_file_folder = Path(log_file).parent
log_file_folder.mkdir(parents=True, exist_ok=True)
file_handler = logging.FileHandler(log_file, 'a')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.setLevel(logging.ERROR)
sh = logging.StreamHandler()
sh.setFormatter(formatter)
logger.addHandler(sh)
logger_initialized[name] = True
logger.propagate = False
return logger
TokenIDConverter:
token_path: models/token_list.pkl
token_path: resources/models/token_list.pkl
unk_symbol: <unk>
CharTokenizer:
......@@ -8,7 +8,7 @@ CharTokenizer:
remove_non_linguistic_symbols: false
WavFrontend:
cmvn_file: models/am.mvn
cmvn_file: resources/models/am.mvn
frontend_conf:
fs: 16000
window: hamming
......@@ -20,10 +20,11 @@ WavFrontend:
filter_length_max: -.inf
Model:
model_path: models/asr_paraformerv2.onnx
model_path: resources/models/asr_paraformerv2.onnx
use_cuda: false
CUDAExecutionProvider:
device_id: 0
arena_extend_strategy: kNextPowerOfTwo
cudnn_conv_algo_search: EXHAUSTIVE
do_copy_in_default_stream: true
\ No newline at end of file
do_copy_in_default_stream: true
batch_size: 3
\ No newline at end of file
......@@ -4,14 +4,44 @@
import os
from pathlib import Path
os.sys.path.append(str(Path(__file__).resolve().parent.parent))
import pytest
import librosa
project_dir = Path(__file__).resolve().parent.parent
os.sys.path.append(str(project_dir))
from rapid_paraformer import RapidParaformer
paraformer = RapidParaformer()
cfg_path = project_dir / 'resources' / 'config.yaml'
paraformer = RapidParaformer(cfg_path)
def test_normal():
def test_input_by_path():
wav_file = 'test_wavs/0478_00017.wav'
result = paraformer(wav_file)
assert result[0][0][:5] == '呃说不配合'
assert result[0][:5] == '呃说不配合'
def test_input_by_ndarray():
wav_file = 'test_wavs/0478_00017.wav'
waveform, _ = librosa.load(wav_file)
result = paraformer(waveform[None, ...])
assert result[0][:5] == '呃说不配合'
def test_input_by_str_list():
wave_list = [
'test_wavs/0478_00017.wav',
'test_wavs/asr_example_zh.wav',
]
result = paraformer(wave_list)
assert result[0][:5] == '呃说不配合'
def test_empty():
wav_file = None
with pytest.raises(TypeError) as exc_info:
paraformer(wav_file)
raise TypeError()
assert exc_info.type is TypeError
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment