Commit 764b3a75 authored by Sugon_ldc

add new model

FetchContent_Declare(glog
URL https://github.com/google/glog/archive/v0.4.0.zip
URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc
)
FetchContent_MakeAvailable(glog)
include_directories(${glog_SOURCE_DIR}/src ${glog_BINARY_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/grpc)
# third_party: grpc
# On how to build grpc, you may refer to https://github.com/grpc/grpc
# We recommend manually cloning the repo recursively to avoid network connection problems
FetchContent_Declare(gRPC
GIT_REPOSITORY https://github.com/grpc/grpc
GIT_TAG v1.37.1
)
FetchContent_MakeAvailable(gRPC)
FetchContent_Declare(googletest
URL https://github.com/google/googletest/archive/release-1.11.0.zip
URL_HASH SHA256=353571c2440176ded91c2de6d6cd88ddd41401d14692ec1f99e35d013feda55a
)
if(MSVC)
set(gtest_force_shared_crt ON CACHE BOOL "Always use msvcrt.dll" FORCE)
endif()
FetchContent_MakeAvailable(googletest)
if(TORCH)
add_definitions(-DUSE_TORCH)
if(NOT ANDROID)
if(GPU)
if (NOT ${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
message(FATAL_ERROR "GPU is supported only Linux, you can use CPU version")
else()
add_definitions(-DUSE_GPU)
endif()
endif()
if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
if(${CMAKE_BUILD_TYPE} MATCHES "Release")
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-1.13.0%2Bcpu.zip")
set(URL_HASH "SHA256=bece54d36377990257e9d028c687c5b6759c5cfec0a0153da83cf6f0f71f648f")
else()
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-debug-1.13.0%2Bcpu.zip")
set(URL_HASH "SHA256=3cc7ba3c3865d86f03d78c2f0878fdbed8b764359476397a5c95cf3bba0d665a")
endif()
elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
if(CXX11_ABI)
if(NOT GPU)
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.13.0%2Bcpu.zip")
set(URL_HASH "SHA256=d52f63577a07adb0bfd6d77c90f7da21896e94f71eb7dcd55ed7835ccb3b2b59")
else()
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu113/libtorch-cxx11-abi-shared-with-deps-1.12.0%2Bcu113.zip")
set(URL_HASH "SHA256=80f089939de20e68e3fcad4dfa72a26c8bf91b5e77b11042f671f39ebac35865")
endif()
else()
if(NOT GPU)
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.13.0%2Bcpu.zip")
set(URL_HASH "SHA256=bee1b7be308792aa60fc95a4f5274d9658cb7248002d0e333d49eb81ec88430c")
else()
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu113/libtorch-shared-with-deps-1.11.0%2Bcu113.zip")
set(URL_HASH "SHA256=90159ecce3ff451f3ef3f657493b6c7c96759c3b74bbd70c1695f2ea2f81e1ad")
endif()
endif()
elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Darwin")
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-macos-1.13.0.zip")
set(URL_HASH "SHA256=a8f80050b95489b4e002547910410c2c230e9f590ffab2482e19e809afe4f7aa")
elseif(${CMAKE_SYSTEM_NAME} STREQUAL "iOS")
add_definitions(-DIOS)
else()
message(FATAL_ERROR "Unsupported System '${CMAKE_SYSTEM_NAME}' (expected 'Windows', 'Linux', 'Darwin' or 'iOS')")
endif()
# iOS uses LibTorch from pod install
if(NOT IOS)
FetchContent_Declare(libtorch
URL ${LIBTORCH_URL}
URL_HASH ${URL_HASH}
)
FetchContent_MakeAvailable(libtorch)
find_package(Torch REQUIRED PATHS ${libtorch_SOURCE_DIR} NO_DEFAULT_PATH)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS} -DC10_USE_GLOG")
endif()
if(MSVC)
file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
file(COPY ${TORCH_DLLS} DESTINATION ${CMAKE_BINARY_DIR})
endif()
else()
# Change version in runtime/android/app/build.gradle.
file(GLOB PYTORCH_INCLUDE_DIRS "${build_DIR}/pytorch_android*.aar/headers")
file(GLOB PYTORCH_LINK_DIRS "${build_DIR}/pytorch_android*.aar/jni/${ANDROID_ABI}")
find_library(PYTORCH_LIBRARY pytorch_jni
PATHS ${PYTORCH_LINK_DIRS}
NO_CMAKE_FIND_ROOT_PATH
)
find_library(FBJNI_LIBRARY fbjni
PATHS ${PYTORCH_LINK_DIRS}
NO_CMAKE_FIND_ROOT_PATH
)
include_directories(
${PYTORCH_INCLUDE_DIRS}
${PYTORCH_INCLUDE_DIRS}/torch/csrc/api/include
)
endif()
endif()
if(ONNX)
set(ONNX_VERSION "1.12.0")
if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-win-x64-${ONNX_VERSION}.zip")
set(URL_HASH "SHA256=8b5d61204989350b7904ac277f5fbccd3e6736ddbb6ec001e412723d71c9c176")
elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-linux-aarch64-${ONNX_VERSION}.tgz")
set(URL_HASH "SHA256=5820d9f343df73c63b6b2b174a1ff62575032e171c9564bcf92060f46827d0ac")
else()
set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-linux-x64-${ONNX_VERSION}.tgz")
set(URL_HASH "SHA256=5d503ce8540358b59be26c675e42081be14a3e833a5301926f555451046929c5")
endif()
elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Darwin")
set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-osx-x86_64-${ONNX_VERSION}.tgz")
set(URL_HASH "SHA256=09b17f712f8c6f19bb63da35d508815b443cbb473e16c6192abfaa297c02f600")
else()
message(FATAL_ERROR "Unsupported CMake System Name '${CMAKE_SYSTEM_NAME}' (expected 'Windows', 'Linux' or 'Darwin')")
endif()
FetchContent_Declare(onnxruntime
URL ${ONNX_URL}
URL_HASH ${URL_HASH}
)
FetchContent_MakeAvailable(onnxruntime)
include_directories(${onnxruntime_SOURCE_DIR}/include)
link_directories(${onnxruntime_SOURCE_DIR}/lib)
if(MSVC)
file(GLOB ONNX_DLLS "${onnxruntime_SOURCE_DIR}/lib/*.dll")
file(COPY ${ONNX_DLLS} DESTINATION ${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE})
endif()
add_definitions(-DUSE_ONNX)
endif()
if(NOT ANDROID)
include(gflags)
# We can't build glog with gflags unless gflags is pre-installed.
# If glog is built with pre-installed gflags, there will be a conflict.
set(WITH_GFLAGS OFF CACHE BOOL "whether build glog with gflags" FORCE)
include(glog)
if(NOT GRAPH_TOOLS)
set(HAVE_BIN OFF CACHE BOOL "Build the fst binaries" FORCE)
set(HAVE_SCRIPT OFF CACHE BOOL "Build the fstscript" FORCE)
endif()
set(HAVE_COMPACT OFF CACHE BOOL "Build compact" FORCE)
set(HAVE_CONST OFF CACHE BOOL "Build const" FORCE)
set(HAVE_GRM OFF CACHE BOOL "Build grm" FORCE)
set(HAVE_FAR OFF CACHE BOOL "Build far" FORCE)
set(HAVE_PDT OFF CACHE BOOL "Build pdt" FORCE)
set(HAVE_MPDT OFF CACHE BOOL "Build mpdt" FORCE)
set(HAVE_LINEAR OFF CACHE BOOL "Build linear" FORCE)
set(HAVE_LOOKAHEAD OFF CACHE BOOL "Build lookahead" FORCE)
set(HAVE_NGRAM OFF CACHE BOOL "Build ngram" FORCE)
set(HAVE_SPECIAL OFF CACHE BOOL "Build special" FORCE)
if(MSVC)
add_compile_options(/W0 /wd4244 /wd4267)
endif()
# "OpenFST port for Windows" builds openfst with cmake for multiple platforms.
# Openfst is compiled with glog/gflags to avoid log and flag conflicts with log and flags in wenet/libtorch.
# To build openfst with gflags and glog, we comment out some vars of {flags, log}.h and flags.cc.
set(openfst_SOURCE_DIR ${fc_base}/openfst-src CACHE PATH "OpenFST source directory")
FetchContent_Declare(openfst
URL https://github.com/kkm000/openfst/archive/refs/tags/win/1.6.5.1.tar.gz
URL_HASH SHA256=02c49b559c3976a536876063369efc0e41ab374be1035918036474343877046e
PATCH_COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/patch/openfst ${openfst_SOURCE_DIR}
)
FetchContent_MakeAvailable(openfst)
add_dependencies(fst gflags glog)
target_link_libraries(fst PUBLIC gflags_nothreads_static glog)
include_directories(${openfst_SOURCE_DIR}/src/include)
else()
set(openfst_BINARY_DIR ${build_DIR}/wenet-openfst-android-1.0.2.aar/jni)
include_directories(${openfst_BINARY_DIR}/include)
link_directories(${openfst_BINARY_DIR}/${ANDROID_ABI})
link_libraries(log gflags_nothreads glog fst)
endif()
FetchContent_Declare(pybind11
URL https://github.com/pybind/pybind11/archive/refs/tags/v2.9.2.zip
URL_HASH SHA256=d1646e6f70d8a3acb2ddd85ce1ed543b5dd579c68b8fb8e9638282af20edead8
)
FetchContent_MakeAvailable(pybind11)
add_subdirectory(${pybind11_SOURCE_DIR})
if(NOT WIN32)
string(ASCII 27 Esc)
set(ColourReset "${Esc}[m")
set(ColourBold "${Esc}[1m")
set(Red "${Esc}[31m")
set(Green "${Esc}[32m")
set(Yellow "${Esc}[33m")
set(Blue "${Esc}[34m")
set(Magenta "${Esc}[35m")
set(Cyan "${Esc}[36m")
set(White "${Esc}[37m")
set(BoldRed "${Esc}[1;31m")
set(BoldGreen "${Esc}[1;32m")
set(BoldYellow "${Esc}[1;33m")
set(BoldBlue "${Esc}[1;34m")
set(BoldMagenta "${Esc}[1;35m")
set(BoldCyan "${Esc}[1;36m")
set(BoldWhite "${Esc}[1;37m")
endif()
if(XPU)
set(RUNTIME_KUNLUN_PATH ${CMAKE_CURRENT_SOURCE_DIR})
message(STATUS "RUNTIME_KUNLUN_PATH is ${RUNTIME_KUNLUN_PATH} .\n")
set(KUNLUN_XPU_PATH ${RUNTIME_KUNLUN_PATH}/xpu)
if(NOT DEFINED ENV{XPU_API_PATH})
message(FATAL_ERROR "${BoldRed}NO ENV{XPU_API_PATH} in your env. Please set XPU_API_PATH.${ColourReset}\n")
else()
set(XPU_API_PATH $ENV{XPU_API_PATH})
message("set XPU_API_PATH from env_var. Val is $ENV{XPU_API_PATH}.")
endif()
include_directories(${RUNTIME_KUNLUN_PATH} ${KUNLUN_XPU_PATH}/
${XPU_API_PATH}/output/include ${XPU_API_PATH}/../runtime/include)
link_directories(${XPU_API_PATH}/output/so/ ${XPU_API_PATH}/../runtime/output/so/)
add_definitions(-DUSE_XPU)
endif()
set(decoder_srcs
asr_decoder.cc
asr_model.cc
context_graph.cc
ctc_prefix_beam_search.cc
ctc_wfst_beam_search.cc
ctc_endpoint.cc
)
if(NOT TORCH AND NOT ONNX AND NOT XPU AND NOT IOS AND NOT BPU)
message(FATAL_ERROR "Please build with TORCH or ONNX or XPU or IOS or BPU!!!")
endif()
if(TORCH OR IOS)
list(APPEND decoder_srcs torch_asr_model.cc)
endif()
if(ONNX)
list(APPEND decoder_srcs onnx_asr_model.cc)
endif()
add_library(decoder STATIC ${decoder_srcs})
target_link_libraries(decoder PUBLIC kaldi-decoder frontend
post_processor utils)
if(ANDROID)
target_link_libraries(decoder PUBLIC ${PYTORCH_LIBRARY} ${FBJNI_LIBRARY})
else()
if(TORCH)
target_link_libraries(decoder PUBLIC ${TORCH_LIBRARIES})
endif()
if(ONNX)
target_link_libraries(decoder PUBLIC onnxruntime)
endif()
if(BPU)
target_link_libraries(decoder PUBLIC bpu_asr_model)
endif()
if(XPU)
target_link_libraries(decoder PUBLIC xpu_conformer)
endif()
endif()
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
// 2022 Binbin Zhang (binbzha@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/asr_decoder.h"
#include <ctype.h>
#include <algorithm>
#include <limits>
#include <utility>
#include "utils/timer.h"
namespace wenet {
AsrDecoder::AsrDecoder(std::shared_ptr<FeaturePipeline> feature_pipeline,
std::shared_ptr<DecodeResource> resource,
const DecodeOptions& opts)
: feature_pipeline_(std::move(feature_pipeline)),
      // Make a copy of the ASR model since we will change its inner state
      // during decoding
model_(resource->model->Copy()),
post_processor_(resource->post_processor),
symbol_table_(resource->symbol_table),
fst_(resource->fst),
unit_table_(resource->unit_table),
opts_(opts),
ctc_endpointer_(new CtcEndpoint(opts.ctc_endpoint_config)) {
if (opts_.reverse_weight > 0) {
    // Check if the model has a right-to-left decoder
CHECK(model_->is_bidirectional_decoder());
}
if (nullptr == fst_) {
searcher_.reset(new CtcPrefixBeamSearch(opts.ctc_prefix_search_opts,
resource->context_graph));
} else {
searcher_.reset(new CtcWfstBeamSearch(*fst_, opts.ctc_wfst_search_opts,
resource->context_graph));
}
ctc_endpointer_->frame_shift_in_ms(frame_shift_in_ms());
}
void AsrDecoder::Reset() {
start_ = false;
result_.clear();
num_frames_ = 0;
global_frame_offset_ = 0;
model_->Reset();
searcher_->Reset();
feature_pipeline_->Reset();
ctc_endpointer_->Reset();
}
void AsrDecoder::ResetContinuousDecoding() {
global_frame_offset_ = num_frames_;
start_ = false;
result_.clear();
model_->Reset();
searcher_->Reset();
ctc_endpointer_->Reset();
}
DecodeState AsrDecoder::Decode(bool block) {
return this->AdvanceDecoding(block);
}
void AsrDecoder::Rescoring() {
// Do attention rescoring
Timer timer;
AttentionRescoring();
VLOG(2) << "Rescoring cost latency: " << timer.Elapsed() << "ms.";
}
DecodeState AsrDecoder::AdvanceDecoding(bool block) {
DecodeState state = DecodeState::kEndBatch;
model_->set_chunk_size(opts_.chunk_size);
model_->set_num_left_chunks(opts_.num_left_chunks);
int num_required_frames = model_->num_frames_for_chunk(start_);
std::vector<std::vector<float>> chunk_feats;
// Return immediately if we do not want to block
if (!block && !feature_pipeline_->input_finished() &&
feature_pipeline_->NumQueuedFrames() < num_required_frames) {
return DecodeState::kWaitFeats;
}
  // If Read() fails, we have reached the end of the input
if (!feature_pipeline_->Read(num_required_frames, &chunk_feats)) {
state = DecodeState::kEndFeats;
}
num_frames_ += chunk_feats.size();
VLOG(2) << "Required " << num_required_frames << " get "
<< chunk_feats.size();
Timer timer;
std::vector<std::vector<float>> ctc_log_probs;
model_->ForwardEncoder(chunk_feats, &ctc_log_probs);
int forward_time = timer.Elapsed();
if (opts_.ctc_wfst_search_opts.blank_scale != 1.0) {
for (int i = 0; i < ctc_log_probs.size(); i++) {
ctc_log_probs[i][0] = ctc_log_probs[i][0]
+ std::log(opts_.ctc_wfst_search_opts.blank_scale);
}
}
timer.Reset();
searcher_->Search(ctc_log_probs);
int search_time = timer.Elapsed();
VLOG(3) << "forward takes " << forward_time << " ms, search takes "
<< search_time << " ms";
UpdateResult();
if (state != DecodeState::kEndFeats) {
if (ctc_endpointer_->IsEndpoint(ctc_log_probs, DecodedSomething())) {
VLOG(1) << "Endpoint is detected at " << num_frames_;
state = DecodeState::kEndpoint;
}
}
start_ = true;
return state;
}
void AsrDecoder::UpdateResult(bool finish) {
const auto& hypotheses = searcher_->Outputs();
const auto& inputs = searcher_->Inputs();
const auto& likelihood = searcher_->Likelihood();
const auto& times = searcher_->Times();
result_.clear();
CHECK_EQ(hypotheses.size(), likelihood.size());
for (size_t i = 0; i < hypotheses.size(); i++) {
const std::vector<int>& hypothesis = hypotheses[i];
DecodeResult path;
path.score = likelihood[i];
int offset = global_frame_offset_ * feature_frame_shift_in_ms();
for (size_t j = 0; j < hypothesis.size(); j++) {
std::string word = symbol_table_->Find(hypothesis[j]);
// A detailed explanation of this if-else branch can be found in
// https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058
if (searcher_->Type() == kWfstBeamSearch) {
path.sentence += (' ' + word);
} else {
path.sentence += (word);
}
}
    // Timestamps are only supported in the final result.
    // The timestamps of the CtcWfstBeamSearch output may be inaccurate due to
    // various FST operations applied when building the decoding graph. So here
    // we use the timestamps of the input (e2e model units), which are more
    // accurate; this requires the symbol table of the e2e model used in
    // training.
if (unit_table_ != nullptr && finish) {
const std::vector<int>& input = inputs[i];
const std::vector<int>& time_stamp = times[i];
CHECK_EQ(input.size(), time_stamp.size());
for (size_t j = 0; j < input.size(); j++) {
std::string word = unit_table_->Find(input[j]);
int start = time_stamp[j] * frame_shift_in_ms() - time_stamp_gap_ > 0
? time_stamp[j] * frame_shift_in_ms() - time_stamp_gap_
: 0;
if (j > 0) {
start = (time_stamp[j] - time_stamp[j - 1]) * frame_shift_in_ms() <
time_stamp_gap_
? (time_stamp[j - 1] + time_stamp[j]) / 2 *
frame_shift_in_ms()
: start;
}
int end = time_stamp[j] * frame_shift_in_ms();
if (j < input.size() - 1) {
end = (time_stamp[j + 1] - time_stamp[j]) * frame_shift_in_ms() <
time_stamp_gap_
? (time_stamp[j + 1] + time_stamp[j]) / 2 *
frame_shift_in_ms()
: end;
}
WordPiece word_piece(word, offset + start, offset + end);
path.word_pieces.emplace_back(word_piece);
}
}
if (post_processor_ != nullptr) {
path.sentence = post_processor_->Process(path.sentence, finish);
}
result_.emplace_back(path);
}
if (DecodedSomething()) {
VLOG(1) << "Partial CTC result " << result_[0].sentence;
}
}
void AsrDecoder::AttentionRescoring() {
searcher_->FinalizeSearch();
UpdateResult(true);
// No need to do rescoring
if (0.0 == opts_.rescoring_weight) {
return;
}
  // Inputs() returns the N-best input ids, which are the basic units for rescoring.
  // In CtcPrefixBeamSearch, inputs are the same as outputs.
const auto& hypotheses = searcher_->Inputs();
int num_hyps = hypotheses.size();
if (num_hyps <= 0) {
return;
}
std::vector<float> rescoring_score;
model_->AttentionRescoring(hypotheses, opts_.reverse_weight,
&rescoring_score);
// Combine ctc score and rescoring score
for (size_t i = 0; i < num_hyps; ++i) {
result_[i].score = opts_.rescoring_weight * rescoring_score[i] +
opts_.ctc_weight * result_[i].score;
}
std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
}
} // namespace wenet
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
// 2022 Binbin Zhang (binbzha@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DECODER_ASR_DECODER_H_
#define DECODER_ASR_DECODER_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "fst/fstlib.h"
#include "fst/symbol-table.h"
#include "decoder/asr_model.h"
#include "decoder/context_graph.h"
#include "decoder/ctc_endpoint.h"
#include "decoder/ctc_prefix_beam_search.h"
#include "decoder/ctc_wfst_beam_search.h"
#include "decoder/search_interface.h"
#include "frontend/feature_pipeline.h"
#include "post_processor/post_processor.h"
#include "utils/utils.h"
namespace wenet {
struct DecodeOptions {
  // chunk_size is the number of frames in one chunk after subsampling.
  // e.g. if the subsampling rate is 4 and chunk_size = 16, one chunk
  // covers 64 = 16 * 4 input frames.
int chunk_size = 16;
int num_left_chunks = -1;
// final_score = rescoring_weight * rescoring_score + ctc_weight * ctc_score;
// rescoring_score = left_to_right_score * (1 - reverse_weight) +
// right_to_left_score * reverse_weight
  // Please note that the concept of ctc_score differs between the following
  // two search methods:
  // For CtcPrefixBeamSearch, it is a sum (prefix) score + context score.
  // For CtcWfstBeamSearch, it is a max (Viterbi) path score + context score.
  // So ctc_weight should be set carefully according to the search method.
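  // A purely illustrative calculation (numbers assumed, not from the source):
  // with rescoring_weight = 1.0, ctc_weight = 0.5 and reverse_weight = 0.3, a
  // hypothesis with ctc_score = -4.0, left_to_right_score = -2.0 and
  // right_to_left_score = -3.0 gets
  //   rescoring_score = -2.0 * (1 - 0.3) + (-3.0) * 0.3 = -2.3
  //   final_score     = 1.0 * (-2.3) + 0.5 * (-4.0)     = -4.3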
float ctc_weight = 0.5;
float rescoring_weight = 1.0;
float reverse_weight = 0.0;
CtcEndpointConfig ctc_endpoint_config;
CtcPrefixBeamSearchOptions ctc_prefix_search_opts;
CtcWfstBeamSearchOptions ctc_wfst_search_opts;
};
struct WordPiece {
std::string word;
int start = -1;
int end = -1;
WordPiece(std::string word, int start, int end)
: word(std::move(word)), start(start), end(end) {}
};
struct DecodeResult {
float score = -kFloatMax;
std::string sentence;
std::vector<WordPiece> word_pieces;
static bool CompareFunc(const DecodeResult& a, const DecodeResult& b) {
return a.score > b.score;
}
};
enum DecodeState {
kEndBatch = 0x00, // End of current decoding batch, normal case
kEndpoint = 0x01, // Endpoint is detected
  kEndFeats = 0x02,  // All features have been decoded
  kWaitFeats = 0x03  // Not enough features for one chunk of inference; wait
};
// DecodeResource is thread safe; it can be shared by multiple
// decoding threads
struct DecodeResource {
std::shared_ptr<AsrModel> model = nullptr;
std::shared_ptr<fst::SymbolTable> symbol_table = nullptr;
std::shared_ptr<fst::Fst<fst::StdArc>> fst = nullptr;
std::shared_ptr<fst::SymbolTable> unit_table = nullptr;
std::shared_ptr<ContextGraph> context_graph = nullptr;
std::shared_ptr<PostProcessor> post_processor = nullptr;
};
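// A hedged usage sketch (not part of this API; the model loader below is
// hypothetical): one DecodeResource is built once and shared across decoding
// threads, while each thread owns its own FeaturePipeline and AsrDecoder.
//   auto resource = std::make_shared<DecodeResource>();
//   resource->model = LoadAsrModelSomehow();  // hypothetical loader
//   resource->symbol_table = ...;             // elided
//   // per decoding thread:
//   auto feature_pipeline = std::make_shared<FeaturePipeline>(feature_config);
//   AsrDecoder decoder(feature_pipeline, resource, decode_opts);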
// Torch ASR decoder
class AsrDecoder {
public:
AsrDecoder(std::shared_ptr<FeaturePipeline> feature_pipeline,
std::shared_ptr<DecodeResource> resource,
const DecodeOptions& opts);
  // @param block: if true, block when features are not enough for one chunk of
  // inference. Otherwise, return kWaitFeats.
DecodeState Decode(bool block = true);
void Rescoring();
void Reset();
void ResetContinuousDecoding();
bool DecodedSomething() const {
return !result_.empty() && !result_[0].sentence.empty();
}
  // This method is used for time benchmarking
int num_frames_in_current_chunk() const {
return num_frames_in_current_chunk_;
}
int frame_shift_in_ms() const {
return model_->subsampling_rate() *
feature_pipeline_->config().frame_shift * 1000 /
feature_pipeline_->config().sample_rate;
}
int feature_frame_shift_in_ms() const {
return feature_pipeline_->config().frame_shift * 1000 /
feature_pipeline_->config().sample_rate;
}
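  // Illustrative numbers (assumed, not from the source): with 16 kHz audio, a
  // frame shift of 160 samples (10 ms) and subsampling_rate() = 4,
  // feature_frame_shift_in_ms() = 160 * 1000 / 16000 = 10 ms and
  // frame_shift_in_ms() = 4 * 160 * 1000 / 16000 = 40 ms.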
const std::vector<DecodeResult>& result() const { return result_; }
private:
DecodeState AdvanceDecoding(bool block = true);
void AttentionRescoring();
void UpdateResult(bool finish = false);
std::shared_ptr<FeaturePipeline> feature_pipeline_;
std::shared_ptr<AsrModel> model_;
std::shared_ptr<PostProcessor> post_processor_;
std::shared_ptr<fst::Fst<fst::StdArc>> fst_ = nullptr;
// output symbol table
std::shared_ptr<fst::SymbolTable> symbol_table_;
// e2e unit symbol table
std::shared_ptr<fst::SymbolTable> unit_table_ = nullptr;
const DecodeOptions& opts_;
// cache feature
bool start_ = false;
// For continuous decoding
int num_frames_ = 0;
int global_frame_offset_ = 0;
const int time_stamp_gap_ = 100; // timestamp gap between words in a sentence
std::unique_ptr<SearchInterface> searcher_;
std::unique_ptr<CtcEndpoint> ctc_endpointer_;
int num_frames_in_current_chunk_ = 0;
std::vector<DecodeResult> result_;
public:
WENET_DISALLOW_COPY_AND_ASSIGN(AsrDecoder);
};
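// A hedged sketch of a typical decoding loop over AsrDecoder (assumed usage,
// error handling elided):
//   while (true) {
//     DecodeState state = decoder.Decode();
//     if (state == DecodeState::kEndFeats) {
//       decoder.Rescoring();
//       break;
//     } else if (state == DecodeState::kEndpoint) {
//       decoder.Rescoring();
//       decoder.ResetContinuousDecoding();
//     }
//     if (decoder.DecodedSomething()) {
//       // consume decoder.result()[0].sentence as the partial result
//     }
//   }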
} // namespace wenet
#endif // DECODER_ASR_DECODER_H_
// Copyright 2022 Horizon Robotics. All Rights Reserved.
// Author: binbin.zhang@horizon.ai (Binbin Zhang)
#include "decoder/asr_model.h"
#include <memory>
#include <utility>
namespace wenet {
int AsrModel::num_frames_for_chunk(bool start) const {
int num_required_frames = 0;
if (chunk_size_ > 0) {
if (!start) { // First batch
int context = right_context_ + 1; // Add current frame
num_required_frames = (chunk_size_ - 1) * subsampling_rate_ + context;
} else {
num_required_frames = chunk_size_ * subsampling_rate_;
}
} else {
num_required_frames = std::numeric_limits<int>::max();
}
return num_required_frames;
}
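// An illustrative calculation (parameter values assumed, not from the source):
// with chunk_size_ = 16, subsampling_rate_ = 4 and right_context_ = 6, the
// first chunk of an utterance (start == false) requires
// (16 - 1) * 4 + (6 + 1) = 67 frames, while every later chunk requires
// 16 * 4 = 64 frames.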
void AsrModel::CacheFeature(
const std::vector<std::vector<float>>& chunk_feats) {
// Cache feature for next chunk
const int cached_feature_size = 1 + right_context_ - subsampling_rate_;
if (chunk_feats.size() >= cached_feature_size) {
    // TODO(Binbin Zhang): We only deal with the case
    // chunk_feats.size() > cached_feature_size here; this is consistent
    // with our current model. Refine it later if we have a new model or
    // new requirements.
cached_feature_.resize(cached_feature_size);
for (int i = 0; i < cached_feature_size; ++i) {
cached_feature_[i] =
chunk_feats[chunk_feats.size() - cached_feature_size + i];
}
}
}
void AsrModel::ForwardEncoder(
const std::vector<std::vector<float>>& chunk_feats,
std::vector<std::vector<float>>* ctc_prob) {
ctc_prob->clear();
int num_frames = cached_feature_.size() + chunk_feats.size();
if (num_frames >= right_context_ + 1) {
this->ForwardEncoderFunc(chunk_feats, ctc_prob);
this->CacheFeature(chunk_feats);
}
}
} // namespace wenet
// Copyright 2022 Horizon Robotics. All Rights Reserved.
// Author: binbin.zhang@horizon.ai (Binbin Zhang)
#ifndef DECODER_ASR_MODEL_H_
#define DECODER_ASR_MODEL_H_
#include <limits>
#include <memory>
#include <string>
#include <vector>
#include "utils/timer.h"
#include "utils/utils.h"
namespace wenet {
class AsrModel {
public:
virtual int right_context() const { return right_context_; }
virtual int subsampling_rate() const { return subsampling_rate_; }
virtual int sos() const { return sos_; }
virtual int eos() const { return eos_; }
virtual bool is_bidirectional_decoder() const {
return is_bidirectional_decoder_;
}
virtual int offset() const { return offset_; }
  // If chunk_size > 0, streaming case. Otherwise, non-streaming case
virtual void set_chunk_size(int chunk_size) { chunk_size_ = chunk_size; }
virtual void set_num_left_chunks(int num_left_chunks) {
num_left_chunks_ = num_left_chunks;
}
  // start: whether this is the first chunk of an utterance
virtual int num_frames_for_chunk(bool start) const;
virtual void Reset() = 0;
virtual void ForwardEncoder(
const std::vector<std::vector<float>>& chunk_feats,
std::vector<std::vector<float>>* ctc_prob);
virtual void AttentionRescoring(const std::vector<std::vector<int>>& hyps,
float reverse_weight,
std::vector<float>* rescoring_score) = 0;
virtual std::shared_ptr<AsrModel> Copy() const = 0;
protected:
virtual void ForwardEncoderFunc(
const std::vector<std::vector<float>>& chunk_feats,
std::vector<std::vector<float>>* ctc_prob) = 0;
virtual void CacheFeature(const std::vector<std::vector<float>>& chunk_feats);
int right_context_ = 1;
int subsampling_rate_ = 1;
int sos_ = 0;
int eos_ = 0;
bool is_bidirectional_decoder_ = false;
int chunk_size_ = 16;
int num_left_chunks_ = -1; // -1 means all left chunks
int offset_ = 0;
std::vector<std::vector<float>> cached_feature_;
};
} // namespace wenet
#endif // DECODER_ASR_MODEL_H_
// Copyright (c) 2021 Mobvoi Inc (Zhendong Peng)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/context_graph.h"
#include <utility>
#include "fst/determinize.h"
#include "utils/string.h"
#include "utils/utils.h"
namespace wenet {
ContextGraph::ContextGraph(ContextConfig config) : config_(config) {}
void ContextGraph::BuildContextGraph(
const std::vector<std::string>& query_contexts,
const std::shared_ptr<fst::SymbolTable>& symbol_table) {
CHECK(symbol_table != nullptr) << "Symbols table should not be nullptr!";
start_tag_id_ = symbol_table->AddSymbol("<context>");
end_tag_id_ = symbol_table->AddSymbol("</context>");
symbol_table_ = symbol_table;
if (query_contexts.empty()) {
if (graph_ != nullptr) graph_.reset();
return;
}
std::unique_ptr<fst::StdVectorFst> ofst(new fst::StdVectorFst());
// State 0 is the start state and the final state.
int start_state = ofst->AddState();
ofst->SetStart(start_state);
ofst->SetFinal(start_state, fst::StdArc::Weight::One());
LOG(INFO) << "Contexts count size: " << query_contexts.size();
int count = 0;
for (const auto& context : query_contexts) {
if (context.size() > config_.max_context_length) {
LOG(INFO) << "Skip long context: " << context;
continue;
}
if (++count > config_.max_contexts) break;
std::vector<std::string> words;
    // Split the context into words using the symbol table, and build the context graph.
bool no_oov = SplitUTF8StringToWords(Trim(context), symbol_table, &words);
if (!no_oov) {
LOG(WARNING) << "Ignore unknown word found during compilation.";
continue;
}
int prev_state = start_state;
int next_state = start_state;
float escape_score = 0;
for (size_t i = 0; i < words.size(); ++i) {
int word_id = symbol_table_->Find(words[i]);
float score = (i * config_.incremental_context_score
+ config_.context_score) * UTF8StringLength(words[i]);
next_state = (i < words.size() - 1) ? ofst->AddState() : start_state;
ofst->AddArc(prev_state,
fst::StdArc(word_id, word_id, score, next_state));
// Add escape arc to clean the previous context score.
if (i > 0) {
        // The ilabel and olabel of the escape arc are 0 (<epsilon>).
ofst->AddArc(prev_state, fst::StdArc(0, 0, -escape_score, start_state));
}
prev_state = next_state;
escape_score += score;
}
}
std::unique_ptr<fst::StdVectorFst> det_fst(new fst::StdVectorFst());
fst::Determinize(*ofst, det_fst.get());
graph_ = std::move(det_fst);
}
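// An illustrative example (default config values assumed): with
// context_score = 3.0 and incremental_context_score = 0.0, a context word of
// two UTF-8 characters adds an arc with score (i * 0.0 + 3.0) * 2 = 6.0, while
// the escape arc back to the start state carries the negated accumulated score
// so that partially matched contexts are not rewarded.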
int ContextGraph::GetNextState(int cur_state, int word_id, float* score,
bool* is_start_boundary, bool* is_end_boundary) {
int next_state = 0;
for (fst::ArcIterator<fst::StdFst> aiter(*graph_, cur_state); !aiter.Done();
aiter.Next()) {
const fst::StdArc& arc = aiter.Value();
if (arc.ilabel == 0) {
      // Escape score; it will be overwritten when ilabel equals the word id.
*score = arc.weight.Value();
} else if (arc.ilabel == word_id) {
next_state = arc.nextstate;
*score = arc.weight.Value();
if (cur_state == 0) {
*is_start_boundary = true;
}
if (graph_->Final(arc.nextstate) == fst::StdArc::Weight::One()) {
*is_end_boundary = true;
}
break;
}
}
return next_state;
}
bool ContextGraph::SplitUTF8StringToWords(
const std::string& str,
const std::shared_ptr<fst::SymbolTable>& symbol_table,
std::vector<std::string>* words) {
std::vector<std::string> chars;
SplitUTF8StringToChars(Trim(str), &chars);
bool no_oov = true;
for (size_t start = 0; start < chars.size();) {
for (size_t end = chars.size(); end > start; --end) {
std::string word;
for (size_t i = start; i < end; i++) {
word += chars[i];
}
// Skip space.
if (word == " ") {
start = end;
continue;
}
      // Add '▁' at the beginning of an English word.
if (IsAlpha(word)) {
word = kSpaceSymbol + word;
}
if (symbol_table->Find(word) != -1) {
words->emplace_back(word);
start = end;
continue;
}
if (end == start + 1) {
++start;
no_oov = false;
LOG(WARNING) << word << " is oov.";
}
}
}
return no_oov;
}
} // namespace wenet
// Copyright (c) 2021 Mobvoi Inc (Zhendong Peng)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DECODER_CONTEXT_GRAPH_H_
#define DECODER_CONTEXT_GRAPH_H_
#include <memory>
#include <string>
#include <vector>
#include "fst/compose.h"
#include "fst/fst.h"
#include "fst/vector-fst.h"
namespace wenet {
using StateId = fst::StdArc::StateId;
struct ContextConfig {
int max_contexts = 5000;
int max_context_length = 100;
float context_score = 3.0;
float incremental_context_score = 0.0;
};
class ContextGraph {
public:
explicit ContextGraph(ContextConfig config);
void BuildContextGraph(const std::vector<std::string>& query_context,
const std::shared_ptr<fst::SymbolTable>& symbol_table);
int GetNextState(int cur_state, int word_id, float* score,
bool* is_start_boundary, bool* is_end_boundary);
int start_tag_id() { return start_tag_id_; }
int end_tag_id() { return end_tag_id_; }
private:
bool SplitUTF8StringToWords(
const std::string& str,
const std::shared_ptr<fst::SymbolTable>& symbol_table,
std::vector<std::string>* words);
int start_tag_id_ = -1;
int end_tag_id_ = -1;
ContextConfig config_;
std::shared_ptr<fst::SymbolTable> symbol_table_ = nullptr;
std::unique_ptr<fst::StdVectorFst> graph_ = nullptr;
DISALLOW_COPY_AND_ASSIGN(ContextGraph);
};
} // namespace wenet
#endif // DECODER_CONTEXT_GRAPH_H_
// Copyright (c) 2021 Mobvoi Inc (Zhendong Peng)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/ctc_endpoint.h"
#include <math.h>
#include <string>
#include <vector>
#include "utils/log.h"
namespace wenet {
CtcEndpoint::CtcEndpoint(const CtcEndpointConfig& config) : config_(config) {
Reset();
}
void CtcEndpoint::Reset() {
num_frames_decoded_ = 0;
num_frames_trailing_blank_ = 0;
}
static bool RuleActivated(const CtcEndpointRule& rule,
const std::string& rule_name, bool decoded_sth,
int trailing_silence, int utterance_length) {
bool ans = (decoded_sth || !rule.must_decoded_sth) &&
trailing_silence >= rule.min_trailing_silence &&
utterance_length >= rule.min_utterance_length;
if (ans) {
VLOG(2) << "Endpointing rule " << rule_name
<< " activated: " << (decoded_sth ? "true" : "false") << ','
<< trailing_silence << ',' << utterance_length;
}
return ans;
}
bool CtcEndpoint::IsEndpoint(
const std::vector<std::vector<float>>& ctc_log_probs,
bool decoded_something) {
for (int t = 0; t < ctc_log_probs.size(); ++t) {
const auto& logp_t = ctc_log_probs[t];
float blank_prob = expf(logp_t[config_.blank]);
num_frames_decoded_++;
if (blank_prob > config_.blank_threshold) {
num_frames_trailing_blank_++;
} else {
num_frames_trailing_blank_ = 0;
}
}
CHECK_GE(num_frames_decoded_, num_frames_trailing_blank_);
CHECK_GT(frame_shift_in_ms_, 0);
int utterance_length = num_frames_decoded_ * frame_shift_in_ms_;
int trailing_silence = num_frames_trailing_blank_ * frame_shift_in_ms_;
if (RuleActivated(config_.rule1, "rule1", decoded_something, trailing_silence,
utterance_length))
return true;
if (RuleActivated(config_.rule2, "rule2", decoded_something, trailing_silence,
utterance_length))
return true;
if (RuleActivated(config_.rule3, "rule3", decoded_something, trailing_silence,
utterance_length))
return true;
return false;
}
} // namespace wenet
// Copyright (c) 2021 Mobvoi Inc (Zhendong Peng)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DECODER_CTC_ENDPOINT_H_
#define DECODER_CTC_ENDPOINT_H_
#include <vector>
namespace wenet {
struct CtcEndpointRule {
bool must_decoded_sth;
int min_trailing_silence;
int min_utterance_length;
CtcEndpointRule(bool must_decoded_sth = true, int min_trailing_silence = 1000,
int min_utterance_length = 0)
: must_decoded_sth(must_decoded_sth),
min_trailing_silence(min_trailing_silence),
min_utterance_length(min_utterance_length) {}
};
struct CtcEndpointConfig {
/// We consider blank as silence for purposes of endpointing.
int blank = 0; // blank id
float blank_threshold = 0.8; // blank threshold to be silence
/// We support three rules. We terminate decoding if ANY of these rules
/// evaluates to "true". If you want to add more rules, do it by changing this
/// code. If you want to disable a rule, you can set the silence-timeout for
/// that rule to a very large number.
/// rule1 times out after 5000 ms of silence, even if we decoded nothing.
CtcEndpointRule rule1;
/// rule2 times out after 1000 ms of silence after decoding something.
CtcEndpointRule rule2;
/// rule3 times out after the utterance is 20000 ms long, regardless of
/// anything else.
CtcEndpointRule rule3;
CtcEndpointConfig()
: rule1(false, 5000, 0), rule2(true, 1000, 0), rule3(false, 0, 20000) {}
};
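// A hypothetical tweak (not from the source): to effectively disable rule2,
// give it an unreachably large trailing-silence timeout, e.g.
//   CtcEndpointConfig config;
//   config.rule2 = CtcEndpointRule(true, std::numeric_limits<int>::max(), 0);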
class CtcEndpoint {
public:
explicit CtcEndpoint(const CtcEndpointConfig& config);
void Reset();
/// This function returns true if this set of endpointing rules thinks we
/// should terminate decoding.
bool IsEndpoint(const std::vector<std::vector<float>>& ctc_log_probs,
bool decoded_something);
void frame_shift_in_ms(int frame_shift_in_ms) {
frame_shift_in_ms_ = frame_shift_in_ms;
}
private:
CtcEndpointConfig config_;
int frame_shift_in_ms_ = -1;
int num_frames_decoded_ = 0;
int num_frames_trailing_blank_ = 0;
};
} // namespace wenet
#endif // DECODER_CTC_ENDPOINT_H_
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/ctc_prefix_beam_search.h"
#include <algorithm>
#include <tuple>
#include <unordered_map>
#include <utility>
#include "utils/log.h"
#include "utils/utils.h"
namespace wenet {
CtcPrefixBeamSearch::CtcPrefixBeamSearch(
const CtcPrefixBeamSearchOptions& opts,
const std::shared_ptr<ContextGraph>& context_graph)
: opts_(opts), context_graph_(context_graph) {
Reset();
}
void CtcPrefixBeamSearch::Reset() {
hypotheses_.clear();
likelihood_.clear();
cur_hyps_.clear();
viterbi_likelihood_.clear();
times_.clear();
outputs_.clear();
abs_time_step_ = 0;
PrefixScore prefix_score;
prefix_score.s = 0.0;
prefix_score.ns = -kFloatMax;
prefix_score.v_s = 0.0;
prefix_score.v_ns = 0.0;
std::vector<int> empty;
cur_hyps_[empty] = prefix_score;
outputs_.emplace_back(empty);
hypotheses_.emplace_back(empty);
likelihood_.emplace_back(prefix_score.total_score());
times_.emplace_back(empty);
}
static bool PrefixScoreCompare(
const std::pair<std::vector<int>, PrefixScore>& a,
const std::pair<std::vector<int>, PrefixScore>& b) {
return a.second.total_score() > b.second.total_score();
}
void CtcPrefixBeamSearch::UpdateOutputs(
const std::pair<std::vector<int>, PrefixScore>& prefix) {
const std::vector<int>& input = prefix.first;
const std::vector<int>& start_boundaries = prefix.second.start_boundaries;
const std::vector<int>& end_boundaries = prefix.second.end_boundaries;
std::vector<int> output;
int s = 0;
int e = 0;
for (int i = 0; i < input.size(); ++i) {
if (s < start_boundaries.size() && i == start_boundaries[s]) {
output.emplace_back(context_graph_->start_tag_id());
++s;
}
output.emplace_back(input[i]);
if (e < end_boundaries.size() && i == end_boundaries[e]) {
output.emplace_back(context_graph_->end_tag_id());
++e;
}
}
outputs_.emplace_back(output);
}
void CtcPrefixBeamSearch::UpdateHypotheses(
const std::vector<std::pair<std::vector<int>, PrefixScore>>& hpys) {
cur_hyps_.clear();
outputs_.clear();
hypotheses_.clear();
likelihood_.clear();
viterbi_likelihood_.clear();
times_.clear();
for (auto& item : hpys) {
cur_hyps_[item.first] = item.second;
UpdateOutputs(item);
hypotheses_.emplace_back(std::move(item.first));
likelihood_.emplace_back(item.second.total_score());
viterbi_likelihood_.emplace_back(item.second.viterbi_score());
times_.emplace_back(item.second.times());
}
}
// Please refer to https://robin1001.github.io/2020/12/11/ctc-search
// for how CTC prefix beam search works; it also contains a simple graph demo.
void CtcPrefixBeamSearch::Search(const std::vector<std::vector<float>>& logp) {
if (logp.size() == 0) return;
int first_beam_size =
std::min(static_cast<int>(logp[0].size()), opts_.first_beam_size);
for (int t = 0; t < logp.size(); ++t, ++abs_time_step_) {
const std::vector<float>& logp_t = logp[t];
std::unordered_map<std::vector<int>, PrefixScore, PrefixHash> next_hyps;
// 1. First beam prune, only select topk candidates
std::vector<float> topk_score;
std::vector<int32_t> topk_index;
TopK(logp_t, first_beam_size, &topk_score, &topk_index);
// 2. Token passing
for (int i = 0; i < topk_index.size(); ++i) {
int id = topk_index[i];
auto prob = topk_score[i];
for (const auto& it : cur_hyps_) {
const std::vector<int>& prefix = it.first;
const PrefixScore& prefix_score = it.second;
        // If prefix doesn't exist in next_hyps, next_hyps[prefix] will insert
        // PrefixScore(-inf, -inf) by default, since the default constructor
        // of PrefixScore sets both s (blank ending score) and
        // ns (non-blank ending score) to -inf.
if (id == opts_.blank) {
// Case 0: *a + ε => *a
PrefixScore& next_score = next_hyps[prefix];
next_score.s = LogAdd(next_score.s, prefix_score.score() + prob);
next_score.v_s = prefix_score.viterbi_score() + prob;
next_score.times_s = prefix_score.times();
// Prefix not changed, copy the context from prefix.
if (context_graph_ && !next_score.has_context) {
next_score.CopyContext(prefix_score);
next_score.has_context = true;
}
} else if (!prefix.empty() && id == prefix.back()) {
// Case 1: *a + a => *a
PrefixScore& next_score1 = next_hyps[prefix];
next_score1.ns = LogAdd(next_score1.ns, prefix_score.ns + prob);
if (next_score1.v_ns < prefix_score.v_ns + prob) {
next_score1.v_ns = prefix_score.v_ns + prob;
if (next_score1.cur_token_prob < prob) {
next_score1.cur_token_prob = prob;
next_score1.times_ns = prefix_score.times_ns;
CHECK_GT(next_score1.times_ns.size(), 0);
next_score1.times_ns.back() = abs_time_step_;
}
}
if (context_graph_ && !next_score1.has_context) {
next_score1.CopyContext(prefix_score);
next_score1.has_context = true;
}
// Case 2: *aε + a => *aa
std::vector<int> new_prefix(prefix);
new_prefix.emplace_back(id);
PrefixScore& next_score2 = next_hyps[new_prefix];
next_score2.ns = LogAdd(next_score2.ns, prefix_score.s + prob);
if (next_score2.v_ns < prefix_score.v_s + prob) {
next_score2.v_ns = prefix_score.v_s + prob;
next_score2.cur_token_prob = prob;
next_score2.times_ns = prefix_score.times_s;
next_score2.times_ns.emplace_back(abs_time_step_);
}
if (context_graph_ && !next_score2.has_context) {
// Prefix changed, calculate the context score.
next_score2.UpdateContext(context_graph_, prefix_score, id,
prefix.size());
next_score2.has_context = true;
}
} else {
// Case 3: *a + b => *ab, *aε + b => *ab
std::vector<int> new_prefix(prefix);
new_prefix.emplace_back(id);
PrefixScore& next_score = next_hyps[new_prefix];
next_score.ns = LogAdd(next_score.ns, prefix_score.score() + prob);
if (next_score.v_ns < prefix_score.viterbi_score() + prob) {
next_score.v_ns = prefix_score.viterbi_score() + prob;
next_score.cur_token_prob = prob;
next_score.times_ns = prefix_score.times();
next_score.times_ns.emplace_back(abs_time_step_);
}
if (context_graph_ && !next_score.has_context) {
// Calculate the context score.
next_score.UpdateContext(context_graph_, prefix_score, id,
prefix.size());
next_score.has_context = true;
}
}
}
}
// 3. Second beam prune, only keep top n best paths
std::vector<std::pair<std::vector<int>, PrefixScore>> arr(next_hyps.begin(),
next_hyps.end());
int second_beam_size =
std::min(static_cast<int>(arr.size()), opts_.second_beam_size);
std::nth_element(arr.begin(), arr.begin() + second_beam_size, arr.end(),
PrefixScoreCompare);
arr.resize(second_beam_size);
std::sort(arr.begin(), arr.end(), PrefixScoreCompare);
// 4. Update cur_hyps_ and get new result
UpdateHypotheses(arr);
}
}
void CtcPrefixBeamSearch::FinalizeSearch() { UpdateFinalContext(); }
void CtcPrefixBeamSearch::UpdateFinalContext() {
if (context_graph_ == nullptr) return;
CHECK_EQ(hypotheses_.size(), cur_hyps_.size());
CHECK_EQ(hypotheses_.size(), likelihood_.size());
  // We should back off the context score/state when the context is
  // not fully matched at the end of decoding.
for (const auto& prefix : hypotheses_) {
PrefixScore& prefix_score = cur_hyps_[prefix];
if (prefix_score.context_state != 0) {
prefix_score.UpdateContext(context_graph_, prefix_score, 0,
prefix.size());
}
}
std::vector<std::pair<std::vector<int>, PrefixScore>> arr(cur_hyps_.begin(),
cur_hyps_.end());
std::sort(arr.begin(), arr.end(), PrefixScoreCompare);
// Update cur_hyps_ and get new result
UpdateHypotheses(arr);
}
} // namespace wenet
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DECODER_CTC_PREFIX_BEAM_SEARCH_H_
#define DECODER_CTC_PREFIX_BEAM_SEARCH_H_
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include "decoder/context_graph.h"
#include "decoder/search_interface.h"
#include "utils/utils.h"
namespace wenet {
struct CtcPrefixBeamSearchOptions {
int blank = 0; // blank id
int first_beam_size = 10;
int second_beam_size = 10;
};
struct PrefixScore {
  float s = -kFloatMax;               // blank ending score
  float ns = -kFloatMax;              // non-blank ending score
  float v_s = -kFloatMax;             // viterbi blank ending score
  float v_ns = -kFloatMax;            // viterbi non-blank ending score
  float cur_token_prob = -kFloatMax;  // prob of current token
  std::vector<int> times_s;           // times of viterbi blank path
  std::vector<int> times_ns;          // times of viterbi non-blank path
float score() const { return LogAdd(s, ns); }
float viterbi_score() const { return v_s > v_ns ? v_s : v_ns; }
const std::vector<int>& times() const {
return v_s > v_ns ? times_s : times_ns;
}
bool has_context = false;
int context_state = 0;
float context_score = 0;
std::vector<int> start_boundaries;
std::vector<int> end_boundaries;
void CopyContext(const PrefixScore& prefix_score) {
context_state = prefix_score.context_state;
context_score = prefix_score.context_score;
start_boundaries = prefix_score.start_boundaries;
end_boundaries = prefix_score.end_boundaries;
}
void UpdateContext(const std::shared_ptr<ContextGraph>& context_graph,
const PrefixScore& prefix_score, int word_id,
int prefix_len) {
this->CopyContext(prefix_score);
float score = 0;
bool is_start_boundary = false;
bool is_end_boundary = false;
context_state =
context_graph->GetNextState(prefix_score.context_state, word_id, &score,
&is_start_boundary, &is_end_boundary);
context_score += score;
if (is_start_boundary) start_boundaries.emplace_back(prefix_len);
if (is_end_boundary) end_boundaries.emplace_back(prefix_len);
}
float total_score() const { return score() + context_score; }
};
struct PrefixHash {
size_t operator()(const std::vector<int>& prefix) const {
size_t hash_code = 0;
// here we use KB&DR hash code
for (int id : prefix) {
hash_code = id + 31 * hash_code;
}
return hash_code;
}
};
class CtcPrefixBeamSearch : public SearchInterface {
public:
explicit CtcPrefixBeamSearch(
const CtcPrefixBeamSearchOptions& opts,
const std::shared_ptr<ContextGraph>& context_graph = nullptr);
void Search(const std::vector<std::vector<float>>& logp) override;
void Reset() override;
void FinalizeSearch() override;
SearchType Type() const override { return SearchType::kPrefixBeamSearch; }
void UpdateOutputs(const std::pair<std::vector<int>, PrefixScore>& prefix);
void UpdateHypotheses(
const std::vector<std::pair<std::vector<int>, PrefixScore>>& hpys);
void UpdateFinalContext();
const std::vector<float>& viterbi_likelihood() const {
return viterbi_likelihood_;
}
const std::vector<std::vector<int>>& Inputs() const override {
return hypotheses_;
}
const std::vector<std::vector<int>>& Outputs() const override {
return outputs_;
}
const std::vector<float>& Likelihood() const override { return likelihood_; }
const std::vector<std::vector<int>>& Times() const override { return times_; }
private:
int abs_time_step_ = 0;
// N-best list and corresponding likelihood_, in sorted order
std::vector<std::vector<int>> hypotheses_;
std::vector<float> likelihood_;
std::vector<float> viterbi_likelihood_;
std::vector<std::vector<int>> times_;
std::unordered_map<std::vector<int>, PrefixScore, PrefixHash> cur_hyps_;
std::shared_ptr<ContextGraph> context_graph_ = nullptr;
// Outputs contain the hypotheses_ and tags like: <context> and </context>
std::vector<std::vector<int>> outputs_;
const CtcPrefixBeamSearchOptions& opts_;
public:
WENET_DISALLOW_COPY_AND_ASSIGN(CtcPrefixBeamSearch);
};
} // namespace wenet
#endif // DECODER_CTC_PREFIX_BEAM_SEARCH_H_
// Copyright (c) 2021 Mobvoi Inc (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/ctc_wfst_beam_search.h"
#include <utility>
namespace wenet {
void DecodableTensorScaled::Reset() {
num_frames_ready_ = 0;
done_ = false;
  // Give an empty initialization; an error will be thrown if
  // AcceptLoglikes has not been called
logp_.clear();
}
void DecodableTensorScaled::AcceptLoglikes(const std::vector<float>& logp) {
++num_frames_ready_;
// TODO(Binbin Zhang): Avoid copy here
logp_ = logp;
}
float DecodableTensorScaled::LogLikelihood(int32 frame, int32 index) {
CHECK_GT(index, 0);
CHECK_LT(frame, num_frames_ready_);
return scale_ * logp_[index - 1];
}
bool DecodableTensorScaled::IsLastFrame(int32 frame) const {
CHECK_LT(frame, num_frames_ready_);
return done_ && (frame == num_frames_ready_ - 1);
}
int32 DecodableTensorScaled::NumIndices() const {
LOG(FATAL) << "Not implement";
return 0;
}
CtcWfstBeamSearch::CtcWfstBeamSearch(
const fst::Fst<fst::StdArc>& fst, const CtcWfstBeamSearchOptions& opts,
const std::shared_ptr<ContextGraph>& context_graph)
: decodable_(opts.acoustic_scale),
decoder_(fst, opts, context_graph),
context_graph_(context_graph),
opts_(opts) {
Reset();
}
void CtcWfstBeamSearch::Reset() {
num_frames_ = 0;
decoded_frames_mapping_.clear();
is_last_frame_blank_ = false;
last_best_ = 0;
inputs_.clear();
outputs_.clear();
likelihood_.clear();
times_.clear();
decodable_.Reset();
decoder_.InitDecoding();
}
void CtcWfstBeamSearch::Search(const std::vector<std::vector<float>>& logp) {
if (0 == logp.size()) {
return;
}
  // Every time we get the log posteriors, we decode them all before returning
for (int i = 0; i < logp.size(); i++) {
float blank_score = std::exp(logp[i][0]);
if (blank_score > opts_.blank_skip_thresh * opts_.blank_scale) {
VLOG(3) << "skipping frame " << num_frames_ << " score " << blank_score;
is_last_frame_blank_ = true;
last_frame_prob_ = logp[i];
} else {
// Get the best symbol
int cur_best =
std::max_element(logp[i].begin(), logp[i].end()) - logp[i].begin();
      // Optionally add one blank frame back if we skipped it between two
      // identical symbols
if (cur_best != 0 && is_last_frame_blank_ && cur_best == last_best_) {
decodable_.AcceptLoglikes(last_frame_prob_);
decoder_.AdvanceDecoding(&decodable_, 1);
decoded_frames_mapping_.push_back(num_frames_ - 1);
VLOG(2) << "Adding blank frame at symbol " << cur_best;
}
last_best_ = cur_best;
decodable_.AcceptLoglikes(logp[i]);
decoder_.AdvanceDecoding(&decodable_, 1);
decoded_frames_mapping_.push_back(num_frames_);
is_last_frame_blank_ = false;
}
num_frames_++;
}
// Get the best path
inputs_.clear();
outputs_.clear();
likelihood_.clear();
if (decoded_frames_mapping_.size() > 0) {
inputs_.resize(1);
outputs_.resize(1);
likelihood_.resize(1);
kaldi::Lattice lat;
decoder_.GetBestPath(&lat, false);
std::vector<int> alignment;
kaldi::LatticeWeight weight;
fst::GetLinearSymbolSequence(lat, &alignment, &outputs_[0], &weight);
ConvertToInputs(alignment, &inputs_[0]);
RemoveContinuousTags(&outputs_[0]);
VLOG(3) << weight.Value1() << " " << weight.Value2();
likelihood_[0] = -(weight.Value1() + weight.Value2());
}
}
void CtcWfstBeamSearch::FinalizeSearch() {
decodable_.SetFinish();
decoder_.FinalizeDecoding();
inputs_.clear();
outputs_.clear();
likelihood_.clear();
times_.clear();
if (decoded_frames_mapping_.size() > 0) {
std::vector<kaldi::Lattice> nbest_lats;
if (opts_.nbest == 1) {
kaldi::Lattice lat;
decoder_.GetBestPath(&lat, true);
nbest_lats.push_back(std::move(lat));
} else {
      // Get the N-best paths from the lattice (CompactLattice)
kaldi::CompactLattice clat;
decoder_.GetLattice(&clat, true);
kaldi::Lattice lat, nbest_lat;
fst::ConvertLattice(clat, &lat);
// TODO(Binbin Zhang): it's n-best word lists here, not character n-best
fst::ShortestPath(lat, &nbest_lat, opts_.nbest);
fst::ConvertNbestToVector(nbest_lat, &nbest_lats);
}
int nbest = nbest_lats.size();
inputs_.resize(nbest);
outputs_.resize(nbest);
likelihood_.resize(nbest);
times_.resize(nbest);
for (int i = 0; i < nbest; i++) {
kaldi::LatticeWeight weight;
std::vector<int> alignment;
fst::GetLinearSymbolSequence(nbest_lats[i], &alignment, &outputs_[i],
&weight);
ConvertToInputs(alignment, &inputs_[i], &times_[i]);
RemoveContinuousTags(&outputs_[i]);
likelihood_[i] = -(weight.Value1() + weight.Value2());
}
}
}
void CtcWfstBeamSearch::ConvertToInputs(const std::vector<int>& alignment,
std::vector<int>* input,
std::vector<int>* time) {
input->clear();
if (time != nullptr) time->clear();
for (int cur = 0; cur < alignment.size(); ++cur) {
// ignore blank
if (alignment[cur] - 1 == 0) continue;
    // merge consecutive identical labels
if (cur > 0 && alignment[cur] == alignment[cur - 1]) continue;
input->push_back(alignment[cur] - 1);
if (time != nullptr) {
time->push_back(decoded_frames_mapping_[cur]);
}
}
}
void CtcWfstBeamSearch::RemoveContinuousTags(std::vector<int>* output) {
if (context_graph_) {
for (auto it = output->begin(); it != output->end();) {
if (*it == context_graph_->start_tag_id() ||
*it == context_graph_->end_tag_id()) {
if (it + 1 != output->end() && *it == *(it + 1)) {
it = output->erase(it);
continue;
}
}
++it;
}
}
}
} // namespace wenet