Commit 978a6781 authored by mayong's avatar mayong
Browse files

update files

parent 3aad0a8f
...@@ -43,7 +43,7 @@ class Audio { ...@@ -43,7 +43,7 @@ class Audio {
Audio(int data_type, int size); Audio(int data_type, int size);
~Audio(); ~Audio();
void disp(); void disp();
void loadwav(const char *filename); bool loadwav(const char *filename);
int fetch_chunck(float *&dout, int len); int fetch_chunck(float *&dout, int len);
int fetch(float *&dout, int &len, int &flag); int fetch(float *&dout, int &len, int &flag);
void padding(); void padding();
......
...@@ -112,7 +112,7 @@ void Audio::disp() ...@@ -112,7 +112,7 @@ void Audio::disp()
speech_len); speech_len);
} }
void Audio::loadwav(const char *filename) bool Audio::loadwav(const char *filename)
{ {
if (speech_buff != NULL) { if (speech_buff != NULL) {
...@@ -124,6 +124,8 @@ void Audio::loadwav(const char *filename) ...@@ -124,6 +124,8 @@ void Audio::loadwav(const char *filename)
FILE *fp; FILE *fp;
fp = fopen(filename, "rb"); fp = fopen(filename, "rb");
if (fp == nullptr)
return false;
fseek(fp, 0, SEEK_END); fseek(fp, 0, SEEK_END);
uint32_t nFileLen = ftell(fp); uint32_t nFileLen = ftell(fp);
fseek(fp, 44, SEEK_SET); fseek(fp, 44, SEEK_SET);
...@@ -150,6 +152,7 @@ void Audio::loadwav(const char *filename) ...@@ -150,6 +152,7 @@ void Audio::loadwav(const char *filename)
AudioFrame *frame = new AudioFrame(speech_len); AudioFrame *frame = new AudioFrame(speech_len);
frame_queue.push(frame); frame_queue.push(frame);
return true;
} }
int Audio::fetch_chunck(float *&dout, int len) int Audio::fetch_chunck(float *&dout, int len)
......
...@@ -162,28 +162,28 @@ string ModelImp::forward(float* din, int len, int flag) ...@@ -162,28 +162,28 @@ string ModelImp::forward(float* din, int len, int flag)
input_onnx.emplace_back(std::move(onnx_feats)); input_onnx.emplace_back(std::move(onnx_feats));
input_onnx.emplace_back(std::move(onnx_feats_len)); input_onnx.emplace_back(std::move(onnx_feats_len));
//auto output = m_session_encoder->Run(run_option, string result;
// m_strEncInputName.data(), try {
// input_onnx.data(),
// m_strEncInputName.size(),
// m_strEncOutputName.data(),
// m_strEncOutputName.size()
//);
auto outputTensor = m_session->Run(run_option, m_szInputNames.data(), input_onnx.data(), m_szInputNames.size(), m_szOutputNames.data(), m_szOutputNames.size());
//assert(outputTensor.size() == 1 && outputTensor[0].IsTensor());
std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
auto outputTensor = m_session->Run(run_option, m_szInputNames.data(), input_onnx.data(), m_szInputNames.size(), m_szOutputNames.data(), m_szOutputNames.size());
//assert(outputTensor.size() == 1 && outputTensor[0].IsTensor()); int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape(); float* floatData = outputTensor[0].GetTensorMutableData<float>();
auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
//float* floatSize = outputTensor[1].GetTensorMutableData<float>();
//std::vector<float> out_data(floatArray, floatArray + outputCount);
result = greedy_search(floatData, *encoder_out_lens);
}
catch (...)
{
result = "";
}
int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
float* floatData = outputTensor[0].GetTensorMutableData<float>();
auto encoder_out_lens = outputTensor[1].GetTensorMutableData<int64_t>();
//float* floatSize = outputTensor[1].GetTensorMutableData<float>();
//std::vector<float> out_data(floatArray, floatArray + outputCount);
string result = greedy_search(floatData, *encoder_out_lens);
if(in) if(in)
delete in; delete in;
......
...@@ -20,7 +20,11 @@ int main(int argc, char *argv[]) ...@@ -20,7 +20,11 @@ int main(int argc, char *argv[])
} }
struct timeval start, end; struct timeval start, end;
Audio audio(0); Audio audio(0);
audio.loadwav(argv[2]); if (!audio.loadwav(argv[2]))
{
printf("cannot load %s\n", argv[2]);
return -1;
}
audio.disp(); audio.disp();
gettimeofday(&start, NULL); gettimeofday(&start, NULL);
Model *mm = create_model(argv[1], 3); Model *mm = create_model(argv[1], 3);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment