Commit 764b3a75 authored by Sugon_ldc's avatar Sugon_ldc
Browse files

add new model

parents
# Copyright (c) 2022 Mddct(hamddct@gmail.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tarfile
from pathlib import Path
from urllib.request import urlretrieve
import tqdm
def download(url: str, dest: str, only_child=True):
    """Download a ``*.tar.gz`` archive from ``url`` into ``dest`` and extract it.

    Args:
        url: http(s) URL of the archive; its last path component is used
            as the local file name.
        dest: local directory the archive is downloaded to and extracted
            into.  Created if it does not exist.
        only_child: if True, flatten the archive: only members nested inside
            a top-level folder are extracted, directly into ``dest``.
            If False, the full archive tree is extracted into ``dest``.
    """
    # Create the destination instead of asserting it exists: ``assert`` is
    # stripped under ``python -O`` and the caller may not have prepared it.
    os.makedirs(dest, exist_ok=True)
    print('Downloading {} to {}'.format(url, dest))

    def progress_hook(t):
        """Adapt a tqdm bar to the urlretrieve reporthook protocol."""
        last_b = [0]

        def update_to(b=1, bsize=1, tsize=None):
            if tsize not in (None, -1):
                t.total = tsize
            displayed = t.update((b - last_b[0]) * bsize)
            last_b[0] = b
            return displayed

        return update_to

    # *.tar.gz
    name = url.split("/")[-1]
    tar_path = os.path.join(dest, name)
    with tqdm.tqdm(unit='B',
                   unit_scale=True,
                   unit_divisor=1024,
                   miniters=1,
                   desc=(name)) as t:
        urlretrieve(url,
                    filename=tar_path,
                    reporthook=progress_hook(t),
                    data=None)
        t.total = t.n
    with tarfile.open(tar_path) as f:
        if not only_child:
            f.extractall(dest)
        else:
            for tarinfo in f:
                # Skip top-level entries (the wrapping folder itself).
                if "/" not in tarinfo.name:
                    continue
                name = os.path.basename(tarinfo.name)
                fileobj = f.extractfile(tarinfo)
                # extractfile() returns None for non-regular members
                # (directories, links); skip those instead of crashing.
                if fileobj is None:
                    continue
                with open(os.path.join(dest, name), "wb") as writer:
                    writer.write(fileobj.read())
class Hub(object):
    """Hub for wenet pretrained runtime models.

    Maps a language id to a downloadable model archive and caches the
    extracted model under ``~/.wenet/<lang>``.
    """
    # TODO(Mddct): make assets class to support other language
    Assets = {
        # wenetspeech
        "chs":
        "https://github.com/wenet-e2e/wenet/releases/download/v2.0.1/chs.tar.gz",
        # gigaspeech
        "en":
        "https://github.com/wenet-e2e/wenet/releases/download/v2.0.1/en.tar.gz"
    }

    def __init__(self) -> None:
        pass

    @staticmethod
    def get_model_by_lang(lang: str) -> str:
        """Return the local model dir for ``lang``, downloading it if needed.

        Raises:
            ValueError: if ``lang`` is not one of the supported languages.
        """
        # Raise a real error instead of ``assert``: assert is stripped
        # under ``python -O`` and gives a less helpful message.
        if lang not in Hub.Assets:
            raise ValueError('Unsupported language {}, available: {}'.format(
                lang, sorted(Hub.Assets.keys())))
        # NOTE(Mddct): model_dir structure
        # Path.home()/.wenet
        #   - chs
        #      - units.txt
        #      - final.zip
        #   - en
        #      - units.txt
        #      - final.zip
        model_url = Hub.Assets[lang]
        model_dir = os.path.join(Path.home(), ".wenet", lang)
        os.makedirs(model_dir, exist_ok=True)
        # TODO(Mddct): model metadata
        # Skip the download when the model files are already cached.
        if {"final.zip", "units.txt"}.issubset(set(os.listdir(model_dir))):
            return model_dir
        download(model_url, model_dir, only_child=True)
        return model_dir
#!/usr/bin/env python3
# Copyright (c) 2020 Xiaomi Corporation (author: Fangjun Kuang)
# 2022 Binbin Zhang(binbzha@qq.com)
import glob
import os
import shutil
import sys
import setuptools
from setuptools.command.build_ext import build_ext
def cmake_extension(name, *args, **kwargs) -> setuptools.Extension:
    """Create a placeholder C++ Extension whose sources are built by CMake.

    The source list is intentionally empty: the actual compilation happens
    in the custom ``build_ext`` command, not through setuptools.
    """
    kwargs["language"] = "c++"
    return setuptools.Extension(name, [], *args, **kwargs)
class BuildExtension(build_ext):
    """Custom ``build_ext`` that delegates the native build to CMake.

    Configures and builds the ``_wenet`` target with CMake, then copies the
    produced shared libraries into the package build directory.
    """

    def build_extension(self, ext: setuptools.extension.Extension):
        os.makedirs(self.build_temp, exist_ok=True)
        os.makedirs(self.build_lib, exist_ok=True)

        cmake_args = os.environ.get("WENET_CMAKE_ARGS",
                                    "-DCMAKE_BUILD_TYPE=Release")
        if "PYTHON_EXECUTABLE" not in cmake_args:
            print(f"Setting PYTHON_EXECUTABLE to {sys.executable}")
            cmake_args += f" -DPYTHON_EXECUTABLE={sys.executable}"

        src_dir = os.path.dirname(os.path.abspath(__file__))
        # Check the configure step as well: previously its return code was
        # ignored, so a failed configure only surfaced later as a confusing
        # build error.
        ret = os.system(f"cmake {cmake_args} -B {self.build_temp} -S {src_dir}")
        if ret == 0:
            ret = os.system(f"""
                cmake --build {self.build_temp} --target _wenet --config Release
            """)
        if ret != 0:
            raise Exception(
                "\nBuild wenet failed. Please check the error message.\n"
                "You can ask for help by creating an issue on GitHub.\n"
                "\nClick:\n https://github.com/wenet-e2e/wenet/issues/new\n"
            )

        libs = []
        # Python extension module produced by the build.
        # (renamed loop variable: the original shadowed the ``ext`` parameter)
        for suffix in ['so', 'pyd']:
            libs.extend(
                glob.glob(f"{self.build_temp}/**/_wenet*.{suffix}",
                          recursive=True))
        # Shared C API library the extension links against.
        for suffix in ['so', 'dylib', 'dll']:
            libs.extend(
                glob.glob(f"{self.build_temp}/**/*wenet_api.{suffix}",
                          recursive=True))
        for lib in libs:
            print(f"Copying {lib} to {self.build_lib}/")
            shutil.copy(f"{lib}", f"{self.build_lib}/")
def read_long_description():
    """Return the contents of README.md, used as the PyPI long description."""
    with open("README.md", encoding="utf8") as readme_file:
        return readme_file.read()
# Name of the pip package; the pure-python sources live in ./py.
package_name = "wenetruntime"

setuptools.setup(
    name=package_name,
    version='1.0.12',
    author="Binbin Zhang",
    author_email="binbzha@qq.com",
    # Map the package name onto the "py" source directory.
    package_dir={
        package_name: "py",
    },
    packages=[package_name],
    url="https://github.com/wenet-e2e/wenet",
    long_description=read_long_description(),
    long_description_content_type="text/markdown",
    # The native module is built by CMake via the custom BuildExtension.
    ext_modules=[cmake_extension("_wenet")],
    cmdclass={"build_ext": BuildExtension},
    zip_safe=False,
    setup_requires=["tqdm"],
    install_requires=["torch", "tqdm"],
    classifiers=[
        "Programming Language :: C++",
        "Programming Language :: Python :: 3",
        "Operating System :: OS Independent",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
    ],
    license="Apache licensed, as found in the LICENSE file",
)
../../core/utils
\ No newline at end of file
# Build the shared C API library only when the Torch backend is enabled;
# it wraps the decoder library behind a plain C interface (wenet_api.h).
if(TORCH)
  add_library(wenet_api SHARED wenet_api.cc)
  target_link_libraries(wenet_api PUBLIC decoder)
endif()
# WeNet API
We refer [vosk](https://github.com/alphacep/vosk-api/blob/master/src/vosk_api.h)
for the interface design.
We are going to implement the following interfaces:

- [x] non-streaming recognition
- [ ] streaming recognition
- [ ] nbest
- [ ] contextual biasing word
- [ ] alignment
- [ ] language support (post processor)
- [ ] label check
// Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "api/wenet_api.h"
#include <memory>
#include <string>
#include <vector>
#include "decoder/asr_decoder.h"
#include "decoder/torch_asr_model.h"
#include "post_processor/post_processor.h"
#include "utils/file.h"
#include "utils/json.h"
#include "utils/string.h"
// Recognizer bundles feature extraction, the Torch ASR model and the
// decoder behind one object; the C API below wraps it as an opaque handle.
class Recognizer {
 public:
  explicit Recognizer(const std::string& model_dir) {
    // FeaturePipeline init: 80-dim features at 16 kHz.
    feature_config_ = std::make_shared<wenet::FeaturePipelineConfig>(80, 16000);
    feature_pipeline_ =
        std::make_shared<wenet::FeaturePipeline>(*feature_config_);
    // Resource init
    resource_ = std::make_shared<wenet::DecodeResource>();
    wenet::TorchAsrModel::InitEngineThreads();

    std::string model_path = wenet::JoinPath(model_dir, "final.zip");
    CHECK(wenet::FileExists(model_path));
    auto model = std::make_shared<wenet::TorchAsrModel>();
    model->Read(model_path);
    resource_->model = model;

    // units.txt: E2E model unit
    std::string unit_path = wenet::JoinPath(model_dir, "units.txt");
    CHECK(wenet::FileExists(unit_path));
    resource_->unit_table = std::shared_ptr<fst::SymbolTable>(
        fst::SymbolTable::ReadText(unit_path));

    // TLG.fst is optional: present only when decoding with a language model.
    std::string fst_path = wenet::JoinPath(model_dir, "TLG.fst");
    if (wenet::FileExists(fst_path)) {  // With LM
      resource_->fst = std::shared_ptr<fst::Fst<fst::StdArc>>(
          fst::Fst<fst::StdArc>::Read(fst_path));
      std::string symbol_path = wenet::JoinPath(model_dir, "words.txt");
      CHECK(wenet::FileExists(symbol_path));
      resource_->symbol_table = std::shared_ptr<fst::SymbolTable>(
          fst::SymbolTable::ReadText(symbol_path));
    } else {  // Without LM, symbol_table is the same as unit_table
      resource_->symbol_table = resource_->unit_table;
    }

    // Context config init
    context_config_ = std::make_shared<wenet::ContextConfig>();
    decode_options_ = std::make_shared<wenet::DecodeOptions>();
    post_process_opts_ = std::make_shared<wenet::PostProcessOptions>();
  }

  // Clear feature, decoder and result state for the next utterance.
  void Reset() {
    if (feature_pipeline_ != nullptr) {
      feature_pipeline_->Reset();
    }
    if (decoder_ != nullptr) {
      decoder_->Reset();
    }
    result_.clear();
  }

  // Lazily build the decoder once all set_* options have been applied.
  void InitDecoder() {
    CHECK(decoder_ == nullptr);
    // Optional init context graph
    if (context_.size() > 0) {
      context_config_->context_score = context_score_;
      auto context_graph =
          std::make_shared<wenet::ContextGraph>(*context_config_);
      context_graph->BuildContextGraph(context_, resource_->symbol_table);
      resource_->context_graph = context_graph;
    }
    // PostProcessor
    if (language_ == "chs") {  // TODO(Binbin Zhang): CJK(chs, jp, kr)
      post_process_opts_->language_type = wenet::kMandarinEnglish;
    } else {
      post_process_opts_->language_type = wenet::kIndoEuropean;
    }
    resource_->post_processor =
        std::make_shared<wenet::PostProcessor>(*post_process_opts_);
    // Init decoder
    decoder_ = std::make_shared<wenet::AsrDecoder>(feature_pipeline_, resource_,
                                                   *decode_options_);
  }

  // Feed 16-bit little-endian PCM data and run the decoder until it needs
  // more input. `last > 0` marks the end of the utterance.
  void Decode(const char* data, int len, int last) {
    using wenet::DecodeState;
    // Init decoder when it is called first time
    if (decoder_ == nullptr) {
      InitDecoder();
    }
    // Convert 16 bits PCM data to float
    CHECK_EQ(len % 2, 0);
    feature_pipeline_->AcceptWaveform(reinterpret_cast<const int16_t*>(data),
                                      len / 2);
    if (last > 0) {
      feature_pipeline_->set_input_finished();
    }
    while (true) {
      DecodeState state = decoder_->Decode(false);
      if (state == DecodeState::kWaitFeats) {
        break;
      } else if (state == DecodeState::kEndFeats) {
        decoder_->Rescoring();
        UpdateResult(true);
        break;
      } else if (state == DecodeState::kEndpoint && continuous_decoding_) {
        decoder_->Rescoring();
        UpdateResult(true);
        decoder_->ResetContinuousDecoding();
      } else {  // kEndBatch
        UpdateResult(false);
      }
    }
  }

  // Serialize the current decoding result into `result_` as JSON.
  // `final_result` selects the "type" field and enables n-best and
  // (optionally) word-level timestamps.
  void UpdateResult(bool final_result) {
    json::JSON obj;
    obj["type"] = final_result ? "final_result" : "partial_result";
    int nbest = final_result ? nbest_ : 1;
    obj["nbest"] = json::Array();
    for (int i = 0; i < nbest && i < decoder_->result().size(); i++) {
      json::JSON one;
      // Fix: "sentence" used to be assigned twice (redundant duplicate).
      one["sentence"] = decoder_->result()[i].sentence;
      if (final_result && enable_timestamp_) {
        one["word_pieces"] = json::Array();
        for (const auto& word_piece : decoder_->result()[i].word_pieces) {
          json::JSON piece;
          piece["word"] = word_piece.word;
          piece["start"] = word_piece.start;
          piece["end"] = word_piece.end;
          one["word_pieces"].append(piece);
        }
      }
      obj["nbest"].append(one);
    }
    result_ = obj.dump();
  }

  const char* GetResult() { return result_.c_str(); }

  void set_nbest(int n) { nbest_ = n; }
  void set_enable_timestamp(bool flag) { enable_timestamp_ = flag; }
  void AddContext(const char* word) { context_.emplace_back(word); }
  void set_context_score(float score) { context_score_ = score; }
  void set_language(const char* lang) { language_ = lang; }
  void set_continuous_decoding(bool flag) { continuous_decoding_ = flag; }

 private:
  // NOTE(Binbin Zhang): All use shared_ptr for clone in the future
  std::shared_ptr<wenet::FeaturePipelineConfig> feature_config_ = nullptr;
  std::shared_ptr<wenet::FeaturePipeline> feature_pipeline_ = nullptr;
  std::shared_ptr<wenet::DecodeResource> resource_ = nullptr;
  std::shared_ptr<wenet::DecodeOptions> decode_options_ = nullptr;
  std::shared_ptr<wenet::AsrDecoder> decoder_ = nullptr;
  std::shared_ptr<wenet::ContextConfig> context_config_ = nullptr;
  std::shared_ptr<wenet::PostProcessOptions> post_process_opts_ = nullptr;

  int nbest_ = 1;
  std::string result_;
  bool enable_timestamp_ = false;
  std::vector<std::string> context_;
  // Fix: was uninitialized — reading it in InitDecoder() when AddContext()
  // was used without set_context_score() was undefined behavior.
  // 3.0 matches the usual wenet ContextConfig default — confirm.
  float context_score_ = 3.0;
  std::string language_ = "chs";
  bool continuous_decoding_ = false;
};
void* wenet_init(const char* model_dir) {
Recognizer* decoder = new Recognizer(model_dir);
return reinterpret_cast<void*>(decoder);
}
// Destroy a Recognizer previously created by wenet_init().
void wenet_free(void* decoder) {
  delete static_cast<Recognizer*>(decoder);
}
void wenet_reset(void* decoder) {
Recognizer* recognizer = reinterpret_cast<Recognizer*>(decoder);
recognizer->Reset();
}
void wenet_decode(void* decoder, const char* data, int len, int last) {
Recognizer* recognizer = reinterpret_cast<Recognizer*>(decoder);
recognizer->Decode(data, len, last);
}
const char* wenet_get_result(void* decoder) {
Recognizer* recognizer = reinterpret_cast<Recognizer*>(decoder);
return recognizer->GetResult();
}
// Route glog output to stderr and set the verbose (VLOG) level.
void wenet_set_log_level(int level) {
  FLAGS_logtostderr = true;
  FLAGS_v = level;
}
void wenet_set_nbest(void* decoder, int n) {
Recognizer* recognizer = reinterpret_cast<Recognizer*>(decoder);
recognizer->set_nbest(n);
}
void wenet_set_timestamp(void* decoder, int flag) {
Recognizer* recognizer = reinterpret_cast<Recognizer*>(decoder);
bool enable = flag > 0 ? true : false;
recognizer->set_enable_timestamp(enable);
}
void wenet_add_context(void* decoder, const char* word) {
Recognizer* recognizer = reinterpret_cast<Recognizer*>(decoder);
recognizer->AddContext(word);
}
void wenet_set_context_score(void* decoder, float score) {
Recognizer* recognizer = reinterpret_cast<Recognizer*>(decoder);
recognizer->set_context_score(score);
}
void wenet_set_language(void* decoder, const char* lang) {
Recognizer* recognizer = reinterpret_cast<Recognizer*>(decoder);
recognizer->set_language(lang);
}
void wenet_set_continuous_decoding(void* decoder, int flag) {
Recognizer* recognizer = reinterpret_cast<Recognizer*>(decoder);
recognizer->set_continuous_decoding(flag > 0);
}
// Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef API_WENET_API_H_
#define API_WENET_API_H_
#ifdef __cplusplus
extern "C" {
#endif
/** Init decoder from the file and returns the object
*
* @param model_dir: the model dir
 * @returns model object or NULL if a problem occurred
*/
void* wenet_init(const char* model_dir);
/** Free wenet decoder and corresponding resource
*/
void wenet_free(void* decoder);
/** Reset decoder for next decoding
*/
void wenet_reset(void* decoder);
/** Decode the input wav data
* @param data: pcm data, encoded as int16_t(16 bits)
* @param len: data length
* @param last: if it is the last package
*/
void wenet_decode(void* decoder, const char* data, int len, int last);
/** Get decode result in json format
* It returns partial result when last is 0
* It returns final result when last is 1
{
"nbest" : [{
"sentence" : "are you okay"
"word_pieces" : [{
"end" : 960,
"start" : 0,
"word" : "are"
}, {
"end" : 1200,
"start" : 960,
"word" : "you"
}, {
...}]
}, {
"sentence" : "are you ok"
}],
"type" : "final_result"
}
"type": final_result/partial_result
"nbest": nbest is enabled when n > 1 in final_result
"sentence": the ASR result
"word_pieces": optional, output timestamp when enabled
*/
const char* wenet_get_result(void* decoder);
/** Set n-best, range 1~10
* wenet_get_result will return top-n best results
*/
void wenet_set_nbest(void* decoder, int n);
/** Whether to enable word level timestamp in results
disable it when flag = 0, otherwise enable
*/
void wenet_set_timestamp(void* decoder, int flag);
/** Add one contextual biasing
*/
void wenet_add_context(void* decoder, const char* word);
/** Set contextual biasing bonus score
*/
void wenet_set_context_score(void* decoder, float score);
/** Set language, has an effect on the postprocessing
* @param: lang, could be chs/en now
*/
void wenet_set_language(void* decoder, const char* lang);
/** Set log level
* We use glog in wenet, so the level is the glog level
*/
void wenet_set_log_level(int level);
/** Enable continuous decoding or not
* flag > 0: enable, otherwise disable
*/
void wenet_set_continuous_decoding(void* decoder, int flag);
#ifdef __cplusplus
}
#endif
#endif // API_WENET_API_H_
# Command-line decoding and label-checking tools (always built).
add_executable(decoder_main decoder_main.cc)
target_link_libraries(decoder_main PUBLIC decoder)
add_executable(label_checker_main label_checker_main.cc)
target_link_libraries(label_checker_main PUBLIC decoder)

# C API demo, currently disabled.
# if(TORCH)
#   add_executable(api_main api_main.cc)
#   target_link_libraries(api_main PUBLIC wenet_api)
# endif()

# Websocket client/server demos.
if(WEBSOCKET)
  add_executable(websocket_client_main websocket_client_main.cc)
  target_link_libraries(websocket_client_main PUBLIC websocket)
  add_executable(websocket_server_main websocket_server_main.cc)
  target_link_libraries(websocket_server_main PUBLIC websocket)
endif()

# gRPC client/server demos.
if(GRPC)
  add_executable(grpc_server_main grpc_server_main.cc)
  target_link_libraries(grpc_server_main PUBLIC wenet_grpc)
  add_executable(grpc_client_main grpc_client_main.cc)
  target_link_libraries(grpc_client_main PUBLIC wenet_grpc)
endif()

# HTTP client/server demos.
if(HTTP)
  add_executable(http_client_main http_client_main.cc)
  target_link_libraries(http_client_main PUBLIC http)
  add_executable(http_server_main http_server_main.cc)
  target_link_libraries(http_server_main PUBLIC http)
endif()
// Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "api/wenet_api.h"
#include "frontend/wav.h"
#include "utils/flags.h"
DEFINE_string(model_dir, "", "model dir path");
DEFINE_string(wav_path, "", "single wave path");
DEFINE_bool(enable_timestamp, false, "enable timestamps");
// Demo of the wenet C API: load a model, decode one wav file repeatedly
// in non-streaming mode and log the JSON result each time.
int main(int argc, char* argv[]) {
  gflags::ParseCommandLineFlags(&argc, &argv, false);
  google::InitGoogleLogging(argv[0]);
  wenet_set_log_level(2);

  void* decoder = wenet_init(FLAGS_model_dir.c_str());
  wenet_set_timestamp(decoder, FLAGS_enable_timestamp == true ? 1 : 0);
  wenet::WavReader wav_reader(FLAGS_wav_path);
  // Convert the reader's sample values to the 16-bit PCM wenet_decode()
  // expects.
  std::vector<int16_t> data(wav_reader.num_samples());
  for (int i = 0; i < wav_reader.num_samples(); i++) {
    data[i] = static_cast<int16_t>(*(wav_reader.data() + i));
  }
  // Decode the same audio 10 times, resetting the decoder in between.
  for (int i = 0; i < 10; i++) {
    // Return the final result when last is 1
    wenet_decode(decoder, reinterpret_cast<const char*>(data.data()),
                 data.size() * 2, 1);
    const char* result = wenet_get_result(decoder);
    LOG(INFO) << i << " " << result;
    wenet_reset(decoder);
  }
  wenet_free(decoder);
  return 0;
}
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iomanip>
#include <thread>
#include <utility>
#include "decoder/params.h"
#include "frontend/wav.h"
#include "utils/flags.h"
#include "utils/string.h"
#include "utils/thread_pool.h"
#include "utils/timer.h"
#include "utils/utils.h"
DEFINE_bool(simulate_streaming, false, "simulate streaming input");
DEFINE_bool(output_nbest, false, "output n-best of decode result");
DEFINE_string(wav_path, "", "single wave path");
DEFINE_string(wav_scp, "", "input wav scp");
DEFINE_string(result, "", "result output file");
DEFINE_bool(continuous_decoding, false, "continuous decoding mode");
DEFINE_int32(thread_num, 1, "num of decode thread");
DEFINE_int32(warmup, 0, "num of warmup decode, 0 means no warmup");
std::shared_ptr<wenet::DecodeOptions> g_decode_config;
std::shared_ptr<wenet::FeaturePipelineConfig> g_feature_config;
std::shared_ptr<wenet::DecodeResource> g_decode_resource;
std::ofstream g_result;
std::mutex g_mutex;
int g_total_waves_dur = 0;
int g_total_decode_time = 0;
// Decode one (utterance id, wav path) pair with a freshly constructed
// decoder built from the global configs/resource. Unless `warmup` is set,
// the result is written to stdout or FLAGS_result and the global duration
// and decode-time counters are updated under g_mutex.
void decode(std::pair<std::string, std::string> wav, bool warmup = false) {
  wenet::WavReader wav_reader(wav.second);
  int num_samples = wav_reader.num_samples();
  CHECK_EQ(wav_reader.sample_rate(), FLAGS_sample_rate);

  // All audio is fed at once; set_input_finished() lets the decoder know
  // no more features will arrive.
  auto feature_pipeline =
      std::make_shared<wenet::FeaturePipeline>(*g_feature_config);
  feature_pipeline->AcceptWaveform(wav_reader.data(), num_samples);
  feature_pipeline->set_input_finished();
  LOG(INFO) << "num frames " << feature_pipeline->num_frames();

  wenet::AsrDecoder decoder(feature_pipeline, g_decode_resource,
                            *g_decode_config);

  int wave_dur = static_cast<int>(static_cast<float>(num_samples) /
                                  wav_reader.sample_rate() * 1000);
  int decode_time = 0;
  std::string final_result;
  while (true) {
    wenet::Timer timer;
    wenet::DecodeState state = decoder.Decode();
    if (state == wenet::DecodeState::kEndFeats) {
      decoder.Rescoring();
    }
    int chunk_decode_time = timer.Elapsed();
    decode_time += chunk_decode_time;
    if (decoder.DecodedSomething()) {
      LOG(INFO) << "Partial result: " << decoder.result()[0].sentence;
    }

    // In continuous decoding mode an endpoint closes the current segment:
    // rescore it, append its text, then reset for the next segment.
    if (FLAGS_continuous_decoding && state == wenet::DecodeState::kEndpoint) {
      if (decoder.DecodedSomething()) {
        decoder.Rescoring();
        LOG(INFO) << "Final result (continuous decoding): "
                  << decoder.result()[0].sentence;
        final_result.append(decoder.result()[0].sentence);
      }
      decoder.ResetContinuousDecoding();
    }

    if (state == wenet::DecodeState::kEndFeats) {
      break;
    } else if (FLAGS_chunk_size > 0 && FLAGS_simulate_streaming) {
      // Sleep to pace the decoder at real time, as if audio arrived live.
      float frame_shift_in_ms =
          static_cast<float>(g_feature_config->frame_shift) /
          wav_reader.sample_rate() * 1000;
      auto wait_time =
          decoder.num_frames_in_current_chunk() * frame_shift_in_ms -
          chunk_decode_time;
      if (wait_time > 0) {
        LOG(INFO) << "Simulate streaming, waiting for " << wait_time << "ms";
        std::this_thread::sleep_for(
            std::chrono::milliseconds(static_cast<int>(wait_time)));
      }
    }
  }
  if (decoder.DecodedSomething()) {
    final_result.append(decoder.result()[0].sentence);
  }
  LOG(INFO) << wav.first << " Final result: " << final_result << std::endl;
  LOG(INFO) << "Decoded " << wave_dur << "ms audio taken " << decode_time
            << "ms.";

  if (!warmup) {
    // Serialize result output and the shared counters across decode threads.
    g_mutex.lock();
    std::ostream& buffer = FLAGS_result.empty() ? std::cout : g_result;
    if (!FLAGS_output_nbest) {
      buffer << wav.first << " " << final_result << std::endl;
    } else {
      buffer << "wav " << wav.first << std::endl;
      auto& results = decoder.result();
      for (auto& r : results) {
        if (r.sentence.empty()) continue;
        buffer << "candidate " << r.score << " " << r.sentence << std::endl;
      }
    }
    g_total_waves_dur += wave_dur;
    g_total_decode_time += decode_time;
    g_mutex.unlock();
  }
}
// Offline decoding driver: reads one wav or a kaldi-style wav scp, decodes
// each entry on a thread pool, and reports total time and RTF.
int main(int argc, char* argv[]) {
  gflags::ParseCommandLineFlags(&argc, &argv, false);
  google::InitGoogleLogging(argv[0]);
  g_decode_config = wenet::InitDecodeOptionsFromFlags();
  g_feature_config = wenet::InitFeaturePipelineConfigFromFlags();
  g_decode_resource = wenet::InitDecodeResourceFromFlags();

  if (FLAGS_wav_path.empty() && FLAGS_wav_scp.empty()) {
    LOG(FATAL) << "Please provide the wave path or the wav scp.";
  }
  // Collect (utt_id, wav_path) pairs from either input flag.
  std::vector<std::pair<std::string, std::string>> waves;
  if (!FLAGS_wav_path.empty()) {
    waves.emplace_back(make_pair("test", FLAGS_wav_path));
  } else {
    std::ifstream wav_scp(FLAGS_wav_scp);
    std::string line;
    while (getline(wav_scp, line)) {
      std::vector<std::string> strs;
      wenet::SplitString(line, &strs);
      CHECK_GE(strs.size(), 2);
      waves.emplace_back(make_pair(strs[0], strs[1]));
    }
    if (waves.empty()) {
      LOG(FATAL) << "Please provide non-empty wav scp.";
    }
  }

  if (!FLAGS_result.empty()) {
    g_result.open(FLAGS_result, std::ios::out);
  }

  // Warmup: decode the first wav FLAGS_warmup times without recording stats.
  if (FLAGS_warmup > 0) {
    LOG(INFO) << "Warming up...";
    {
      // Scope ends => ThreadPool destructor joins all warmup tasks.
      ThreadPool pool(FLAGS_thread_num);
      auto wav = waves[0];
      for (int i = 0; i < FLAGS_warmup; i++) {
        pool.enqueue(decode, wav, true);
      }
    }
    LOG(INFO) << "Warmup done.";
  }

  {
    // Decode all waves; pool destruction waits for completion.
    ThreadPool pool(FLAGS_thread_num);
    for (auto& wav : waves) {
      pool.enqueue(decode, wav, false);
    }
  }

  LOG(INFO) << "Total: decoded " << g_total_waves_dur << "ms audio taken "
            << g_total_decode_time << "ms.";
  LOG(INFO) << "RTF: " << std::setprecision(4)
            << static_cast<float>(g_total_decode_time) / g_total_waves_dur;
  return 0;
}
// Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/wav.h"
#include "grpc/grpc_client.h"
#include "utils/flags.h"
#include "utils/timer.h"
DEFINE_string(hostname, "127.0.0.1", "hostname of websocket server");
DEFINE_int32(port, 10086, "port of websocket server");
DEFINE_int32(nbest, 1, "n-best of decode result");
DEFINE_string(wav_path, "", "test wav file path");
DEFINE_bool(continuous_decoding, false, "continuous decoding mode");
// gRPC streaming client demo: sends a 16 kHz wav to the server in 0.5 s
// chunks, pacing with sleeps to simulate real-time input.
int main(int argc, char* argv[]) {
  gflags::ParseCommandLineFlags(&argc, &argv, false);
  google::InitGoogleLogging(argv[0]);
  wenet::GrpcClient client(FLAGS_hostname, FLAGS_port, FLAGS_nbest,
                           FLAGS_continuous_decoding);

  wenet::WavReader wav_reader(FLAGS_wav_path);
  const int sample_rate = 16000;
  // Only support 16K
  CHECK_EQ(wav_reader.sample_rate(), sample_rate);
  const int num_samples = wav_reader.num_samples();
  std::vector<float> pcm_data(wav_reader.data(),
                              wav_reader.data() + num_samples);
  // Send data every 0.5 second
  const float interval = 0.5;
  const int sample_interval = interval * sample_rate;
  for (int start = 0; start < num_samples; start += sample_interval) {
    // Stop streaming early once the server reports it is done.
    if (client.done()) {
      break;
    }
    int end = std::min(start + sample_interval, num_samples);
    // Convert to short
    std::vector<int16_t> data;
    data.reserve(end - start);
    for (int j = start; j < end; j++) {
      data.push_back(static_cast<int16_t>(pcm_data[j]));
    }
    // Send PCM data
    client.SendBinaryData(data.data(), data.size() * sizeof(int16_t));
    VLOG(2) << "Send " << data.size() << " samples";
    // Pace the stream at real time.
    std::this_thread::sleep_for(
        std::chrono::milliseconds(static_cast<int>(interval * 1000)));
  }
  // Measure only the wait for the server's final response.
  wenet::Timer timer;
  client.Join();
  VLOG(2) << "Total latency: " << timer.Elapsed() << "ms.";
  return 0;
}
// Copyright (c) 2021 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <grpcpp/ext/proto_server_reflection_plugin.h>
#include <grpcpp/grpcpp.h>
#include <grpcpp/health_check_service_interface.h>
#include "decoder/params.h"
#include "grpc/grpc_server.h"
#include "utils/log.h"
DEFINE_int32(port, 10086, "grpc listening port");
DEFINE_int32(workers, 4, "grpc num workers");
using grpc::Server;
using grpc::ServerBuilder;
// gRPC ASR server: builds the decode resource from flags and serves the
// GrpcServer service on 0.0.0.0:FLAGS_port until terminated.
int main(int argc, char* argv[]) {
  gflags::ParseCommandLineFlags(&argc, &argv, false);
  google::InitGoogleLogging(argv[0]);
  auto decode_config = wenet::InitDecodeOptionsFromFlags();
  auto feature_config = wenet::InitFeaturePipelineConfigFromFlags();
  auto decode_resource = wenet::InitDecodeResourceFromFlags();

  wenet::GrpcServer service(feature_config, decode_config, decode_resource);
  // Standard gRPC health checking and reflection for debugging tools.
  grpc::EnableDefaultHealthCheckService(true);
  grpc::reflection::InitProtoReflectionServerBuilderPlugin();
  ServerBuilder builder;
  std::string address("0.0.0.0:" + std::to_string(FLAGS_port));
  builder.AddListeningPort(address, grpc::InsecureServerCredentials());
  builder.RegisterService(&service);
  // One completion queue per worker thread.
  builder.SetSyncServerOption(ServerBuilder::SyncServerOption::NUM_CQS,
                              FLAGS_workers);
  std::unique_ptr<Server> server(builder.BuildAndStart());
  LOG(INFO) << "Listening at port " << FLAGS_port;
  server->Wait();
  google::ShutdownGoogleLogging();
  return 0;
}
// Copyright (c) 2023 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/wav.h"
#include "utils/flags.h"
#include "utils/timer.h"
#include "http/http_client.h"
DEFINE_string(hostname, "127.0.0.1", "hostname of http server");
DEFINE_int32(port, 10086, "port of http server");
DEFINE_int32(nbest, 1, "n-best of decode result");
DEFINE_string(wav_path, "", "test wav file path");
// HTTP client demo: sends one 16 kHz wav as 16-bit PCM in a single request
// and logs the round-trip latency.
int main(int argc, char* argv[]) {
  gflags::ParseCommandLineFlags(&argc, &argv, false);
  google::InitGoogleLogging(argv[0]);
  wenet::WavReader wav_reader(FLAGS_wav_path);
  const int sample_rate = 16000;
  // Only support 16K
  CHECK_EQ(wav_reader.sample_rate(), sample_rate);
  const int num_samples = wav_reader.num_samples();
  // Convert to short
  std::vector<int16_t> data;
  data.reserve(num_samples);
  for (int j = 0; j < num_samples; j++) {
    data.push_back(static_cast<int16_t>(wav_reader.data()[j]));
  }
  // Send data
  wenet::HttpClient client(FLAGS_hostname, FLAGS_port);
  client.set_nbest(FLAGS_nbest);
  wenet::Timer timer;
  VLOG(2) << "Send " << data.size() << " samples";
  client.SendBinaryData(data.data(), data.size() * sizeof(int16_t));
  VLOG(2) << "Total latency: " << timer.Elapsed() << "ms.";
  return 0;
}
// Copyright (c) 2023 Ximalaya Speech Team (Xiang Lyu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/params.h"
#include "utils/log.h"
#include "http/http_server.h"
DEFINE_int32(port, 10086, "http listening port");
// HTTP ASR server: builds the decode resource from flags and serves
// recognition requests on FLAGS_port until terminated.
int main(int argc, char* argv[]) {
  gflags::ParseCommandLineFlags(&argc, &argv, false);
  google::InitGoogleLogging(argv[0]);
  auto decode_config = wenet::InitDecodeOptionsFromFlags();
  auto feature_config = wenet::InitFeaturePipelineConfigFromFlags();
  auto decode_resource = wenet::InitDecodeResourceFromFlags();

  wenet::HttpServer server(FLAGS_port, feature_config, decode_config,
                           decode_resource);
  LOG(INFO) << "Listening at port " << FLAGS_port;
  server.Start();
  return 0;
}
// Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <sstream>
#include <unordered_map>
#include <vector>
#include "decoder/params.h"
#include "frontend/wav.h"
#include "utils/flags.h"
#include "utils/string.h"
// Inputs: kaldi-style text (key + transcript) and wav.scp (key + path).
DEFINE_string(text, "", "kaldi style text input file");
DEFINE_string(wav_scp, "", "kaldi style wav scp");
// Arc weights charged when the alignment deviates from the reference
// transcript (used by CompileAlignFst below).
DEFINE_double(is_penalty, 1.0,
              "insertion/substitution penalty for align insertion");
DEFINE_double(del_penalty, 1.0, "deletion penalty for align insertion");
// Outputs: aligned text and per-word timestamps.
DEFINE_string(result, "", "result output file");
DEFINE_string(timestamp, "", "timestamp output file");
namespace wenet {

// Special alignment symbols appended to the model's unit table by
// MakeSymbolTableForFst() below.
const char* kDeletion = "<del>";
// Is: Insertion and substitution
const char* kIsStart = "<is>";
const char* kIsEnd = "</is>";

// Converts a UTF-8 transcript into a sequence of symbol ids.
// White space is mapped to the special "▁" symbol; characters that are
// not present in symbol_table are silently dropped. Always returns true.
bool MapToLabel(const std::string& text,
                std::shared_ptr<fst::SymbolTable> symbol_table,
                std::vector<int>* labels) {
  labels->clear();
  // Split label to char sequence
  std::vector<std::string> chars;
  SplitUTF8StringToChars(text, &chars);
  for (size_t i = 0; i < chars.size(); i++) {
    // ▁ is special symbol for white space
    std::string label = chars[i] != " " ? chars[i] : "▁";
    int id = symbol_table->Find(label);
    if (id != -1) {  // fst::kNoSymbol
      labels->push_back(id);
    }
  }
  return true;
}

// Builds the FST symbol table from the model unit table: <eps> takes
// id 0, every unit id is shifted up by one (so <blank> becomes 1), and
// the three special alignment symbols are appended after the last unit.
std::shared_ptr<fst::SymbolTable> MakeSymbolTableForFst(
    std::shared_ptr<fst::SymbolTable> isymbol_table) {
  CHECK(isymbol_table != nullptr);
  auto osymbol_table = std::make_shared<fst::SymbolTable>();
  osymbol_table->AddSymbol("<eps>", 0);
  CHECK_EQ(isymbol_table->Find("<blank>"), 0);
  osymbol_table->AddSymbol("<blank>", 1);
  for (int i = 1; i < isymbol_table->NumSymbols(); i++) {
    std::string symbol = isymbol_table->Find(i);
    osymbol_table->AddSymbol(symbol, i + 1);
  }
  osymbol_table->AddSymbol(kDeletion, isymbol_table->NumSymbols() + 1);
  osymbol_table->AddSymbol(kIsStart, isymbol_table->NumSymbols() + 2);
  osymbol_table->AddSymbol(kIsEnd, isymbol_table->NumSymbols() + 3);
  return osymbol_table;
}

// Compiles the CTC topology FST: <blank> (id 1) self-loops at the
// start state with epsilon output, and every regular unit gets a state
// whose self-loop absorbs repeated frames, so consecutive identical
// labels collapse into a single output symbol.
void CompileCtcFst(std::shared_ptr<fst::SymbolTable> symbol_table,
                   fst::StdVectorFst* ofst) {
  ofst->DeleteStates();
  int start = ofst->AddState();
  ofst->SetStart(start);
  CHECK_EQ(symbol_table->Find("<eps>"), 0);
  CHECK_EQ(symbol_table->Find("<blank>"), 1);
  ofst->AddArc(start, fst::StdArc(1, 0, 0.0, start));
  // Exclude the three special symbols (kDeletion, kIsStart, kIsEnd),
  // which occupy the last three ids of the table.
  for (int i = 2; i < symbol_table->NumSymbols() - 3; i++) {
    int s = ofst->AddState();
    ofst->AddArc(start, fst::StdArc(i, i, 0.0, s));
    ofst->AddArc(s, fst::StdArc(i, 0, 0.0, s));
    ofst->AddArc(s, fst::StdArc(0, 0, 0.0, start));
  }
  ofst->SetFinal(start, fst::StdArc::Weight::One());
  fst::ArcSort(ofst, fst::StdOLabelCompare());
}

// Compiles the transcript-specific alignment FST. At each reference
// position the path may take the correct label (free), a deletion
// (cost FLAGS_del_penalty, emitting <del>), or detour through the
// shared filler states to absorb inserted/substituted units (cost
// FLAGS_is_penalty each, bracketed by <is>...</is> on the output side).
void CompileAlignFst(std::vector<int> labels,
                     std::shared_ptr<fst::SymbolTable> symbol_table,
                     fst::StdVectorFst* ofst) {
  ofst->DeleteStates();
  int deletion = symbol_table->Find(kDeletion);
  int insertion_start = symbol_table->Find(kIsStart);
  int insertion_end = symbol_table->Find(kIsEnd);
  int start = ofst->AddState();
  ofst->SetStart(start);
  // Filler states, shared by every position: consume any regular unit
  // at FLAGS_is_penalty apiece.
  int filler_start = ofst->AddState();
  int filler_end = ofst->AddState();
  for (int i = 2; i < symbol_table->NumSymbols() - 3; i++) {
    ofst->AddArc(filler_start, fst::StdArc(i, i, FLAGS_is_penalty, filler_end));
  }
  ofst->AddArc(filler_end, fst::StdArc(0, 0, 0.0, filler_start));
  int prev = start;
  // Alignment path and optional filler
  for (size_t i = 0; i < labels.size(); i++) {
    int cur = ofst->AddState();
    // 1. Insertion or Substitution
    ofst->AddArc(prev, fst::StdArc(0, insertion_start, 0.0, filler_start));
    ofst->AddArc(filler_end, fst::StdArc(0, insertion_end, 0.0, prev));
    // 2. Correct
    ofst->AddArc(prev, fst::StdArc(labels[i], labels[i], 0.0, cur));
    // 3. Deletion
    ofst->AddArc(prev, fst::StdArc(0, deletion, FLAGS_del_penalty, cur));
    prev = cur;
  }
  // Optionally allow a trailing filler after the last label.
  ofst->AddArc(prev, fst::StdArc(0, insertion_start, 0.0, filler_start));
  ofst->AddArc(filler_end, fst::StdArc(0, insertion_end, 0.0, prev));
  ofst->SetFinal(prev, fst::StdArc::Weight::One());
  fst::ArcSort(ofst, fst::StdILabelCompare());
}

}  // namespace wenet
// Forced-alignment tool: for every (key, transcript) pair in FLAGS_text,
// composes the shared CTC topology FST with a transcript-specific
// alignment FST, decodes the matching wav from FLAGS_wav_scp, and writes
// the aligned text to FLAGS_result and word timestamps to
// FLAGS_timestamp (or stdout).
int main(int argc, char* argv[]) {
  gflags::ParseCommandLineFlags(&argc, &argv, false);
  google::InitGoogleLogging(argv[0]);
  auto decode_config = wenet::InitDecodeOptionsFromFlags();
  auto feature_config = wenet::InitFeaturePipelineConfigFromFlags();
  auto decode_resource = wenet::InitDecodeResourceFromFlags();
  // The unit table is required to derive the on-the-fly symbol table.
  CHECK(decode_resource->unit_table != nullptr);
  auto wfst_symbol_table =
      wenet::MakeSymbolTableForFst(decode_resource->unit_table);
  // wfst_symbol_table->WriteText("fst.txt");
  // Reset symbol_table to on-the-fly generated wfst_symbol_table
  decode_resource->symbol_table = wfst_symbol_table;
  // Compile ctc FST (shared across all utterances; the align FST below
  // is rebuilt per transcript).
  fst::StdVectorFst ctc_fst;
  wenet::CompileCtcFst(wfst_symbol_table, &ctc_fst);
  // ctc_fst.Write("ctc.fst");
  // Load kaldi-style wav.scp: one "key path" pair per line.
  std::unordered_map<std::string, std::string> wav_table;
  std::ifstream wav_is(FLAGS_wav_scp);
  std::string line;
  while (std::getline(wav_is, line)) {
    std::vector<std::string> strs;
    wenet::SplitString(line, &strs);
    CHECK_EQ(strs.size(), 2);
    wav_table[strs[0]] = strs[1];
  }
  std::ifstream text_is(FLAGS_text);
  std::ofstream result_os(FLAGS_result, std::ios::out);
  std::ofstream timestamp_out;
  if (!FLAGS_timestamp.empty()) {
    timestamp_out.open(FLAGS_timestamp, std::ios::out);
  }
  // Timestamps go to the file when given, otherwise to stdout.
  std::ostream& timestamp_os =
      FLAGS_timestamp.empty() ? std::cout : timestamp_out;
  while (std::getline(text_is, line)) {
    std::vector<std::string> strs;
    wenet::SplitString(line, &strs);
    if (strs.size() < 2) continue;  // skip lines without a transcript
    std::string key = strs[0];
    LOG(INFO) << "Processing " << key;
    if (wav_table.find(key) != wav_table.end()) {
      // Remaining fields form the reference transcript.
      strs.erase(strs.begin());
      std::string text = wenet::JoinString(" ", strs);
      std::vector<int> labels;
      wenet::MapToLabel(text, wfst_symbol_table, &labels);
      // Prepare FST for alignment decoding
      fst::StdVectorFst align_fst;
      wenet::CompileAlignFst(labels, wfst_symbol_table, &align_fst);
      // align_fst.Write("align.fst");
      auto decoding_fst = std::make_shared<fst::StdVectorFst>();
      fst::Compose(ctc_fst, align_fst, decoding_fst.get());
      // decoding_fst->Write("decoding.fst");
      // Prepare feature pipeline for this utterance.
      wenet::WavReader wav_reader;
      if (!wav_reader.Open(wav_table[key])) {
        LOG(WARNING) << "Error in reading " << wav_table[key];
        continue;
      }
      int num_samples = wav_reader.num_samples();
      CHECK_EQ(wav_reader.sample_rate(), FLAGS_sample_rate);
      auto feature_pipeline =
          std::make_shared<wenet::FeaturePipeline>(*feature_config);
      feature_pipeline->AcceptWaveform(wav_reader.data(), num_samples);
      feature_pipeline->set_input_finished();
      // Point the shared resource at this utterance's composed FST
      // before constructing the decoder.
      decode_resource->fst = decoding_fst;
      LOG(INFO) << "num frames " << feature_pipeline->num_frames();
      wenet::AsrDecoder decoder(feature_pipeline, decode_resource,
                                *decode_config);
      // Consume all features, then rescore the final hypothesis.
      while (true) {
        wenet::DecodeState state = decoder.Decode();
        if (state == wenet::DecodeState::kEndFeats) {
          decoder.Rescoring();
          break;
        }
      }
      std::string final_result;
      std::string timestamp_str;
      if (decoder.DecodedSomething()) {
        const wenet::DecodeResult& result = decoder.result()[0];
        final_result = result.sentence;
        // Emit "word start end" triples for every aligned word piece.
        std::stringstream ss;
        for (const auto& w : result.word_pieces) {
          ss << " " << w.word << " " << w.start << " " << w.end;
        }
        timestamp_str = ss.str();
      }
      result_os << key << " " << final_result << std::endl;
      timestamp_os << key << " " << timestamp_str << std::endl;
      LOG(INFO) << key << " " << final_result;
    } else {
      LOG(WARNING) << "No wav file for " << key;
    }
  }
  return 0;
}
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "frontend/wav.h"
#include "utils/flags.h"
#include "utils/timer.h"
#include "websocket/websocket_client.h"
// Connection and decoding options for the websocket ASR client.
DEFINE_string(hostname, "127.0.0.1", "hostname of websocket server");
DEFINE_int32(port, 10086, "port of websocket server");
DEFINE_int32(nbest, 1, "n-best of decode result");
DEFINE_string(wav_path, "", "test wav file path");
DEFINE_bool(continuous_decoding, false, "continuous decoding mode");
// Websocket ASR client: streams a 16 kHz wav file to the server in
// 0.5 s chunks (pacing each send with a sleep to simulate real-time
// capture), then sends the end signal and waits for the final result.
int main(int argc, char* argv[]) {
  gflags::ParseCommandLineFlags(&argc, &argv, false);
  google::InitGoogleLogging(argv[0]);
  wenet::WebSocketClient client(FLAGS_hostname, FLAGS_port);
  client.set_nbest(FLAGS_nbest);
  client.set_continuous_decoding(FLAGS_continuous_decoding);
  client.SendStartSignal();
  wenet::WavReader wav_reader(FLAGS_wav_path);
  // Only support 16K
  const int sample_rate = 16000;
  CHECK_EQ(wav_reader.sample_rate(), sample_rate);
  const int num_samples = wav_reader.num_samples();
  // Send data every 0.5 second
  const float interval = 0.5;
  const int sample_interval = interval * sample_rate;
  int offset = 0;
  while (offset < num_samples && !client.done()) {
    const int chunk_end = std::min(offset + sample_interval, num_samples);
    // Convert to short
    std::vector<int16_t> pcm;
    pcm.reserve(chunk_end - offset);
    for (int idx = offset; idx < chunk_end; ++idx) {
      pcm.push_back(static_cast<int16_t>(wav_reader.data()[idx]));
    }
    // TODO(Binbin Zhang): Network order?
    // Send PCM data
    client.SendBinaryData(pcm.data(), pcm.size() * sizeof(int16_t));
    VLOG(2) << "Send " << pcm.size() << " samples";
    std::this_thread::sleep_for(
        std::chrono::milliseconds(static_cast<int>(interval * 1000)));
    offset += sample_interval;
  }
  wenet::Timer timer;
  client.SendEndSignal();
  client.Join();
  VLOG(2) << "Total latency: " << timer.Elapsed() << "ms.";
  return 0;
}
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/params.h"
#include "utils/log.h"
#include "websocket/websocket_server.h"
// TCP port the websocket ASR server listens on.
DEFINE_int32(port, 10086, "websocket listening port");
int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
auto decode_config = wenet::InitDecodeOptionsFromFlags();
auto feature_config = wenet::InitFeaturePipelineConfigFromFlags();
auto decode_resource = wenet::InitDecodeResourceFromFlags();
wenet::WebSocketServer server(FLAGS_port, feature_config, decode_config,
decode_resource);
LOG(INFO) << "Listening at port " << FLAGS_port;
server.Start();
return 0;
}
# Fetch Boost 1.75.0 sources for the http/websocket server components.
# NOTE(review): boostorg.jfrog.io has been deprecated upstream; if this
# download starts failing, the same archive (verify the identical
# SHA256) is available from the official boost archive host — confirm
# before switching.
FetchContent_Declare(boost
  URL https://boostorg.jfrog.io/artifactory/main/release/1.75.0/source/boost_1_75_0.tar.gz
  URL_HASH SHA256=aeb26f80e80945e82ee93e5939baebdca47b9dee80a07d3144be1e1a6a66dd6a
)
FetchContent_MakeAvailable(boost)
# Header-only usage: only the unpacked source dir is added.
include_directories(${boost_SOURCE_DIR})
if(MSVC)
  # On MSVC, link Boost dynamically and disable its auto-link pragmas.
  add_definitions(-DBOOST_ALL_DYN_LINK -DBOOST_ALL_NO_LIB)
endif()
\ No newline at end of file
# Horizon BPU inference backend (easy_dnn): only supported on
# Linux/aarch64; any other platform is a hard configure error.
if(BPU)
  if(${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
    if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
      set(EASY_DNN_URL "https://github.com/xingchensong/toolchain_pkg/releases/download/easy_dnn/easy_dnn.0.4.11.tar.gz")
      set(URL_HASH "SHA256=a1a6f77d1baae7181d75ec5d37a2ee529ac4e1c4400babd6ceb1c007392a4904")
    else()
      message(FATAL_ERROR "Unsupported CMake System Processor '${CMAKE_SYSTEM_PROCESSOR}' (expected 'aarch64')")
    endif()
  else()
    message(FATAL_ERROR "Unsupported CMake System Name '${CMAKE_SYSTEM_NAME}' (expected 'Linux')")
  endif()
  FetchContent_Declare(easy_dnn
    URL ${EASY_DNN_URL}
    URL_HASH ${URL_HASH}
  )
  FetchContent_MakeAvailable(easy_dnn)
  # The archive bundles three components (easy_dnn, dnn, hlog), each
  # with its own versioned include/ and lib/ tree.
  include_directories(${easy_dnn_SOURCE_DIR}/easy_dnn/0.4.11_linux_aarch64-j3_hobot_gcc6.5.0/files/easy_dnn/include)
  include_directories(${easy_dnn_SOURCE_DIR}/dnn/1.7.0_linux_aarch64-j3_hobot_gcc6.5.0/files/dnn/include)
  include_directories(${easy_dnn_SOURCE_DIR}/hlog/0.4.7_linux_aarch64-j3_hobot_gcc6.5.0/files/hlog/include)
  link_directories(${easy_dnn_SOURCE_DIR}/easy_dnn/0.4.11_linux_aarch64-j3_hobot_gcc6.5.0/files/easy_dnn/lib)
  link_directories(${easy_dnn_SOURCE_DIR}/dnn/1.7.0_linux_aarch64-j3_hobot_gcc6.5.0/files/dnn/lib)
  link_directories(${easy_dnn_SOURCE_DIR}/hlog/0.4.7_linux_aarch64-j3_hobot_gcc6.5.0/files/hlog/lib)
  add_definitions(-DUSE_BPU)
  # NOTE(xcsong): Reasons for adding flag `-fuse-ld=gold`:
  # https://stackoverflow.com/questions/59915966/unknown-gcc-linker-error-but-builds-sucessfully/59916438#59916438
  # https://github.com/tensorflow/tensorflow/issues/47849
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fuse-ld=gold")
endif()
# Fetch gflags 2.2.2 — command-line flag parsing used by all tool mains.
FetchContent_Declare(gflags
  URL https://github.com/gflags/gflags/archive/v2.2.2.zip
  URL_HASH SHA256=19713a36c9f32b33df59d1c79b4958434cb005b5b47dc5400a7a4b078111d9b5
)
FetchContent_MakeAvailable(gflags)
# gflags generates its public headers into the build tree, hence the
# _BINARY_DIR include path.
include_directories(${gflags_BINARY_DIR}/include)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment