Unverified Commit b40649b6 authored by SWHL's avatar SWHL Committed by GitHub
Browse files

Merge pull request #4 from RapidAI/rapid_paraformer

Catch errors raised while running inference on a wav file.
parents e144495d 6ffe3be3
## Rapid paraformer ## Rapid paraformer
<p align="left">
<a href=""><img src="https://img.shields.io/badge/Python->=3.7,<=3.10-aff.svg"></a>
<a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Win%2C%20Mac-pink.svg"></a>
</p>
- 模型出自阿里达摩院[Paraformer语音识别-中文-通用-16k-离线-large-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) - 模型出自阿里达摩院[Paraformer语音识别-中文-通用-16k-离线-large-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
- 本分支对模型做了转换,仅采用ONNXRuntime推理引擎 - 本仓库仅对模型做了转换,只采用ONNXRuntime推理引擎
#### 更新日志
- 2023-02-10 v2.0.1 update:
- 添加对输入音频为噪音或者静音的文件推理结果捕捉
#### 使用步骤 #### 使用步骤
......
...@@ -6,7 +6,7 @@ from rapid_paraformer import RapidParaformer ...@@ -6,7 +6,7 @@ from rapid_paraformer import RapidParaformer
paraformer = RapidParaformer() paraformer = RapidParaformer()
wav_path = 'test_wavs/example_test.wav' wav_path = 'test_wavs/1657082555863994221829340016640.wav'
print(wav_path) print(wav_path)
result = paraformer(str(wav_path)) result = paraformer(str(wav_path))
print(result) print(result)
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
# @Author: SWHL # @Author: SWHL
# @Contact: liekkaskono@163.com # @Contact: liekkaskono@163.com
import traceback
from pathlib import Path from pathlib import Path
from typing import List from typing import List
import librosa import librosa
import numpy as np import numpy as np
from .utils import (CharTokenizer, Hypothesis, OrtInferSession, from .utils import (CharTokenizer, Hypothesis, ONNXRuntimeError, OrtInferSession,
TokenIDConverter, WavFrontend, read_yaml) TokenIDConverter, WavFrontend, read_yaml, get_logger)
cur_dir = Path(__file__).resolve().parent cur_dir = Path(__file__).resolve().parent
logging = get_logger()
class RapidParaformer(): class RapidParaformer():
...@@ -32,7 +34,11 @@ class RapidParaformer(): ...@@ -32,7 +34,11 @@ class RapidParaformer():
speech, _ = self.frontend_asr.forward_fbank(waveform) speech, _ = self.frontend_asr.forward_fbank(waveform)
feats, feats_len = self.frontend_asr.forward_lfr_cmvn(speech) feats, feats_len = self.frontend_asr.forward_lfr_cmvn(speech)
am_scores = self.ort_infer(input_content=[feats, feats_len]) try:
am_scores = self.ort_infer(input_content=[feats, feats_len])
except ONNXRuntimeError:
logging.error(traceback.format_exc())
return []
results = [] results = []
for am_score in am_scores: for am_score in am_scores:
......
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
# @Author: SWHL # @Author: SWHL
# @Contact: liekkaskono@163.com # @Contact: liekkaskono@163.com
import functools
import logging import logging
import pickle import pickle
from pathlib import Path from pathlib import Path
...@@ -16,6 +17,8 @@ from .kaldifeat import compute_fbank_feats ...@@ -16,6 +17,8 @@ from .kaldifeat import compute_fbank_feats
root_dir = Path(__file__).resolve().parent root_dir = Path(__file__).resolve().parent
logger_initialized = {}
class TokenIDConverter(): class TokenIDConverter():
def __init__(self, token_path: Union[Path, str], def __init__(self, token_path: Union[Path, str],
...@@ -280,6 +283,10 @@ class TokenIDConverterError(Exception): ...@@ -280,6 +283,10 @@ class TokenIDConverterError(Exception):
pass pass
class ONNXRuntimeError(Exception):
    """Raised when an ONNXRuntime inference call fails."""
class OrtInferSession(): class OrtInferSession():
def __init__(self, config): def __init__(self, config):
sess_opt = SessionOptions() sess_opt = SessionOptions()
...@@ -315,7 +322,10 @@ class OrtInferSession(): ...@@ -315,7 +322,10 @@ class OrtInferSession():
def __call__(self,
             input_content: List[Union[np.ndarray, np.ndarray]]) -> np.ndarray:
    """Run the ONNXRuntime session on ``input_content``.

    Args:
        input_content: Input arrays, paired positionally with the names
            returned by ``get_input_names``.

    Returns:
        The first element of the output list returned by ``session.run``.

    Raises:
        ONNXRuntimeError: If ``session.run`` raises. The original exception
            is chained as ``__cause__``.
    """
    input_dict = dict(zip(self.get_input_names(), input_content))
    try:
        return self.session.run(None, input_dict)[0]
    except Exception as e:
        # Wrap any runtime failure in one domain error so callers only need
        # to catch ONNXRuntimeError; the cause stays attached for debugging.
        raise ONNXRuntimeError('ONNXRuntime inference failed.') from e
def get_input_names(self, ):
    """Return the ``name`` of every input reported by the session, in order."""
    names = []
    for node in self.session.get_inputs():
        names.append(node.name)
    return names
...@@ -348,3 +358,47 @@ def read_yaml(yaml_path: Union[str, Path]) -> Dict: ...@@ -348,3 +358,47 @@ def read_yaml(yaml_path: Union[str, Path]) -> Dict:
with open(str(yaml_path), 'rb') as f: with open(str(yaml_path), 'rb') as f:
data = yaml.load(f, Loader=yaml.Loader) data = yaml.load(f, Loader=yaml.Loader)
return data return data
@functools.lru_cache()
def get_logger(name='xxx',
               log_file=root_dir.joinpath('error.log')):
    """Return the error-level logger registered under ``name``.

    The first call for a given name configures the logger: its level is set
    to ``logging.ERROR``, propagation to ancestor loggers is disabled and,
    when ``log_file`` is truthy, a ``FileHandler`` appending to that file is
    attached (missing parent directories are created). If the log file or
    its directory cannot be created or opened, a ``StreamHandler`` is
    attached instead, so an unwritable location does not raise.

    Calls for a name that is already configured, or that starts with an
    already configured name, return the logger without adding handlers.

    Args:
        name (str): Logger name.
        log_file (str | Path | None): Path of the log file. When falsy, no
            handler is attached.

    Returns:
        logging.Logger: The requested logger.
    """
    logger = logging.getLogger(name)

    # ``name.startswith(name)`` is always true, so this single check covers
    # both an exact match and a name extending an already configured one.
    if any(name.startswith(known) for known in logger_initialized):
        return logger

    formatter = logging.Formatter(
        '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
        datefmt="%Y/%m/%d %H:%M:%S")

    if log_file:
        try:
            Path(log_file).parent.mkdir(parents=True, exist_ok=True)
            handler = logging.FileHandler(log_file, 'a')
        except OSError:
            # The default log file lives next to this module, which may be
            # read-only (e.g. a system-wide install). Fall back to stderr
            # rather than failing while the logger is being set up.
            handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    logger.setLevel(logging.ERROR)
    logger_initialized[name] = True
    logger.propagate = False
    return logger
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment