Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
wenet_onnxruntime
Commits
abc11f55
Commit
abc11f55
authored
Aug 20, 2024
by
yangql
Browse files
增加对dtk24.04.1的支持,以及对k100_AI的支持。
parent
58043336
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
69 additions
and
69 deletions
+69
-69
Src/decoder/onnx_asr_model.h
Src/decoder/onnx_asr_model.h
+69
-69
No files found.
Src/decoder/onnx_asr_model.h
View file @
abc11f55
#ifndef DECODER_ONNX_ASR_MODEL_H_
#define DECODER_ONNX_ASR_MODEL_H_
#include <memory>
#include <string>
#include <vector>
#include <onnxruntime
/core/session/onnxruntime
_cxx_api.h>
// NOLINT
#include "decoder/asr_model.h"
#include "utils/log.h"
#include "utils/utils.h"
namespace
wenet
{
class
OnnxAsrModel
:
public
AsrModel
{
public:
static
void
InitEngineThreads
(
int
num_threads
=
1
);
public:
OnnxAsrModel
()
=
default
;
OnnxAsrModel
(
const
OnnxAsrModel
&
other
);
void
Read
(
const
std
::
string
&
model_dir
);
void
Reset
()
override
;
void
AttentionRescoring
(
const
std
::
vector
<
std
::
vector
<
int
>>&
hyps
,
float
reverse_weight
,
std
::
vector
<
float
>*
rescoring_score
)
override
;
std
::
shared_ptr
<
AsrModel
>
Copy
()
const
override
;
void
GetInputOutputInfo
(
const
std
::
shared_ptr
<
Ort
::
Session
>&
session
,
std
::
vector
<
const
char
*>*
in_names
,
std
::
vector
<
const
char
*>*
out_names
);
protected:
void
ForwardEncoderFunc
(
const
std
::
vector
<
std
::
vector
<
float
>>&
chunk_feats
,
std
::
vector
<
std
::
vector
<
float
>>*
ctc_prob
)
override
;
float
ComputeAttentionScore
(
const
float
*
prob
,
const
std
::
vector
<
int
>&
hyp
,
int
eos
,
int
decode_out_len
);
private:
int
encoder_output_size_
=
0
;
int
num_blocks_
=
0
;
int
cnn_module_kernel_
=
0
;
int
head_
=
0
;
// sessions
// NOTE(Mddct): The Env holds the logging state used by all other objects.
// One Env must be created before using any other Onnxruntime functionality.
static
Ort
::
Env
env_
;
// shared environment across threads.
static
Ort
::
SessionOptions
session_options_
;
std
::
shared_ptr
<
Ort
::
Session
>
encoder_session_
=
nullptr
;
std
::
shared_ptr
<
Ort
::
Session
>
rescore_session_
=
nullptr
;
std
::
shared_ptr
<
Ort
::
Session
>
ctc_session_
=
nullptr
;
// node names
std
::
vector
<
const
char
*>
encoder_in_names_
,
encoder_out_names_
;
std
::
vector
<
const
char
*>
ctc_in_names_
,
ctc_out_names_
;
std
::
vector
<
const
char
*>
rescore_in_names_
,
rescore_out_names_
;
// caches
Ort
::
Value
att_cache_ort_
{
nullptr
};
Ort
::
Value
cnn_cache_ort_
{
nullptr
};
std
::
vector
<
Ort
::
Value
>
encoder_outs_
;
std
::
vector
<
float
>
att_cache_
;
std
::
vector
<
float
>
cnn_cache_
;
};
}
// namespace wenet
#endif // DECODER_ONNX_ASR_MODEL_H_
#ifndef DECODER_ONNX_ASR_MODEL_H_
#define DECODER_ONNX_ASR_MODEL_H_
#include <memory>
#include <string>
#include <vector>
#include <onnxruntime_cxx_api.h>
#include "decoder/asr_model.h"
#include "utils/log.h"
#include "utils/utils.h"
namespace
wenet
{
class
OnnxAsrModel
:
public
AsrModel
{
public:
static
void
InitEngineThreads
(
int
num_threads
=
1
);
public:
OnnxAsrModel
()
=
default
;
OnnxAsrModel
(
const
OnnxAsrModel
&
other
);
void
Read
(
const
std
::
string
&
model_dir
);
void
Reset
()
override
;
void
AttentionRescoring
(
const
std
::
vector
<
std
::
vector
<
int
>>&
hyps
,
float
reverse_weight
,
std
::
vector
<
float
>*
rescoring_score
)
override
;
std
::
shared_ptr
<
AsrModel
>
Copy
()
const
override
;
void
GetInputOutputInfo
(
const
std
::
shared_ptr
<
Ort
::
Session
>&
session
,
std
::
vector
<
const
char
*>*
in_names
,
std
::
vector
<
const
char
*>*
out_names
);
protected:
void
ForwardEncoderFunc
(
const
std
::
vector
<
std
::
vector
<
float
>>&
chunk_feats
,
std
::
vector
<
std
::
vector
<
float
>>*
ctc_prob
)
override
;
float
ComputeAttentionScore
(
const
float
*
prob
,
const
std
::
vector
<
int
>&
hyp
,
int
eos
,
int
decode_out_len
);
private:
int
encoder_output_size_
=
0
;
int
num_blocks_
=
0
;
int
cnn_module_kernel_
=
0
;
int
head_
=
0
;
// sessions
// NOTE(Mddct): The Env holds the logging state used by all other objects.
// One Env must be created before using any other Onnxruntime functionality.
static
Ort
::
Env
env_
;
// shared environment across threads.
static
Ort
::
SessionOptions
session_options_
;
std
::
shared_ptr
<
Ort
::
Session
>
encoder_session_
=
nullptr
;
std
::
shared_ptr
<
Ort
::
Session
>
rescore_session_
=
nullptr
;
std
::
shared_ptr
<
Ort
::
Session
>
ctc_session_
=
nullptr
;
// node names
std
::
vector
<
const
char
*>
encoder_in_names_
,
encoder_out_names_
;
std
::
vector
<
const
char
*>
ctc_in_names_
,
ctc_out_names_
;
std
::
vector
<
const
char
*>
rescore_in_names_
,
rescore_out_names_
;
// caches
Ort
::
Value
att_cache_ort_
{
nullptr
};
Ort
::
Value
cnn_cache_ort_
{
nullptr
};
std
::
vector
<
Ort
::
Value
>
encoder_outs_
;
std
::
vector
<
float
>
att_cache_
;
std
::
vector
<
float
>
cnn_cache_
;
};
}
// namespace wenet
#endif // DECODER_ONNX_ASR_MODEL_H_
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment