#ifndef DECODER_ONNX_ASR_MODEL_H_ #define DECODER_ONNX_ASR_MODEL_H_ #include #include #include #include #include "decoder/asr_model.h" #include "utils/log.h" #include "utils/utils.h" namespace wenet { class OnnxAsrModel : public AsrModel { public: static void InitEngineThreads(int num_threads = 1); public: OnnxAsrModel() = default; OnnxAsrModel(const OnnxAsrModel& other); void Read(const std::string& model_dir); void Reset() override; void AttentionRescoring(const std::vector>& hyps, float reverse_weight, std::vector* rescoring_score) override; std::shared_ptr Copy() const override; void GetInputOutputInfo(const std::shared_ptr& session, std::vector* in_names, std::vector* out_names); protected: void ForwardEncoderFunc(const std::vector>& chunk_feats, std::vector>* ctc_prob) override; float ComputeAttentionScore(const float* prob, const std::vector& hyp, int eos, int decode_out_len); private: int encoder_output_size_ = 0; int num_blocks_ = 0; int cnn_module_kernel_ = 0; int head_ = 0; // sessions // NOTE(Mddct): The Env holds the logging state used by all other objects. // One Env must be created before using any other Onnxruntime functionality. static Ort::Env env_; // shared environment across threads. static Ort::SessionOptions session_options_; std::shared_ptr encoder_session_ = nullptr; std::shared_ptr rescore_session_ = nullptr; std::shared_ptr ctc_session_ = nullptr; // node names std::vector encoder_in_names_, encoder_out_names_; std::vector ctc_in_names_, ctc_out_names_; std::vector rescore_in_names_, rescore_out_names_; // caches Ort::Value att_cache_ort_{nullptr}; Ort::Value cnn_cache_ort_{nullptr}; std::vector encoder_outs_; std::vector att_cache_; std::vector cnn_cache_; }; } // namespace wenet #endif // DECODER_ONNX_ASR_MODEL_H_