Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
wenet_onnxruntime
Commits
3c4ea2c0
Commit
3c4ea2c0
authored
Jun 06, 2025
by
wufan3
Browse files
fix BUG[92942]:Improving wenet inference performance by creating a separate session for each thread
parent
61aeca13
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
4 deletions
+9
-4
Src/bin/decoder_main.cc
Src/bin/decoder_main.cc
+5
-4
Src/decoder/params.h
Src/decoder/params.h
+4
-0
No files found.
Src/bin/decoder_main.cc
View file @
3c4ea2c0
...
@@ -17,7 +17,7 @@ DEFINE_string(wav_scp, "", "input wav scp");
...
@@ -17,7 +17,7 @@ DEFINE_string(wav_scp, "", "input wav scp");
DEFINE_string
(
result
,
"./result"
,
"result output file"
);
DEFINE_string
(
result
,
"./result"
,
"result output file"
);
DEFINE_bool
(
continuous_decoding
,
false
,
"continuous decoding mode"
);
DEFINE_bool
(
continuous_decoding
,
false
,
"continuous decoding mode"
);
DEFINE_int32
(
thread_num
,
1
,
"num of decode thread"
);
DEFINE_int32
(
thread_num
,
1
,
"num of decode thread"
);
DEFINE_int32
(
warmup
,
0
,
"num of warmup decode, 0 means no warmup"
);
DEFINE_int32
(
warmup
,
1
,
"num of warmup decode, 0 means no warmup"
);
// std::shared_ptr<wenet::DecodeOptions> g_decode_config;
// std::shared_ptr<wenet::DecodeOptions> g_decode_config;
// std::shared_ptr<wenet::FeaturePipelineConfig> g_feature_config;
// std::shared_ptr<wenet::FeaturePipelineConfig> g_feature_config;
...
@@ -28,7 +28,8 @@ std::mutex g_mutex;
...
@@ -28,7 +28,8 @@ std::mutex g_mutex;
int
g_total_waves_dur
=
0
;
int
g_total_waves_dur
=
0
;
int
g_total_decode_time
=
0
;
int
g_total_decode_time
=
0
;
void
Decode
(
std
::
pair
<
std
::
string
,
std
::
string
>
wav
,
bool
warmup
,
std
::
shared_ptr
<
wenet
::
DecodeOptions
>
g_decode_config
,
std
::
shared_ptr
<
wenet
::
FeaturePipelineConfig
>
g_feature_config
,
std
::
shared_ptr
<
wenet
::
DecodeResource
>
g_decode_resource
)
{
void
Decode
(
std
::
pair
<
std
::
string
,
std
::
string
>
wav
,
bool
warmup
,
std
::
shared_ptr
<
wenet
::
DecodeOptions
>
g_decode_config
,
std
::
shared_ptr
<
wenet
::
FeaturePipelineConfig
>
g_feature_config
)
{
std
::
shared_ptr
<
wenet
::
DecodeResource
>
g_decode_resource
=
wenet
::
InitDecodeResourceFromFlags
();
wenet
::
WavReader
wav_reader
(
wav
.
second
);
wenet
::
WavReader
wav_reader
(
wav
.
second
);
int
num_samples
=
wav_reader
.
num_samples
();
int
num_samples
=
wav_reader
.
num_samples
();
CHECK_EQ
(
wav_reader
.
sample_rate
(),
FLAGS_sample_rate
);
CHECK_EQ
(
wav_reader
.
sample_rate
(),
FLAGS_sample_rate
);
...
@@ -156,7 +157,7 @@ int main(int argc, char* argv[]) {
...
@@ -156,7 +157,7 @@ int main(int argc, char* argv[]) {
ThreadPool
pool
(
FLAGS_thread_num
);
ThreadPool
pool
(
FLAGS_thread_num
);
auto
wav
=
waves
[
0
];
auto
wav
=
waves
[
0
];
for
(
int
i
=
0
;
i
<
FLAGS_warmup
;
i
++
)
{
for
(
int
i
=
0
;
i
<
FLAGS_warmup
;
i
++
)
{
pool
.
enqueue
(
Decode
,
wav
,
true
,
g_decode_config
,
g_feature_config
,
g_decode_resource
);
pool
.
enqueue
(
Decode
,
wav
,
true
,
g_decode_config
,
g_feature_config
);
}
}
}
}
LOG
(
INFO
)
<<
"Warmup done."
;
LOG
(
INFO
)
<<
"Warmup done."
;
...
@@ -165,7 +166,7 @@ int main(int argc, char* argv[]) {
...
@@ -165,7 +166,7 @@ int main(int argc, char* argv[]) {
{
{
ThreadPool
pool
(
FLAGS_thread_num
);
ThreadPool
pool
(
FLAGS_thread_num
);
for
(
auto
&
wav
:
waves
)
{
for
(
auto
&
wav
:
waves
)
{
pool
.
enqueue
(
Decode
,
wav
,
false
,
g_decode_config
,
g_feature_config
,
g_decode_resource
);
pool
.
enqueue
(
Decode
,
wav
,
false
,
g_decode_config
,
g_feature_config
);
}
}
}
}
...
...
Src/decoder/params.h
View file @
3c4ea2c0
...
@@ -108,11 +108,15 @@ std::shared_ptr<DecodeOptions> InitDecodeOptionsFromFlags() {
...
@@ -108,11 +108,15 @@ std::shared_ptr<DecodeOptions> InitDecodeOptionsFromFlags() {
}
}
std
::
shared_ptr
<
DecodeResource
>
InitDecodeResourceFromFlags
()
{
std
::
shared_ptr
<
DecodeResource
>
InitDecodeResourceFromFlags
()
{
static
bool
isRegisterROCM
=
false
;
auto
resource
=
std
::
make_shared
<
DecodeResource
>
();
auto
resource
=
std
::
make_shared
<
DecodeResource
>
();
const
int
kNumGemmThreads
=
1
;
const
int
kNumGemmThreads
=
1
;
if
(
!
FLAGS_onnx_dir
.
empty
())
{
if
(
!
FLAGS_onnx_dir
.
empty
())
{
LOG
(
INFO
)
<<
"Reading onnx model "
;
LOG
(
INFO
)
<<
"Reading onnx model "
;
if
(
isRegisterROCM
==
false
)
{
OnnxAsrModel
::
InitEngineThreads
(
kNumGemmThreads
);
OnnxAsrModel
::
InitEngineThreads
(
kNumGemmThreads
);
isRegisterROCM
=
true
;
}
auto
model
=
std
::
make_shared
<
OnnxAsrModel
>
();
auto
model
=
std
::
make_shared
<
OnnxAsrModel
>
();
model
->
Read
(
FLAGS_onnx_dir
);
model
->
Read
(
FLAGS_onnx_dir
);
resource
->
model
=
model
;
resource
->
model
=
model
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment