"cacheflow/frontend/simple_frontend.py" did not exist on "d359cda5fae1c9a6fe54ba12b940572edfbf87ac"
Commit 764b3a75 authored by Sugon_ldc's avatar Sugon_ldc
Browse files

add new model

parents
// Copyright (c) 2021 Mobvoi Inc (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DECODER_CTC_WFST_BEAM_SEARCH_H_
#define DECODER_CTC_WFST_BEAM_SEARCH_H_
#include <memory>
#include <vector>
#include "decoder/context_graph.h"
#include "decoder/search_interface.h"
#include "kaldi/decoder/lattice-faster-online-decoder.h"
#include "utils/utils.h"
namespace wenet {
class DecodableTensorScaled : public kaldi::DecodableInterface {
public:
explicit DecodableTensorScaled(float scale = 1.0) : scale_(scale) { Reset(); }
void Reset();
int32 NumFramesReady() const override { return num_frames_ready_; }
bool IsLastFrame(int32 frame) const override;
float LogLikelihood(int32 frame, int32 index) override;
int32 NumIndices() const override;
void AcceptLoglikes(const std::vector<float>& logp);
void SetFinish() { done_ = true; }
private:
int num_frames_ready_ = 0;
float scale_ = 1.0;
bool done_ = false;
std::vector<float> logp_;
};
// LatticeFasterDecoderConfig has the following key members
// beam: decoding beam
// max_active: Decoder max active states
// lattice_beam: Lattice generation beam
struct CtcWfstBeamSearchOptions : public kaldi::LatticeFasterDecoderConfig {
float acoustic_scale = 1.0;
float nbest = 10;
// When blank score is greater than this thresh, skip the frame in viterbi
// search
float blank_skip_thresh = 0.98;
float blank_scale = 1.0;
};
class CtcWfstBeamSearch : public SearchInterface {
public:
explicit CtcWfstBeamSearch(
const fst::Fst<fst::StdArc>& fst, const CtcWfstBeamSearchOptions& opts,
const std::shared_ptr<ContextGraph>& context_graph);
void Search(const std::vector<std::vector<float>>& logp) override;
void Reset() override;
void FinalizeSearch() override;
SearchType Type() const override { return SearchType::kWfstBeamSearch; }
// For CTC prefix beam search, both inputs and outputs are hypotheses_
const std::vector<std::vector<int>>& Inputs() const override {
return inputs_;
}
const std::vector<std::vector<int>>& Outputs() const override {
return outputs_;
}
const std::vector<float>& Likelihood() const override { return likelihood_; }
const std::vector<std::vector<int>>& Times() const override { return times_; }
private:
// Sub one and remove <blank>
void ConvertToInputs(const std::vector<int>& alignment,
std::vector<int>* input,
std::vector<int>* time = nullptr);
void RemoveContinuousTags(std::vector<int>* output);
int num_frames_ = 0;
std::vector<int> decoded_frames_mapping_;
int last_best_ = 0; // last none blank best id
std::vector<float> last_frame_prob_;
bool is_last_frame_blank_ = false;
std::vector<std::vector<int>> inputs_, outputs_;
std::vector<float> likelihood_;
std::vector<std::vector<int>> times_;
DecodableTensorScaled decodable_;
kaldi::LatticeFasterOnlineDecoder decoder_;
std::shared_ptr<ContextGraph> context_graph_;
const CtcWfstBeamSearchOptions& opts_;
};
} // namespace wenet
#endif // DECODER_CTC_WFST_BEAM_SEARCH_H_
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
// 2022 ZeXuan Li (lizexuan@huya.com)
// Xingchen Song(sxc19@mails.tsinghua.edu.cn)
// hamddct@gmail.com (Mddct)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/onnx_asr_model.h"
#include <algorithm>
#include <memory>
#include <utility>
#include "utils/string.h"
namespace wenet {
Ort::Env OnnxAsrModel::env_ = Ort::Env(ORT_LOGGING_LEVEL_WARNING, "");
Ort::SessionOptions OnnxAsrModel::session_options_ = Ort::SessionOptions();
void OnnxAsrModel::InitEngineThreads(int num_threads) {
session_options_.SetIntraOpNumThreads(num_threads);
}
void OnnxAsrModel::GetInputOutputInfo(
const std::shared_ptr<Ort::Session>& session,
std::vector<const char*>* in_names, std::vector<const char*>* out_names) {
Ort::AllocatorWithDefaultOptions allocator;
// Input info
int num_nodes = session->GetInputCount();
in_names->resize(num_nodes);
for (int i = 0; i < num_nodes; ++i) {
char* name = session->GetInputName(i, allocator);
Ort::TypeInfo type_info = session->GetInputTypeInfo(i);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
ONNXTensorElementDataType type = tensor_info.GetElementType();
std::vector<int64_t> node_dims = tensor_info.GetShape();
std::stringstream shape;
for (auto j : node_dims) {
shape << j;
shape << " ";
}
LOG(INFO) << "\tInput " << i << " : name=" << name << " type=" << type
<< " dims=" << shape.str();
(*in_names)[i] = name;
}
// Output info
num_nodes = session->GetOutputCount();
out_names->resize(num_nodes);
for (int i = 0; i < num_nodes; ++i) {
char* name = session->GetOutputName(i, allocator);
Ort::TypeInfo type_info = session->GetOutputTypeInfo(i);
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
ONNXTensorElementDataType type = tensor_info.GetElementType();
std::vector<int64_t> node_dims = tensor_info.GetShape();
std::stringstream shape;
for (auto j : node_dims) {
shape << j;
shape << " ";
}
LOG(INFO) << "\tOutput " << i << " : name=" << name << " type=" << type
<< " dims=" << shape.str();
(*out_names)[i] = name;
}
}
void OnnxAsrModel::Read(const std::string& model_dir) {
std::string encoder_onnx_path = model_dir + "/encoder.onnx";
std::string rescore_onnx_path = model_dir + "/decoder.onnx";
std::string ctc_onnx_path = model_dir + "/ctc.onnx";
// 1. Load sessions
try {
#ifdef _MSC_VER
encoder_session_ = std::make_shared<Ort::Session>(
env_, ToWString(encoder_onnx_path).c_str(), session_options_);
rescore_session_ = std::make_shared<Ort::Session>(
env_, ToWString(rescore_onnx_path).c_str(), session_options_);
ctc_session_ = std::make_shared<Ort::Session>(
env_, ToWString(ctc_onnx_path).c_str(), session_options_);
#else
encoder_session_ = std::make_shared<Ort::Session>(
env_, encoder_onnx_path.c_str(), session_options_);
rescore_session_ = std::make_shared<Ort::Session>(
env_, rescore_onnx_path.c_str(), session_options_);
ctc_session_ = std::make_shared<Ort::Session>(env_, ctc_onnx_path.c_str(),
session_options_);
#endif
} catch (std::exception const& e) {
LOG(ERROR) << "error when load onnx model: " << e.what();
exit(0);
}
// 2. Read metadata
auto model_metadata = encoder_session_->GetModelMetadata();
Ort::AllocatorWithDefaultOptions allocator;
encoder_output_size_ =
atoi(model_metadata.LookupCustomMetadataMap("output_size", allocator));
num_blocks_ =
atoi(model_metadata.LookupCustomMetadataMap("num_blocks", allocator));
head_ = atoi(model_metadata.LookupCustomMetadataMap("head", allocator));
cnn_module_kernel_ = atoi(
model_metadata.LookupCustomMetadataMap("cnn_module_kernel", allocator));
subsampling_rate_ = atoi(
model_metadata.LookupCustomMetadataMap("subsampling_rate", allocator));
right_context_ =
atoi(model_metadata.LookupCustomMetadataMap("right_context", allocator));
sos_ = atoi(model_metadata.LookupCustomMetadataMap("sos_symbol", allocator));
eos_ = atoi(model_metadata.LookupCustomMetadataMap("eos_symbol", allocator));
is_bidirectional_decoder_ = atoi(model_metadata.LookupCustomMetadataMap(
"is_bidirectional_decoder", allocator));
chunk_size_ =
atoi(model_metadata.LookupCustomMetadataMap("chunk_size", allocator));
num_left_chunks_ =
atoi(model_metadata.LookupCustomMetadataMap("left_chunks", allocator));
LOG(INFO) << "Onnx Model Info:";
LOG(INFO) << "\tencoder_output_size " << encoder_output_size_;
LOG(INFO) << "\tnum_blocks " << num_blocks_;
LOG(INFO) << "\thead " << head_;
LOG(INFO) << "\tcnn_module_kernel " << cnn_module_kernel_;
LOG(INFO) << "\tsubsampling_rate " << subsampling_rate_;
LOG(INFO) << "\tright_context " << right_context_;
LOG(INFO) << "\tsos " << sos_;
LOG(INFO) << "\teos " << eos_;
LOG(INFO) << "\tis bidirectional decoder " << is_bidirectional_decoder_;
LOG(INFO) << "\tchunk_size " << chunk_size_;
LOG(INFO) << "\tnum_left_chunks " << num_left_chunks_;
// 3. Read model nodes
LOG(INFO) << "Onnx Encoder:";
GetInputOutputInfo(encoder_session_, &encoder_in_names_, &encoder_out_names_);
LOG(INFO) << "Onnx CTC:";
GetInputOutputInfo(ctc_session_, &ctc_in_names_, &ctc_out_names_);
LOG(INFO) << "Onnx Rescore:";
GetInputOutputInfo(rescore_session_, &rescore_in_names_, &rescore_out_names_);
}
OnnxAsrModel::OnnxAsrModel(const OnnxAsrModel& other) {
// metadatas
encoder_output_size_ = other.encoder_output_size_;
num_blocks_ = other.num_blocks_;
head_ = other.head_;
cnn_module_kernel_ = other.cnn_module_kernel_;
right_context_ = other.right_context_;
subsampling_rate_ = other.subsampling_rate_;
sos_ = other.sos_;
eos_ = other.eos_;
is_bidirectional_decoder_ = other.is_bidirectional_decoder_;
chunk_size_ = other.chunk_size_;
num_left_chunks_ = other.num_left_chunks_;
offset_ = other.offset_;
// sessions
encoder_session_ = other.encoder_session_;
ctc_session_ = other.ctc_session_;
rescore_session_ = other.rescore_session_;
// node names
encoder_in_names_ = other.encoder_in_names_;
encoder_out_names_ = other.encoder_out_names_;
ctc_in_names_ = other.ctc_in_names_;
ctc_out_names_ = other.ctc_out_names_;
rescore_in_names_ = other.rescore_in_names_;
rescore_out_names_ = other.rescore_out_names_;
}
std::shared_ptr<AsrModel> OnnxAsrModel::Copy() const {
auto asr_model = std::make_shared<OnnxAsrModel>(*this);
// Reset the inner states for new decoding
asr_model->Reset();
return asr_model;
}
void OnnxAsrModel::Reset() {
offset_ = 0;
encoder_outs_.clear();
cached_feature_.clear();
// Reset att_cache
Ort::MemoryInfo memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
if (num_left_chunks_ > 0) {
int required_cache_size = chunk_size_ * num_left_chunks_;
offset_ = required_cache_size;
att_cache_.resize(num_blocks_ * head_ * required_cache_size *
encoder_output_size_ / head_ * 2,
0.0);
const int64_t att_cache_shape[] = {num_blocks_, head_, required_cache_size,
encoder_output_size_ / head_ * 2};
att_cache_ort_ = Ort::Value::CreateTensor<float>(
memory_info, att_cache_.data(), att_cache_.size(), att_cache_shape, 4);
} else {
att_cache_.resize(0, 0.0);
const int64_t att_cache_shape[] = {num_blocks_, head_, 0,
encoder_output_size_ / head_ * 2};
att_cache_ort_ = Ort::Value::CreateTensor<float>(
memory_info, att_cache_.data(), att_cache_.size(), att_cache_shape, 4);
}
// Reset cnn_cache
cnn_cache_.resize(
num_blocks_ * encoder_output_size_ * (cnn_module_kernel_ - 1), 0.0);
const int64_t cnn_cache_shape[] = {num_blocks_, 1, encoder_output_size_,
cnn_module_kernel_ - 1};
cnn_cache_ort_ = Ort::Value::CreateTensor<float>(
memory_info, cnn_cache_.data(), cnn_cache_.size(), cnn_cache_shape, 4);
}
void OnnxAsrModel::ForwardEncoderFunc(
const std::vector<std::vector<float>>& chunk_feats,
std::vector<std::vector<float>>* out_prob) {
Ort::MemoryInfo memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
// 1. Prepare onnx required data, splice cached_feature_ and chunk_feats
// chunk
int num_frames = cached_feature_.size() + chunk_feats.size();
const int feature_dim = chunk_feats[0].size();
std::vector<float> feats;
for (size_t i = 0; i < cached_feature_.size(); ++i) {
feats.insert(feats.end(), cached_feature_[i].begin(),
cached_feature_[i].end());
}
for (size_t i = 0; i < chunk_feats.size(); ++i) {
feats.insert(feats.end(), chunk_feats[i].begin(), chunk_feats[i].end());
}
const int64_t feats_shape[3] = {1, num_frames, feature_dim};
Ort::Value feats_ort = Ort::Value::CreateTensor<float>(
memory_info, feats.data(), feats.size(), feats_shape, 3);
// offset
int64_t offset_int64 = static_cast<int64_t>(offset_);
Ort::Value offset_ort = Ort::Value::CreateTensor<int64_t>(
memory_info, &offset_int64, 1, std::vector<int64_t>{}.data(), 0);
// required_cache_size
int64_t required_cache_size = chunk_size_ * num_left_chunks_;
Ort::Value required_cache_size_ort = Ort::Value::CreateTensor<int64_t>(
memory_info, &required_cache_size, 1, std::vector<int64_t>{}.data(), 0);
// att_mask
Ort::Value att_mask_ort{nullptr};
std::vector<uint8_t> att_mask(required_cache_size + chunk_size_, 1);
if (num_left_chunks_ > 0) {
int chunk_idx = offset_ / chunk_size_ - num_left_chunks_;
if (chunk_idx < num_left_chunks_) {
for (int i = 0; i < (num_left_chunks_ - chunk_idx) * chunk_size_; ++i) {
att_mask[i] = 0;
}
}
const int64_t att_mask_shape[] = {1, 1, required_cache_size + chunk_size_};
att_mask_ort = Ort::Value::CreateTensor<bool>(
memory_info, reinterpret_cast<bool*>(att_mask.data()), att_mask.size(),
att_mask_shape, 3);
}
// 2. Encoder chunk forward
std::vector<Ort::Value> inputs;
for (auto name : encoder_in_names_) {
if (!strcmp(name, "chunk")) {
inputs.emplace_back(std::move(feats_ort));
} else if (!strcmp(name, "offset")) {
inputs.emplace_back(std::move(offset_ort));
} else if (!strcmp(name, "required_cache_size")) {
inputs.emplace_back(std::move(required_cache_size_ort));
} else if (!strcmp(name, "att_cache")) {
inputs.emplace_back(std::move(att_cache_ort_));
} else if (!strcmp(name, "cnn_cache")) {
inputs.emplace_back(std::move(cnn_cache_ort_));
} else if (!strcmp(name, "att_mask")) {
inputs.emplace_back(std::move(att_mask_ort));
}
}
std::vector<Ort::Value> ort_outputs = encoder_session_->Run(
Ort::RunOptions{nullptr}, encoder_in_names_.data(), inputs.data(),
inputs.size(), encoder_out_names_.data(), encoder_out_names_.size());
offset_ += static_cast<int>(
ort_outputs[0].GetTensorTypeAndShapeInfo().GetShape()[1]);
att_cache_ort_ = std::move(ort_outputs[1]);
cnn_cache_ort_ = std::move(ort_outputs[2]);
std::vector<Ort::Value> ctc_inputs;
ctc_inputs.emplace_back(std::move(ort_outputs[0]));
std::vector<Ort::Value> ctc_ort_outputs = ctc_session_->Run(
Ort::RunOptions{nullptr}, ctc_in_names_.data(), ctc_inputs.data(),
ctc_inputs.size(), ctc_out_names_.data(), ctc_out_names_.size());
encoder_outs_.push_back(std::move(ctc_inputs[0]));
float* logp_data = ctc_ort_outputs[0].GetTensorMutableData<float>();
auto type_info = ctc_ort_outputs[0].GetTensorTypeAndShapeInfo();
int num_outputs = type_info.GetShape()[1];
int output_dim = type_info.GetShape()[2];
out_prob->resize(num_outputs);
for (int i = 0; i < num_outputs; i++) {
(*out_prob)[i].resize(output_dim);
memcpy((*out_prob)[i].data(), logp_data + i * output_dim,
sizeof(float) * output_dim);
}
}
float OnnxAsrModel::ComputeAttentionScore(const float* prob,
const std::vector<int>& hyp, int eos,
int decode_out_len) {
float score = 0.0f;
for (size_t j = 0; j < hyp.size(); ++j) {
score += *(prob + j * decode_out_len + hyp[j]);
}
score += *(prob + hyp.size() * decode_out_len + eos);
return score;
}
void OnnxAsrModel::AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) {
Ort::MemoryInfo memory_info =
Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
CHECK(rescoring_score != nullptr);
int num_hyps = hyps.size();
rescoring_score->resize(num_hyps, 0.0f);
if (num_hyps == 0) {
return;
}
// No encoder output
if (encoder_outs_.size() == 0) {
return;
}
std::vector<int64_t> hyps_lens;
int max_hyps_len = 0;
for (size_t i = 0; i < num_hyps; ++i) {
int length = hyps[i].size() + 1;
max_hyps_len = std::max(length, max_hyps_len);
hyps_lens.emplace_back(static_cast<int64_t>(length));
}
std::vector<float> rescore_input;
int encoder_len = 0;
for (int i = 0; i < encoder_outs_.size(); i++) {
float* encoder_outs_data = encoder_outs_[i].GetTensorMutableData<float>();
auto type_info = encoder_outs_[i].GetTensorTypeAndShapeInfo();
for (int j = 0; j < type_info.GetElementCount(); j++) {
rescore_input.emplace_back(encoder_outs_data[j]);
}
encoder_len += type_info.GetShape()[1];
}
const int64_t decode_input_shape[] = {1, encoder_len, encoder_output_size_};
std::vector<int64_t> hyps_pad;
for (size_t i = 0; i < num_hyps; ++i) {
const std::vector<int>& hyp = hyps[i];
hyps_pad.emplace_back(sos_);
size_t j = 0;
for (; j < hyp.size(); ++j) {
hyps_pad.emplace_back(hyp[j]);
}
if (j == max_hyps_len - 1) {
continue;
}
for (; j < max_hyps_len - 1; ++j) {
hyps_pad.emplace_back(0);
}
}
const int64_t hyps_pad_shape[] = {num_hyps, max_hyps_len};
const int64_t hyps_lens_shape[] = {num_hyps};
Ort::Value decode_input_tensor_ = Ort::Value::CreateTensor<float>(
memory_info, rescore_input.data(), rescore_input.size(),
decode_input_shape, 3);
Ort::Value hyps_pad_tensor_ = Ort::Value::CreateTensor<int64_t>(
memory_info, hyps_pad.data(), hyps_pad.size(), hyps_pad_shape, 2);
Ort::Value hyps_lens_tensor_ = Ort::Value::CreateTensor<int64_t>(
memory_info, hyps_lens.data(), hyps_lens.size(), hyps_lens_shape, 1);
std::vector<Ort::Value> rescore_inputs;
rescore_inputs.emplace_back(std::move(hyps_pad_tensor_));
rescore_inputs.emplace_back(std::move(hyps_lens_tensor_));
rescore_inputs.emplace_back(std::move(decode_input_tensor_));
std::vector<Ort::Value> rescore_outputs = rescore_session_->Run(
Ort::RunOptions{nullptr}, rescore_in_names_.data(), rescore_inputs.data(),
rescore_inputs.size(), rescore_out_names_.data(),
rescore_out_names_.size());
float* decoder_outs_data = rescore_outputs[0].GetTensorMutableData<float>();
float* r_decoder_outs_data = rescore_outputs[1].GetTensorMutableData<float>();
auto type_info = rescore_outputs[0].GetTensorTypeAndShapeInfo();
int decode_out_len = type_info.GetShape()[2];
for (size_t i = 0; i < num_hyps; ++i) {
const std::vector<int>& hyp = hyps[i];
float score = 0.0f;
// left to right decoder score
score = ComputeAttentionScore(
decoder_outs_data + max_hyps_len * decode_out_len * i, hyp, eos_,
decode_out_len);
// Optional: Used for right to left score
float r_score = 0.0f;
if (is_bidirectional_decoder_ && reverse_weight > 0) {
std::vector<int> r_hyp(hyp.size());
std::reverse_copy(hyp.begin(), hyp.end(), r_hyp.begin());
// right to left decoder score
r_score = ComputeAttentionScore(
r_decoder_outs_data + max_hyps_len * decode_out_len * i, r_hyp, eos_,
decode_out_len);
}
// combined left-to-right and right-to-left score
(*rescoring_score)[i] =
score * (1 - reverse_weight) + r_score * reverse_weight;
}
}
} // namespace wenet
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
// 2022 ZeXuan Li (lizexuan@huya.com)
// Xingchen Song(sxc19@mails.tsinghua.edu.cn)
// hamddct@gmail.com (Mddct)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DECODER_ONNX_ASR_MODEL_H_
#define DECODER_ONNX_ASR_MODEL_H_
#include <memory>
#include <string>
#include <vector>
#include "onnxruntime_cxx_api.h" // NOLINT
#include "decoder/asr_model.h"
#include "utils/log.h"
#include "utils/utils.h"
namespace wenet {
class OnnxAsrModel : public AsrModel {
public:
static void InitEngineThreads(int num_threads = 1);
public:
OnnxAsrModel() = default;
OnnxAsrModel(const OnnxAsrModel& other);
void Read(const std::string& model_dir);
void Reset() override;
void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) override;
std::shared_ptr<AsrModel> Copy() const override;
void GetInputOutputInfo(const std::shared_ptr<Ort::Session>& session,
std::vector<const char*>* in_names,
std::vector<const char*>* out_names);
protected:
void ForwardEncoderFunc(const std::vector<std::vector<float>>& chunk_feats,
std::vector<std::vector<float>>* ctc_prob) override;
float ComputeAttentionScore(const float* prob, const std::vector<int>& hyp,
int eos, int decode_out_len);
private:
int encoder_output_size_ = 0;
int num_blocks_ = 0;
int cnn_module_kernel_ = 0;
int head_ = 0;
// sessions
// NOTE(Mddct): The Env holds the logging state used by all other objects.
// One Env must be created before using any other Onnxruntime functionality.
static Ort::Env env_; // shared environment across threads.
static Ort::SessionOptions session_options_;
std::shared_ptr<Ort::Session> encoder_session_ = nullptr;
std::shared_ptr<Ort::Session> rescore_session_ = nullptr;
std::shared_ptr<Ort::Session> ctc_session_ = nullptr;
// node names
std::vector<const char*> encoder_in_names_, encoder_out_names_;
std::vector<const char*> ctc_in_names_, ctc_out_names_;
std::vector<const char*> rescore_in_names_, rescore_out_names_;
// caches
Ort::Value att_cache_ort_{nullptr};
Ort::Value cnn_cache_ort_{nullptr};
std::vector<Ort::Value> encoder_outs_;
// NOTE: Instead of making a copy of the xx_cache, ONNX only maintains
// its data pointer when initializing xx_cache_ort (see https://github.com/
// microsoft/onnxruntime/blob/master/onnxruntime/core/framework
// /tensor.cc#L102-L129), so we need the following variables to keep
// our data "alive" during the lifetime of decoder.
std::vector<float> att_cache_;
std::vector<float> cnn_cache_;
};
} // namespace wenet
#endif // DECODER_ONNX_ASR_MODEL_H_
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
// 2022 Binbin Zhang (binbzha@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DECODER_PARAMS_H_
#define DECODER_PARAMS_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "decoder/asr_decoder.h"
#ifdef USE_ONNX
#include "decoder/onnx_asr_model.h"
#endif
#ifdef USE_TORCH
#include "decoder/torch_asr_model.h"
#endif
#ifdef USE_XPU
#include "xpu/xpu_asr_model.h"
#endif
#ifdef USE_BPU
#include "bpu/bpu_asr_model.h"
#endif
#include "frontend/feature_pipeline.h"
#include "post_processor/post_processor.h"
#include "utils/flags.h"
#include "utils/string.h"
DEFINE_int32(device_id, 0, "set XPU DeviceID for ASR model");
// TorchAsrModel flags
DEFINE_string(model_path, "", "pytorch exported model path");
// OnnxAsrModel flags
DEFINE_string(onnx_dir, "", "directory where the onnx model is saved");
// XPUAsrModel flags
DEFINE_string(xpu_model_dir, "",
"directory where the XPU model and weights is saved");
// BPUAsrModel flags
DEFINE_string(bpu_model_dir, "",
"directory where the HORIZON BPU model is saved");
// FeaturePipelineConfig flags
DEFINE_int32(num_bins, 80, "num mel bins for fbank feature");
DEFINE_int32(sample_rate, 16000, "sample rate for audio");
// TLG fst
DEFINE_string(fst_path, "", "TLG fst path");
// DecodeOptions flags
DEFINE_int32(chunk_size, 16, "decoding chunk size");
DEFINE_int32(num_left_chunks, -1, "left chunks in decoding");
DEFINE_double(ctc_weight, 0.5,
"ctc weight when combining ctc score and rescoring score");
DEFINE_double(rescoring_weight, 1.0,
"rescoring weight when combining ctc score and rescoring score");
DEFINE_double(reverse_weight, 0.0,
"used for bitransformer rescoring. it must be 0.0 if decoder is"
"conventional transformer decoder, and only reverse_weight > 0.0"
"dose the right to left decoder will be calculated and used");
DEFINE_int32(max_active, 7000, "max active states in ctc wfst search");
DEFINE_int32(min_active, 200, "min active states in ctc wfst search");
DEFINE_double(beam, 16.0, "beam in ctc wfst search");
DEFINE_double(lattice_beam, 10.0, "lattice beam in ctc wfst search");
DEFINE_double(acoustic_scale, 1.0, "acoustic scale for ctc wfst search");
DEFINE_double(blank_skip_thresh, 1.0,
"blank skip thresh for ctc wfst search, 1.0 means no skip");
DEFINE_double(blank_scale, 1.0, "blank scale for ctc wfst search");
DEFINE_double(length_penalty, 0.0,
"length penalty ctc wfst search, will not"
"apply on self-loop arc, for balancing the del/ins ratio, "
"suggest set to -3.0");
DEFINE_int32(nbest, 10, "nbest for ctc wfst or prefix search");
// SymbolTable flags
DEFINE_string(dict_path, "",
"dict symbol table path, required when LM is enabled");
DEFINE_string(unit_path, "",
"e2e model unit symbol table, it is used in both "
"with/without LM scenarios for context/timestamp");
// Context flags
DEFINE_string(context_path, "", "context path, is used to build context graph");
DEFINE_double(context_score, 3.0, "is used to rescore the decoded result");
// PostProcessOptions flags
DEFINE_int32(language_type, 0,
"remove spaces according to language type"
"0x00 = kMandarinEnglish, "
"0x01 = kIndoEuropean");
DEFINE_bool(lowercase, true, "lowercase final result if needed");
namespace wenet {
std::shared_ptr<FeaturePipelineConfig> InitFeaturePipelineConfigFromFlags() {
auto feature_config = std::make_shared<FeaturePipelineConfig>(
FLAGS_num_bins, FLAGS_sample_rate);
return feature_config;
}
std::shared_ptr<DecodeOptions> InitDecodeOptionsFromFlags() {
auto decode_config = std::make_shared<DecodeOptions>();
decode_config->chunk_size = FLAGS_chunk_size;
decode_config->num_left_chunks = FLAGS_num_left_chunks;
decode_config->ctc_weight = FLAGS_ctc_weight;
decode_config->reverse_weight = FLAGS_reverse_weight;
decode_config->rescoring_weight = FLAGS_rescoring_weight;
decode_config->ctc_wfst_search_opts.max_active = FLAGS_max_active;
decode_config->ctc_wfst_search_opts.min_active = FLAGS_min_active;
decode_config->ctc_wfst_search_opts.beam = FLAGS_beam;
decode_config->ctc_wfst_search_opts.lattice_beam = FLAGS_lattice_beam;
decode_config->ctc_wfst_search_opts.acoustic_scale = FLAGS_acoustic_scale;
decode_config->ctc_wfst_search_opts.blank_skip_thresh =
FLAGS_blank_skip_thresh;
decode_config->ctc_wfst_search_opts.blank_scale = FLAGS_blank_scale;
decode_config->ctc_wfst_search_opts.length_penalty = FLAGS_length_penalty;
decode_config->ctc_wfst_search_opts.nbest = FLAGS_nbest;
decode_config->ctc_prefix_search_opts.first_beam_size = FLAGS_nbest;
decode_config->ctc_prefix_search_opts.second_beam_size = FLAGS_nbest;
return decode_config;
}
std::shared_ptr<DecodeResource> InitDecodeResourceFromFlags() {
auto resource = std::make_shared<DecodeResource>();
const int kNumGemmThreads = 1;
if (!FLAGS_onnx_dir.empty()) {
#ifdef USE_ONNX
LOG(INFO) << "Reading onnx model ";
OnnxAsrModel::InitEngineThreads(kNumGemmThreads);
auto model = std::make_shared<OnnxAsrModel>();
model->Read(FLAGS_onnx_dir);
resource->model = model;
#else
LOG(FATAL) << "Please rebuild with cmake options '-DONNX=ON'.";
#endif
} else if (!FLAGS_model_path.empty()) {
#ifdef USE_TORCH
LOG(INFO) << "Reading torch model " << FLAGS_model_path;
TorchAsrModel::InitEngineThreads(kNumGemmThreads);
auto model = std::make_shared<TorchAsrModel>();
model->Read(FLAGS_model_path);
resource->model = model;
#else
LOG(FATAL) << "Please rebuild with cmake options '-DTORCH=ON'.";
#endif
} else if (!FLAGS_xpu_model_dir.empty()) {
#ifdef USE_XPU
LOG(INFO) << "Reading XPU WeNet model weight from " << FLAGS_xpu_model_dir;
auto model = std::make_shared<XPUAsrModel>();
model->SetEngineThreads(kNumGemmThreads);
model->SetDeviceId(FLAGS_device_id);
model->Read(FLAGS_xpu_model_dir);
resource->model = model;
#else
LOG(FATAL) << "Please rebuild with cmake options '-DXPU=ON'.";
#endif
} else if (!FLAGS_bpu_model_dir.empty()) {
#ifdef USE_BPU
LOG(INFO) << "Reading Horizon BPU model from " << FLAGS_bpu_model_dir;
auto model = std::make_shared<BPUAsrModel>();
model->Read(FLAGS_bpu_model_dir);
resource->model = model;
#else
LOG(FATAL) << "Please rebuild with cmake options '-DBPU=ON'.";
#endif
} else {
LOG(FATAL) << "Please set ONNX, TORCH, XPU or BPU model path!!!";
}
LOG(INFO) << "Reading unit table " << FLAGS_unit_path;
auto unit_table = std::shared_ptr<fst::SymbolTable>(
fst::SymbolTable::ReadText(FLAGS_unit_path));
CHECK(unit_table != nullptr);
resource->unit_table = unit_table;
if (!FLAGS_fst_path.empty()) { // With LM
CHECK(!FLAGS_dict_path.empty());
LOG(INFO) << "Reading fst " << FLAGS_fst_path;
auto fst = std::shared_ptr<fst::Fst<fst::StdArc>>(
fst::Fst<fst::StdArc>::Read(FLAGS_fst_path));
CHECK(fst != nullptr);
resource->fst = fst;
LOG(INFO) << "Reading symbol table " << FLAGS_dict_path;
auto symbol_table = std::shared_ptr<fst::SymbolTable>(
fst::SymbolTable::ReadText(FLAGS_dict_path));
CHECK(symbol_table != nullptr);
resource->symbol_table = symbol_table;
} else { // Without LM, symbol_table is the same as unit_table
resource->symbol_table = unit_table;
}
if (!FLAGS_context_path.empty()) {
LOG(INFO) << "Reading context " << FLAGS_context_path;
std::vector<std::string> contexts;
std::ifstream infile(FLAGS_context_path);
std::string context;
while (getline(infile, context)) {
contexts.emplace_back(Trim(context));
}
ContextConfig config;
config.context_score = FLAGS_context_score;
resource->context_graph = std::make_shared<ContextGraph>(config);
resource->context_graph->BuildContextGraph(contexts,
resource->symbol_table);
}
PostProcessOptions post_process_opts;
post_process_opts.language_type =
FLAGS_language_type == 0 ? kMandarinEnglish : kIndoEuropean;
post_process_opts.lowercase = FLAGS_lowercase;
resource->post_processor =
std::make_shared<PostProcessor>(std::move(post_process_opts));
return resource;
}
} // namespace wenet
#endif // DECODER_PARAMS_H_
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DECODER_SEARCH_INTERFACE_H_
#define DECODER_SEARCH_INTERFACE_H_
namespace wenet {
#include <vector>
enum SearchType {
kPrefixBeamSearch = 0x00,
kWfstBeamSearch = 0x01,
};
class SearchInterface {
public:
virtual ~SearchInterface() {}
virtual void Search(const std::vector<std::vector<float>>& logp) = 0;
virtual void Reset() = 0;
virtual void FinalizeSearch() = 0;
virtual SearchType Type() const = 0;
// N-best inputs id
virtual const std::vector<std::vector<int>>& Inputs() const = 0;
// N-best outputs id
virtual const std::vector<std::vector<int>>& Outputs() const = 0;
// N-best likelihood
virtual const std::vector<float>& Likelihood() const = 0;
// N-best timestamp
virtual const std::vector<std::vector<int>>& Times() const = 0;
};
} // namespace wenet
#endif // DECODER_SEARCH_INTERFACE_H_
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
// 2022 Binbin Zhang (binbzha@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/torch_asr_model.h"
#include <algorithm>
#include <memory>
#include <stdexcept>
#include <utility>
#include "torch/script.h"
#ifndef IOS
#include "torch/torch.h"
#endif
namespace wenet {
#ifndef IOS
void TorchAsrModel::InitEngineThreads(int num_threads) {
// For multi-thread performance
at::set_num_threads(num_threads);
VLOG(1) << "Num intra-op threads: " << at::get_num_threads();
}
#endif
void TorchAsrModel::Read(const std::string& model_path) {
torch::DeviceType device = at::kCPU;
#ifdef USE_GPU
if (!torch::cuda::is_available()) {
VLOG(1) << "CUDA is not available! Please check your GPU settings";
throw std::runtime_error("CUDA is not available!");
} else {
VLOG(1) << "CUDA available! Running on GPU";
device = at::kCUDA;
}
#endif
torch::jit::script::Module model = torch::jit::load(model_path, device);
model_ = std::make_shared<TorchModule>(std::move(model));
torch::NoGradGuard no_grad;
model_->eval();
torch::jit::IValue o1 = model_->run_method("subsampling_rate");
CHECK_EQ(o1.isInt(), true);
subsampling_rate_ = o1.toInt();
torch::jit::IValue o2 = model_->run_method("right_context");
CHECK_EQ(o2.isInt(), true);
right_context_ = o2.toInt();
torch::jit::IValue o3 = model_->run_method("sos_symbol");
CHECK_EQ(o3.isInt(), true);
sos_ = o3.toInt();
torch::jit::IValue o4 = model_->run_method("eos_symbol");
CHECK_EQ(o4.isInt(), true);
eos_ = o4.toInt();
torch::jit::IValue o5 = model_->run_method("is_bidirectional_decoder");
CHECK_EQ(o5.isBool(), true);
is_bidirectional_decoder_ = o5.toBool();
VLOG(1) << "Torch Model Info:";
VLOG(1) << "\tsubsampling_rate " << subsampling_rate_;
VLOG(1) << "\tright context " << right_context_;
VLOG(1) << "\tsos " << sos_;
VLOG(1) << "\teos " << eos_;
VLOG(1) << "\tis bidirectional decoder " << is_bidirectional_decoder_;
}
TorchAsrModel::TorchAsrModel(const TorchAsrModel& other) {
// 1. Init the model info
right_context_ = other.right_context_;
subsampling_rate_ = other.subsampling_rate_;
sos_ = other.sos_;
eos_ = other.eos_;
is_bidirectional_decoder_ = other.is_bidirectional_decoder_;
chunk_size_ = other.chunk_size_;
num_left_chunks_ = other.num_left_chunks_;
offset_ = other.offset_;
// 2. Model copy, just copy the model ptr since:
// PyTorch allows using multiple CPU threads during TorchScript model
// inference, please see https://pytorch.org/docs/stable/notes/cpu_
// threading_torchscript_inference.html
model_ = other.model_;
// NOTE(Binbin Zhang):
// inner states for forward are not copied here.
}
std::shared_ptr<AsrModel> TorchAsrModel::Copy() const {
auto asr_model = std::make_shared<TorchAsrModel>(*this);
// Reset the inner states for new decoding
asr_model->Reset();
return asr_model;
}
void TorchAsrModel::Reset() {
offset_ = 0;
att_cache_ = std::move(torch::zeros({0, 0, 0, 0}));
cnn_cache_ = std::move(torch::zeros({0, 0, 0, 0}));
encoder_outs_.clear();
cached_feature_.clear();
}
void TorchAsrModel::ForwardEncoderFunc(
const std::vector<std::vector<float>>& chunk_feats,
std::vector<std::vector<float>>* out_prob) {
// 1. Prepare libtorch required data, splice cached_feature_ and chunk_feats
// The first dimension is for batchsize, which is 1.
int num_frames = cached_feature_.size() + chunk_feats.size();
const int feature_dim = chunk_feats[0].size();
torch::Tensor feats =
torch::zeros({1, num_frames, feature_dim}, torch::kFloat);
for (size_t i = 0; i < cached_feature_.size(); ++i) {
torch::Tensor row =
torch::from_blob(const_cast<float*>(cached_feature_[i].data()),
{feature_dim}, torch::kFloat)
.clone();
feats[0][i] = std::move(row);
}
for (size_t i = 0; i < chunk_feats.size(); ++i) {
torch::Tensor row =
torch::from_blob(const_cast<float*>(chunk_feats[i].data()),
{feature_dim}, torch::kFloat)
.clone();
feats[0][cached_feature_.size() + i] = std::move(row);
}
// 2. Encoder chunk forward
#ifdef USE_GPU
feats = feats.to(at::kCUDA);
att_cache_ = att_cache_.to(at::kCUDA);
cnn_cache_ = cnn_cache_.to(at::kCUDA);
#endif
int required_cache_size = chunk_size_ * num_left_chunks_;
torch::NoGradGuard no_grad;
std::vector<torch::jit::IValue> inputs = {feats, offset_, required_cache_size,
att_cache_, cnn_cache_};
// Refer interfaces in wenet/transformer/asr_model.py
auto outputs =
model_->get_method("forward_encoder_chunk")(inputs).toTuple()->elements();
CHECK_EQ(outputs.size(), 3);
#ifdef USE_GPU
torch::Tensor chunk_out = outputs[0].toTensor().to(at::kCPU);
att_cache_ = outputs[1].toTensor().to(at::kCPU);
cnn_cache_ = outputs[2].toTensor().to(at::kCPU);
#else
torch::Tensor chunk_out = outputs[0].toTensor();
att_cache_ = outputs[1].toTensor();
cnn_cache_ = outputs[2].toTensor();
#endif
offset_ += chunk_out.size(1);
// The first dimension of returned value is for batchsize, which is 1
#ifdef USE_GPU
chunk_out = chunk_out.to(at::kCUDA);
torch::Tensor ctc_log_probs =
model_->run_method("ctc_activation", chunk_out).toTensor();
ctc_log_probs = ctc_log_probs.to(at::kCPU)[0];
encoder_outs_.push_back(std::move(chunk_out.to(at::kCPU)));
#else
torch::Tensor ctc_log_probs =
model_->run_method("ctc_activation", chunk_out).toTensor()[0];
encoder_outs_.push_back(std::move(chunk_out));
#endif
// Copy to output
int num_outputs = ctc_log_probs.size(0);
int output_dim = ctc_log_probs.size(1);
out_prob->resize(num_outputs);
for (int i = 0; i < num_outputs; i++) {
(*out_prob)[i].resize(output_dim);
memcpy((*out_prob)[i].data(), ctc_log_probs[i].data_ptr(),
sizeof(float) * output_dim);
}
}
float TorchAsrModel::ComputeAttentionScore(const torch::Tensor& prob,
const std::vector<int>& hyp,
int eos) {
float score = 0.0f;
auto accessor = prob.accessor<float, 2>();
for (size_t j = 0; j < hyp.size(); ++j) {
score += accessor[j][hyp[j]];
}
score += accessor[hyp.size()][eos];
return score;
}
void TorchAsrModel::AttentionRescoring(
const std::vector<std::vector<int>>& hyps, float reverse_weight,
std::vector<float>* rescoring_score) {
CHECK(rescoring_score != nullptr);
int num_hyps = hyps.size();
rescoring_score->resize(num_hyps, 0.0f);
if (num_hyps == 0) {
return;
}
// No encoder output
if (encoder_outs_.size() == 0) {
return;
}
torch::NoGradGuard no_grad;
// Step 1: Prepare input for libtorch
torch::Tensor hyps_length = torch::zeros({num_hyps}, torch::kLong);
int max_hyps_len = 0;
for (size_t i = 0; i < num_hyps; ++i) {
int length = hyps[i].size() + 1;
max_hyps_len = std::max(length, max_hyps_len);
hyps_length[i] = static_cast<int64_t>(length);
}
torch::Tensor hyps_tensor =
torch::zeros({num_hyps, max_hyps_len}, torch::kLong);
for (size_t i = 0; i < num_hyps; ++i) {
const std::vector<int>& hyp = hyps[i];
hyps_tensor[i][0] = sos_;
for (size_t j = 0; j < hyp.size(); ++j) {
hyps_tensor[i][j + 1] = hyp[j];
}
}
// Step 2: Forward attention decoder by hyps and corresponding encoder_outs_
torch::Tensor encoder_out = torch::cat(encoder_outs_, 1);
#ifdef USE_GPU
hyps_tensor = hyps_tensor.to(at::kCUDA);
hyps_length = hyps_length.to(at::kCUDA);
encoder_out = encoder_out.to(at::kCUDA);
#endif
auto outputs = model_
->run_method("forward_attention_decoder", hyps_tensor,
hyps_length, encoder_out, reverse_weight)
.toTuple()
->elements();
#ifdef USE_GPU
auto probs = outputs[0].toTensor().to(at::kCPU);
auto r_probs = outputs[1].toTensor().to(at::kCPU);
#else
auto probs = outputs[0].toTensor();
auto r_probs = outputs[1].toTensor();
#endif
CHECK_EQ(probs.size(0), num_hyps);
CHECK_EQ(probs.size(1), max_hyps_len);
// Step 3: Compute rescoring score
for (size_t i = 0; i < num_hyps; ++i) {
const std::vector<int>& hyp = hyps[i];
float score = 0.0f;
// left-to-right decoder score
score = ComputeAttentionScore(probs[i], hyp, eos_);
// Optional: Used for right to left score
float r_score = 0.0f;
if (is_bidirectional_decoder_ && reverse_weight > 0) {
// right-to-left score
CHECK_EQ(r_probs.size(0), num_hyps);
CHECK_EQ(r_probs.size(1), max_hyps_len);
std::vector<int> r_hyp(hyp.size());
std::reverse_copy(hyp.begin(), hyp.end(), r_hyp.begin());
// right to left decoder score
r_score = ComputeAttentionScore(r_probs[i], r_hyp, eos_);
}
// combined left-to-right and right-to-left score
(*rescoring_score)[i] =
score * (1 - reverse_weight) + r_score * reverse_weight;
}
}
} // namespace wenet
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
// 2022 Binbin Zhang (binbzha@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DECODER_TORCH_ASR_MODEL_H_
#define DECODER_TORCH_ASR_MODEL_H_
#include <memory>
#include <string>
#include <vector>
#include "torch/script.h"
#ifndef IOS
#include "torch/torch.h"
#endif
#include "decoder/asr_model.h"
#include "utils/utils.h"
namespace wenet {
class TorchAsrModel : public AsrModel {
public:
#ifndef IOS
static void InitEngineThreads(int num_threads = 1);
#endif
public:
using TorchModule = torch::jit::script::Module;
TorchAsrModel() = default;
TorchAsrModel(const TorchAsrModel& other);
void Read(const std::string& model_path);
std::shared_ptr<TorchModule> torch_model() const { return model_; }
void Reset() override;
void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) override;
std::shared_ptr<AsrModel> Copy() const override;
protected:
void ForwardEncoderFunc(const std::vector<std::vector<float>>& chunk_feats,
std::vector<std::vector<float>>* ctc_prob) override;
float ComputeAttentionScore(const torch::Tensor& prob,
const std::vector<int>& hyp, int eos);
private:
std::shared_ptr<TorchModule> model_ = nullptr;
std::vector<torch::Tensor> encoder_outs_;
// transformer/conformer attention cache
torch::Tensor att_cache_ = torch::zeros({0, 0, 0, 0});
// conformer-only conv_module cache
torch::Tensor cnn_cache_ = torch::zeros({0, 0, 0, 0});
};
} // namespace wenet
#endif // DECODER_TORCH_ASR_MODEL_H_
add_library(frontend STATIC
feature_pipeline.cc
fft.cc
)
target_link_libraries(frontend PUBLIC utils)
\ No newline at end of file
// Copyright (c) 2017 Personal (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef FRONTEND_FBANK_H_
#define FRONTEND_FBANK_H_
#include <cstring>
#include <limits>
#include <random>
#include <utility>
#include <vector>
#include "frontend/fft.h"
#include "utils/log.h"
namespace wenet {
// This code is based on kaldi Fbank implementation, please see
// https://github.com/kaldi-asr/kaldi/blob/master/src/feat/feature-fbank.cc
class Fbank {
public:
Fbank(int num_bins, int sample_rate, int frame_length, int frame_shift)
: num_bins_(num_bins),
sample_rate_(sample_rate),
frame_length_(frame_length),
frame_shift_(frame_shift),
use_log_(true),
remove_dc_offset_(true),
generator_(0),
distribution_(0, 1.0),
dither_(0.0) {
fft_points_ = UpperPowerOfTwo(frame_length_);
// generate bit reversal table and trigonometric function table
const int fft_points_4 = fft_points_ / 4;
bitrev_.resize(fft_points_);
sintbl_.resize(fft_points_ + fft_points_4);
make_sintbl(fft_points_, sintbl_.data());
make_bitrev(fft_points_, bitrev_.data());
int num_fft_bins = fft_points_ / 2;
float fft_bin_width = static_cast<float>(sample_rate_) / fft_points_;
int low_freq = 20, high_freq = sample_rate_ / 2;
float mel_low_freq = MelScale(low_freq);
float mel_high_freq = MelScale(high_freq);
float mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1);
bins_.resize(num_bins_);
center_freqs_.resize(num_bins_);
for (int bin = 0; bin < num_bins; ++bin) {
float left_mel = mel_low_freq + bin * mel_freq_delta,
center_mel = mel_low_freq + (bin + 1) * mel_freq_delta,
right_mel = mel_low_freq + (bin + 2) * mel_freq_delta;
center_freqs_[bin] = InverseMelScale(center_mel);
std::vector<float> this_bin(num_fft_bins);
int first_index = -1, last_index = -1;
for (int i = 0; i < num_fft_bins; ++i) {
float freq = (fft_bin_width * i); // Center frequency of this fft
// bin.
float mel = MelScale(freq);
if (mel > left_mel && mel < right_mel) {
float weight;
if (mel <= center_mel)
weight = (mel - left_mel) / (center_mel - left_mel);
else
weight = (right_mel - mel) / (right_mel - center_mel);
this_bin[i] = weight;
if (first_index == -1) first_index = i;
last_index = i;
}
}
CHECK(first_index != -1 && last_index >= first_index);
bins_[bin].first = first_index;
int size = last_index + 1 - first_index;
bins_[bin].second.resize(size);
for (int i = 0; i < size; ++i) {
bins_[bin].second[i] = this_bin[first_index + i];
}
}
// povey window
povey_window_.resize(frame_length_);
double a = M_2PI / (frame_length - 1);
for (int i = 0; i < frame_length; ++i) {
povey_window_[i] = pow(0.5 - 0.5 * cos(a * i), 0.85);
}
}
void set_use_log(bool use_log) { use_log_ = use_log; }
void set_remove_dc_offset(bool remove_dc_offset) {
remove_dc_offset_ = remove_dc_offset;
}
void set_dither(float dither) { dither_ = dither; }
int num_bins() const { return num_bins_; }
static inline float InverseMelScale(float mel_freq) {
return 700.0f * (expf(mel_freq / 1127.0f) - 1.0f);
}
static inline float MelScale(float freq) {
return 1127.0f * logf(1.0f + freq / 700.0f);
}
static int UpperPowerOfTwo(int n) {
return static_cast<int>(pow(2, ceil(log(n) / log(2))));
}
// pre emphasis
void PreEmphasis(float coeff, std::vector<float>* data) const {
if (coeff == 0.0) return;
for (int i = data->size() - 1; i > 0; i--)
(*data)[i] -= coeff * (*data)[i - 1];
(*data)[0] -= coeff * (*data)[0];
}
// Apply povey window on data in place
void Povey(std::vector<float>* data) const {
CHECK_GE(data->size(), povey_window_.size());
for (size_t i = 0; i < povey_window_.size(); ++i) {
(*data)[i] *= povey_window_[i];
}
}
// Compute fbank feat, return num frames
int Compute(const std::vector<float>& wave,
std::vector<std::vector<float>>* feat) {
int num_samples = wave.size();
if (num_samples < frame_length_) return 0;
int num_frames = 1 + ((num_samples - frame_length_) / frame_shift_);
feat->resize(num_frames);
std::vector<float> fft_real(fft_points_, 0), fft_img(fft_points_, 0);
std::vector<float> power(fft_points_ / 2);
for (int i = 0; i < num_frames; ++i) {
std::vector<float> data(wave.data() + i * frame_shift_,
wave.data() + i * frame_shift_ + frame_length_);
// optional add noise
if (dither_ != 0.0) {
for (size_t j = 0; j < data.size(); ++j)
data[j] += dither_ * distribution_(generator_);
}
// optinal remove dc offset
if (remove_dc_offset_) {
float mean = 0.0;
for (size_t j = 0; j < data.size(); ++j) mean += data[j];
mean /= data.size();
for (size_t j = 0; j < data.size(); ++j) data[j] -= mean;
}
PreEmphasis(0.97, &data);
Povey(&data);
// copy data to fft_real
memset(fft_img.data(), 0, sizeof(float) * fft_points_);
memset(fft_real.data() + frame_length_, 0,
sizeof(float) * (fft_points_ - frame_length_));
memcpy(fft_real.data(), data.data(), sizeof(float) * frame_length_);
fft(bitrev_.data(), sintbl_.data(), fft_real.data(), fft_img.data(),
fft_points_);
// power
for (int j = 0; j < fft_points_ / 2; ++j) {
power[j] = fft_real[j] * fft_real[j] + fft_img[j] * fft_img[j];
}
(*feat)[i].resize(num_bins_);
// cepstral coefficients, triangle filter array
for (int j = 0; j < num_bins_; ++j) {
float mel_energy = 0.0;
int s = bins_[j].first;
for (size_t k = 0; k < bins_[j].second.size(); ++k) {
mel_energy += bins_[j].second[k] * power[s + k];
}
// optional use log
if (use_log_) {
if (mel_energy < std::numeric_limits<float>::epsilon())
mel_energy = std::numeric_limits<float>::epsilon();
mel_energy = logf(mel_energy);
}
(*feat)[i][j] = mel_energy;
}
}
return num_frames;
}
private:
int num_bins_;
int sample_rate_;
int frame_length_, frame_shift_;
int fft_points_;
bool use_log_;
bool remove_dc_offset_;
std::vector<float> center_freqs_;
std::vector<std::pair<int, std::vector<float>>> bins_;
std::vector<float> povey_window_;
std::default_random_engine generator_;
std::normal_distribution<float> distribution_;
float dither_;
// bit reversal table
std::vector<int> bitrev_;
// trigonometric function table
std::vector<float> sintbl_;
};
} // namespace wenet
#endif // FRONTEND_FBANK_H_
// Copyright (c) 2017 Personal (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/feature_pipeline.h"
#include <algorithm>
#include <utility>
namespace wenet {
FeaturePipeline::FeaturePipeline(const FeaturePipelineConfig& config)
: config_(config),
feature_dim_(config.num_bins),
fbank_(config.num_bins, config.sample_rate, config.frame_length,
config.frame_shift),
num_frames_(0),
input_finished_(false) {}
void FeaturePipeline::AcceptWaveform(const float* pcm, const int size) {
std::vector<std::vector<float>> feats;
std::vector<float> waves;
waves.insert(waves.end(), remained_wav_.begin(), remained_wav_.end());
waves.insert(waves.end(), pcm, pcm + size);
int num_frames = fbank_.Compute(waves, &feats);
feature_queue_.Push(std::move(feats));
num_frames_ += num_frames;
int left_samples = waves.size() - config_.frame_shift * num_frames;
remained_wav_.resize(left_samples);
std::copy(waves.begin() + config_.frame_shift * num_frames, waves.end(),
remained_wav_.begin());
// We are still adding wave, notify input is not finished
finish_condition_.notify_one();
}
void FeaturePipeline::AcceptWaveform(const int16_t* pcm, const int size) {
auto* float_pcm = new float[size];
for (size_t i = 0; i < size; i++) {
float_pcm[i] = static_cast<float>(pcm[i]);
}
this->AcceptWaveform(float_pcm, size);
delete[] float_pcm;
}
void FeaturePipeline::set_input_finished() {
CHECK(!input_finished_);
{
std::lock_guard<std::mutex> lock(mutex_);
input_finished_ = true;
}
finish_condition_.notify_one();
}
bool FeaturePipeline::ReadOne(std::vector<float>* feat) {
if (!feature_queue_.Empty()) {
*feat = std::move(feature_queue_.Pop());
return true;
} else {
std::unique_lock<std::mutex> lock(mutex_);
while (!input_finished_) {
// This will release the lock and wait for notify_one()
// from AcceptWaveform() or set_input_finished()
finish_condition_.wait(lock);
if (!feature_queue_.Empty()) {
*feat = std::move(feature_queue_.Pop());
return true;
}
}
CHECK(input_finished_);
// Double check queue.empty, see issue#893 for detailed discussions.
if (!feature_queue_.Empty()) {
*feat = std::move(feature_queue_.Pop());
return true;
} else {
return false;
}
}
}
bool FeaturePipeline::Read(int num_frames,
std::vector<std::vector<float>>* feats) {
feats->clear();
if (feature_queue_.Size() >= num_frames) {
*feats = std::move(feature_queue_.Pop(num_frames));
return true;
} else {
std::unique_lock<std::mutex> lock(mutex_);
while (!input_finished_) {
// This will release the lock and wait for notify_one()
// from AcceptWaveform() or set_input_finished()
finish_condition_.wait(lock);
if (feature_queue_.Size() >= num_frames) {
*feats = std::move(feature_queue_.Pop(num_frames));
return true;
}
}
CHECK(input_finished_);
// Double check queue.empty, see issue#893 for detailed discussions.
if (feature_queue_.Size() >= num_frames) {
*feats = std::move(feature_queue_.Pop(num_frames));
return true;
} else {
*feats = std::move(feature_queue_.Pop(feature_queue_.Size()));
return false;
}
}
}
void FeaturePipeline::Reset() {
input_finished_ = false;
num_frames_ = 0;
remained_wav_.clear();
feature_queue_.Clear();
}
} // namespace wenet
// Copyright (c) 2017 Personal (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef FRONTEND_FEATURE_PIPELINE_H_
#define FRONTEND_FEATURE_PIPELINE_H_
#include <mutex>
#include <queue>
#include <string>
#include <vector>
#include "frontend/fbank.h"
#include "utils/blocking_queue.h"
#include "utils/log.h"
namespace wenet {
struct FeaturePipelineConfig {
int num_bins;
int sample_rate;
int frame_length;
int frame_shift;
FeaturePipelineConfig(int num_bins, int sample_rate)
: num_bins(num_bins), // 80 dim fbank
sample_rate(sample_rate) { // 16k sample rate
frame_length = sample_rate / 1000 * 25; // frame length 25ms
frame_shift = sample_rate / 1000 * 10; // frame shift 10ms
}
void Info() const {
LOG(INFO) << "feature pipeline config"
<< " num_bins " << num_bins << " frame_length " << frame_length
<< " frame_shift " << frame_shift;
}
};
// Typically, FeaturePipeline is used in two threads: one thread A calls
// AcceptWaveform() to add raw wav data and set_input_finished() to notice
// the end of input wav, another thread B (decoder thread) calls Read() to
// consume features.So a BlockingQueue is used to make this class thread safe.
// The Read() is designed as a blocking method when there is no feature
// in feature_queue_ and the input is not finished.
// See bin/decoder_main.cc, websocket/websocket_server.cc and
// decoder/torch_asr_decoder.cc for usage
class FeaturePipeline {
public:
explicit FeaturePipeline(const FeaturePipelineConfig& config);
// The feature extraction is done in AcceptWaveform().
void AcceptWaveform(const float* pcm, const int size);
void AcceptWaveform(const int16_t* pcm, const int size);
// Current extracted frames number.
int num_frames() const { return num_frames_; }
int feature_dim() const { return feature_dim_; }
const FeaturePipelineConfig& config() const { return config_; }
// The caller should call this method when speech input is end.
// Never call AcceptWaveform() after calling set_input_finished() !
void set_input_finished();
bool input_finished() const { return input_finished_; }
// Return False if input is finished and no feature could be read.
// Return True if a feature is read.
// This function is a blocking method. It will block the thread when
// there is no feature in feature_queue_ and the input is not finished.
bool ReadOne(std::vector<float>* feat);
// Read #num_frames frame features.
// Return False if less than #num_frames features are read and the
// input is finished.
// Return True if #num_frames features are read.
// This function is a blocking method when there is no feature
// in feature_queue_ and the input is not finished.
bool Read(int num_frames, std::vector<std::vector<float>>* feats);
void Reset();
bool IsLastFrame(int frame) const {
return input_finished_ && (frame == num_frames_ - 1);
}
int NumQueuedFrames() const { return feature_queue_.Size(); }
private:
const FeaturePipelineConfig& config_;
int feature_dim_;
Fbank fbank_;
BlockingQueue<std::vector<float>> feature_queue_;
int num_frames_;
bool input_finished_;
// The feature extraction is done in AcceptWaveform().
// This waveform sample points are consumed by frame size.
// The residual waveform sample points after framing are
// kept to be used in next AcceptWaveform() calling.
std::vector<float> remained_wav_;
// Used to block the Read when there is no feature in feature_queue_
// and the input is not finished.
mutable std::mutex mutex_;
std::condition_variable finish_condition_;
};
} // namespace wenet
#endif // FRONTEND_FEATURE_PIPELINE_H_
// Copyright (c) 2016 Network
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include "frontend/fft.h"
namespace wenet {
void make_sintbl(int n, float* sintbl) {
int i, n2, n4, n8;
float c, s, dc, ds, t;
n2 = n / 2;
n4 = n / 4;
n8 = n / 8;
t = sin(M_PI / n);
dc = 2 * t * t;
ds = sqrt(dc * (2 - dc));
t = 2 * dc;
c = sintbl[n4] = 1;
s = sintbl[0] = 0;
for (i = 1; i < n8; ++i) {
c -= dc;
dc += t * c;
s += ds;
ds -= t * s;
sintbl[i] = s;
sintbl[n4 - i] = c;
}
if (n8 != 0) sintbl[n8] = sqrt(0.5);
for (i = 0; i < n4; ++i) sintbl[n2 - i] = sintbl[i];
for (i = 0; i < n2 + n4; ++i) sintbl[i + n2] = -sintbl[i];
}
void make_bitrev(int n, int* bitrev) {
int i, j, k, n2;
n2 = n / 2;
i = j = 0;
for (;;) {
bitrev[i] = j;
if (++i >= n) break;
k = n2;
while (k <= j) {
j -= k;
k /= 2;
}
j += k;
}
}
// bitrev: bit reversal table
// sintbl: trigonometric function table
// x:real part
// y:image part
// n: fft length
int fft(const int* bitrev, const float* sintbl, float* x, float* y, int n) {
int i, j, k, ik, h, d, k2, n4, inverse;
float t, s, c, dx, dy;
/* preparation */
if (n < 0) {
n = -n;
inverse = 1; /* inverse transform */
} else {
inverse = 0;
}
n4 = n / 4;
if (n == 0) {
return 0;
}
/* bit reversal */
for (i = 0; i < n; ++i) {
j = bitrev[i];
if (i < j) {
t = x[i];
x[i] = x[j];
x[j] = t;
t = y[i];
y[i] = y[j];
y[j] = t;
}
}
/* transformation */
for (k = 1; k < n; k = k2) {
h = 0;
k2 = k + k;
d = n / k2;
for (j = 0; j < k; ++j) {
c = sintbl[h + n4];
if (inverse)
s = -sintbl[h];
else
s = sintbl[h];
for (i = j; i < n; i += k2) {
ik = i + k;
dx = s * y[ik] + c * x[ik];
dy = c * y[ik] - s * x[ik];
x[ik] = x[i] - dx;
x[i] += dx;
y[ik] = y[i] - dy;
y[i] += dy;
}
h += d;
}
}
if (inverse) {
/* divide by n in case of the inverse transformation */
for (i = 0; i < n; ++i) {
x[i] /= n;
y[i] /= n;
}
}
return 0; /* finished successfully */
}
} // namespace wenet
// Copyright (c) 2016 Network
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef FRONTEND_FFT_H_
#define FRONTEND_FFT_H_
#ifndef M_PI
#define M_PI 3.1415926535897932384626433832795
#endif
#ifndef M_2PI
#define M_2PI 6.283185307179586476925286766559005
#endif
namespace wenet {
// Fast Fourier Transform
void make_sintbl(int n, float* sintbl);
void make_bitrev(int n, int* bitrev);
int fft(const int* bitrev, const float* sintbl, float* x, float* y, int n);
} // namespace wenet
#endif // FRONTEND_FFT_H_
// Copyright (c) 2016 Personal (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef FRONTEND_WAV_H_
#define FRONTEND_WAV_H_
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <string>
#include "utils/log.h"
namespace wenet {
struct WavHeader {
char riff[4] = {'R', 'I', 'F', 'F'};
unsigned int size = 0;
char wav[4] = {'W', 'A', 'V', 'E'};
char fmt[4] = {'f', 'm', 't', ' '};
unsigned int fmt_size = 16;
uint16_t format = 1;
uint16_t channels = 0;
unsigned int sample_rate = 0;
unsigned int bytes_per_second = 0;
uint16_t block_size = 0;
uint16_t bit = 0;
char data[4] = {'d', 'a', 't', 'a'};
unsigned int data_size = 0;
WavHeader() {}
WavHeader(int num_samples, int num_channel, int sample_rate,
int bits_per_sample) {
data_size = num_samples * num_channel * (bits_per_sample / 8);
size = sizeof(WavHeader) - 8 + data_size;
channels = num_channel;
this->sample_rate = sample_rate;
bytes_per_second = sample_rate * num_channel * (bits_per_sample / 8);
block_size = num_channel * (bits_per_sample / 8);
bit = bits_per_sample;
}
};
class WavReader {
public:
WavReader() : data_(nullptr) {}
explicit WavReader(const std::string& filename) { Open(filename); }
bool Open(const std::string& filename) {
FILE* fp = fopen(filename.c_str(), "rb");
if (NULL == fp) {
LOG(WARNING) << "Error in read " << filename;
return false;
}
WavHeader header;
fread(&header, 1, sizeof(header), fp);
if (header.fmt_size < 16) {
fprintf(stderr,
"WaveData: expect PCM format data "
"to have fmt chunk of at least size 16.\n");
return false;
} else if (header.fmt_size > 16) {
int offset = 44 - 8 + header.fmt_size - 16;
fseek(fp, offset, SEEK_SET);
fread(header.data, 8, sizeof(char), fp);
}
// check "RIFF" "WAVE" "fmt " "data"
// Skip any sub-chunks between "fmt" and "data". Usually there will
// be a single "fact" sub chunk, but on Windows there can also be a
// "list" sub chunk.
while (0 != strncmp(header.data, "data", 4)) {
// We will just ignore the data in these chunks.
fseek(fp, header.data_size, SEEK_CUR);
// read next sub chunk
fread(header.data, 8, sizeof(char), fp);
}
num_channel_ = header.channels;
sample_rate_ = header.sample_rate;
bits_per_sample_ = header.bit;
int num_data = header.data_size / (bits_per_sample_ / 8);
data_ = new float[num_data];
num_samples_ = num_data / num_channel_;
for (int i = 0; i < num_data; ++i) {
switch (bits_per_sample_) {
case 8: {
char sample;
fread(&sample, 1, sizeof(char), fp);
data_[i] = static_cast<float>(sample);
break;
}
case 16: {
int16_t sample;
fread(&sample, 1, sizeof(int16_t), fp);
data_[i] = static_cast<float>(sample);
break;
}
case 32: {
int sample;
fread(&sample, 1, sizeof(int), fp);
data_[i] = static_cast<float>(sample);
break;
}
default:
fprintf(stderr, "unsupported quantization bits");
exit(1);
}
}
fclose(fp);
return true;
}
int num_channel() const { return num_channel_; }
int sample_rate() const { return sample_rate_; }
int bits_per_sample() const { return bits_per_sample_; }
int num_samples() const { return num_samples_; }
~WavReader() {
delete[] data_;
}
const float* data() const { return data_; }
private:
int num_channel_;
int sample_rate_;
int bits_per_sample_;
int num_samples_; // sample points per channel
float* data_;
};
class WavWriter {
public:
WavWriter(const float* data, int num_samples, int num_channel,
int sample_rate, int bits_per_sample)
: data_(data),
num_samples_(num_samples),
num_channel_(num_channel),
sample_rate_(sample_rate),
bits_per_sample_(bits_per_sample) {}
void Write(const std::string& filename) {
FILE* fp = fopen(filename.c_str(), "wb");
WavHeader header(num_samples_, num_channel_, sample_rate_,
bits_per_sample_);
fwrite(&header, 1, sizeof(header), fp);
for (int i = 0; i < num_samples_; ++i) {
for (int j = 0; j < num_channel_; ++j) {
switch (bits_per_sample_) {
case 8: {
char sample = static_cast<char>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
case 16: {
int16_t sample = static_cast<int16_t>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
case 32: {
int sample = static_cast<int>(data_[i * num_channel_ + j]);
fwrite(&sample, 1, sizeof(sample), fp);
break;
}
}
}
}
fclose(fp);
}
private:
const float* data_;
int num_samples_; // total float points in data_
int num_channel_;
int sample_rate_;
int bits_per_sample_;
};
class StreamWavWriter {
public:
StreamWavWriter(int num_channel, int sample_rate, int bits_per_sample)
: num_channel_(num_channel),
sample_rate_(sample_rate),
bits_per_sample_(bits_per_sample),
total_num_samples_(0) {}
StreamWavWriter(const std::string& filename, int num_channel,
int sample_rate, int bits_per_sample)
: StreamWavWriter(num_channel, sample_rate, bits_per_sample) {
Open(filename);
}
void Open(const std::string& filename) {
fp_ = fopen(filename.c_str(), "wb");
fseek(fp_, sizeof(WavHeader), SEEK_SET);
}
void Write(const int16_t* sample_data, size_t num_samples) {
fwrite(sample_data, sizeof(int16_t), num_samples, fp_);
total_num_samples_ += num_samples;
}
void Close() {
WavHeader header(total_num_samples_, num_channel_, sample_rate_,
bits_per_sample_);
fseek(fp_, 0L, SEEK_SET);
fwrite(&header, 1, sizeof(header), fp_);
fclose(fp_);
}
private:
FILE* fp_;
int num_channel_;
int sample_rate_;
int bits_per_sample_;
size_t total_num_samples_;
};
} // namespace wenet
#endif // FRONTEND_WAV_H_
# compile wenet.proto
set(PROTO_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
add_custom_command(
OUTPUT ${PROTO_DIR}/wenet.pb.cc
${PROTO_DIR}/wenet.pb.h
${PROTO_DIR}/wenet.grpc.pb.cc
${PROTO_DIR}/wenet.grpc.pb.h
COMMAND ${protobuf_BINARY_DIR}/protoc
ARGS --grpc_out "${PROTO_DIR}"
--cpp_out "${PROTO_DIR}"
-I "${PROTO_DIR}"
--plugin=protoc-gen-grpc=${grpc_BINARY_DIR}/grpc_cpp_plugin
wenet.proto)
# grpc_server/client
link_directories(${protobuf_BINARY_DIR}/lib)
add_library(wenet_grpc STATIC
grpc_client.cc
grpc_server.cc
wenet.pb.cc
wenet.grpc.pb.cc
)
target_link_libraries(wenet_grpc PUBLIC grpc++ grpc++_reflection decoder)
// Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "grpc/grpc_client.h"
#include "utils/log.h"
namespace wenet {
using grpc::Channel;
using grpc::ClientContext;
using grpc::ClientReaderWriter;
using grpc::Status;
using wenet::Request;
using wenet::Response;
GrpcClient::GrpcClient(const std::string& host, int port, int nbest,
bool continuous_decoding)
: host_(host),
port_(port),
nbest_(nbest),
continuous_decoding_(continuous_decoding) {
Connect();
t_.reset(new std::thread(&GrpcClient::ReadLoopFunc, this));
}
void GrpcClient::Connect() {
channel_ = grpc::CreateChannel(host_ + ":" + std::to_string(port_),
grpc::InsecureChannelCredentials());
stub_ = ASR::NewStub(channel_);
context_ = std::make_shared<ClientContext>();
stream_ = stub_->Recognize(context_.get());
request_ = std::make_shared<Request>();
response_ = std::make_shared<Response>();
request_->mutable_decode_config()->set_nbest_config(nbest_);
request_->mutable_decode_config()->set_continuous_decoding_config(
continuous_decoding_);
stream_->Write(*request_);
}
void GrpcClient::SendBinaryData(const void* data, size_t size) {
const int16_t* pdata = reinterpret_cast<const int16_t*>(data);
request_->set_audio_data(pdata, size);
stream_->Write(*request_);
}
void GrpcClient::ReadLoopFunc() {
try {
while (stream_->Read(response_.get())) {
for (int i = 0; i < response_->nbest_size(); i++) {
// you can also traverse wordpieces like demonstrated above
LOG(INFO) << i + 1 << "best " << response_->nbest(i).sentence();
}
if (response_->status() != Response_Status_ok) {
break;
}
if (response_->type() == Response_Type_speech_end) {
done_ = true;
break;
}
}
} catch (std::exception const& e) {
LOG(ERROR) << e.what();
}
}
void GrpcClient::Join() {
stream_->WritesDone();
t_->join();
Status status = stream_->Finish();
if (!status.ok()) {
LOG(INFO) << "Recognize rpc failed.";
}
}
} // namespace wenet
// Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef GRPC_GRPC_CLIENT_H_
#define GRPC_GRPC_CLIENT_H_
#include <grpc/grpc.h>
#include <grpcpp/channel.h>
#include <grpcpp/client_context.h>
#include <grpcpp/create_channel.h>
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include "grpc/wenet.grpc.pb.h"
#include "utils/utils.h"
namespace wenet {
using grpc::Channel;
using grpc::ClientContext;
using grpc::ClientReaderWriter;
using wenet::ASR;
using wenet::Request;
using wenet::Response;
class GrpcClient {
public:
GrpcClient(const std::string& host, int port, int nbest,
bool continuous_decoding);
void SendBinaryData(const void* data, size_t size);
void ReadLoopFunc();
void Join();
bool done() const { return done_; }
private:
void Connect();
std::string host_;
int port_;
std::shared_ptr<Channel> channel_{nullptr};
std::unique_ptr<ASR::Stub> stub_{nullptr};
std::shared_ptr<ClientContext> context_{nullptr};
std::unique_ptr<ClientReaderWriter<Request, Response>> stream_{nullptr};
std::shared_ptr<Request> request_{nullptr};
std::shared_ptr<Response> response_{nullptr};
int nbest_ = 1;
bool continuous_decoding_ = false;
bool done_ = false;
std::unique_ptr<std::thread> t_{nullptr};
WENET_DISALLOW_COPY_AND_ASSIGN(GrpcClient);
};
} // namespace wenet
#endif // GRPC_GRPC_CLIENT_H_
// Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "grpc/grpc_server.h"
namespace wenet {
using grpc::ServerReaderWriter;
using wenet::Request;
using wenet::Response;
GrpcConnectionHandler::GrpcConnectionHandler(
ServerReaderWriter<Response, Request>* stream,
std::shared_ptr<Request> request, std::shared_ptr<Response> response,
std::shared_ptr<FeaturePipelineConfig> feature_config,
std::shared_ptr<DecodeOptions> decode_config,
std::shared_ptr<DecodeResource> decode_resource)
: stream_(std::move(stream)),
request_(std::move(request)),
response_(std::move(response)),
feature_config_(std::move(feature_config)),
decode_config_(std::move(decode_config)),
decode_resource_(std::move(decode_resource)) {}
void GrpcConnectionHandler::OnSpeechStart() {
LOG(INFO) << "Received speech start signal, start reading speech";
got_start_tag_ = true;
response_->set_status(Response::ok);
response_->set_type(Response::server_ready);
stream_->Write(*response_);
feature_pipeline_ = std::make_shared<FeaturePipeline>(*feature_config_);
decoder_ = std::make_shared<AsrDecoder>(feature_pipeline_, decode_resource_,
*decode_config_);
// Start decoder thread
decode_thread_ = std::make_shared<std::thread>(
&GrpcConnectionHandler::DecodeThreadFunc, this);
}
void GrpcConnectionHandler::OnSpeechEnd() {
LOG(INFO) << "Received speech end signal";
CHECK(feature_pipeline_ != nullptr);
feature_pipeline_->set_input_finished();
got_end_tag_ = true;
}
void GrpcConnectionHandler::OnPartialResult() {
LOG(INFO) << "Partial result";
response_->set_status(Response::ok);
response_->set_type(Response::partial_result);
stream_->Write(*response_);
}
void GrpcConnectionHandler::OnFinalResult() {
LOG(INFO) << "Final result";
response_->set_status(Response::ok);
response_->set_type(Response::final_result);
stream_->Write(*response_);
}
void GrpcConnectionHandler::OnFinish() {
// Send finish tag
response_->set_status(Response::ok);
response_->set_type(Response::speech_end);
stream_->Write(*response_);
}
void GrpcConnectionHandler::OnSpeechData() {
// Read binary PCM data
const int16_t* pcm_data =
reinterpret_cast<const int16_t*>(request_->audio_data().c_str());
int num_samples = request_->audio_data().length() / sizeof(int16_t);
VLOG(2) << "Received " << num_samples << " samples";
CHECK(feature_pipeline_ != nullptr);
CHECK(decoder_ != nullptr);
feature_pipeline_->AcceptWaveform(pcm_data, num_samples);
}
void GrpcConnectionHandler::SerializeResult(bool finish) {
for (const DecodeResult& path : decoder_->result()) {
Response_OneBest* one_best_ = response_->add_nbest();
one_best_->set_sentence(path.sentence);
if (finish) {
for (const WordPiece& word_piece : path.word_pieces) {
Response_OnePiece* one_piece_ = one_best_->add_wordpieces();
one_piece_->set_word(word_piece.word);
one_piece_->set_start(word_piece.start);
one_piece_->set_end(word_piece.end);
}
}
if (response_->nbest_size() == nbest_) {
break;
}
}
return;
}
void GrpcConnectionHandler::DecodeThreadFunc() {
while (true) {
DecodeState state = decoder_->Decode();
response_->clear_status();
response_->clear_type();
response_->clear_nbest();
if (state == DecodeState::kEndFeats) {
decoder_->Rescoring();
SerializeResult(true);
OnFinalResult();
OnFinish();
stop_recognition_ = true;
break;
} else if (state == DecodeState::kEndpoint) {
decoder_->Rescoring();
SerializeResult(true);
OnFinalResult();
// If it's not continuous decoding, continue to do next recognition
// otherwise stop the recognition
if (continuous_decoding_) {
decoder_->ResetContinuousDecoding();
} else {
OnFinish();
stop_recognition_ = true;
break;
}
} else {
if (decoder_->DecodedSomething()) {
SerializeResult(false);
OnPartialResult();
}
}
}
}
void GrpcConnectionHandler::operator()() {
try {
while (stream_->Read(request_.get())) {
if (!got_start_tag_) {
nbest_ = request_->decode_config().nbest_config();
continuous_decoding_ =
request_->decode_config().continuous_decoding_config();
OnSpeechStart();
} else {
OnSpeechData();
}
}
OnSpeechEnd();
LOG(INFO) << "Read all pcm data, wait for decoding thread";
if (decode_thread_ != nullptr) {
decode_thread_->join();
}
} catch (std::exception const& e) {
LOG(ERROR) << e.what();
}
}
Status GrpcServer::Recognize(ServerContext* context,
ServerReaderWriter<Response, Request>* stream) {
LOG(INFO) << "Get Recognize request" << std::endl;
auto request = std::make_shared<Request>();
auto response = std::make_shared<Response>();
GrpcConnectionHandler handler(stream, request, response, feature_config_,
decode_config_, decode_resource_);
std::thread t(std::move(handler));
t.join();
return Status::OK;
}
} // namespace wenet
// Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef GRPC_GRPC_SERVER_H_
#define GRPC_GRPC_SERVER_H_
#include <iostream>
#include <memory>
#include <string>
#include <thread>
#include <utility>
#include <vector>
#include "decoder/asr_decoder.h"
#include "frontend/feature_pipeline.h"
#include "utils/log.h"
#include "grpc/wenet.grpc.pb.h"
namespace wenet {
using grpc::ServerContext;
using grpc::ServerReaderWriter;
using grpc::Status;
using wenet::ASR;
using wenet::Request;
using wenet::Response;
class GrpcConnectionHandler {
public:
GrpcConnectionHandler(ServerReaderWriter<Response, Request>* stream,
std::shared_ptr<Request> request,
std::shared_ptr<Response> response,
std::shared_ptr<FeaturePipelineConfig> feature_config,
std::shared_ptr<DecodeOptions> decode_config,
std::shared_ptr<DecodeResource> decode_resource);
void operator()();
private:
void OnSpeechStart();
void OnSpeechEnd();
void OnFinish();
void OnSpeechData();
void OnPartialResult();
void OnFinalResult();
void DecodeThreadFunc();
void SerializeResult(bool finish);
bool continuous_decoding_ = false;
int nbest_ = 1;
ServerReaderWriter<Response, Request>* stream_;
std::shared_ptr<Request> request_;
std::shared_ptr<Response> response_;
std::shared_ptr<FeaturePipelineConfig> feature_config_;
std::shared_ptr<DecodeOptions> decode_config_;
std::shared_ptr<DecodeResource> decode_resource_;
bool got_start_tag_ = false;
bool got_end_tag_ = false;
// When endpoint is detected, stop recognition, and stop receiving data.
bool stop_recognition_ = false;
std::shared_ptr<FeaturePipeline> feature_pipeline_ = nullptr;
std::shared_ptr<AsrDecoder> decoder_ = nullptr;
std::shared_ptr<std::thread> decode_thread_ = nullptr;
};
class GrpcServer final : public ASR::Service {
public:
GrpcServer(std::shared_ptr<FeaturePipelineConfig> feature_config,
std::shared_ptr<DecodeOptions> decode_config,
std::shared_ptr<DecodeResource> decode_resource)
: feature_config_(std::move(feature_config)),
decode_config_(std::move(decode_config)),
decode_resource_(std::move(decode_resource)) {}
Status Recognize(ServerContext* context,
ServerReaderWriter<Response, Request>* reader) override;
private:
std::shared_ptr<FeaturePipelineConfig> feature_config_;
std::shared_ptr<DecodeOptions> decode_config_;
std::shared_ptr<DecodeResource> decode_resource_;
DISALLOW_COPY_AND_ASSIGN(GrpcServer);
};
} // namespace wenet
#endif // GRPC_GRPC_SERVER_H_
// Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
option java_package = "ex.grpc";
option objc_class_prefix = "wenet";
package wenet;
service ASR {
rpc Recognize (stream Request) returns (stream Response) {}
}
message Request {
message DecodeConfig {
int32 nbest_config = 1;
bool continuous_decoding_config = 2;
}
oneof RequestPayload {
DecodeConfig decode_config = 1;
bytes audio_data = 2;
}
}
message Response {
message OneBest {
string sentence = 1;
repeated OnePiece wordpieces = 2;
}
message OnePiece {
string word = 1;
int32 start = 2;
int32 end = 3;
}
enum Status {
ok = 0;
failed = 1;
}
enum Type {
server_ready = 0;
partial_result = 1;
final_result = 2;
speech_end = 3;
}
Status status = 1;
Type type = 2;
repeated OneBest nbest = 3;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment