rapidasr.h 1.29 KB
Newer Older
SWHL's avatar
SWHL committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#pragma once

#include <string>
#include <vector>

class CQmASRRecog
{
private:

	Ort::Session* m_session_encoder=nullptr;
	Ort::Session* m_session_decoder = nullptr;
	Ort::Env envDecoder = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "QmASR_decoder");
	Ort::Env envEncoder = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "QmASR_encoder");
	Ort::SessionOptions sessionOptions = Ort::SessionOptions();

	string m_strConfig, m_strDict;

	vector<string> m_vecEncInputName, m_vecEncOutputName, m_vecDecInputName, m_vecDecOutputName;
	vector<const char *> m_strEncInputName, m_strEncOutputName, m_strDecInputName, m_strDecOutputName;
	bool m_bIsLoaded = false;
	vector<std::string> m_Vocabulary;

	float m_reverse_weight=0.f;  // train.yamlжȡ

public :
	CQmASRRecog(const char * szModelDir, int nThread);
	CQmASRRecog(const char* szEncoder, const char* szDecoder, const char* szDict, const char* szConfig, int nThread);
	~CQmASRRecog();

	bool LoadModel(const char* szEncoder, const char* szDecoder, const char* szDict, const char* szConfig, int nNumThread);
	bool LoadModel(const char* szModelDir, int nNumThread);
	bool IsLoaded();

	int  ExtractFeature(vector<float>& wav, std::vector<std::vector<float>>& feats ,wenet::FeaturePipelineConfig& config); // ȡ
	PRAPIDASR_RECOG_RESULT DoRecognize(vector<vector<float>>& feats, RAPIDASR_MODE Mode = RPASRM_CTC_GREEDY_SEARCH);


};