Commit 392da8a4 authored by SWHL's avatar SWHL
Browse files

Add test code

parent 159403db
Pipeline #334 failed with stages
in 0 seconds
......@@ -9,7 +9,8 @@ import librosa
import numpy as np
from .utils import (CharTokenizer, Hypothesis, ONNXRuntimeError, OrtInferSession,
TokenIDConverter, WavFrontend, read_yaml, get_logger)
TokenIDConverter, WavFrontend, read_yaml, get_logger,
OpenVINOInferSession)
cur_dir = Path(__file__).resolve().parent
logging = get_logger()
......@@ -28,6 +29,7 @@ class RapidParaformer():
**config['WavFrontend']['frontend_conf']
)
self.ort_infer = OrtInferSession(config['Model'])
self.vino_infer = OpenVINOInferSession(config['Model'])
def __call__(self, wav_path: str) -> List:
waveform = librosa.load(wav_path)[0][None, ...]
......@@ -35,7 +37,8 @@ class RapidParaformer():
speech, _ = self.frontend_asr.forward_fbank(waveform)
feats, feats_len = self.frontend_asr.forward_lfr_cmvn(speech)
try:
am_scores = self.ort_infer(input_content=[feats, feats_len])
# am_scores = self.ort_infer(input_content=[feats, feats_len])
am_scores = self.vino_infer(input_content=[feats, feats_len])
except ONNXRuntimeError:
logging.error(traceback.format_exc())
return []
......
......@@ -11,6 +11,7 @@ import numpy as np
import yaml
from onnxruntime import (GraphOptimizationLevel, InferenceSession,
SessionOptions, get_available_providers, get_device)
from openvino.runtime import Core
from typeguard import check_argument_types
from .kaldifeat import compute_fbank_feats
......@@ -351,6 +352,29 @@ class OrtInferSession():
raise FileExistsError(f'{model_path} is not a file.')
class OpenVINOInferSession():
def __init__(self, config):
ie = Core()
config['model_path'] = str(root_dir / config['model_path'])
self._verify_model(config['model_path'])
model_onnx = ie.read_model(config['model_path'])
compile_model = ie.compile_model(model=model_onnx, device_name='CPU')
self.session = compile_model.create_infer_request()
def __call__(self, input_content: np.ndarray) -> np.ndarray:
self.session.infer(inputs=[input_content])
return self.session.get_output_tensor().data
@staticmethod
def _verify_model(model_path):
model_path = Path(model_path)
if not model_path.exists():
raise FileNotFoundError(f'{model_path} does not exists.')
if not model_path.is_file():
raise FileExistsError(f'{model_path} is not a file.')
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
if not Path(yaml_path).exists():
raise FileExistsError(f'The {yaml_path} does not exist.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment