Commit 7d0435cd authored by SWHL's avatar SWHL
Browse files

Update to v2.0.2

parent 159403db
...@@ -10,6 +10,11 @@ ...@@ -10,6 +10,11 @@
#### 更新日志 #### 更新日志
- 2023-02-11 v2.0.2 update:
- 模型和推理代码解耦(`rapid_paraformer``resources`
- 支持批量推理(通过`resources/config.yaml``batch_size`指定)
- 增加多种输入方式(`Union[str, np.ndarray, List[str]]`
- 2023-02-10 v2.0.1 update: - 2023-02-10 v2.0.1 update:
- 添加对输入音频为噪音或者静音的文件推理结果捕捉。 - 添加对输入音频为噪音或者静音的文件推理结果捕捉。
...@@ -20,38 +25,53 @@ ...@@ -20,38 +25,53 @@
pip install -r requirements.txt pip install -r requirements.txt
``` ```
2. 下载模型 2. 下载模型
- 由于模型太大(881M),上传到仓库不容易下载,提供百度云下载连接:[asr_paraformerv2.onnx](https://pan.baidu.com/s/1-nEf2eUpkzlcRqiYEwub2A?pwd=dcr3) - 由于模型太大(823.8M),上传到仓库不容易下载,提供百度云下载连接:[asr_paraformerv2.onnx](https://pan.baidu.com/s/1-nEf2eUpkzlcRqiYEwub2A?pwd=dcr3)(模型MD5: `9ca331381a470bc4458cc6c0b0b165de`
- 模型下载之后,放在`rapid_paraformer/models`目录下即可,最终目录结构如下: - 模型下载之后,放在`resources/models`目录下即可,最终目录结构如下:
```text ```text
rapid_paraformer .
├── config.yaml ├── demo.py
├── __init__.py ├── rapid_paraformer
├── kaldifeat
│   ├── feature.py
│   ├── __init__.py │   ├── __init__.py
│   ├── ivector.py │   ├── kaldifeat
│   ├── LICENSE │   ├── __pycache__
│   └── README.md │   ├── rapid_paraformer.py
├── models │   └── utils.py
│   ├── am.mvn ├── README.md
│   ├── asr_paraformerv2.onnx # 放在这里 ├── requirements.txt
│   └── token_list.pkl ├── resources
├── rapid_paraformer.py │   ├── config.yaml
└── utils.py │   └── models
│   ├── am.mvn
│   ├── asr_paraformerv2.onnx # 放在这里
│   └── token_list.pkl
├── test_onnx.py
├── tests
│   ├── __pycache__
│   └── test_infer.py
└── test_wavs
├── 0478_00017.wav
└── asr_example_zh.wav
``` ```
3. 运行demo 3. 运行demo
```python ```python
from rapid_paraformer import RapidParaformer from rapid_paraformer import RapidParaformer
paraformer = RapidParaformer()
wav_path = 'test_wavs/example_test.wav' config_path = 'resources/config.yaml'
result = paraformer(str(wav_path)) paraformer = RapidParaformer(config_path)
# 输入:支持Union[str, np.ndarray, List[str]] 三种方式传入
# 输出: List[asr_res]
wav_path = [
'test_wavs/0478_00017.wav',
]
result = paraformer(wav_path)
print(result) print(result)
``` ```
4. 查看结果 4. 查看结果
```text ```text
[['呃说不配合就不配合的好以上的话呢我们摘取八九十三条因为这三条的话呢比较典型啊一些数字比较明确尤其是时间那么我们要投资者就是了解这一点啊不要轻信这个市场可以快速回来啊这些配市公司啊后期又利好了可 ['呃说不配合就不配合的好以上的话呢我们摘取八九十三条因为这三条的话呢比较典型啊一些数字比较明确尤其是时间那么我们要投资者就是了解这一点啊不要轻信这个市场可以快速回来啊这些配市公司啊后期又利好了可
以快速快速攻能包括像前一段时间啊有些媒体在二三月份的时候']] 以快速快速攻能包括像前一段时间啊有些媒体在二三月份的时候']
``` ```
...@@ -4,9 +4,20 @@ ...@@ -4,9 +4,20 @@
from rapid_paraformer import RapidParaformer from rapid_paraformer import RapidParaformer
paraformer = RapidParaformer() config_path = 'resources/config.yaml'
paraformer = RapidParaformer(config_path)
wav_path = [
'test_wavs/0478_00017.wav',
'test_wavs/asr_example_zh.wav',
'test_wavs/0478_00017.wav',
'test_wavs/asr_example_zh.wav',
'test_wavs/0478_00017.wav',
'test_wavs/asr_example_zh.wav',
]
wav_path = 'test_wavs/0478_00017.wav'
print(wav_path) print(wav_path)
result = paraformer(str(wav_path)) # wav_path = 'test_wavs/0478_00017.wav'
result = paraformer(wav_path)
print(result) print(result)
...@@ -3,50 +3,109 @@ ...@@ -3,50 +3,109 @@
# @Contact: liekkaskono@163.com # @Contact: liekkaskono@163.com
import traceback import traceback
from pathlib import Path from pathlib import Path
from typing import List from typing import List, Union, Tuple
import librosa import librosa
import numpy as np import numpy as np
from .utils import (CharTokenizer, Hypothesis, ONNXRuntimeError, OrtInferSession, from .utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
TokenIDConverter, WavFrontend, read_yaml, get_logger) OrtInferSession, TokenIDConverter, WavFrontend, get_logger,
read_yaml)
cur_dir = Path(__file__).resolve().parent
logging = get_logger() logging = get_logger()
class RapidParaformer(): class RapidParaformer():
def __init__(self, config_path: str = None) -> None: def __init__(self, config_path: Union[str, Path]) -> None:
config = read_yaml(cur_dir / 'config.yaml') if not Path(config_path).exists():
if config_path: raise FileNotFoundError(f'{config_path} does not exist.')
config = read_yaml(config_path)
config = read_yaml(config_path)
self.converter = TokenIDConverter(**config['TokenIDConverter']) self.converter = TokenIDConverter(**config['TokenIDConverter'])
self.tokenizer = CharTokenizer(**config['CharTokenizer']) self.tokenizer = CharTokenizer(**config['CharTokenizer'])
self.frontend_asr = WavFrontend( self.frontend = WavFrontend(
cmvn_file=config['WavFrontend']['cmvn_file'], cmvn_file=config['WavFrontend']['cmvn_file'],
**config['WavFrontend']['frontend_conf'] **config['WavFrontend']['frontend_conf']
) )
self.ort_infer = OrtInferSession(config['Model']) self.ort_infer = OrtInferSession(config['Model'])
self.batch_size = config['Model']['batch_size']
def __call__(self, wav_path: str) -> List:
waveform = librosa.load(wav_path)[0][None, ...] def __call__(self, wav_content: Union[str, np.ndarray, List[str]]) -> List:
waveform_list = self.load_data(wav_content)
speech, _ = self.frontend_asr.forward_fbank(waveform) waveform_nums = len(waveform_list)
feats, feats_len = self.frontend_asr.forward_lfr_cmvn(speech)
try: asr_res = []
am_scores = self.ort_infer(input_content=[feats, feats_len]) for beg_idx in range(0, waveform_nums, self.batch_size):
except ONNXRuntimeError: end_idx = min(waveform_nums, beg_idx + self.batch_size)
logging.error(traceback.format_exc())
return [] feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
results = [] try:
for am_score in am_scores: am_scores, valid_token_lens = self.infer(feats, feats_len)
pred_res = self.infer_one_feat(am_score) except ONNXRuntimeError:
results.append(pred_res) logging.error(traceback.format_exc())
return results preds = []
else:
def infer_one_feat(self, am_score: np.ndarray) -> List[str]: preds = self.decode(am_scores, valid_token_lens)
asr_res.extend(preds)
return asr_res
def load_data(self,
wav_content: Union[str, np.ndarray, List[str]]) -> List:
def load_wav(path: str) -> np.ndarray:
waveform, _ = librosa.load(path)
return waveform[None, ...]
if isinstance(wav_content, np.ndarray):
return [wav_content]
if isinstance(wav_content, str):
return [load_wav(wav_content)]
if isinstance(wav_content, list):
return [load_wav(path) for path in wav_content]
raise TypeError(
f'The type of {wav_content} is not in [str, np.ndarray, list]')
def extract_feat(self,
waveform_list: List[np.ndarray]
) -> Tuple[np.ndarray, np.ndarray]:
feats, feats_len = [], []
for waveform in waveform_list:
speech, _ = self.frontend.fbank(waveform)
feat, feat_len = self.frontend.lfr_cmvn(speech)
feats.append(feat)
feats_len.append(feat_len)
feats = self.pad_feats(feats, np.max(feats_len))
feats_len = np.array(feats_len).astype(np.int32)
return feats, feats_len
@staticmethod
def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
pad_width = ((0, max_feat_len - cur_len), (0, 0))
return np.pad(feat, pad_width, 'constant', constant_values=0)
feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
feats = np.array(feat_res).astype(np.float32)
return feats
def infer(self, feats: np.ndarray,
feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
am_scores, token_nums = self.ort_infer([feats, feats_len])
return am_scores, token_nums
def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
return [self.decode_one(am_score, token_num)
for am_score, token_num in zip(am_scores, token_nums)]
def decode_one(self,
am_score: np.ndarray,
valid_token_num: int) -> List[str]:
yseq = am_score.argmax(axis=-1) yseq = am_score.argmax(axis=-1)
score = am_score.max(axis=-1) score = am_score.max(axis=-1)
score = np.sum(score, axis=-1) score = np.sum(score, axis=-1)
...@@ -54,27 +113,25 @@ class RapidParaformer(): ...@@ -54,27 +113,25 @@ class RapidParaformer():
# pad with mask tokens to ensure compatibility with sos/eos tokens # pad with mask tokens to ensure compatibility with sos/eos tokens
# asr_model.sos:1 asr_model.eos:2 # asr_model.sos:1 asr_model.eos:2
yseq = np.array([1] + yseq.tolist() + [2]) yseq = np.array([1] + yseq.tolist() + [2])
nbest_hyps = [Hypothesis(yseq=yseq, score=score)] hyp = Hypothesis(yseq=yseq, score=score)
infer_res = []
for hyp in nbest_hyps:
# remove sos/eos and get results
last_pos = -1
token_int = hyp.yseq[1:last_pos].tolist()
# remove blank symbol id, which is assumed to be 0 # remove sos/eos and get results
token_int = list(filter(lambda x: x not in (0, 2), token_int)) last_pos = -1
token_int = hyp.yseq[1:last_pos].tolist()
# Change integer-ids to tokens # remove blank symbol id, which is assumed to be 0
token = self.converter.ids2tokens(token_int) token_int = list(filter(lambda x: x not in (0, 2), token_int))
text = self.tokenizer.tokens2text(token) # Change integer-ids to tokens
infer_res.append(text) token = self.converter.ids2tokens(token_int)
return infer_res text = self.tokenizer.tokens2text(token)
return text[:valid_token_num-1]
if __name__ == '__main__': if __name__ == '__main__':
paraformer = RapidParaformer() project_dir = Path(__file__).resolve().parent.parent
cfg_path = project_dir / 'resources' / 'config.yaml'
paraformer = RapidParaformer(cfg_path)
wav_file = '0478_00017.wav' wav_file = '0478_00017.wav'
for i in range(1000): for i in range(1000):
......
...@@ -25,7 +25,7 @@ class TokenIDConverter(): ...@@ -25,7 +25,7 @@ class TokenIDConverter():
unk_symbol: str = "<unk>",): unk_symbol: str = "<unk>",):
check_argument_types() check_argument_types()
self.token_list = self.load_token(root_dir / token_path) self.token_list = self.load_token(token_path)
self.unk_symbol = unk_symbol self.unk_symbol = unk_symbol
@staticmethod @staticmethod
...@@ -148,58 +148,37 @@ class WavFrontend(): ...@@ -148,58 +148,37 @@ class WavFrontend():
self.filter_length_max = filter_length_max self.filter_length_max = filter_length_max
self.lfr_m = lfr_m self.lfr_m = lfr_m
self.lfr_n = lfr_n self.lfr_n = lfr_n
self.cmvn_file = root_dir / cmvn_file self.cmvn_file = cmvn_file
self.dither = dither self.dither = dither
def forward_fbank(self,
input_content: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
feats, feats_lens = [], []
batch_size = input_content.shape[0]
input_lengths = np.array([input_content.shape[1]])
for i in range(batch_size):
waveform_length = input_lengths[i]
waveform = input_content[i][:waveform_length]
waveform = waveform * (1 << 15)
mat = compute_fbank_feats(waveform,
num_mel_bins=self.n_mels,
frame_length=self.frame_length,
frame_shift=self.frame_shift,
dither=self.dither,
energy_floor=0.0,
sample_frequency=self.fs)
feats.append(mat)
feats_lens.append(mat.shape[0])
feats_pad = np.array(feats).astype(np.float32)
feats_lens = np.array(feats_lens).astype(np.int64)
return feats_pad, feats_lens
def forward_lfr_cmvn(self,
input_content: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray]:
feats, feats_lens = [], []
batch_size = input_content.shape[0]
if self.cmvn_file: if self.cmvn_file:
cmvn = self.load_cmvn() self.cmvn = self.load_cmvn()
input_lengths = np.array([input_content.shape[1]]) def fbank(self,
for i in range(batch_size): input_content: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
mat = input_content[i, :input_lengths[i], :] waveform_len = input_content.shape[1]
waveform = input_content[0][:waveform_len]
if self.lfr_m != 1 or self.lfr_n != 1: waveform = waveform * (1 << 15)
mat = self.apply_lfr(mat, self.lfr_m, self.lfr_n) mat = compute_fbank_feats(waveform,
num_mel_bins=self.n_mels,
if self.cmvn_file: frame_length=self.frame_length,
mat = self.apply_cmvn(mat, cmvn) frame_shift=self.frame_shift,
dither=self.dither,
energy_floor=0.0,
sample_frequency=self.fs)
feat = mat.astype(np.float32)
feat_len = np.array(mat.shape[0]).astype(np.int32)
return feat, feat_len
def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
if self.lfr_m != 1 or self.lfr_n != 1:
feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)
feats.append(mat) if self.cmvn_file:
feats_lens.append(mat.shape[0]) feat = self.apply_cmvn(feat)
feats_pad = np.array(feats).astype(np.float32) feat_len = np.array(feat.shape[0]).astype(np.int32)
feats_lens = np.array(feats_lens).astype(np.int32) return feat, feat_len
return feats_pad, feats_lens
@staticmethod @staticmethod
def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray: def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
...@@ -225,13 +204,13 @@ class WavFrontend(): ...@@ -225,13 +204,13 @@ class WavFrontend():
LFR_outputs = np.vstack(LFR_inputs).astype(np.float32) LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
return LFR_outputs return LFR_outputs
def apply_cmvn(self, inputs: np.ndarray, cmvn: np.ndarray) -> np.ndarray: def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
""" """
Apply CMVN with mvn data Apply CMVN with mvn data
""" """
frame, dim = inputs.shape frame, dim = inputs.shape
means = np.tile(cmvn[0:1, :dim], (frame, 1)) means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
vars = np.tile(cmvn[1:2, :dim], (frame, 1)) vars = np.tile(self.cmvn[1:2, :dim], (frame, 1))
inputs = (inputs + means) * vars inputs = (inputs + means) * vars
return inputs return inputs
...@@ -306,7 +285,7 @@ class OrtInferSession(): ...@@ -306,7 +285,7 @@ class OrtInferSession():
EP_list = [(cuda_ep, config[cuda_ep])] EP_list = [(cuda_ep, config[cuda_ep])]
EP_list.append((cpu_ep, cpu_provider_options)) EP_list.append((cpu_ep, cpu_provider_options))
config['model_path'] = str(root_dir / config['model_path']) config['model_path'] = config['model_path']
self._verify_model(config['model_path']) self._verify_model(config['model_path'])
self.session = InferenceSession(config['model_path'], self.session = InferenceSession(config['model_path'],
sess_options=sess_opt, sess_options=sess_opt,
...@@ -323,7 +302,7 @@ class OrtInferSession(): ...@@ -323,7 +302,7 @@ class OrtInferSession():
input_content: List[Union[np.ndarray, np.ndarray]]) -> np.ndarray: input_content: List[Union[np.ndarray, np.ndarray]]) -> np.ndarray:
input_dict = dict(zip(self.get_input_names(), input_content)) input_dict = dict(zip(self.get_input_names(), input_content))
try: try:
return self.session.run(None, input_dict)[0] return self.session.run(None, input_dict)
except Exception as e: except Exception as e:
raise ONNXRuntimeError('ONNXRuntime inferece failed.') from e raise ONNXRuntimeError('ONNXRuntime inferece failed.') from e
...@@ -361,20 +340,14 @@ def read_yaml(yaml_path: Union[str, Path]) -> Dict: ...@@ -361,20 +340,14 @@ def read_yaml(yaml_path: Union[str, Path]) -> Dict:
@functools.lru_cache() @functools.lru_cache()
def get_logger(name='xxx', def get_logger(name='rapdi_paraformer'):
log_file=root_dir.joinpath('error.log')):
"""Initialize and get a logger by name. """Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be be directly returned. During initialization, a StreamHandler will always be
added. If `log_file` is specified a FileHandler will also be added. added.
Args: Args:
name (str): Logger name. name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level. Note that only the process of
rank 0 is affected, and other processes will set the level to
"Error" thus be silent most of the time.
Returns: Returns:
logging.Logger: The expected logger. logging.Logger: The expected logger.
""" """
...@@ -390,15 +363,9 @@ def get_logger(name='xxx', ...@@ -390,15 +363,9 @@ def get_logger(name='xxx',
'[%(asctime)s] %(name)s %(levelname)s: %(message)s', '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
datefmt="%Y/%m/%d %H:%M:%S") datefmt="%Y/%m/%d %H:%M:%S")
if log_file: sh = logging.StreamHandler()
log_file_folder = Path(log_file).parent sh.setFormatter(formatter)
log_file_folder.mkdir(parents=True, exist_ok=True) logger.addHandler(sh)
file_handler = logging.FileHandler(log_file, 'a')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.setLevel(logging.ERROR)
logger_initialized[name] = True logger_initialized[name] = True
logger.propagate = False logger.propagate = False
return logger return logger
TokenIDConverter: TokenIDConverter:
token_path: models/token_list.pkl token_path: resources/models/token_list.pkl
unk_symbol: <unk> unk_symbol: <unk>
CharTokenizer: CharTokenizer:
...@@ -8,7 +8,7 @@ CharTokenizer: ...@@ -8,7 +8,7 @@ CharTokenizer:
remove_non_linguistic_symbols: false remove_non_linguistic_symbols: false
WavFrontend: WavFrontend:
cmvn_file: models/am.mvn cmvn_file: resources/models/am.mvn
frontend_conf: frontend_conf:
fs: 16000 fs: 16000
window: hamming window: hamming
...@@ -20,10 +20,11 @@ WavFrontend: ...@@ -20,10 +20,11 @@ WavFrontend:
filter_length_max: -.inf filter_length_max: -.inf
Model: Model:
model_path: models/asr_paraformerv2.onnx model_path: resources/models/model.onnx
use_cuda: false use_cuda: false
CUDAExecutionProvider: CUDAExecutionProvider:
device_id: 0 device_id: 0
arena_extend_strategy: kNextPowerOfTwo arena_extend_strategy: kNextPowerOfTwo
cudnn_conv_algo_search: EXHAUSTIVE cudnn_conv_algo_search: EXHAUSTIVE
do_copy_in_default_stream: true do_copy_in_default_stream: true
\ No newline at end of file batch_size: 3
\ No newline at end of file
...@@ -4,14 +4,44 @@ ...@@ -4,14 +4,44 @@
import os import os
from pathlib import Path from pathlib import Path
os.sys.path.append(str(Path(__file__).resolve().parent.parent)) import pytest
import librosa
project_dir = Path(__file__).resolve().parent.parent
os.sys.path.append(str(project_dir))
from rapid_paraformer import RapidParaformer from rapid_paraformer import RapidParaformer
paraformer = RapidParaformer()
cfg_path = project_dir / 'resources' / 'config.yaml'
paraformer = RapidParaformer(cfg_path)
def test_normal(): def test_input_by_path():
wav_file = 'test_wavs/0478_00017.wav' wav_file = 'test_wavs/0478_00017.wav'
result = paraformer(wav_file) result = paraformer(wav_file)
assert result[0][0][:5] == '呃说不配合' assert result[0][:5] == '呃说不配合'
def test_input_by_ndarray():
wav_file = 'test_wavs/0478_00017.wav'
waveform, _ = librosa.load(wav_file)
result = paraformer(waveform[None, ...])
assert result[0][:5] == '呃说不配合'
def test_input_by_str_list():
wave_list = [
'test_wavs/0478_00017.wav',
'test_wavs/asr_example_zh.wav',
]
result = paraformer(wave_list)
assert result[0][:5] == '呃说不配合'
def test_empty():
wav_file = None
with pytest.raises(TypeError) as exc_info:
paraformer(wav_file)
raise TypeError()
assert exc_info.type is TypeError
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment