Commit 3c4ea2c0 authored by wufan3
Browse files

Fix BUG[92942]: improve wenet inference performance by creating a separate session for each thread

parent 61aeca13
...@@ -17,7 +17,7 @@ DEFINE_string(wav_scp, "", "input wav scp"); ...@@ -17,7 +17,7 @@ DEFINE_string(wav_scp, "", "input wav scp");
DEFINE_string(result, "./result", "result output file"); DEFINE_string(result, "./result", "result output file");
DEFINE_bool(continuous_decoding, false, "continuous decoding mode"); DEFINE_bool(continuous_decoding, false, "continuous decoding mode");
DEFINE_int32(thread_num, 1, "num of decode thread"); DEFINE_int32(thread_num, 1, "num of decode thread");
DEFINE_int32(warmup, 0, "num of warmup decode, 0 means no warmup"); DEFINE_int32(warmup, 1, "num of warmup decode, 0 means no warmup");
// std::shared_ptr<wenet::DecodeOptions> g_decode_config; // std::shared_ptr<wenet::DecodeOptions> g_decode_config;
// std::shared_ptr<wenet::FeaturePipelineConfig> g_feature_config; // std::shared_ptr<wenet::FeaturePipelineConfig> g_feature_config;
...@@ -28,7 +28,8 @@ std::mutex g_mutex; ...@@ -28,7 +28,8 @@ std::mutex g_mutex;
int g_total_waves_dur = 0; int g_total_waves_dur = 0;
int g_total_decode_time = 0; int g_total_decode_time = 0;
void Decode(std::pair<std::string, std::string> wav, bool warmup, std::shared_ptr<wenet::DecodeOptions> g_decode_config, std::shared_ptr<wenet::FeaturePipelineConfig> g_feature_config, std::shared_ptr<wenet::DecodeResource> g_decode_resource) { void Decode(std::pair<std::string, std::string> wav, bool warmup, std::shared_ptr<wenet::DecodeOptions> g_decode_config, std::shared_ptr<wenet::FeaturePipelineConfig> g_feature_config) {
std::shared_ptr<wenet::DecodeResource> g_decode_resource = wenet::InitDecodeResourceFromFlags();
wenet::WavReader wav_reader(wav.second); wenet::WavReader wav_reader(wav.second);
int num_samples = wav_reader.num_samples(); int num_samples = wav_reader.num_samples();
CHECK_EQ(wav_reader.sample_rate(), FLAGS_sample_rate); CHECK_EQ(wav_reader.sample_rate(), FLAGS_sample_rate);
...@@ -156,7 +157,7 @@ int main(int argc, char* argv[]) { ...@@ -156,7 +157,7 @@ int main(int argc, char* argv[]) {
ThreadPool pool(FLAGS_thread_num); ThreadPool pool(FLAGS_thread_num);
auto wav = waves[0]; auto wav = waves[0];
for (int i = 0; i < FLAGS_warmup; i++) { for (int i = 0; i < FLAGS_warmup; i++) {
pool.enqueue(Decode, wav, true, g_decode_config, g_feature_config, g_decode_resource); pool.enqueue(Decode, wav, true, g_decode_config, g_feature_config);
} }
} }
LOG(INFO) << "Warmup done."; LOG(INFO) << "Warmup done.";
...@@ -165,7 +166,7 @@ int main(int argc, char* argv[]) { ...@@ -165,7 +166,7 @@ int main(int argc, char* argv[]) {
{ {
ThreadPool pool(FLAGS_thread_num); ThreadPool pool(FLAGS_thread_num);
for (auto& wav : waves) { for (auto& wav : waves) {
pool.enqueue(Decode, wav, false, g_decode_config, g_feature_config, g_decode_resource); pool.enqueue(Decode, wav, false, g_decode_config, g_feature_config);
} }
} }
......
...@@ -108,11 +108,15 @@ std::shared_ptr<DecodeOptions> InitDecodeOptionsFromFlags() { ...@@ -108,11 +108,15 @@ std::shared_ptr<DecodeOptions> InitDecodeOptionsFromFlags() {
} }
std::shared_ptr<DecodeResource> InitDecodeResourceFromFlags() { std::shared_ptr<DecodeResource> InitDecodeResourceFromFlags() {
static bool isRegisterROCM = false;
auto resource = std::make_shared<DecodeResource>(); auto resource = std::make_shared<DecodeResource>();
const int kNumGemmThreads = 1; const int kNumGemmThreads = 1;
if (!FLAGS_onnx_dir.empty()) { if (!FLAGS_onnx_dir.empty()) {
LOG(INFO) << "Reading onnx model "; LOG(INFO) << "Reading onnx model ";
if (isRegisterROCM == false) {
OnnxAsrModel::InitEngineThreads(kNumGemmThreads); OnnxAsrModel::InitEngineThreads(kNumGemmThreads);
isRegisterROCM = true;
}
auto model = std::make_shared<OnnxAsrModel>(); auto model = std::make_shared<OnnxAsrModel>();
model->Read(FLAGS_onnx_dir); model->Read(FLAGS_onnx_dir);
resource->model = model; resource->model = model;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment