Commit 688b6eac authored by SWHL's avatar SWHL
Browse files

Update files

parents
# wenet conformer training configuration (train.yaml).
# NOTE(review): the nesting indentation of this YAML was lost when the page
# was captured — keys such as batch_conf/fbank_conf/filter_conf belong under
# dataset_conf, and the *_conf keys under decoder_conf/encoder_conf, in the
# original wenet config. Restore indentation before parsing.
accum_grad: 16
cmvn_file: exp/conformer/global_cmvn
dataset_conf:
batch_conf:
batch_size: 32
batch_type: static
fbank_conf:
dither: 1.0
frame_length: 25
frame_shift: 10
num_mel_bins: 80
filter_conf:
max_length: 1200
min_length: 10
token_max_length: 100
token_min_length: 1
resample_conf:
resample_rate: 16000
shuffle: true
shuffle_conf:
shuffle_size: 1500
sort: true
sort_conf:
sort_size: 1000
spec_aug: true
spec_aug_conf:
max_f: 30
max_t: 50
num_f_mask: 2
num_t_mask: 2
speed_perturb: false
decoder: transformer
decoder_conf:
attention_heads: 8
dropout_rate: 0.1
linear_units: 2048
num_blocks: 6
positional_dropout_rate: 0.1
self_attention_dropout_rate: 0.0
src_attention_dropout_rate: 0.0
encoder: conformer
encoder_conf:
activation_type: swish
attention_dropout_rate: 0.0
attention_heads: 8
cnn_module_kernel: 15
cnn_module_norm: layer_norm
dropout_rate: 0.1
input_layer: conv2d
linear_units: 2048
normalize_before: true
num_blocks: 12
output_size: 512
pos_enc_layer_type: rel_pos
positional_dropout_rate: 0.1
selfattention_layer_type: rel_selfattn
use_cnn_module: true
grad_clip: 5
input_dim: 80
is_json_cmvn: true
log_interval: 100
max_epoch: 50
model_conf:
ctc_weight: 0.3
length_normalized_loss: false
lsm_weight: 0.1
optim: adam
optim_conf:
lr: 0.001
# output_dim is the vocabulary size (number of entries in words.txt).
output_dim: 5537
scheduler: warmuplr
scheduler_conf:
warmup_steps: 5000
## download models
```
URL:https://pan.baidu.com/s/1BTR-uR_8WWBFpvOisNR_PA
CODE:9xjz
```
#!/bin/bash
# Export a trained wenet checkpoint to ONNX.
# Usage: ./export_onnx_model.sh <model_dir_name>
export PYTHONPATH=/root/wenet-onnx/wenet # the directory of wenet root.
export CUDA_VISIBLE_DEVICES="1" # gpu id
FP16=--fp16
SRCDIR=/root/wenet-onnx/models/$1 #the directory of source path of checkpoint models.
OUTDIR=onnx_$1
# Quote path expansions so names with spaces don't word-split;
# $FP16 is left unquoted on purpose so it disappears cleanly when empty.
python export_onnx.py --config "$SRCDIR/train.yaml" --checkpoint "$SRCDIR/final.pt" --cmvn_file "$SRCDIR/global_cmvn" $FP16 --output_onnx_dir "$OUTDIR"
\ No newline at end of file
Change the value of default directories in export_onnx_model.sh
Put script export_onnx_model.sh into wenet/wenet/bin/
recognize_onnx.py test-data format: one JSON record per line, e.g.
```
{"key": "mywave", "wav": "test.wav", "txt": ""}
```
#!/bin/bash
# Run ONNX recognition over a JSONL test file.
# Usage: ./recognize.sh <test_data_file>
MODELDIR=/root/wenet-onnx/wenet/wenet/bin/models/onnx_20211025_conformer_exp
# Quoted expansions prevent word splitting. (A stray editor artifact '~'
# line after this command was removed: as a command it would try to
# execute $HOME and fail.)
python recognize_onnx.py --test_data "$1" --dict "$MODELDIR/words.txt" --config "$MODELDIR/train.yaml" --encoder_onnx "$MODELDIR/encoder.onnx" --decoder_onnx "$MODELDIR/decoder.onnx" --result_file result.txt
#!/bin/bash
# Run ONNX recognition with the fp16-exported models.
# Usage: ./recognize_fp16.sh <test_data_file>
FP16=--fp16
MODELDIR=/root/wenet-onnx/wenet/wenet/bin/models/onnx_20211025_conformer_exp
# Quoted expansions prevent word splitting; $FP16 stays unquoted so it
# vanishes cleanly if cleared.
python recognize_onnx.py --test_data "$1" --dict "$MODELDIR/words.txt" --config "$MODELDIR/train.yaml" --encoder_onnx "$MODELDIR/encoder_fp16.onnx" --decoder_onnx "$MODELDIR/decoder_fp16.onnx" --result_file result.txt $FP16
stuff
\ No newline at end of file
SDK for windows.
#include "precomp.h"
#ifdef __cplusplus
extern "C" {
#endif
// APIs for rapidasr
// Creates a recognizer from a model directory; returns an opaque handle,
// or nullptr when the models cannot be loaded. nThread controls the
// ONNX Runtime inter-op thread count. Release with RpASR_Uninit().
_RAPIDASRAPI RAPIDASR_HANDLE RpASR_init(const char* szModelDir,int nThread)
{
CQmASRRecog* pObj = new CQmASRRecog(szModelDir, nThread);
if (!pObj)
return nullptr;
if (!pObj->IsLoaded())
{
// Construction ran but model loading failed — dispose and report failure.
delete pObj;
return nullptr;
}
return pObj;
}
// Recognizes speech from an in-memory WAV image (szBuf/nLen hold the whole
// WAV file bytes). Returns a result handle to free with RpASRFreeResult(),
// or nullptr on failure.
_RAPIDASRAPI RAPIDASR_RESULT RpASRRecogBuffer(RAPIDASR_HANDLE handle, const char* szBuf, int nLen, RAPIDASR_MODE Mode)
{
CQmASRRecog* pRecogObj = (CQmASRRecog*)handle;
if (!pRecogObj)
return nullptr;
vector<float> wav;
wenet::WavReaderMem Reader(szBuf,nLen, wav);
// The feature pipeline below is hard-wired to QM_DEFAULT_SAMPLE_RATE, so
// the input audio must actually be at that rate. The original assert
// compared Reader.sample_rate() with itself (always true); check against
// the expected rate instead.
assert(Reader.sample_rate() == QM_DEFAULT_SAMPLE_RATE);
wenet::FeaturePipelineConfig config(QM_FEATURE_DIMENSION, QM_DEFAULT_SAMPLE_RATE);
vector<vector<float>> feats;
if (pRecogObj->ExtractFeature(wav, feats, config) > 0)
return pRecogObj->DoRecognize(feats, Mode);
else
return nullptr;
}
// Recognizes speech from a WAV file on disk. The feature pipeline is built
// with the file's own sample rate. Returns a result handle to free with
// RpASRFreeResult(), or nullptr on failure.
_RAPIDASRAPI RAPIDASR_RESULT RpASRRecogFile(RAPIDASR_HANDLE handle, const char* szWavfile, RAPIDASR_MODE Mode)
{
CQmASRRecog* pRecogObj = (CQmASRRecog*)handle;
if (!pRecogObj)
return nullptr;
vector<float> wav;
wenet::WavReader Reader(szWavfile, wav);
// (The original asserted Reader.sample_rate() == Reader.sample_rate(),
// which is always true — removed. The config below already adapts to the
// file's actual rate.)
wenet::FeaturePipelineConfig config(QM_FEATURE_DIMENSION, Reader.sample_rate());
vector<vector<float>> feats;
if (pRecogObj->ExtractFeature(wav, feats,config) > 0)
return pRecogObj->DoRecognize(feats,Mode);
else
return nullptr;
}
// Returns how many result strings are stored in Result; 0 for a null handle.
_RAPIDASRAPI const int RpASRGetRetNumber(RAPIDASR_RESULT Result)
{
PRAPIDASR_RECOG_RESULT pResult = (PRAPIDASR_RECOG_RESULT)Result;
return pResult ? (int)pResult->Strings.size() : 0;
}
// Returns the nIndex-th result string (0-based), or nullptr when the handle
// is null or the index is out of range. The returned pointer stays valid
// until RpASRFreeResult() is called on Result.
_RAPIDASRAPI const char* RpASRGetResult(RAPIDASR_RESULT Result,int nIndex)
{
PRAPIDASR_RECOG_RESULT pResult = (PRAPIDASR_RECOG_RESULT)Result;
if(!pResult)
return nullptr;
// Reject negative indices explicitly; the original's signed/unsigned
// comparison let a negative nIndex slip past the size() check.
if (nIndex < 0 || (size_t)nIndex >= pResult->Strings.size())
return nullptr;
return pResult->Strings[nIndex].c_str();
}
// Frees a result handle returned by the recognition APIs; null is ignored.
_RAPIDASRAPI void RpASRFreeResult(RAPIDASR_RESULT Result)
{
if (!Result)
return;
delete (PRAPIDASR_RECOG_RESULT)Result;
}
// Destroys a recognizer created by RpASR_init(); null handles are ignored
// (delete on nullptr is a no-op, matching the original's explicit guard).
_RAPIDASRAPI void RpASR_Uninit(RAPIDASR_HANDLE handle)
{
delete (CQmASRRecog*)handle;
}
#ifdef __cplusplus
}
#endif
#include "precomp.h"
// Reports whether model loading in the constructor succeeded; callers must
// check this before using the recognizer.
bool CQmASRRecog::IsLoaded()
{
return m_bIsLoaded;
}
// Constructs from a model directory expected to contain the encoder/decoder
// ONNX files, dictionary and train.yaml (see LoadModel(szModelDir,...)).
CQmASRRecog::CQmASRRecog(const char* szModelDir,int nThread)
{
m_bIsLoaded = LoadModel(szModelDir,nThread);
}
// Constructs from explicit paths to the encoder, decoder, dictionary and
// YAML config files. nThread is the ONNX Runtime inter-op thread count.
CQmASRRecog::CQmASRRecog(const char* szEncoder, const char* szDecoder, const char* szDict, const char* szConfig,int nThread)
{
m_bIsLoaded = LoadModel(szEncoder, szDecoder, szDict, szConfig,nThread);
}
// Builds the individual model-file paths from a base directory and forwards
// to the path-based LoadModel overload. Returns false on a null/empty
// directory or when loading fails.
bool CQmASRRecog::LoadModel(const char* szModelDir,int nNumThread)
{
// Validate FIRST: the original constructed std::string from szModelDir
// before the null check (UB for nullptr), and an empty string would
// index szModelDir[-1] below.
if (!szModelDir || !*szModelDir)
return false;
string strEncoder, strDecoder;
string strBaseDir= szModelDir;
// Ensure the base path ends with the platform separator.
if (szModelDir[strlen(szModelDir) - 1] != OS_SEP[0])
{
strBaseDir = strBaseDir + OS_SEP;
}
strEncoder = strBaseDir + QM_ENCODER_MODEL;
strDecoder = strBaseDir + QM_DECODER_MODEL;
m_strDict = strBaseDir + QM_DICT_FILE;
m_strConfig = strBaseDir + QM_CONFIG_FILE;
return LoadModel(strEncoder.c_str(), strDecoder.c_str(), m_strDict.c_str(), m_strConfig.c_str(),nNumThread);
}
// Loads the encoder/decoder ONNX sessions, caches their input/output tensor
// names, reads the vocabulary file and the optional reverse_weight from the
// YAML config. Returns false only when the dictionary cannot be opened.
bool CQmASRRecog::LoadModel(const char* szEncoder, const char* szDecoder, const char* szDict, const char* szConfig, int nNumThread)
{
sessionOptions.SetInterOpNumThreads(nNumThread);
sessionOptions.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
// NOTE(review): Ort::Session constructors throw on a bad model path; there
// is no try/catch here, so a missing ONNX file propagates an exception out
// of the constructor. Also, the sessions allocated below are leaked if the
// dictionary open fails and we return false — verify intended.
m_session_encoder = new Ort::Session(envEncoder, szEncoder, sessionOptions);
m_session_decoder = new Ort::Session(envDecoder, szDecoder, sessionOptions);
// Cache tensor names once; the m_str* vectors hold raw c_str() pointers
// into the m_vec* string vectors, so those must outlive any Run() call.
getInputNameAll(m_session_encoder, m_vecEncInputName);
for (auto& item : m_vecEncInputName)
m_strEncInputName.push_back(item.c_str());
getOutputNameAll (m_session_encoder, m_vecEncOutputName);
for (auto& item : m_vecEncOutputName)
m_strEncOutputName.push_back(item.c_str());
getInputNameAll (m_session_decoder, m_vecDecInputName);
for (auto& item : m_vecDecInputName)
m_strDecInputName.push_back(item.c_str());
getOutputNameAll(m_session_decoder, m_vecDecOutputName);
for (auto& item : m_vecDecOutputName)
m_strDecOutputName.push_back(item.c_str());
// load vocabulary
ifstream fdict(szDict);
if (!fdict.is_open())
return false;
char strLine[101];
string strToken;
int nIndex;
// Each line is "<token> <id>". The id is parsed but discarded: the token's
// position in m_Vocabulary is assumed to equal its id (i.e. the file is
// assumed to be sorted by id with no gaps) — TODO confirm against words.txt.
// Lines longer than 100 chars are truncated by getline.
while(fdict.getline(strLine,100))
{
stringstream sstr;
sstr.str(strLine);
sstr >> strToken;
sstr >> nIndex;
m_Vocabulary.push_back(strToken);
}
// load config
//model_conf:
// ctc_weight: 0.3
// length_normalized_loss : false
// lsm_weight : 0.1
// reverse_weight : 0.3
try
{
YAML::Node conf = YAML::LoadFile(szConfig);
auto var = conf["model_conf"]["reverse_weight"];
try {
m_reverse_weight = var.as<float>();
}
catch (YAML::TypedBadConversion<float>& e) {
// Key missing or not a float: fall back to 0 (no right-to-left decoder).
m_reverse_weight = 0.0f;
}
}
catch (YAML::BadFile& e)
{
// A missing/unreadable config is tolerated; only the default is lost.
m_reverse_weight = 0.0f;
}
return true;
}
// Releases both ONNX Runtime sessions created in LoadModel().
CQmASRRecog::~CQmASRRecog()
{
// delete on nullptr is a no-op, so no guard is needed; null the pointers
// afterwards to leave the object in a harmless state.
delete m_session_encoder;
m_session_encoder = nullptr;
delete m_session_decoder;
m_session_decoder = nullptr;
}
//
// By default, input audio is assumed to be 16-bit PCM sampled at 16 kHz.
//
//
// https://blog.csdn.net/hongmaodaxia/article/details/44224825
// http://fancyerii.github.io/kaldicodes/feature/
// https://github.com/kli017/wenet/tree/wenet-ort
// Computes fbank features for the whole utterance in one shot.
//
// wav    : raw samples (whole utterance).
// feats  : receives one feature row (config.num_bins floats) per frame.
// config : frame length/shift, mel-bin count and sample rate for the fbank.
//
// Returns the number of frames produced. (The commented-out streaming /
// chunked feature-queue code inherited from the wenet pipeline was removed;
// this implementation only supports whole-utterance decoding.)
int CQmASRRecog::ExtractFeature(vector<float> & wav, std::vector<std::vector<float>>& feats, wenet::FeaturePipelineConfig& config)
{
wenet::Fbank fbank_(config.num_bins,config.sample_rate,config.frame_length,config.frame_shift);
return fbank_.Compute(wav, &feats);
}
// Runs the two-pass ONNX pipeline on one utterance's feature matrix:
// encoder -> (greedy | CTC prefix beam search | attention rescoring via the
// decoder model), and returns a newly allocated result the caller must free
// with RpASRFreeResult(). pResult->Result is QAC_OK on success.
// NOTE(review): on the early `return nullptr` inside the rescoring branch,
// pResult and the PathTrie objects in batch_root are leaked.
PRAPIDASR_RECOG_RESULT CQmASRRecog::DoRecognize(vector<vector<float>> & feats, RAPIDASR_MODE Mode)
{
PRAPIDASR_RECOG_RESULT pResult= new RAPIDASR_RECOG_RESULT;
pResult->Result = QAC_ERROR;
// for encoder model
Ort::RunOptions run_option{nullptr};
int num_frames = feats.size();
auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
// Encoder input shape: [batch=1, frames, feature-dim].
std::vector<int64_t> input_mask_node_dims = { 1, num_frames, QM_FEATURE_DIMENSION };
// Flatten the per-frame rows into one contiguous buffer for the tensor.
std::vector<float> flatfeats;
for (auto& e : feats)
{
flatfeats.insert(flatfeats.end(), e.begin(), e.end());
}
Ort::Value onnx_feats =Ort::Value::CreateTensor<float>(memory_info,
flatfeats.data(),
flatfeats.size(),
input_mask_node_dims.data(),
input_mask_node_dims.size());
// Second encoder input: the valid frame count, as a 1-element int32 tensor.
std::vector<int32_t> feats_len{ num_frames };
std::vector<int64_t> feats_len_dim{1};
Ort::Value onnx_feats_len = Ort::Value::CreateTensor(
memory_info,
feats_len.data(),
feats_len.size()*sizeof(int32_t),
feats_len_dim.data(),
feats_len_dim.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);
std::vector<Ort::Value> input_onnx;
input_onnx.emplace_back(std::move(onnx_feats));
input_onnx.emplace_back(std::move(onnx_feats_len));
auto output= m_session_encoder->Run(run_option,
m_strEncInputName.data(),
input_onnx.data(),
m_strEncInputName.size(),
m_strEncOutputName.data(),
m_strEncOutputName.size()
);
// encoder_out, encoder_out_lens, ctc_log_probs, beam_log_probs, beam_log_probs_idx = output
assert(output.size() == 5 && output[0].IsTensor());
// Raw pointers below alias the `output` tensors and are only valid while
// `output` is alive (it is, for the rest of this function).
vector<int64_t> shape_encoder_out = output[0].GetTensorTypeAndShapeInfo().GetShape();
auto encoder_out = output[0].GetTensorMutableData<float>();
//int nLen= std::accumulate(shape_encode_out.begin(), shape_encode_out.end(), 1, std::multiplies<int64_t>());
vector<int64_t> shape_encoder_out_lens= output[1].GetTensorTypeAndShapeInfo().GetShape();
// NOTE(review): int32/int64 (below) are non-standard type names —
// presumably typedefs from precomp.h; confirm they match the tensor dtypes.
auto encoder_out_lens = output[1].GetTensorMutableData<int32>();
vector<int64_t> shape_ctc_log_probs = output[2].GetTensorTypeAndShapeInfo().GetShape();
auto ctc_log_probs = output[2].GetTensorMutableData<float>();
vector<int64_t> shape_beam_log_probs = output[3].GetTensorTypeAndShapeInfo().GetShape();
auto beam_log_probs = output[3].GetTensorMutableData<float>();
vector<int64_t> shape_beam_log_probs_idx = output[4].GetTensorTypeAndShapeInfo().GetShape();
auto beam_log_probs_idx = output[4].GetTensorMutableData<int64>();
// beam_log_probs shape appears to be [batch, frames, beam].
auto beam_size = shape_beam_log_probs[2];
auto batch_size = shape_beam_log_probs[0];
int num_process = 2; // the number of processors.
// sos/eos are assumed to be the last vocabulary entry — TODO confirm
// this matches the exported model's token table.
int sos, eos;
sos = eos = m_Vocabulary.size() -1;
if (Mode == RPASRM_CTC_GREEDY_SEARCH) //ctc greedy search
{
// NOTE(review): when beam_size == 1 no result is produced at all in this
// mode — verify that is intended.
if (beam_size != 1)
{
// Take the top-1 token index of every frame (column 0 of the beam).
vector<int> log_probs_idx;
for (int i = 0; i < shape_beam_log_probs_idx[1]; i++)
{
//log_probs_idx = beam_log_probs_idx[:, : , 0]
log_probs_idx.push_back(*(beam_log_probs_idx+shape_beam_log_probs_idx[2]*i));
}
vector<vector<int>> batch_sents;
batch_sents.push_back(log_probs_idx);
// map_batch(..., true, ...) presumably applies CTC collapse/blank
// removal before mapping ids to tokens — confirm in its definition.
auto sentence = map_batch(batch_sents, m_Vocabulary, num_process,true,0);
pResult->Strings = sentence;
}
}
else
if (Mode == RPASRM_CTC_RPEFIX_BEAM_SEARCH || Mode == RPASRM_ATTENSION_RESCORING)
{
vector<vector<vector<double>>> batch_log_probs_seq;
vector<vector<vector<int>>> batch_log_probs_idx;
vector<PathTrie*> batch_root;
vector<bool> batch_start;
// NOTE(review): these shadow the outer beam_size/batch_size (harmless but
// confusing); the stray second ';' below is a typo.
size_t beam_size= shape_beam_log_probs[2];
size_t batch_size= shape_beam_log_probs[0];;
// Re-pack the encoder's per-frame top-k log-probs/indices into the nested
// vectors the CTC beam-search decoder expects, one entry per batch item.
for (int i = 0; i < shape_encoder_out_lens[0]; i++)
{
auto num_sent = encoder_out_lens[i];
vector <vector<double>> batch_log_probs_seq_list;
vector <vector<int>> batch_log_probs_index_list;
// NOTE(review): the indexing below (s * beam + t) ignores the batch
// offset i, so for batch_size > 1 every item reads batch 0's data —
// verify against the exporting script.
for (int s = 0; s < num_sent; s++)
{
vector<double> temp;
for (int t = 0; t < shape_beam_log_probs[2]; t++)
temp.push_back(beam_log_probs[s * shape_beam_log_probs[2] + t]);
batch_log_probs_seq_list.push_back(temp);
vector<int> tempindex;
for (int t = 0; t < shape_beam_log_probs_idx[2]; t++)
tempindex.push_back(beam_log_probs_idx[s * shape_beam_log_probs_idx[2] + t]);
batch_log_probs_index_list.push_back(tempindex);
}
batch_root.push_back(new PathTrie);
batch_log_probs_seq.push_back(batch_log_probs_seq_list);
batch_log_probs_idx.push_back(batch_log_probs_index_list);
batch_start.push_back(true);
}
auto score_hyps=ctc_beam_search_decoder_batch(batch_log_probs_seq, batch_log_probs_idx, batch_root, batch_start, beam_size,num_process,0, -2, 0.99999);
if (Mode == RPASRM_CTC_RPEFIX_BEAM_SEARCH)
{
// Keep only the best-scoring hypothesis of every batch item.
vector<std::vector<int>> batch_sents;
for (auto& item :score_hyps)
{
batch_sents.push_back( item[0].second);
}
auto sentences = map_batch(batch_sents, m_Vocabulary, num_process, false, 0);
pResult->Strings = sentences;
}
if (Mode == RPASRM_ATTENSION_RESCORING)
{
// Pad every item's hypothesis list to exactly beam_size entries and
// collect the CTC scores and token sequences for the decoder pass.
int max_len = 0;
vector<vector<float>> ctc_score;
vector<vector<int>> all_hyps;
for (auto& hyps : score_hyps)
{
auto cur_len = hyps.size();
if ( cur_len < beam_size)
{
vector<int> tmp;
for (int s = 0; s < hyps[0].second.size(); s++)
tmp.push_back(0);
// Filler hypotheses get a huge negative score so rescoring never
// picks them (the literal is truncated on conversion to float).
for (int i = 0; i< beam_size - cur_len; i++)
{
hyps.push_back(std::make_pair(-999999999999, tmp));
}
// hyps += (beam_size - cur_len) * [(-float("INF"), (0, ))]
}
vector<float> cur_ctc_score;
for (auto& hyp : hyps)
{
cur_ctc_score.push_back(hyp.first);
all_hyps.push_back(hyp.second);
if (hyp.second.size() > max_len)
max_len = hyp.second.size();
}
ctc_score.push_back(cur_ctc_score);
}
// hyps_pad_sos_eos
// r_hyps_pad_sos_eos
// Build [batch, beam, max_len+2] tensors of token ids padded with
// IGNORE_ID: forward hypotheses, reversed hypotheses, and their lengths.
auto lastdim = max_len + 2;
auto eos_len = batch_size * beam_size * lastdim;
vector<int64_t> hyps_pad_sos_eos(eos_len, IGNORE_ID);
vector<int64_t> r_hyps_pad_sos_eos(eos_len, IGNORE_ID);
vector<int32_t> hyps_lens_sos(batch_size* beam_size,1);
int k = 0;
for (int i = 0; i < batch_size; i++)
{
for (int j = 0; j < beam_size; j++)
{
vector<int64_t> tmp,rtmp;
auto cand = all_hyps[k];
auto rcand = cand;
reverse(rcand.begin(), rcand.end());
int l = cand.size() + 2;
// Row layout: sos, tokens..., eos (reversed tokens for rtmp).
tmp.push_back(sos);
rtmp.push_back(sos);
tmp.insert(tmp.begin()+1,cand.begin(), cand.end());
rtmp.insert(rtmp.begin()+1,rcand.begin(), rcand.end());
tmp.push_back(eos);
rtmp.push_back(eos);
// NOTE(review): the destination offset `i * j * lastdim` is zero for
// the whole first row/column (i==0 or j==0) and collides for other
// (i,j) pairs — it looks like it should be (i * beam_size + j) *
// lastdim, matching the hyps_lens_sos index below. Verify.
copy( tmp.begin(), tmp.end(), hyps_pad_sos_eos.begin() + i * j * lastdim);
copy( rtmp.begin(), rtmp.end(), r_hyps_pad_sos_eos.begin() + i * j * lastdim);
hyps_lens_sos[beam_size * i + j] = cand.size() + 1;
k++;
}
}
// Wrap the decoder inputs as ONNX tensors. These borrow the buffers
// above, which stay alive until the decoder Run() returns.
Ort::Value onnx_encoder_out = Ort::Value::CreateTensor<float>(memory_info,
encoder_out,
accumulate(shape_encoder_out.begin(), shape_encoder_out.end(), 1, multiplies<int>()),
shape_encoder_out.data(),
shape_encoder_out.size()
);
Ort::Value onnx_encoder_out_len = Ort::Value::CreateTensor<int32_t>(memory_info,
encoder_out_lens,
accumulate(shape_encoder_out_lens.begin(), shape_encoder_out_lens.end(), 1, multiplies<int>()),
shape_encoder_out_lens.data(),
shape_encoder_out_lens.size());
std::vector<int64_t> hyps_pad_sos_eos_dims = { batch_size, beam_size, lastdim };
Ort::Value onnx_hyps_pad_sos_eos = Ort::Value::CreateTensor<int64_t>(memory_info,
hyps_pad_sos_eos.data(),
hyps_pad_sos_eos.size(),
hyps_pad_sos_eos_dims.data(),
hyps_pad_sos_eos_dims.size());
//hyps_pad_sos_eos
std::vector<int64_t> r_hyps_pad_sos_eos_dims = { batch_size, beam_size, lastdim };
Ort::Value onnx_r_hyps_pad_sos_eos = Ort::Value::CreateTensor<int64_t>(memory_info,
r_hyps_pad_sos_eos.data(),
r_hyps_pad_sos_eos.size(),
r_hyps_pad_sos_eos_dims.data(),
r_hyps_pad_sos_eos_dims.size());
std::vector<int64_t> hyps_len_sos_dims = { batch_size, beam_size };
Ort::Value onnx_hyps_len_sos = Ort::Value::CreateTensor<int32_t>(memory_info,
hyps_lens_sos.data(),
hyps_lens_sos.size(),
hyps_len_sos_dims.data(),
hyps_len_sos_dims.size());
std::vector<float> flat_ctc_score;
for (auto& e : ctc_score)
{
flat_ctc_score.insert(flat_ctc_score.end(), e.begin(), e.end());
}
std::vector<int64_t> ctc_score_dims = { batch_size, beam_size };
Ort::Value onnx_ctc_score = Ort::Value::CreateTensor<float>(memory_info,
flat_ctc_score.data(),
flat_ctc_score.size(),
ctc_score_dims.data(),
ctc_score_dims.size());
std::vector<Ort::Value> input_onnx;
input_onnx.emplace_back(std::move(onnx_encoder_out));
input_onnx.emplace_back(std::move(onnx_encoder_out_len));
input_onnx.emplace_back(std::move(onnx_hyps_pad_sos_eos));
input_onnx.emplace_back(std::move(onnx_hyps_len_sos));
// NOTE(review): the 5th decoder input depends on reverse_weight — either
// the reversed hypotheses OR the ctc scores, never both. Confirm this
// matches the input list of the exported decoder graph.
if(m_reverse_weight)
input_onnx.emplace_back(std::move(onnx_r_hyps_pad_sos_eos));
else
input_onnx.emplace_back(std::move(onnx_ctc_score));
auto decoder_output = m_session_decoder->Run(run_option,
m_strDecInputName.data(),
input_onnx.data(),
m_strDecInputName.size(),
m_strDecOutputName.data(),
m_strDecOutputName.size()
);
assert(decoder_output.size() == 1 && decoder_output[0].IsTensor());
//auto best_index = decoder_output[0];
// The decoder emits, per batch item, the index of the best hypothesis
// within that item's beam.
vector<int64_t> shape_best_index = decoder_output[0].GetTensorTypeAndShapeInfo().GetShape();
auto best_index = decoder_output[0].GetTensorMutableData<int64_t>();
k = 0;
vector<vector<int>> batch_sents;
for (int i=0; i< shape_best_index[0]; i++)
{
// Invalid if the index exceeds the number of selected hypotheses.
// NOTE(review): `beam_size - k` shrinks (goes negative) as k grows by
// beam_size per item — the intended check is probably
// best_index[i] >= beam_size. Also, this early return leaks pResult
// and the batch_root tries.
if (best_index[i] >= (beam_size - k))
return nullptr;
batch_sents.push_back(all_hyps[k+ best_index[i]]);
k += beam_size;
}
auto sentences = map_batch(batch_sents, m_Vocabulary, num_process);
pResult->Strings = sentences;
}
// Free the prefix-search tries allocated above.
for (auto& item : batch_root)
delete item;
}
pResult->Result = QAC_OK;
return pResult;
}
#include <stdlib.h>
#include <stdio.h>
#include "librpasrapi.h"
#define TEST_WAV "/opt/test/test.wav"
#define MODEL_DIR "/opt/test/models/onnx_20211025_conformer_exp"
// Smoke test: load the models, recognize one WAV file with attention
// rescoring, print the first result, and clean up.
int main(int argc, char * argv[])
{
auto Handle =RpASR_init(MODEL_DIR, RP_DEFAULT_THREAD_NUM);
if (!Handle)
{
printf("Can't load models from %s\n", MODEL_DIR);
return -1;
}
auto retHandle =RpASRRecogFile(Handle, TEST_WAV, RPASRM_ATTENSION_RESCORING); // RPASRM_CTC_GREEDY_SEARCH); // RPASRM_ATTENSION_RESCORING);
int nNumber =RpASRGetRetNumber(retHandle);
printf(" %d results. String:", nNumber);
const char * szString =RpASRGetResult(retHandle, 0);
// Never pass recognizer output as the printf FORMAT string — a '%' in the
// transcript would be undefined behavior. Also guard against nullptr when
// recognition produced no result.
if (szString)
printf("%s", szString);
printf("\n");
if (retHandle)
RpASRFreeResult(retHandle);
RpASR_Uninit(Handle);
return 0;
}
\ No newline at end of file
@echo off
rem Build openfst for both architectures into openfst/winlib/<arch>.
set CURDIR=%cd%
rem win32 |x64
call:CompileLib %CURDIR% x64
cd %CURDIR%
call:CompileLib %CURDIR% win32
cd %CURDIR%
rem Stop here: without this GOTO, execution falls straight through into the
rem :CompileLib label a third time with empty arguments.
GOTO:EOF
rem Build helper: %1 = source root, %2 = platform (win32|x64)
:CompileLib
cd %~1
cd openfst
if exist build ( rd /q /s build )
mkdir build
cd build
cmake .. -A %~2 -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=../winlib/%~2
cmake --build . --config Release -j8
GOTO:EOF
cmake_minimum_required(VERSION 3.15)
project(ctc_decoder_prj)
set(LIBNAME ctc_decoder)
set(APPNAME "testapp")
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
if(WIN32)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/win-dep/include)
link_directories(${CMAKE_CURRENT_SOURCE_DIR}/win-dep/lib/x64)
else()
# -fPIC is a GCC/Clang flag that MSVC rejects, so only add it off Windows.
add_compile_options(-fPIC)
include_directories(/opt/qmcds/include /opt/qmcds/include/kenlm)
link_directories(/opt/qmcds/lib)
endif()
set(MAIN_SRC "ctc_beam_search_decoder.cpp" "decoder_utils.cpp" "path_trie.cpp" "scorer.cpp")
add_library(${LIBNAME} SHARED ${MAIN_SRC})
# NOTE(review): pthread/bz2/lzma/z are Unix library names; the WIN32 build
# most likely needs the corresponding win-dep import libraries instead.
target_link_libraries(${LIBNAME} PUBLIC pthread fst kenlm kenlm_builder kenlm_filter kenlm_interpolate kenlm_util bz2 lzma z)
add_executable(${APPNAME} test/test.cpp)
target_link_libraries(${APPNAME} PUBLIC ctc_decoder kenlm kenlm_builder kenlm_filter kenlm_interpolate kenlm_util)
{
"configurations": [
{
"name": "x64-Debug",
"generator": "Ninja",
"configurationType": "Debug",
"inheritEnvironments": [ "msvc_x64_x64" ],
"buildRoot": "${projectDir}\\out\\build\\${name}",
"installRoot": "${projectDir}\\out\\install\\${name}",
"cmakeCommandArgs": "",
"buildCommandArgs": "",
"ctestCommandArgs": ""
},
{
"name": "Linux-GCC-Debug",
"generator": "Unix Makefiles",
"configurationType": "Debug",
"cmakeExecutable": "cmake",
"remoteCopySourcesExclusionList": [ ".vs", ".git", "out" ],
"cmakeCommandArgs": "",
"buildCommandArgs": "",
"ctestCommandArgs": "",
"inheritEnvironments": [ "linux_x64" ],
"remoteMachineName": "-1892815710;10.95.24.25 (username=root, port=33, authentication=Password)",
"remoteCMakeListsRoot": "$HOME/.vs/${projectDirName}/${workspaceHash}/src",
"remoteBuildRoot": "$HOME/.vs/${projectDirName}/${workspaceHash}/out/build/${name}",
"remoteInstallRoot": "$HOME/.vs/${projectDirName}/${workspaceHash}/out/install/${name}",
"remoteCopySources": true,
"rsyncCommandArgs": "-t --delete --delete-excluded",
"remoteCopyBuildOutput": false,
"remoteCopySourcesMethod": "rsync"
}
]
}
\ No newline at end of file
Copyright (c) 2012 Jakob Progsch, Václav Zeman
This software is provided 'as-is', without any express or implied
warranty. In no event will the authors be held liable for any damages
arising from the use of this software.
Permission is granted to anyone to use this software for any purpose,
including commercial applications, and to alter it and redistribute it
freely, subject to the following restrictions:
1. The origin of this software must not be misrepresented; you must not
claim that you wrote the original software. If you use this software
in a product, an acknowledgment in the product documentation would be
appreciated but is not required.
2. Altered source versions must be plainly marked as such, and must not be
misrepresented as being the original software.
3. This notice may not be removed or altered from any source
distribution.
ThreadPool
==========
A simple C++11 Thread Pool implementation.
Basic usage:
```c++
// create thread pool with 4 worker threads
ThreadPool pool(4);
// enqueue and store future
auto result = pool.enqueue([](int answer) { return answer; }, 42);
// get result from future
std::cout << result.get() << std::endl;
```
#ifndef THREAD_POOL_H
#define THREAD_POOL_H
#include <vector>
#include <queue>
#include <memory>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <future>
#include <functional>
#include <stdexcept>
// Fixed-size pool of worker threads that execute queued tasks in FIFO order.
class ThreadPool {
public:
    // Launches the given number of workers; they block until work arrives.
    ThreadPool(size_t);
    // Schedules f(args...) and returns a future for its result.
    // Throws std::runtime_error if called after destruction has begun.
    template<class F, class... Args>
    auto enqueue(F&& f, Args&&... args)
        -> std::future<typename std::result_of<F(Args...)>::type>;
    // Signals shutdown, drains remaining tasks, and joins every worker.
    ~ThreadPool();
private:
    std::vector< std::thread > workers;          // joined in the destructor
    std::queue< std::function<void()> > tasks;   // pending type-erased jobs
    std::mutex queue_mutex;                      // guards tasks and stop
    std::condition_variable condition;           // wakes idle workers
    bool stop;                                   // set once, under queue_mutex
};

// Spawn the workers; each loops pulling one job at a time until shutdown.
inline ThreadPool::ThreadPool(size_t threads)
    : stop(false)
{
    workers.reserve(threads);
    for (size_t n = 0; n < threads; ++n) {
        workers.emplace_back([this] {
            while (true) {
                std::function<void()> job;
                {
                    std::unique_lock<std::mutex> guard(this->queue_mutex);
                    // Sleep until there is work or we are shutting down.
                    this->condition.wait(guard,
                        [this] { return this->stop || !this->tasks.empty(); });
                    // Exit only after the queue is fully drained.
                    if (this->stop && this->tasks.empty())
                        return;
                    job = std::move(this->tasks.front());
                    this->tasks.pop();
                }
                // Run outside the lock so other workers can dequeue.
                job();
            }
        });
    }
}

// Package the callable with its arguments, queue it, and hand back a future.
template<class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
    -> std::future<typename std::result_of<F(Args...)>::type>
{
    using return_type = typename std::result_of<F(Args...)>::type;
    // shared_ptr keeps the packaged_task alive inside the copyable
    // std::function wrapper stored in the queue.
    auto bound = std::make_shared< std::packaged_task<return_type()> >(
        std::bind(std::forward<F>(f), std::forward<Args>(args)...));
    std::future<return_type> result = bound->get_future();
    {
        std::unique_lock<std::mutex> guard(queue_mutex);
        // don't allow enqueueing after stopping the pool
        if (stop)
            throw std::runtime_error("enqueue on stopped ThreadPool");
        tasks.emplace([bound]() { (*bound)(); });
    }
    condition.notify_one();
    return result;
}

// Flag shutdown under the lock, wake everyone, and join all workers.
inline ThreadPool::~ThreadPool()
{
    {
        std::unique_lock<std::mutex> guard(queue_mutex);
        stop = true;
    }
    condition.notify_all();
    for (std::thread& worker : workers)
        worker.join();
}
#endif
#include <iostream>
#include <vector>
#include <chrono>
#include "ThreadPool.h"
// Demo: four workers run eight tasks; each task prints a greeting, sleeps
// one second, prints again, and returns its index squared. The squares are
// then printed space-separated on one line.
int main()
{
    ThreadPool pool(4);
    std::vector< std::future<int> > futures;
    futures.reserve(8);
    for (int task = 0; task < 8; ++task) {
        futures.emplace_back(pool.enqueue([task] {
            std::cout << "hello " << task << std::endl;
            std::this_thread::sleep_for(std::chrono::seconds(1));
            std::cout << "world " << task << std::endl;
            return task * task;
        }));
    }
    for (auto& fut : futures)
        std::cout << fut.get() << ' ';
    std::cout << std::endl;
    return 0;
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment