"examples/vscode:/vscode.git/clone" did not exist on "254cad2dc89f9fe5cf61e27e7121b15838e0ad84"
Unverified Commit b40649b6 authored by SWHL's avatar SWHL Committed by GitHub
Browse files

Merge pull request #4 from RapidAI/rapid_paraformer

Catch the error raised when inference on a wav file fails.
parents e144495d 6ffe3be3
## Rapid paraformer
<p align="left">
<a href=""><img src="https://img.shields.io/badge/Python->=3.7,<=3.10-aff.svg"></a>
<a href=""><img src="https://img.shields.io/badge/OS-Linux%2C%20Win%2C%20Mac-pink.svg"></a>
</p>
- 模型出自阿里达摩院[Paraformer语音识别-中文-通用-16k-离线-large-pytorch](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
- 本分支对模型做了转换,仅采用ONNXRuntime推理引擎
- 本仓库仅对模型做了转换,只采用ONNXRuntime推理引擎
#### 更新日志
- 2023-02-10 v2.0.1 update:
- 添加对输入音频为噪音或者静音的文件推理结果捕捉
#### 使用步骤
......
......@@ -6,7 +6,7 @@ from rapid_paraformer import RapidParaformer
paraformer = RapidParaformer()

# Audio file handed to the recognizer. Only one path is assigned: an earlier
# assignment that was immediately overwritten has been dropped as a dead store.
wav_path = 'test_wavs/1657082555863994221829340016640.wav'
print(wav_path)

# wav_path is already a str, so it is passed as-is; the instance is called
# with the path and whatever it returns is printed.
result = paraformer(wav_path)
print(result)
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import traceback
from pathlib import Path
from typing import List
import librosa
import numpy as np
from .utils import (CharTokenizer, Hypothesis, OrtInferSession,
TokenIDConverter, WavFrontend, read_yaml)
from .utils import (CharTokenizer, Hypothesis, ONNXRuntimeError, OrtInferSession,
TokenIDConverter, WavFrontend, read_yaml, get_logger)
# Directory that contains this module file.
cur_dir = Path(__file__).resolve().parent
# Logger object returned by get_logger(). NOTE: in this module the name
# ``logging`` refers to that logger instance, not to the standard-library
# ``logging`` module.
logging = get_logger()
class RapidParaformer():
......@@ -32,7 +34,11 @@ class RapidParaformer():
speech, _ = self.frontend_asr.forward_fbank(waveform)
feats, feats_len = self.frontend_asr.forward_lfr_cmvn(speech)
am_scores = self.ort_infer(input_content=[feats, feats_len])
try:
am_scores = self.ort_infer(input_content=[feats, feats_len])
except ONNXRuntimeError:
logging.error(traceback.format_exc())
return []
results = []
for am_score in am_scores:
......
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import functools
import logging
import pickle
from pathlib import Path
......@@ -16,6 +17,8 @@ from .kaldifeat import compute_fbank_feats
# Directory that contains this module file; get_logger() places its default
# 'error.log' here.
root_dir = Path(__file__).resolve().parent
# Registry of logger names that get_logger() has already configured,
# stored as {name: True}.
logger_initialized = {}
class TokenIDConverter():
def __init__(self, token_path: Union[Path, str],
......@@ -280,6 +283,10 @@ class TokenIDConverterError(Exception):
pass
class ONNXRuntimeError(Exception):
    """Error raised when running the ONNXRuntime session fails."""
class OrtInferSession():
def __init__(self, config):
sess_opt = SessionOptions()
......@@ -315,7 +322,10 @@ class OrtInferSession():
def __call__(self,
             input_content: List[Union[np.ndarray, np.ndarray]]) -> np.ndarray:
    """Run the session on ``input_content`` and return its first output.

    Args:
        input_content: Input arrays, paired positionally with the names
            returned by ``get_input_names()``.

    Returns:
        The first element of the list returned by ``self.session.run``.

    Raises:
        ONNXRuntimeError: If running the session (or indexing its result)
            raises any exception. The original exception is chained as
            ``__cause__``.
    """
    input_dict = dict(zip(self.get_input_names(), input_content))

    # The session call must live inside the try block: an unguarded return
    # placed before it would make the error wrapping below unreachable, and
    # callers that catch ONNXRuntimeError would never see it.
    try:
        return self.session.run(None, input_dict)[0]
    except Exception as e:
        raise ONNXRuntimeError('ONNXRuntime inferece failed.') from e
def get_input_names(self):
    """Return the ``name`` of every input reported by ``self.session``.

    The names are collected in the order ``self.session.get_inputs()``
    yields them.
    """
    names = []
    for node in self.session.get_inputs():
        names.append(node.name)
    return names
......@@ -348,3 +358,47 @@ def read_yaml(yaml_path: Union[str, Path]) -> Dict:
with open(str(yaml_path), 'rb') as f:
data = yaml.load(f, Loader=yaml.Loader)
return data
@functools.lru_cache()
def get_logger(name='xxx',
               log_file=root_dir.joinpath('error.log')):
    """Initialize and return the logger registered under ``name``.

    The first call for a given name configures the logger: if ``log_file`` is
    truthy, its parent directory is created and a ``FileHandler`` opened in
    append mode is attached; the level is set to ``logging.ERROR`` and
    propagation to ancestor loggers is switched off. No ``StreamHandler`` is
    added. Later calls for the same name, or for a name that starts with an
    already-configured name, return the logger without modifying it.

    Args:
        name (str): Logger name.
        log_file (str | Path | None): Path of the log file. When falsy, no
            file handler is attached.

    Returns:
        logging.Logger: The logger for ``name``.
    """
    logger = logging.getLogger(name)

    # ``name.startswith(name)`` is always true, so this single test covers
    # both the exact-name case and the "prefix already configured" case.
    if any(name.startswith(known) for known in logger_initialized):
        return logger

    formatter = logging.Formatter(
        '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
        datefmt="%Y/%m/%d %H:%M:%S")

    if log_file:
        try:
            Path(log_file).parent.mkdir(parents=True, exist_ok=True)
            file_handler = logging.FileHandler(log_file, 'a')
        except OSError as exc:
            # The default log file sits next to the module, which may be a
            # read-only location. Callers may invoke this function at import
            # time, so failing to open the file must not make the package
            # unimportable; warn and continue without a file handler.
            import warnings
            warnings.warn(
                f'Cannot open log file {log_file}: {exc}. '
                'File logging is disabled.')
        else:
            file_handler.setFormatter(formatter)
            logger.addHandler(file_handler)

    logger.setLevel(logging.ERROR)
    logger_initialized[name] = True
    logger.propagate = False
    return logger
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment