Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
RapidASR
Commits
392da8a4
Commit
392da8a4
authored
Feb 11, 2023
by
SWHL
Browse files
Add test code
parent
159403db
Pipeline
#334
failed with stages
in 0 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
2 deletions
+29
-2
rapid_paraformer/rapid_paraformer.py
rapid_paraformer/rapid_paraformer.py
+5
-2
rapid_paraformer/utils.py
rapid_paraformer/utils.py
+24
-0
No files found.
rapid_paraformer/rapid_paraformer.py
View file @
392da8a4
...
...
@@ -9,7 +9,8 @@ import librosa
import
numpy
as
np
from
.utils
import
(
CharTokenizer
,
Hypothesis
,
ONNXRuntimeError
,
OrtInferSession
,
TokenIDConverter
,
WavFrontend
,
read_yaml
,
get_logger
)
TokenIDConverter
,
WavFrontend
,
read_yaml
,
get_logger
,
OpenVINOInferSession
)
cur_dir
=
Path
(
__file__
).
resolve
().
parent
logging
=
get_logger
()
...
...
@@ -28,6 +29,7 @@ class RapidParaformer():
**
config
[
'WavFrontend'
][
'frontend_conf'
]
)
self
.
ort_infer
=
OrtInferSession
(
config
[
'Model'
])
self
.
vino_infer
=
OpenVINOInferSession
(
config
[
'Model'
])
def
__call__
(
self
,
wav_path
:
str
)
->
List
:
waveform
=
librosa
.
load
(
wav_path
)[
0
][
None
,
...]
...
...
@@ -35,7 +37,8 @@ class RapidParaformer():
speech
,
_
=
self
.
frontend_asr
.
forward_fbank
(
waveform
)
feats
,
feats_len
=
self
.
frontend_asr
.
forward_lfr_cmvn
(
speech
)
try
:
am_scores
=
self
.
ort_infer
(
input_content
=
[
feats
,
feats_len
])
# am_scores = self.ort_infer(input_content=[feats, feats_len])
am_scores
=
self
.
vino_infer
(
input_content
=
[
feats
,
feats_len
])
except
ONNXRuntimeError
:
logging
.
error
(
traceback
.
format_exc
())
return
[]
...
...
rapid_paraformer/utils.py
View file @
392da8a4
...
...
@@ -11,6 +11,7 @@ import numpy as np
import
yaml
from
onnxruntime
import
(
GraphOptimizationLevel
,
InferenceSession
,
SessionOptions
,
get_available_providers
,
get_device
)
from
openvino.runtime
import
Core
from
typeguard
import
check_argument_types
from
.kaldifeat
import
compute_fbank_feats
...
...
@@ -351,6 +352,29 @@ class OrtInferSession():
raise
FileExistsError
(
f
'
{
model_path
}
is not a file.'
)
class
OpenVINOInferSession
():
def
__init__
(
self
,
config
):
ie
=
Core
()
config
[
'model_path'
]
=
str
(
root_dir
/
config
[
'model_path'
])
self
.
_verify_model
(
config
[
'model_path'
])
model_onnx
=
ie
.
read_model
(
config
[
'model_path'
])
compile_model
=
ie
.
compile_model
(
model
=
model_onnx
,
device_name
=
'CPU'
)
self
.
session
=
compile_model
.
create_infer_request
()
def
__call__
(
self
,
input_content
:
np
.
ndarray
)
->
np
.
ndarray
:
self
.
session
.
infer
(
inputs
=
[
input_content
])
return
self
.
session
.
get_output_tensor
().
data
@
staticmethod
def
_verify_model
(
model_path
):
model_path
=
Path
(
model_path
)
if
not
model_path
.
exists
():
raise
FileNotFoundError
(
f
'
{
model_path
}
does not exists.'
)
if
not
model_path
.
is_file
():
raise
FileExistsError
(
f
'
{
model_path
}
is not a file.'
)
def
read_yaml
(
yaml_path
:
Union
[
str
,
Path
])
->
Dict
:
if
not
Path
(
yaml_path
).
exists
():
raise
FileExistsError
(
f
'The
{
yaml_path
}
does not exist.'
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment