Commit 3c4ea2c0 authored by wufan3
Browse files

Fix BUG[92942]: Improve WeNet inference performance by creating a separate session for each thread

parent 61aeca13
......@@ -17,7 +17,7 @@ DEFINE_string(wav_scp, "", "input wav scp");
DEFINE_string(result, "./result", "result output file");
DEFINE_bool(continuous_decoding, false, "continuous decoding mode");
DEFINE_int32(thread_num, 1, "num of decode thread");
DEFINE_int32(warmup, 0, "num of warmup decode, 0 means no warmup");
DEFINE_int32(warmup, 1, "num of warmup decode, 0 means no warmup");
// std::shared_ptr<wenet::DecodeOptions> g_decode_config;
// std::shared_ptr<wenet::FeaturePipelineConfig> g_feature_config;
......@@ -28,7 +28,8 @@ std::mutex g_mutex;
int g_total_waves_dur = 0;
int g_total_decode_time = 0;
void Decode(std::pair<std::string, std::string> wav, bool warmup, std::shared_ptr<wenet::DecodeOptions> g_decode_config, std::shared_ptr<wenet::FeaturePipelineConfig> g_feature_config, std::shared_ptr<wenet::DecodeResource> g_decode_resource) {
void Decode(std::pair<std::string, std::string> wav, bool warmup, std::shared_ptr<wenet::DecodeOptions> g_decode_config, std::shared_ptr<wenet::FeaturePipelineConfig> g_feature_config) {
std::shared_ptr<wenet::DecodeResource> g_decode_resource = wenet::InitDecodeResourceFromFlags();
wenet::WavReader wav_reader(wav.second);
int num_samples = wav_reader.num_samples();
CHECK_EQ(wav_reader.sample_rate(), FLAGS_sample_rate);
......@@ -156,7 +157,7 @@ int main(int argc, char* argv[]) {
ThreadPool pool(FLAGS_thread_num);
auto wav = waves[0];
for (int i = 0; i < FLAGS_warmup; i++) {
pool.enqueue(Decode, wav, true, g_decode_config, g_feature_config, g_decode_resource);
pool.enqueue(Decode, wav, true, g_decode_config, g_feature_config);
}
}
LOG(INFO) << "Warmup done.";
......@@ -165,7 +166,7 @@ int main(int argc, char* argv[]) {
{
ThreadPool pool(FLAGS_thread_num);
for (auto& wav : waves) {
pool.enqueue(Decode, wav, false, g_decode_config, g_feature_config, g_decode_resource);
pool.enqueue(Decode, wav, false, g_decode_config, g_feature_config);
}
}
......
......@@ -108,11 +108,15 @@ std::shared_ptr<DecodeOptions> InitDecodeOptionsFromFlags() {
}
std::shared_ptr<DecodeResource> InitDecodeResourceFromFlags() {
static bool isRegisterROCM = false;
auto resource = std::make_shared<DecodeResource>();
const int kNumGemmThreads = 1;
if (!FLAGS_onnx_dir.empty()) {
LOG(INFO) << "Reading onnx model ";
if (isRegisterROCM == false) {
OnnxAsrModel::InitEngineThreads(kNumGemmThreads);
isRegisterROCM = true;
}
auto model = std::make_shared<OnnxAsrModel>();
model->Read(FLAGS_onnx_dir);
resource->model = model;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment