Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
Conformer_pytorch
Commits
764b3a75
Commit
764b3a75
authored
Jun 07, 2023
by
Sugon_ldc
Browse files
add new model
parents
Changes
498
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1741 additions
and
0 deletions
+1741
-0
runtime/core/cmake/glog.cmake
runtime/core/cmake/glog.cmake
+7
-0
runtime/core/cmake/grpc.cmake
runtime/core/cmake/grpc.cmake
+10
-0
runtime/core/cmake/gtest.cmake
runtime/core/cmake/gtest.cmake
+9
-0
runtime/core/cmake/libtorch.cmake
runtime/core/cmake/libtorch.cmake
+79
-0
runtime/core/cmake/onnx.cmake
runtime/core/cmake/onnx.cmake
+35
-0
runtime/core/cmake/openfst.cmake
runtime/core/cmake/openfst.cmake
+45
-0
runtime/core/cmake/pybind11.cmake
runtime/core/cmake/pybind11.cmake
+8
-0
runtime/core/cmake/xpu.cmake
runtime/core/cmake/xpu.cmake
+37
-0
runtime/core/decoder/CMakeLists.txt
runtime/core/decoder/CMakeLists.txt
+39
-0
runtime/core/decoder/asr_decoder.cc
runtime/core/decoder/asr_decoder.cc
+231
-0
runtime/core/decoder/asr_decoder.h
runtime/core/decoder/asr_decoder.h
+166
-0
runtime/core/decoder/asr_model.cc
runtime/core/decoder/asr_model.cc
+54
-0
runtime/core/decoder/asr_model.h
runtime/core/decoder/asr_model.h
+68
-0
runtime/core/decoder/context_graph.cc
runtime/core/decoder/context_graph.cc
+151
-0
runtime/core/decoder/context_graph.h
runtime/core/decoder/context_graph.h
+65
-0
runtime/core/decoder/ctc_endpoint.cc
runtime/core/decoder/ctc_endpoint.cc
+80
-0
runtime/core/decoder/ctc_endpoint.h
runtime/core/decoder/ctc_endpoint.h
+79
-0
runtime/core/decoder/ctc_prefix_beam_search.cc
runtime/core/decoder/ctc_prefix_beam_search.cc
+235
-0
runtime/core/decoder/ctc_prefix_beam_search.h
runtime/core/decoder/ctc_prefix_beam_search.h
+143
-0
runtime/core/decoder/ctc_wfst_beam_search.cc
runtime/core/decoder/ctc_wfst_beam_search.cc
+200
-0
No files found.
Too many changes to show.
To preserve performance only
498 of 498+
files are displayed.
Plain diff
Email patch
runtime/core/cmake/glog.cmake
0 → 100644
View file @
764b3a75
# third_party: glog, fetched from a pinned release archive.
FetchContent_Declare(glog
  URL      https://github.com/google/glog/archive/v0.4.0.zip
  URL_HASH SHA256=9e1b54eb2782f53cd8af107ecf08d2ab64b8d0dc2b7f5594472f3bd63ca85cdc
)
FetchContent_MakeAvailable(glog)
# Headers are split between the source tree (src/) and the build tree
# (generated config headers), so expose both.
include_directories(${glog_SOURCE_DIR}/src ${glog_BINARY_DIR})
\ No newline at end of file
runtime/core/cmake/grpc.cmake
0 → 100644
View file @
764b3a75
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/grpc)
# third_party: grpc
# On how to build grpc, you may refer to https://github.com/grpc/grpc
# We recommend manually recursive clone the repo to avoid internet connection problem
FetchContent_Declare(gRPC
  GIT_REPOSITORY https://github.com/grpc/grpc
  GIT_TAG        v1.37.1
)
FetchContent_MakeAvailable(gRPC)
\ No newline at end of file
runtime/core/cmake/gtest.cmake
0 → 100644
View file @
764b3a75
# third_party: googletest, pinned to release 1.11.0.
FetchContent_Declare(googletest
  URL      https://github.com/google/googletest/archive/release-1.11.0.zip
  URL_HASH SHA256=353571c2440176ded91c2de6d6cd88ddd41401d14692ec1f99e35d013feda55a
)
if(MSVC)
  # Link gtest against the shared CRT so it matches the rest of the build.
  set(gtest_force_shared_crt ON CACHE BOOL "Always use msvcrt.dll" FORCE)
endif()
FetchContent_MakeAvailable(googletest)
\ No newline at end of file
runtime/core/cmake/libtorch.cmake
0 → 100644
View file @
764b3a75
if(TORCH)
  add_definitions(-DUSE_TORCH)
  if(NOT ANDROID)
    if(GPU)
      # GPU builds are only provided for Linux.
      if(NOT ${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
        message(FATAL_ERROR "GPU is supported only Linux, you can use CPU version")
      else()
        add_definitions(-DUSE_GPU)
      endif()
    endif()

    # Pick the LibTorch archive matching platform / ABI / build type.
    if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
      if(${CMAKE_BUILD_TYPE} MATCHES "Release")
        set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-1.13.0%2Bcpu.zip")
        set(URL_HASH "SHA256=bece54d36377990257e9d028c687c5b6759c5cfec0a0153da83cf6f0f71f648f")
      else()
        set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-debug-1.13.0%2Bcpu.zip")
        set(URL_HASH "SHA256=3cc7ba3c3865d86f03d78c2f0878fdbed8b764359476397a5c95cf3bba0d665a")
      endif()
    elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
      if(CXX11_ABI)
        if(NOT GPU)
          set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-1.13.0%2Bcpu.zip")
          set(URL_HASH "SHA256=d52f63577a07adb0bfd6d77c90f7da21896e94f71eb7dcd55ed7835ccb3b2b59")
        else()
          set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu113/libtorch-cxx11-abi-shared-with-deps-1.12.0%2Bcu113.zip")
          set(URL_HASH "SHA256=80f089939de20e68e3fcad4dfa72a26c8bf91b5e77b11042f671f39ebac35865")
        endif()
      else()
        if(NOT GPU)
          set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-1.13.0%2Bcpu.zip")
          set(URL_HASH "SHA256=bee1b7be308792aa60fc95a4f5274d9658cb7248002d0e333d49eb81ec88430c")
        else()
          set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu113/libtorch-shared-with-deps-1.11.0%2Bcu113.zip")
          set(URL_HASH "SHA256=90159ecce3ff451f3ef3f657493b6c7c96759c3b74bbd70c1695f2ea2f81e1ad")
        endif()
      endif()
    elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Darwin")
      set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-macos-1.13.0.zip")
      set(URL_HASH "SHA256=a8f80050b95489b4e002547910410c2c230e9f590ffab2482e19e809afe4f7aa")
    elseif(${CMAKE_SYSTEM_NAME} STREQUAL "iOS")
      add_definitions(-DIOS)
    else()
      message(FATAL_ERROR "Unsupported System '${CMAKE_SYSTEM_NAME}' (expected 'Windows', 'Linux', 'Darwin' or 'iOS')")
    endif()

    # iOS use LibTorch from pod install
    if(NOT IOS)
      FetchContent_Declare(libtorch
        URL      ${LIBTORCH_URL}
        URL_HASH ${URL_HASH}
      )
      FetchContent_MakeAvailable(libtorch)
      find_package(Torch REQUIRED PATHS ${libtorch_SOURCE_DIR} NO_DEFAULT_PATH)
      set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS} -DC10_USE_GLOG")
    endif()

    if(MSVC)
      # Copy the Torch DLLs next to the binaries so they are found at runtime.
      file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
      file(COPY ${TORCH_DLLS} DESTINATION ${CMAKE_BINARY_DIR})
    endif()
  else()
    # Android: use the prebuilt pytorch_android AAR instead of LibTorch.
    # Change version in runtime/android/app/build.gradle.
    file(GLOB PYTORCH_INCLUDE_DIRS "${build_DIR}/pytorch_android*.aar/headers")
    file(GLOB PYTORCH_LINK_DIRS "${build_DIR}/pytorch_android*.aar/jni/${ANDROID_ABI}")
    find_library(PYTORCH_LIBRARY pytorch_jni
      PATHS ${PYTORCH_LINK_DIRS}
      NO_CMAKE_FIND_ROOT_PATH
    )
    find_library(FBJNI_LIBRARY fbjni
      PATHS ${PYTORCH_LINK_DIRS}
      NO_CMAKE_FIND_ROOT_PATH
    )
    include_directories(
      ${PYTORCH_INCLUDE_DIRS}
      ${PYTORCH_INCLUDE_DIRS}/torch/csrc/api/include
    )
  endif()
endif()
runtime/core/cmake/onnx.cmake
0 → 100644
View file @
764b3a75
if(ONNX)
  set(ONNX_VERSION "1.12.0")
  # Select the prebuilt onnxruntime archive for the host platform.
  if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
    set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-win-x64-${ONNX_VERSION}.zip")
    set(URL_HASH "SHA256=8b5d61204989350b7904ac277f5fbccd3e6736ddbb6ec001e412723d71c9c176")
  elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
    if(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64")
      set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-linux-aarch64-${ONNX_VERSION}.tgz")
      set(URL_HASH "SHA256=5820d9f343df73c63b6b2b174a1ff62575032e171c9564bcf92060f46827d0ac")
    else()
      set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-linux-x64-${ONNX_VERSION}.tgz")
      set(URL_HASH "SHA256=5d503ce8540358b59be26c675e42081be14a3e833a5301926f555451046929c5")
    endif()
  elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Darwin")
    set(ONNX_URL "https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/onnxruntime-osx-x86_64-${ONNX_VERSION}.tgz")
    set(URL_HASH "SHA256=09b17f712f8c6f19bb63da35d508815b443cbb473e16c6192abfaa297c02f600")
  else()
    message(FATAL_ERROR "Unsupported CMake System Name '${CMAKE_SYSTEM_NAME}' (expected 'Windows', 'Linux' or 'Darwin')")
  endif()

  FetchContent_Declare(onnxruntime
    URL      ${ONNX_URL}
    URL_HASH ${URL_HASH}
  )
  FetchContent_MakeAvailable(onnxruntime)
  include_directories(${onnxruntime_SOURCE_DIR}/include)
  link_directories(${onnxruntime_SOURCE_DIR}/lib)

  if(MSVC)
    # Copy the onnxruntime DLLs next to the built executables.
    file(GLOB ONNX_DLLS "${onnxruntime_SOURCE_DIR}/lib/*.dll")
    file(COPY ${ONNX_DLLS} DESTINATION ${CMAKE_BINARY_DIR}/bin/${CMAKE_BUILD_TYPE})
  endif()

  add_definitions(-DUSE_ONNX)
endif()
runtime/core/cmake/openfst.cmake
0 → 100644
View file @
764b3a75
if(NOT ANDROID)
  include(gflags)
  # We can't build glog with gflags, unless gflags is pre-installed.
  # If build glog with pre-installed gflags, there will be conflict.
  set(WITH_GFLAGS OFF CACHE BOOL "whether build glog with gflags" FORCE)
  include(glog)

  # Disable every optional openfst component we do not need.
  if(NOT GRAPH_TOOLS)
    set(HAVE_BIN    OFF CACHE BOOL "Build the fst binaries" FORCE)
    set(HAVE_SCRIPT OFF CACHE BOOL "Build the fstscript"    FORCE)
  endif()
  set(HAVE_COMPACT   OFF CACHE BOOL "Build compact"   FORCE)
  set(HAVE_CONST     OFF CACHE BOOL "Build const"     FORCE)
  set(HAVE_GRM       OFF CACHE BOOL "Build grm"       FORCE)
  set(HAVE_FAR       OFF CACHE BOOL "Build far"       FORCE)
  set(HAVE_PDT       OFF CACHE BOOL "Build pdt"       FORCE)
  set(HAVE_MPDT      OFF CACHE BOOL "Build mpdt"      FORCE)
  set(HAVE_LINEAR    OFF CACHE BOOL "Build linear"    FORCE)
  set(HAVE_LOOKAHEAD OFF CACHE BOOL "Build lookahead" FORCE)
  set(HAVE_NGRAM     OFF CACHE BOOL "Build ngram"     FORCE)
  set(HAVE_SPECIAL   OFF CACHE BOOL "Build special"   FORCE)

  if(MSVC)
    add_compile_options(/W0 /wd4244 /wd4267)
  endif()

  # "OpenFST port for Windows" builds openfst with cmake for multiple platforms.
  # Openfst is compiled with glog/gflags to avoid log and flag conflicts with
  # log and flags in wenet/libtorch.
  # To build openfst with gflags and glog, we comment out some vars of
  # {flags, log}.h and flags.cc.
  set(openfst_SOURCE_DIR ${fc_base}/openfst-src CACHE PATH "OpenFST source directory")
  FetchContent_Declare(openfst
    URL           https://github.com/kkm000/openfst/archive/refs/tags/win/1.6.5.1.tar.gz
    URL_HASH      SHA256=02c49b559c3976a536876063369efc0e41ab374be1035918036474343877046e
    PATCH_COMMAND ${CMAKE_COMMAND} -E copy_directory
                  ${CMAKE_CURRENT_SOURCE_DIR}/patch/openfst ${openfst_SOURCE_DIR}
  )
  FetchContent_MakeAvailable(openfst)
  add_dependencies(fst gflags glog)
  target_link_libraries(fst PUBLIC gflags_nothreads_static glog)
  include_directories(${openfst_SOURCE_DIR}/src/include)
else()
  # Android: use the prebuilt openfst shipped inside the wenet AAR.
  set(openfst_BINARY_DIR ${build_DIR}/wenet-openfst-android-1.0.2.aar/jni)
  include_directories(${openfst_BINARY_DIR}/include)
  link_directories(${openfst_BINARY_DIR}/${ANDROID_ABI})
  link_libraries(log gflags_nothreads glog fst)
endif()
runtime/core/cmake/pybind11.cmake
0 → 100644
View file @
764b3a75
# third_party: pybind11, pinned to v2.9.2.
FetchContent_Declare(pybind11
  URL      https://github.com/pybind/pybind11/archive/refs/tags/v2.9.2.zip
  URL_HASH SHA256=d1646e6f70d8a3acb2ddd85ce1ed543b5dd579c68b8fb8e9638282af20edead8
)
FetchContent_MakeAvailable(pybind11)
# NOTE(review): FetchContent_MakeAvailable normally already calls
# add_subdirectory on the fetched source; confirm this explicit call is
# still required (it may rely on pybind11 skipping a second add).
add_subdirectory(${pybind11_SOURCE_DIR})
\ No newline at end of file
runtime/core/cmake/xpu.cmake
0 → 100644
View file @
764b3a75
# Define ANSI colour escape variables for coloured message() output.
# Windows consoles do not interpret these escapes, so skip them there.
if(NOT WIN32)
  string(ASCII 27 Esc)
  set(ColourReset "${Esc}[m")
  set(ColourBold  "${Esc}[1m")
  set(Red         "${Esc}[31m")
  set(Green       "${Esc}[32m")
  set(Yellow      "${Esc}[33m")
  set(Blue        "${Esc}[34m")
  set(Magenta     "${Esc}[35m")
  set(Cyan        "${Esc}[36m")
  set(White       "${Esc}[37m")
  set(BoldRed     "${Esc}[1;31m")
  set(BoldGreen   "${Esc}[1;32m")
  set(BoldYellow  "${Esc}[1;33m")
  set(BoldBlue    "${Esc}[1;34m")
  set(BoldMagenta "${Esc}[1;35m")
  set(BoldCyan    "${Esc}[1;36m")
  set(BoldWhite   "${Esc}[1;37m")
endif()

if(XPU)
  set(RUNTIME_KUNLUN_PATH ${CMAKE_CURRENT_SOURCE_DIR})
  message(STATUS "RUNTIME_KUNLUN_PATH is ${RUNTIME_KUNLUN_PATH}.\n")
  set(KUNLUN_XPU_PATH ${RUNTIME_KUNLUN_PATH}/xpu)
  # The Kunlun XPU SDK location must come from the environment.
  if(NOT DEFINED ENV{XPU_API_PATH})
    message(FATAL_ERROR "${BoldRed}NO ENV{XPU_API_PATH} in your env. Please set XPU_API_PATH.${ColourReset}\n")
  else()
    set(XPU_API_PATH $ENV{XPU_API_PATH})
    message("set XPU_API_PATH from env_var. Val is $ENV{XPU_API_PATH}.")
  endif()
  include_directories(${RUNTIME_KUNLUN_PATH} ${KUNLUN_XPU_PATH}/
                      ${XPU_API_PATH}/output/include
                      ${XPU_API_PATH}/../runtime/include)
  link_directories(${XPU_API_PATH}/output/so/
                   ${XPU_API_PATH}/../runtime/output/so/)
  add_definitions(-DUSE_XPU)
endif()
runtime/core/decoder/CMakeLists.txt
0 → 100644
View file @
764b3a75
set(decoder_srcs
  asr_decoder.cc
  asr_model.cc
  context_graph.cc
  ctc_prefix_beam_search.cc
  ctc_wfst_beam_search.cc
  ctc_endpoint.cc
)

# Exactly one inference backend must be selected.
if(NOT TORCH AND NOT ONNX AND NOT XPU AND NOT IOS AND NOT BPU)
  message(FATAL_ERROR "Please build with TORCH or ONNX or XPU or IOS or BPU!!!")
endif()
if(TORCH OR IOS)
  list(APPEND decoder_srcs torch_asr_model.cc)
endif()
if(ONNX)
  list(APPEND decoder_srcs onnx_asr_model.cc)
endif()

add_library(decoder STATIC ${decoder_srcs})
target_link_libraries(decoder PUBLIC kaldi-decoder frontend
                      post_processor utils)

# Backend-specific link dependencies.
if(ANDROID)
  target_link_libraries(decoder PUBLIC ${PYTORCH_LIBRARY} ${FBJNI_LIBRARY})
else()
  if(TORCH)
    target_link_libraries(decoder PUBLIC ${TORCH_LIBRARIES})
  endif()
  if(ONNX)
    target_link_libraries(decoder PUBLIC onnxruntime)
  endif()
  if(BPU)
    target_link_libraries(decoder PUBLIC bpu_asr_model)
  endif()
  if(XPU)
    target_link_libraries(decoder PUBLIC xpu_conformer)
  endif()
endif()
runtime/core/decoder/asr_decoder.cc
0 → 100644
View file @
764b3a75
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
// 2022 Binbin Zhang (binbzha@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/asr_decoder.h"
#include <ctype.h>
#include <algorithm>
#include <limits>
#include <utility>
#include "utils/timer.h"
namespace
wenet
{
AsrDecoder::AsrDecoder(std::shared_ptr<FeaturePipeline> feature_pipeline,
                       std::shared_ptr<DecodeResource> resource,
                       const DecodeOptions& opts)
    : feature_pipeline_(std::move(feature_pipeline)),
      // Make a copy of the model ASR model since we will change the inner
      // status of the model
      model_(resource->model->Copy()),
      post_processor_(resource->post_processor),
      symbol_table_(resource->symbol_table),
      fst_(resource->fst),
      unit_table_(resource->unit_table),
      opts_(opts),
      ctc_endpointer_(new CtcEndpoint(opts.ctc_endpoint_config)) {
  if (opts_.reverse_weight > 0) {
    // Check if model has a right to left decoder
    CHECK(model_->is_bidirectional_decoder());
  }
  // No decoding FST provided -> plain CTC prefix beam search;
  // otherwise search on the WFST decoding graph.
  if (nullptr == fst_) {
    searcher_.reset(new CtcPrefixBeamSearch(opts.ctc_prefix_search_opts,
                                            resource->context_graph));
  } else {
    searcher_.reset(new CtcWfstBeamSearch(*fst_, opts.ctc_wfst_search_opts,
                                          resource->context_graph));
  }
  ctc_endpointer_->frame_shift_in_ms(frame_shift_in_ms());
}
// Reset the decoder to a pristine state before decoding a new utterance.
void AsrDecoder::Reset() {
  start_ = false;
  result_.clear();
  num_frames_ = 0;
  global_frame_offset_ = 0;
  model_->Reset();
  searcher_->Reset();
  feature_pipeline_->Reset();
  ctc_endpointer_->Reset();
}
// Reset decoding state for continuous decoding: keep the feature pipeline
// running and remember how many frames were already consumed so that
// timestamps of later segments stay globally aligned.
void AsrDecoder::ResetContinuousDecoding() {
  global_frame_offset_ = num_frames_;
  start_ = false;
  result_.clear();
  model_->Reset();
  searcher_->Reset();
  ctc_endpointer_->Reset();
}
// Public entry point; simply forwards to the internal chunk-advance loop.
DecodeState AsrDecoder::Decode(bool block) {
  return this->AdvanceDecoding(block);
}
// Do attention rescoring and report how long it took.
void AsrDecoder::Rescoring() {
  Timer timer;
  AttentionRescoring();
  VLOG(2) << "Rescoring cost latency: " << timer.Elapsed() << "ms.";
}
// Advance decoding by one chunk: pull features, run the encoder, apply the
// optional blank-scale bias, search, and check for an endpoint.
// @param block if false, return kWaitFeats instead of blocking when the
//        pipeline has fewer queued frames than one chunk needs.
DecodeState AsrDecoder::AdvanceDecoding(bool block) {
  DecodeState state = DecodeState::kEndBatch;
  model_->set_chunk_size(opts_.chunk_size);
  model_->set_num_left_chunks(opts_.num_left_chunks);
  int num_required_frames = model_->num_frames_for_chunk(start_);
  std::vector<std::vector<float>> chunk_feats;
  // Return immediately if we do not want to block
  if (!block && !feature_pipeline_->input_finished() &&
      feature_pipeline_->NumQueuedFrames() < num_required_frames) {
    return DecodeState::kWaitFeats;
  }
  // If not okay, that means we reach the end of the input
  if (!feature_pipeline_->Read(num_required_frames, &chunk_feats)) {
    state = DecodeState::kEndFeats;
  }
  num_frames_ += chunk_feats.size();
  VLOG(2) << "Required " << num_required_frames << " get "
          << chunk_feats.size();
  Timer timer;
  std::vector<std::vector<float>> ctc_log_probs;
  model_->ForwardEncoder(chunk_feats, &ctc_log_probs);
  int forward_time = timer.Elapsed();
  if (opts_.ctc_wfst_search_opts.blank_scale != 1.0) {
    // Bias the blank symbol (index 0) in log domain.
    // Fix: loop index is size_t — the original `int i` triggered a
    // signed/unsigned comparison against ctc_log_probs.size().
    for (size_t i = 0; i < ctc_log_probs.size(); i++) {
      ctc_log_probs[i][0] = ctc_log_probs[i][0] +
                            std::log(opts_.ctc_wfst_search_opts.blank_scale);
    }
  }
  timer.Reset();
  searcher_->Search(ctc_log_probs);
  int search_time = timer.Elapsed();
  VLOG(3) << "forward takes " << forward_time << " ms, search takes "
          << search_time << " ms";
  UpdateResult();

  // Only probe for an endpoint if the input has not already finished.
  if (state != DecodeState::kEndFeats) {
    if (ctc_endpointer_->IsEndpoint(ctc_log_probs, DecodedSomething())) {
      VLOG(1) << "Endpoint is detected at " << num_frames_;
      state = DecodeState::kEndpoint;
    }
  }
  start_ = true;
  return state;
}
// Rebuild result_ from the searcher's current N-best output.
// @param finish when true (final pass) also attach per-word timestamps.
void AsrDecoder::UpdateResult(bool finish) {
  const auto& hypotheses = searcher_->Outputs();
  const auto& inputs = searcher_->Inputs();
  const auto& likelihood = searcher_->Likelihood();
  const auto& times = searcher_->Times();
  result_.clear();

  CHECK_EQ(hypotheses.size(), likelihood.size());
  for (size_t i = 0; i < hypotheses.size(); i++) {
    const std::vector<int>& hypothesis = hypotheses[i];

    DecodeResult path;
    path.score = likelihood[i];
    int offset = global_frame_offset_ * feature_frame_shift_in_ms();
    for (size_t j = 0; j < hypothesis.size(); j++) {
      std::string word = symbol_table_->Find(hypothesis[j]);
      // A detailed explanation of this if-else branch can be found in
      // https://github.com/wenet-e2e/wenet/issues/583#issuecomment-907994058
      if (searcher_->Type() == kWfstBeamSearch) {
        path.sentence += (' ' + word);
      } else {
        path.sentence += (word);
      }
    }

    // TimeStamp is only supported in final result
    // TimeStamp of the output of CtcWfstBeamSearch may be inaccurate due to
    // various FST operations when building the decoding graph. So here we use
    // time stamp of the input(e2e model unit), which is more accurate, and it
    // requires the symbol table of the e2e model used in training.
    if (unit_table_ != nullptr && finish) {
      const std::vector<int>& input = inputs[i];
      const std::vector<int>& time_stamp = times[i];
      CHECK_EQ(input.size(), time_stamp.size());
      for (size_t j = 0; j < input.size(); j++) {
        std::string word = unit_table_->Find(input[j]);
        // Start defaults to time_stamp_gap_ ms before the frame, clamped at 0.
        int start = time_stamp[j] * frame_shift_in_ms() - time_stamp_gap_ > 0
                        ? time_stamp[j] * frame_shift_in_ms() - time_stamp_gap_
                        : 0;
        // If the previous unit is closer than the gap, split the interval.
        if (j > 0) {
          start =
              (time_stamp[j] - time_stamp[j - 1]) * frame_shift_in_ms() <
                      time_stamp_gap_
                  ? (time_stamp[j - 1] + time_stamp[j]) / 2 *
                        frame_shift_in_ms()
                  : start;
        }
        int end = time_stamp[j] * frame_shift_in_ms();
        // Same midpoint rule against the following unit.
        if (j < input.size() - 1) {
          end =
              (time_stamp[j + 1] - time_stamp[j]) * frame_shift_in_ms() <
                      time_stamp_gap_
                  ? (time_stamp[j + 1] + time_stamp[j]) / 2 *
                        frame_shift_in_ms()
                  : end;
        }
        WordPiece word_piece(word, offset + start, offset + end);
        path.word_pieces.emplace_back(word_piece);
      }
    }

    if (post_processor_ != nullptr) {
      path.sentence = post_processor_->Process(path.sentence, finish);
    }
    result_.emplace_back(path);
  }

  if (DecodedSomething()) {
    VLOG(1) << "Partial CTC result " << result_[0].sentence;
  }
}
// Finalize the search, then rescore the N-best list with the attention
// decoder and re-sort results by the combined score.
void AsrDecoder::AttentionRescoring() {
  searcher_->FinalizeSearch();
  UpdateResult(true);
  // No need to do rescoring
  if (0.0 == opts_.rescoring_weight) {
    return;
  }
  // Inputs() returns N-best input ids, which is the basic unit for rescoring
  // In CtcPrefixBeamSearch, inputs are the same to outputs
  const auto& hypotheses = searcher_->Inputs();
  // Fix: keep the count as size_t — the original stored it in an int and
  // then compared `size_t i < int num_hyps` (signed/unsigned mismatch).
  const size_t num_hyps = hypotheses.size();
  if (num_hyps == 0) {
    return;
  }
  std::vector<float> rescoring_score;
  model_->AttentionRescoring(hypotheses, opts_.reverse_weight,
                             &rescoring_score);
  // Combine ctc score and rescoring score
  for (size_t i = 0; i < num_hyps; ++i) {
    result_[i].score = opts_.rescoring_weight * rescoring_score[i] +
                       opts_.ctc_weight * result_[i].score;
  }
  std::sort(result_.begin(), result_.end(), DecodeResult::CompareFunc);
}
}
// namespace wenet
runtime/core/decoder/asr_decoder.h
0 → 100644
View file @
764b3a75
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang, Di Wu)
// 2022 Binbin Zhang (binbzha@qq.com)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DECODER_ASR_DECODER_H_
#define DECODER_ASR_DECODER_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "fst/fstlib.h"
#include "fst/symbol-table.h"
#include "decoder/asr_model.h"
#include "decoder/context_graph.h"
#include "decoder/ctc_endpoint.h"
#include "decoder/ctc_prefix_beam_search.h"
#include "decoder/ctc_wfst_beam_search.h"
#include "decoder/search_interface.h"
#include "frontend/feature_pipeline.h"
#include "post_processor/post_processor.h"
#include "utils/utils.h"
namespace
wenet
{
// Tunable knobs for one decoding session.
struct DecodeOptions {
  // chunk_size is the frame number of one chunk after subsampling.
  // e.g. if subsample rate is 4 and chunk_size = 16, the frames in
  // one chunk are 64 = 16*4
  int chunk_size = 16;
  int num_left_chunks = -1;

  // final_score = rescoring_weight * rescoring_score + ctc_weight * ctc_score;
  // rescoring_score = left_to_right_score * (1 - reverse_weight) +
  //                   right_to_left_score * reverse_weight
  // Please note the concept of ctc_scores in the following two search
  // methods are different.
  // For CtcPrefixBeamSearch, it's a sum(prefix) score + context score
  // For CtcWfstBeamSearch, it's a max(viterbi) path score + context score
  // So we should carefully set ctc_weight according to the search methods.
  float ctc_weight = 0.5;
  float rescoring_weight = 1.0;
  float reverse_weight = 0.0;
  CtcEndpointConfig ctc_endpoint_config;
  CtcPrefixBeamSearchOptions ctc_prefix_search_opts;
  CtcWfstBeamSearchOptions ctc_wfst_search_opts;
};
// One recognized unit together with its [start, end] time span in ms.
struct WordPiece {
  std::string word;
  int start = -1;
  int end = -1;

  WordPiece(std::string word, int start, int end)
      : word(std::move(word)), start(start), end(end) {}
};
// One N-best entry: combined score, text, and optional word timestamps.
struct DecodeResult {
  float score = -kFloatMax;
  std::string sentence;
  std::vector<WordPiece> word_pieces;

  // Sort descending by score (best hypothesis first).
  static bool CompareFunc(const DecodeResult& a, const DecodeResult& b) {
    return a.score > b.score;
  }
};
// Outcome of one AdvanceDecoding() call.
enum DecodeState {
  kEndBatch = 0x00,  // End of current decoding batch, normal case
  kEndpoint = 0x01,  // Endpoint is detected
  kEndFeats = 0x02,  // All feature is decoded
  kWaitFeats = 0x03  // Feat is not enough for one chunk inference, wait
};
// DecodeResource is thread safe, which can be shared for multiple
// decoding threads
struct DecodeResource {
  std::shared_ptr<AsrModel> model = nullptr;
  std::shared_ptr<fst::SymbolTable> symbol_table = nullptr;
  std::shared_ptr<fst::Fst<fst::StdArc>> fst = nullptr;
  std::shared_ptr<fst::SymbolTable> unit_table = nullptr;
  std::shared_ptr<ContextGraph> context_graph = nullptr;
  std::shared_ptr<PostProcessor> post_processor = nullptr;
};
// Torch ASR decoder
class AsrDecoder {
 public:
  AsrDecoder(std::shared_ptr<FeaturePipeline> feature_pipeline,
             std::shared_ptr<DecodeResource> resource,
             const DecodeOptions& opts);
  // @param block: if true, block when feature is not enough for one chunk
  // inference. Otherwise, return kWaitFeats.
  DecodeState Decode(bool block = true);
  void Rescoring();
  void Reset();
  void ResetContinuousDecoding();
  bool DecodedSomething() const {
    return !result_.empty() && !result_[0].sentence.empty();
  }

  // This method is used for time benchmark
  int num_frames_in_current_chunk() const {
    return num_frames_in_current_chunk_;
  }
  // Frame shift in ms of one encoder output frame (after subsampling).
  int frame_shift_in_ms() const {
    return model_->subsampling_rate() *
           feature_pipeline_->config().frame_shift * 1000 /
           feature_pipeline_->config().sample_rate;
  }
  // Frame shift in ms of one raw feature frame.
  int feature_frame_shift_in_ms() const {
    return feature_pipeline_->config().frame_shift * 1000 /
           feature_pipeline_->config().sample_rate;
  }
  const std::vector<DecodeResult>& result() const { return result_; }

 private:
  DecodeState AdvanceDecoding(bool block = true);
  void AttentionRescoring();
  void UpdateResult(bool finish = false);

  std::shared_ptr<FeaturePipeline> feature_pipeline_;
  std::shared_ptr<AsrModel> model_;
  std::shared_ptr<PostProcessor> post_processor_;
  std::shared_ptr<fst::Fst<fst::StdArc>> fst_ = nullptr;
  // output symbol table
  std::shared_ptr<fst::SymbolTable> symbol_table_;
  // e2e unit symbol table
  std::shared_ptr<fst::SymbolTable> unit_table_ = nullptr;
  const DecodeOptions& opts_;
  // cache feature
  bool start_ = false;
  // For continuous decoding
  int num_frames_ = 0;
  int global_frame_offset_ = 0;
  const int time_stamp_gap_ = 100;  // timestamp gap between words in a sentence

  std::unique_ptr<SearchInterface> searcher_;
  std::unique_ptr<CtcEndpoint> ctc_endpointer_;

  int num_frames_in_current_chunk_ = 0;
  std::vector<DecodeResult> result_;

 public:
  WENET_DISALLOW_COPY_AND_ASSIGN(AsrDecoder);
};
}
// namespace wenet
#endif // DECODER_ASR_DECODER_H_
runtime/core/decoder/asr_model.cc
0 → 100644
View file @
764b3a75
// Copyright 2022 Horizon Robotics. All Rights Reserved.
// Author: binbin.zhang@horizon.ai (Binbin Zhang)
#include "decoder/asr_model.h"
#include <memory>
#include <utility>
namespace
wenet
{
// How many raw feature frames must be queued before one chunk can run.
// @param start true once the first chunk of the utterance has been consumed.
int AsrModel::num_frames_for_chunk(bool start) const {
  int num_required_frames = 0;
  if (chunk_size_ > 0) {
    if (!start) {
      // First batch
      int context = right_context_ + 1;  // Add current frame
      num_required_frames = (chunk_size_ - 1) * subsampling_rate_ + context;
    } else {
      num_required_frames = chunk_size_ * subsampling_rate_;
    }
  } else {
    // Non-streaming: wait for the whole utterance.
    num_required_frames = std::numeric_limits<int>::max();
  }
  return num_required_frames;
}
void
AsrModel
::
CacheFeature
(
const
std
::
vector
<
std
::
vector
<
float
>>&
chunk_feats
)
{
// Cache feature for next chunk
const
int
cached_feature_size
=
1
+
right_context_
-
subsampling_rate_
;
if
(
chunk_feats
.
size
()
>=
cached_feature_size
)
{
// TODO(Binbin Zhang): Only deal the case when
// chunk_feats.size() > cached_feature_size here, and it's consistent
// with our current model, refine it later if we have new model or
// new requirements
cached_feature_
.
resize
(
cached_feature_size
);
for
(
int
i
=
0
;
i
<
cached_feature_size
;
++
i
)
{
cached_feature_
[
i
]
=
chunk_feats
[
chunk_feats
.
size
()
-
cached_feature_size
+
i
];
}
}
}
// Run the encoder on one chunk (plus cached context) and emit CTC log probs.
// Skips inference entirely when there is not enough context yet.
void AsrModel::ForwardEncoder(
    const std::vector<std::vector<float>>& chunk_feats,
    std::vector<std::vector<float>>* ctc_prob) {
  ctc_prob->clear();
  int num_frames = cached_feature_.size() + chunk_feats.size();
  if (num_frames >= right_context_ + 1) {
    this->ForwardEncoderFunc(chunk_feats, ctc_prob);
    this->CacheFeature(chunk_feats);
  }
}
}
// namespace wenet
runtime/core/decoder/asr_model.h
0 → 100644
View file @
764b3a75
// Copyright 2022 Horizon Robotics. All Rights Reserved.
// Author: binbin.zhang@horizon.ai (Binbin Zhang)
#ifndef DECODER_ASR_MODEL_H_
#define DECODER_ASR_MODEL_H_
#include <limits>
#include <memory>
#include <string>
#include <vector>
#include "utils/timer.h"
#include "utils/utils.h"
namespace
wenet
{
class
AsrModel
{
public:
virtual
int
right_context
()
const
{
return
right_context_
;
}
virtual
int
subsampling_rate
()
const
{
return
subsampling_rate_
;
}
virtual
int
sos
()
const
{
return
sos_
;
}
virtual
int
eos
()
const
{
return
eos_
;
}
virtual
bool
is_bidirectional_decoder
()
const
{
return
is_bidirectional_decoder_
;
}
virtual
int
offset
()
const
{
return
offset_
;
}
// If chunk_size > 0, streaming case. Otherwise, none streaming case
virtual
void
set_chunk_size
(
int
chunk_size
)
{
chunk_size_
=
chunk_size
;
}
virtual
void
set_num_left_chunks
(
int
num_left_chunks
)
{
num_left_chunks_
=
num_left_chunks
;
}
// start: if it is the start chunk of one sentence
virtual
int
num_frames_for_chunk
(
bool
start
)
const
;
virtual
void
Reset
()
=
0
;
virtual
void
ForwardEncoder
(
const
std
::
vector
<
std
::
vector
<
float
>>&
chunk_feats
,
std
::
vector
<
std
::
vector
<
float
>>*
ctc_prob
);
virtual
void
AttentionRescoring
(
const
std
::
vector
<
std
::
vector
<
int
>>&
hyps
,
float
reverse_weight
,
std
::
vector
<
float
>*
rescoring_score
)
=
0
;
virtual
std
::
shared_ptr
<
AsrModel
>
Copy
()
const
=
0
;
protected:
virtual
void
ForwardEncoderFunc
(
const
std
::
vector
<
std
::
vector
<
float
>>&
chunk_feats
,
std
::
vector
<
std
::
vector
<
float
>>*
ctc_prob
)
=
0
;
virtual
void
CacheFeature
(
const
std
::
vector
<
std
::
vector
<
float
>>&
chunk_feats
);
int
right_context_
=
1
;
int
subsampling_rate_
=
1
;
int
sos_
=
0
;
int
eos_
=
0
;
bool
is_bidirectional_decoder_
=
false
;
int
chunk_size_
=
16
;
int
num_left_chunks_
=
-
1
;
// -1 means all left chunks
int
offset_
=
0
;
std
::
vector
<
std
::
vector
<
float
>>
cached_feature_
;
};
}
// namespace wenet
#endif // DECODER_ASR_MODEL_H_
runtime/core/decoder/context_graph.cc
0 → 100644
View file @
764b3a75
// Copyright (c) 2021 Mobvoi Inc (Zhendong Peng)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/context_graph.h"
#include <utility>
#include "fst/determinize.h"
#include "utils/string.h"
#include "utils/utils.h"
namespace
wenet
{
// Store the biasing configuration; the graph itself is built lazily by
// BuildContextGraph().
ContextGraph::ContextGraph(ContextConfig config) : config_(config) {}
// Compile the biasing phrases into a determinized FST whose arcs carry
// incremental context scores; an empty query list clears the graph.
void ContextGraph::BuildContextGraph(
    const std::vector<std::string>& query_contexts,
    const std::shared_ptr<fst::SymbolTable>& symbol_table) {
  CHECK(symbol_table != nullptr) << "Symbols table should not be nullptr!";
  start_tag_id_ = symbol_table->AddSymbol("<context>");
  end_tag_id_ = symbol_table->AddSymbol("</context>");
  symbol_table_ = symbol_table;
  if (query_contexts.empty()) {
    if (graph_ != nullptr) graph_.reset();
    return;
  }

  std::unique_ptr<fst::StdVectorFst> ofst(new fst::StdVectorFst());
  // State 0 is the start state and the final state.
  int start_state = ofst->AddState();
  ofst->SetStart(start_state);
  ofst->SetFinal(start_state, fst::StdArc::Weight::One());

  LOG(INFO) << "Contexts count size: " << query_contexts.size();
  int count = 0;
  for (const auto& context : query_contexts) {
    if (context.size() > config_.max_context_length) {
      LOG(INFO) << "Skip long context: " << context;
      continue;
    }
    if (++count > config_.max_contexts) break;

    std::vector<std::string> words;
    // Split context to words by symbol table, and build the context graph.
    bool no_oov = SplitUTF8StringToWords(Trim(context), symbol_table, &words);
    if (!no_oov) {
      LOG(WARNING) << "Ignore unknown word found during compilation.";
      continue;
    }

    int prev_state = start_state;
    int next_state = start_state;
    float escape_score = 0;
    for (size_t i = 0; i < words.size(); ++i) {
      int word_id = symbol_table_->Find(words[i]);
      float score =
          (i * config_.incremental_context_score + config_.context_score) *
          UTF8StringLength(words[i]);
      // The last word of a phrase loops back to the start/final state.
      next_state = (i < words.size() - 1) ? ofst->AddState() : start_state;
      ofst->AddArc(prev_state,
                   fst::StdArc(word_id, word_id, score, next_state));
      // Add escape arc to clean the previous context score.
      if (i > 0) {
        // ilabel and olabel of the escape arc is 0 (<epsilon>).
        ofst->AddArc(prev_state,
                     fst::StdArc(0, 0, -escape_score, start_state));
      }
      prev_state = next_state;
      escape_score += score;
    }
  }
  std::unique_ptr<fst::StdVectorFst> det_fst(new fst::StdVectorFst());
  fst::Determinize(*ofst, det_fst.get());
  graph_ = std::move(det_fst);
}
int
ContextGraph
::
GetNextState
(
int
cur_state
,
int
word_id
,
float
*
score
,
bool
*
is_start_boundary
,
bool
*
is_end_boundary
)
{
int
next_state
=
0
;
for
(
fst
::
ArcIterator
<
fst
::
StdFst
>
aiter
(
*
graph_
,
cur_state
);
!
aiter
.
Done
();
aiter
.
Next
())
{
const
fst
::
StdArc
&
arc
=
aiter
.
Value
();
if
(
arc
.
ilabel
==
0
)
{
// escape score, will be overwritten when ilabel equals to word id.
*
score
=
arc
.
weight
.
Value
();
}
else
if
(
arc
.
ilabel
==
word_id
)
{
next_state
=
arc
.
nextstate
;
*
score
=
arc
.
weight
.
Value
();
if
(
cur_state
==
0
)
{
*
is_start_boundary
=
true
;
}
if
(
graph_
->
Final
(
arc
.
nextstate
)
==
fst
::
StdArc
::
Weight
::
One
())
{
*
is_end_boundary
=
true
;
}
break
;
}
}
return
next_state
;
}
// Greedy longest-match segmentation of `str` into symbol-table tokens.
// Starting at each position, tries the longest remaining span first and
// shrinks until a span is found in `symbol_table`. Returns false if any
// single character is out-of-vocabulary (that character is skipped and the
// rest of the string is still segmented).
bool ContextGraph::SplitUTF8StringToWords(
    const std::string& str,
    const std::shared_ptr<fst::SymbolTable>& symbol_table,
    std::vector<std::string>* words) {
  std::vector<std::string> chars;
  SplitUTF8StringToChars(Trim(str), &chars);

  bool no_oov = true;
  for (size_t start = 0; start < chars.size();) {
    // Try candidate spans [start, end), longest first. Note: the
    // `continue`s below that set start = end exit this inner loop via its
    // `end > start` condition, resuming the outer loop at the new start.
    for (size_t end = chars.size(); end > start; --end) {
      std::string word;
      for (size_t i = start; i < end; i++) {
        word += chars[i];
      }
      // Skip space.
      if (word == " ") {
        start = end;
        continue;
      }
      // Add '▁' at the beginning of English word.
      if (IsAlpha(word)) {
        word = kSpaceSymbol + word;
      }
      if (symbol_table->Find(word) != -1) {
        // Longest in-vocabulary match found; consume it.
        words->emplace_back(word);
        start = end;
        continue;
      }
      if (end == start + 1) {
        // Even the single character is unknown: skip it and flag OOV.
        ++start;
        no_oov = false;
        LOG(WARNING) << word << " is oov.";
      }
    }
  }
  return no_oov;
}
}
// namespace wenet
runtime/core/decoder/context_graph.h
0 → 100644
View file @
764b3a75
// Copyright (c) 2021 Mobvoi Inc (Zhendong Peng)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DECODER_CONTEXT_GRAPH_H_
#define DECODER_CONTEXT_GRAPH_H_
#include <memory>
#include <string>
#include <vector>
#include "fst/compose.h"
#include "fst/fst.h"
#include "fst/vector-fst.h"
namespace
wenet
{
using StateId = fst::StdArc::StateId;

// Tunables for building/scoring the context-biasing graph.
struct ContextConfig {
  // Maximum number of biasing phrases compiled into the graph.
  int max_contexts = 5000;
  // Maximum phrase length (in bytes) accepted; longer phrases are skipped.
  int max_context_length = 100;
  // Base bonus granted per matched token.
  float context_score = 3.0;
  // Extra bonus added per token position within a phrase.
  float incremental_context_score = 0.0;
};
// Compiles a list of biasing phrases into a determinized FST and scores
// partial matches during decoding. States are walked token-by-token via
// GetNextState(); <context>/</context> tag ids mark phrase boundaries in
// decoder outputs.
class ContextGraph {
 public:
  explicit ContextGraph(ContextConfig config);
  // Builds graph_ from the phrases; registers tag symbols in symbol_table.
  void BuildContextGraph(const std::vector<std::string>& query_context,
                         const std::shared_ptr<fst::SymbolTable>& symbol_table);
  // Advances from cur_state on word_id; outputs the score delta and
  // whether a phrase boundary was crossed. Returns the next state.
  int GetNextState(int cur_state, int word_id, float* score,
                   bool* is_start_boundary, bool* is_end_boundary);

  // Symbol ids of the <context> / </context> boundary tags (-1 until
  // BuildContextGraph has run).
  int start_tag_id() { return start_tag_id_; }
  int end_tag_id() { return end_tag_id_; }

 private:
  // Greedy longest-match segmentation against the symbol table; returns
  // false if any character is out-of-vocabulary.
  bool SplitUTF8StringToWords(
      const std::string& str,
      const std::shared_ptr<fst::SymbolTable>& symbol_table,
      std::vector<std::string>* words);

  int start_tag_id_ = -1;
  int end_tag_id_ = -1;
  ContextConfig config_;
  std::shared_ptr<fst::SymbolTable> symbol_table_ = nullptr;
  // Determinized context FST; null when no contexts are loaded.
  std::unique_ptr<fst::StdVectorFst> graph_ = nullptr;

  DISALLOW_COPY_AND_ASSIGN(ContextGraph);
};
}
// namespace wenet
#endif // DECODER_CONTEXT_GRAPH_H_
runtime/core/decoder/ctc_endpoint.cc
0 → 100644
View file @
764b3a75
// Copyright (c) 2021 Mobvoi Inc (Zhendong Peng)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/ctc_endpoint.h"
#include <math.h>
#include <string>
#include <vector>
#include "utils/log.h"
namespace
wenet
{
// Copies the endpointing rules and starts from a clean per-utterance state.
CtcEndpoint::CtcEndpoint(const CtcEndpointConfig& config) : config_(config) {
  Reset();
}
// Clears all per-utterance frame counters; call at the start of a new
// utterance.
void CtcEndpoint::Reset() {
  num_frames_trailing_blank_ = 0;
  num_frames_decoded_ = 0;
}
// Returns true when a single endpointing rule fires for the current
// utterance statistics (silence and length are in milliseconds).
static bool RuleActivated(const CtcEndpointRule& rule,
                          const std::string& rule_name, bool decoded_sth,
                          int trailing_silence, int utterance_length) {
  const bool decoding_requirement_met = decoded_sth || !rule.must_decoded_sth;
  const bool enough_trailing_silence =
      trailing_silence >= rule.min_trailing_silence;
  const bool long_enough = utterance_length >= rule.min_utterance_length;
  const bool activated =
      decoding_requirement_met && enough_trailing_silence && long_enough;
  if (activated) {
    VLOG(2) << "Endpointing rule " << rule_name
            << " activated: " << (decoded_sth ? "true" : "false") << ','
            << trailing_silence << ',' << utterance_length;
  }
  return activated;
}
// Updates the blank/decoded frame counters with a new chunk of CTC
// posteriors and returns true if any of the three endpointing rules fires.
// ctc_log_probs: per-frame log-probability rows; decoded_something: whether
// any non-blank token has been emitted so far in this utterance.
// frame_shift_in_ms() must have been set before the first call.
bool CtcEndpoint::IsEndpoint(
    const std::vector<std::vector<float>>& ctc_log_probs,
    bool decoded_something) {
  // Range-for avoids the signed/unsigned comparison of an int index
  // against ctc_log_probs.size().
  for (const auto& logp_t : ctc_log_probs) {
    float blank_prob = expf(logp_t[config_.blank]);
    num_frames_decoded_++;
    if (blank_prob > config_.blank_threshold) {
      // Frame is effectively silence: extend the trailing-blank run.
      num_frames_trailing_blank_++;
    } else {
      // Any speechy frame resets the trailing-silence measurement.
      num_frames_trailing_blank_ = 0;
    }
  }
  CHECK_GE(num_frames_decoded_, num_frames_trailing_blank_);
  CHECK_GT(frame_shift_in_ms_, 0);
  // Convert frame counts to milliseconds for the rule thresholds.
  int utterance_length = num_frames_decoded_ * frame_shift_in_ms_;
  int trailing_silence = num_frames_trailing_blank_ * frame_shift_in_ms_;
  if (RuleActivated(config_.rule1, "rule1", decoded_something,
                    trailing_silence, utterance_length))
    return true;
  if (RuleActivated(config_.rule2, "rule2", decoded_something,
                    trailing_silence, utterance_length))
    return true;
  if (RuleActivated(config_.rule3, "rule3", decoded_something,
                    trailing_silence, utterance_length))
    return true;
  return false;
}
}
// namespace wenet
runtime/core/decoder/ctc_endpoint.h
0 → 100644
View file @
764b3a75
// Copyright (c) 2021 Mobvoi Inc (Zhendong Peng)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DECODER_CTC_ENDPOINT_H_
#define DECODER_CTC_ENDPOINT_H_
#include <vector>
namespace
wenet
{
struct
CtcEndpointRule
{
bool
must_decoded_sth
;
int
min_trailing_silence
;
int
min_utterance_length
;
CtcEndpointRule
(
bool
must_decoded_sth
=
true
,
int
min_trailing_silence
=
1000
,
int
min_utterance_length
=
0
)
:
must_decoded_sth
(
must_decoded_sth
),
min_trailing_silence
(
min_trailing_silence
),
min_utterance_length
(
min_utterance_length
)
{}
};
struct CtcEndpointConfig {
  /// We consider blank as silence for purposes of endpointing.
  int blank = 0;               // blank id
  float blank_threshold = 0.8; // blank threshold to be silence

  /// We support three rules. We terminate decoding if ANY of these rules
  /// evaluates to "true". If you want to add more rules, do it by changing
  /// this code. If you want to disable a rule, you can set the
  /// silence-timeout for that rule to a very large number.

  /// rule1 times out after 5000 ms of silence, even if we decoded nothing.
  CtcEndpointRule rule1;
  /// rule2 times out after 1000 ms of silence after decoding something.
  CtcEndpointRule rule2;
  /// rule3 times out after the utterance is 20000 ms long, regardless of
  /// anything else.
  CtcEndpointRule rule3;

  CtcEndpointConfig()
      : rule1(false, 5000, 0), rule2(true, 1000, 0), rule3(false, 0, 20000) {}
};
// Tracks trailing-blank statistics over streamed CTC posteriors and decides
// when an utterance should be terminated.
class CtcEndpoint {
 public:
  explicit CtcEndpoint(const CtcEndpointConfig& config);

  // Clears the per-utterance counters.
  void Reset();

  /// This function returns true if this set of endpointing rules thinks we
  /// should terminate decoding.
  bool IsEndpoint(const std::vector<std::vector<float>>& ctc_log_probs,
                  bool decoded_something);

  // Must be called before IsEndpoint(); CHECKed to be > 0 there.
  void frame_shift_in_ms(int frame_shift_in_ms) {
    frame_shift_in_ms_ = frame_shift_in_ms;
  }

 private:
  CtcEndpointConfig config_;
  // Milliseconds per decoded frame; -1 until configured.
  int frame_shift_in_ms_ = -1;
  // Total frames seen this utterance.
  int num_frames_decoded_ = 0;
  // Length of the current run of blank-dominated frames.
  int num_frames_trailing_blank_ = 0;
};
}
// namespace wenet
#endif // DECODER_CTC_ENDPOINT_H_
runtime/core/decoder/ctc_prefix_beam_search.cc
0 → 100644
View file @
764b3a75
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/ctc_prefix_beam_search.h"
#include <algorithm>
#include <tuple>
#include <unordered_map>
#include <utility>
#include "utils/log.h"
#include "utils/utils.h"
namespace
wenet
{
// Keeps a reference to the options (caller must keep `opts` alive) and an
// optional shared context-biasing graph, then initializes the search state.
CtcPrefixBeamSearch::CtcPrefixBeamSearch(
    const CtcPrefixBeamSearchOptions& opts,
    const std::shared_ptr<ContextGraph>& context_graph)
    : opts_(opts), context_graph_(context_graph) {
  Reset();
}
void
CtcPrefixBeamSearch
::
Reset
()
{
hypotheses_
.
clear
();
likelihood_
.
clear
();
cur_hyps_
.
clear
();
viterbi_likelihood_
.
clear
();
times_
.
clear
();
outputs_
.
clear
();
abs_time_step_
=
0
;
PrefixScore
prefix_score
;
prefix_score
.
s
=
0.0
;
prefix_score
.
ns
=
-
kFloatMax
;
prefix_score
.
v_s
=
0.0
;
prefix_score
.
v_ns
=
0.0
;
std
::
vector
<
int
>
empty
;
cur_hyps_
[
empty
]
=
prefix_score
;
outputs_
.
emplace_back
(
empty
);
hypotheses_
.
emplace_back
(
empty
);
likelihood_
.
emplace_back
(
prefix_score
.
total_score
());
times_
.
emplace_back
(
empty
);
}
// Orders (prefix, score) entries best-first by total score (acoustic plus
// context bonus).
static bool PrefixScoreCompare(
    const std::pair<std::vector<int>, PrefixScore>& a,
    const std::pair<std::vector<int>, PrefixScore>& b) {
  return b.second.total_score() < a.second.total_score();
}
// Builds the tagged output sequence for one hypothesis: the raw token
// prefix with context_graph_ start/end tag ids spliced in at the recorded
// phrase boundaries, appended to outputs_. start_boundaries/end_boundaries
// hold prefix positions in increasing order (maintained by UpdateContext).
void CtcPrefixBeamSearch::UpdateOutputs(
    const std::pair<std::vector<int>, PrefixScore>& prefix) {
  const std::vector<int>& input = prefix.first;
  const std::vector<int>& start_boundaries = prefix.second.start_boundaries;
  const std::vector<int>& end_boundaries = prefix.second.end_boundaries;

  std::vector<int> output;
  int s = 0;  // next unconsumed start boundary
  int e = 0;  // next unconsumed end boundary
  for (int i = 0; i < input.size(); ++i) {
    // Emit <context> before the token that opens a phrase.
    if (s < start_boundaries.size() && i == start_boundaries[s]) {
      output.emplace_back(context_graph_->start_tag_id());
      ++s;
    }
    output.emplace_back(input[i]);
    // Emit </context> after the token that closes a phrase.
    if (e < end_boundaries.size() && i == end_boundaries[e]) {
      output.emplace_back(context_graph_->end_tag_id());
      ++e;
    }
  }
  outputs_.emplace_back(output);
}
// Replaces the current search state with the pruned, sorted hypothesis list
// `hpys`, refilling cur_hyps_ and all result containers (hypotheses_,
// outputs_, likelihood_, viterbi_likelihood_, times_) in the same order.
void CtcPrefixBeamSearch::UpdateHypotheses(
    const std::vector<std::pair<std::vector<int>, PrefixScore>>& hpys) {
  cur_hyps_.clear();
  outputs_.clear();
  hypotheses_.clear();
  likelihood_.clear();
  viterbi_likelihood_.clear();
  times_.clear();
  for (const auto& item : hpys) {
    cur_hyps_[item.first] = item.second;
    UpdateOutputs(item);
    // NOTE: the original wrote std::move(item.first), but moving from a
    // member of a const-ref element is a silent copy
    // (performance-move-const-arg); copy explicitly instead.
    hypotheses_.emplace_back(item.first);
    likelihood_.emplace_back(item.second.total_score());
    viterbi_likelihood_.emplace_back(item.second.viterbi_score());
    times_.emplace_back(item.second.times());
  }
}
// Please refer https://robin1001.github.io/2020/12/11/ctc-search
// for how CTC prefix beam search works, and there is a simple graph demo in
// it.
//
// Processes one chunk of per-frame CTC log-probabilities, advancing the
// prefix beam one frame at a time. For each frame: prune to the top-k
// tokens, extend every live prefix through the three CTC cases (blank,
// repeated token, new token), then keep the best `second_beam_size`
// prefixes. Also tracks viterbi scores/times and optional context-graph
// bonuses. abs_time_step_ counts frames across chunks.
void CtcPrefixBeamSearch::Search(const std::vector<std::vector<float>>& logp) {
  if (logp.size() == 0) return;
  int first_beam_size =
      std::min(static_cast<int>(logp[0].size()), opts_.first_beam_size);
  for (int t = 0; t < logp.size(); ++t, ++abs_time_step_) {
    const std::vector<float>& logp_t = logp[t];
    std::unordered_map<std::vector<int>, PrefixScore, PrefixHash> next_hyps;
    // 1. First beam prune, only select topk candidates
    std::vector<float> topk_score;
    std::vector<int32_t> topk_index;
    TopK(logp_t, first_beam_size, &topk_score, &topk_index);

    // 2. Token passing
    for (int i = 0; i < topk_index.size(); ++i) {
      int id = topk_index[i];
      auto prob = topk_score[i];
      for (const auto& it : cur_hyps_) {
        const std::vector<int>& prefix = it.first;
        const PrefixScore& prefix_score = it.second;
        // If prefix doesn't exist in next_hyps, next_hyps[prefix] will insert
        // PrefixScore(-inf, -inf) by default, since the default constructor
        // of PrefixScore will set fields s(blank ending score) and
        // ns(none blank ending score) to -inf, respectively.
        if (id == opts_.blank) {
          // Case 0: *a + ε => *a
          PrefixScore& next_score = next_hyps[prefix];
          next_score.s = LogAdd(next_score.s, prefix_score.score() + prob);
          next_score.v_s = prefix_score.viterbi_score() + prob;
          next_score.times_s = prefix_score.times();
          // Prefix not changed, copy the context from prefix.
          if (context_graph_ && !next_score.has_context) {
            next_score.CopyContext(prefix_score);
            next_score.has_context = true;
          }
        } else if (!prefix.empty() && id == prefix.back()) {
          // Case 1: *a + a => *a  (repeat merges into the same prefix)
          PrefixScore& next_score1 = next_hyps[prefix];
          next_score1.ns = LogAdd(next_score1.ns, prefix_score.ns + prob);
          if (next_score1.v_ns < prefix_score.v_ns + prob) {
            next_score1.v_ns = prefix_score.v_ns + prob;
            if (next_score1.cur_token_prob < prob) {
              // Re-anchor the last token's time to this (better) frame.
              next_score1.cur_token_prob = prob;
              next_score1.times_ns = prefix_score.times_ns;
              CHECK_GT(next_score1.times_ns.size(), 0);
              next_score1.times_ns.back() = abs_time_step_;
            }
          }
          if (context_graph_ && !next_score1.has_context) {
            next_score1.CopyContext(prefix_score);
            next_score1.has_context = true;
          }

          // Case 2: *aε + a => *aa  (blank-separated repeat extends prefix)
          std::vector<int> new_prefix(prefix);
          new_prefix.emplace_back(id);
          PrefixScore& next_score2 = next_hyps[new_prefix];
          next_score2.ns = LogAdd(next_score2.ns, prefix_score.s + prob);
          if (next_score2.v_ns < prefix_score.v_s + prob) {
            next_score2.v_ns = prefix_score.v_s + prob;
            next_score2.cur_token_prob = prob;
            next_score2.times_ns = prefix_score.times_s;
            next_score2.times_ns.emplace_back(abs_time_step_);
          }
          if (context_graph_ && !next_score2.has_context) {
            // Prefix changed, calculate the context score.
            next_score2.UpdateContext(context_graph_, prefix_score, id,
                                      prefix.size());
            next_score2.has_context = true;
          }
        } else {
          // Case 3: *a + b => *ab, *aε + b => *ab
          std::vector<int> new_prefix(prefix);
          new_prefix.emplace_back(id);
          PrefixScore& next_score = next_hyps[new_prefix];
          next_score.ns = LogAdd(next_score.ns, prefix_score.score() + prob);
          if (next_score.v_ns < prefix_score.viterbi_score() + prob) {
            next_score.v_ns = prefix_score.viterbi_score() + prob;
            next_score.cur_token_prob = prob;
            next_score.times_ns = prefix_score.times();
            next_score.times_ns.emplace_back(abs_time_step_);
          }
          if (context_graph_ && !next_score.has_context) {
            // Calculate the context score.
            next_score.UpdateContext(context_graph_, prefix_score, id,
                                     prefix.size());
            next_score.has_context = true;
          }
        }
      }
    }

    // 3. Second beam prune, only keep top n best paths
    std::vector<std::pair<std::vector<int>, PrefixScore>> arr(
        next_hyps.begin(), next_hyps.end());
    int second_beam_size =
        std::min(static_cast<int>(arr.size()), opts_.second_beam_size);
    std::nth_element(arr.begin(), arr.begin() + second_beam_size, arr.end(),
                     PrefixScoreCompare);
    arr.resize(second_beam_size);
    std::sort(arr.begin(), arr.end(), PrefixScoreCompare);

    // 4. Update cur_hyps_ and get new result
    UpdateHypotheses(arr);
  }
}
// Called once at the end of an utterance; only needs to settle any
// partially matched context-graph state.
void CtcPrefixBeamSearch::FinalizeSearch() { UpdateFinalContext(); }
// At end of utterance, backs off any hypothesis stuck mid-phrase in the
// context graph (non-zero context state): stepping with word_id 0 follows
// the epsilon escape arc, removing the partial-match bonus. Then re-sorts
// and republishes the hypothesis list.
void CtcPrefixBeamSearch::UpdateFinalContext() {
  if (context_graph_ == nullptr) return;
  CHECK_EQ(hypotheses_.size(), cur_hyps_.size());
  CHECK_EQ(hypotheses_.size(), likelihood_.size());
  // We should backoff the context score/state when the context is
  // not fully matched at the last time.
  for (const auto& prefix : hypotheses_) {
    PrefixScore& prefix_score = cur_hyps_[prefix];
    if (prefix_score.context_state != 0) {
      // Passing prefix_score as its own source is intentional: UpdateContext
      // copies the context fields first, then advances from them.
      prefix_score.UpdateContext(context_graph_, prefix_score, 0,
                                 prefix.size());
    }
  }
  std::vector<std::pair<std::vector<int>, PrefixScore>> arr(cur_hyps_.begin(),
                                                            cur_hyps_.end());
  std::sort(arr.begin(), arr.end(), PrefixScoreCompare);

  // Update cur_hyps_ and get new result
  UpdateHypotheses(arr);
}
}
// namespace wenet
runtime/core/decoder/ctc_prefix_beam_search.h
0 → 100644
View file @
764b3a75
// Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DECODER_CTC_PREFIX_BEAM_SEARCH_H_
#define DECODER_CTC_PREFIX_BEAM_SEARCH_H_
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>
#include "decoder/context_graph.h"
#include "decoder/search_interface.h"
#include "utils/utils.h"
namespace
wenet
{
struct
CtcPrefixBeamSearchOptions
{
int
blank
=
0
;
// blank id
int
first_beam_size
=
10
;
int
second_beam_size
=
10
;
};
// Per-prefix bookkeeping for CTC prefix beam search: log-probabilities of
// the prefix ending in blank (s) vs non-blank (ns), the corresponding
// viterbi (best-single-path) scores and per-token frame times, plus
// optional context-graph biasing state.
struct PrefixScore {
  float s = -kFloatMax;               // blank ending score
  float ns = -kFloatMax;              // none blank ending score
  float v_s = -kFloatMax;             // viterbi blank ending score
  float v_ns = -kFloatMax;            // viterbi none blank ending score
  float cur_token_prob = -kFloatMax;  // prob of current token
  std::vector<int> times_s;           // times of viterbi blank path
  std::vector<int> times_ns;          // times of viterbi none blank path

  // Total log-probability of the prefix (blank + non-blank endings).
  float score() const { return LogAdd(s, ns); }
  // Best single-path score over the two endings.
  float viterbi_score() const { return v_s > v_ns ? v_s : v_ns; }
  // Frame times along the better viterbi path.
  const std::vector<int>& times() const {
    return v_s > v_ns ? times_s : times_ns;
  }

  // Context-graph biasing state.
  bool has_context = false;  // true once context fields are valid this frame
  int context_state = 0;     // current state in the context FST (0 = start)
  float context_score = 0;   // accumulated context bonus
  std::vector<int> start_boundaries;  // prefix positions opening a phrase
  std::vector<int> end_boundaries;    // prefix positions closing a phrase

  // Copies the context fields from a predecessor (prefix unchanged).
  void CopyContext(const PrefixScore& prefix_score) {
    context_state = prefix_score.context_state;
    context_score = prefix_score.context_score;
    start_boundaries = prefix_score.start_boundaries;
    end_boundaries = prefix_score.end_boundaries;
  }

  // Advances the context FST by `word_id` appended at position `prefix_len`
  // of the predecessor's prefix, accumulating the score delta and any
  // phrase boundaries crossed.
  void UpdateContext(const std::shared_ptr<ContextGraph>& context_graph,
                     const PrefixScore& prefix_score, int word_id,
                     int prefix_len) {
    this->CopyContext(prefix_score);

    float score = 0;
    bool is_start_boundary = false;
    bool is_end_boundary = false;
    context_state = context_graph->GetNextState(prefix_score.context_state,
                                                word_id, &score,
                                                &is_start_boundary,
                                                &is_end_boundary);
    context_score += score;
    if (is_start_boundary) start_boundaries.emplace_back(prefix_len);
    if (is_end_boundary) end_boundaries.emplace_back(prefix_len);
  }

  // Ranking score used by the beam: acoustic score plus context bonus.
  float total_score() const { return score() + context_score; }
};
// Hash functor so token-id prefixes can key an unordered_map.
struct PrefixHash {
  size_t operator()(const std::vector<int>& prefix) const {
    // Polynomial rolling hash (BKDR-style) with base 31.
    size_t code = 0;
    for (size_t i = 0; i < prefix.size(); ++i) {
      code = code * 31 + prefix[i];
    }
    return code;
  }
};
// CTC prefix beam search with viterbi time alignment and optional
// context-graph biasing, implementing the streaming SearchInterface.
class CtcPrefixBeamSearch : public SearchInterface {
 public:
  explicit CtcPrefixBeamSearch(
      const CtcPrefixBeamSearchOptions& opts,
      const std::shared_ptr<ContextGraph>& context_graph = nullptr);

  // Advances the search over one chunk of per-frame log-probabilities.
  void Search(const std::vector<std::vector<float>>& logp) override;
  void Reset() override;
  // Settles context-graph backoff at end of utterance.
  void FinalizeSearch() override;
  SearchType Type() const override { return SearchType::kPrefixBeamSearch; }
  // Rebuilds the tagged output for one hypothesis (appends to outputs_).
  void UpdateOutputs(const std::pair<std::vector<int>, PrefixScore>& prefix);
  // Replaces the beam with a pruned, sorted hypothesis list.
  void UpdateHypotheses(
      const std::vector<std::pair<std::vector<int>, PrefixScore>>& hpys);
  // Backs off partially matched context state at end of utterance.
  void UpdateFinalContext();

  const std::vector<float>& viterbi_likelihood() const {
    return viterbi_likelihood_;
  }
  const std::vector<std::vector<int>>& Inputs() const override {
    return hypotheses_;
  }
  const std::vector<std::vector<int>>& Outputs() const override {
    return outputs_;
  }
  const std::vector<float>& Likelihood() const override { return likelihood_; }
  const std::vector<std::vector<int>>& Times() const override {
    return times_;
  }

 private:
  // Global frame counter across Search() calls within one utterance.
  int abs_time_step_ = 0;

  // N-best list and corresponding likelihood_, in sorted order
  std::vector<std::vector<int>> hypotheses_;
  std::vector<float> likelihood_;
  std::vector<float> viterbi_likelihood_;
  std::vector<std::vector<int>> times_;

  // Live beam: prefix -> scores, keyed by token sequence.
  std::unordered_map<std::vector<int>, PrefixScore, PrefixHash> cur_hyps_;
  std::shared_ptr<ContextGraph> context_graph_ = nullptr;
  // Outputs contain the hypotheses_ and tags like: <context> and </context>
  std::vector<std::vector<int>> outputs_;
  // Reference only; the caller owns the options object.
  const CtcPrefixBeamSearchOptions& opts_;

 public:
  WENET_DISALLOW_COPY_AND_ASSIGN(CtcPrefixBeamSearch);
};
}
// namespace wenet
#endif // DECODER_CTC_PREFIX_BEAM_SEARCH_H_
runtime/core/decoder/ctc_wfst_beam_search.cc
0 → 100644
View file @
764b3a75
// Copyright (c) 2021 Mobvoi Inc (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "decoder/ctc_wfst_beam_search.h"
#include <utility>
namespace
wenet
{
// Returns the decodable to its initial state for a new utterance.
void DecodableTensorScaled::Reset() {
  num_frames_ready_ = 0;
  done_ = false;
  // Give an empty initialization, will throw error when
  // AcceptLoglikes is not called
  logp_.clear();
}
// Feeds one frame of log-likelihoods; only the most recent frame is kept,
// so LogLikelihood() must be queried before the next call.
void DecodableTensorScaled::AcceptLoglikes(const std::vector<float>& logp) {
  ++num_frames_ready_;
  // TODO(Binbin Zhang): Avoid copy here
  logp_ = logp;
}
// Scaled log-likelihood of symbol `index` at `frame`. Decoder symbol ids
// are 1-based here (0 is reserved), hence the `index - 1` into logp_;
// CHECK_GT enforces index >= 1.
float DecodableTensorScaled::LogLikelihood(int32 frame, int32 index) {
  CHECK_GT(index, 0);
  CHECK_LT(frame, num_frames_ready_);
  return scale_ * logp_[index - 1];
}
// A frame is "last" only when the utterance has been marked finished and
// `frame` is the final frame accepted so far.
bool DecodableTensorScaled::IsLastFrame(int32 frame) const {
  CHECK_LT(frame, num_frames_ready_);
  const bool is_newest_frame = (frame == num_frames_ready_ - 1);
  return is_newest_frame && done_;
}
// Not needed by the decoder paths used here; calling it is a fatal error.
int32 DecodableTensorScaled::NumIndices() const {
  LOG(FATAL) << "Not implement";
  return 0;
}
// Wires the decodable (with acoustic scaling), the Kaldi lattice decoder
// over the given TLG fst, and the optional context graph, then resets.
// `fst` and `opts` must outlive this object.
CtcWfstBeamSearch::CtcWfstBeamSearch(
    const fst::Fst<fst::StdArc>& fst, const CtcWfstBeamSearchOptions& opts,
    const std::shared_ptr<ContextGraph>& context_graph)
    : decodable_(opts.acoustic_scale),
      decoder_(fst, opts, context_graph),
      context_graph_(context_graph),
      opts_(opts) {
  Reset();
}
// Restores the searcher to a fresh per-utterance state.
void CtcWfstBeamSearch::Reset() {
  // Frame bookkeeping for the blank-skipping heuristic.
  num_frames_ = 0;
  last_best_ = 0;
  is_last_frame_blank_ = false;
  decoded_frames_mapping_.clear();
  // Previously published results.
  inputs_.clear();
  outputs_.clear();
  times_.clear();
  likelihood_.clear();
  // Restart the decodable and the underlying Kaldi decoder.
  decodable_.Reset();
  decoder_.InitDecoding();
}
// Feeds one chunk of CTC log-posteriors to the WFST decoder. Frames whose
// blank probability exceeds blank_skip_thresh * blank_scale are skipped
// (not decoded) for speed; a skipped blank frame is replayed when the same
// best symbol reappears right after it, so repeated tokens are not merged
// incorrectly. decoded_frames_mapping_ maps decoded-frame index back to the
// original frame index. After consuming the chunk, publishes the current
// 1-best partial result into inputs_/outputs_/likelihood_.
void CtcWfstBeamSearch::Search(const std::vector<std::vector<float>>& logp) {
  if (0 == logp.size()) {
    return;
  }
  // Every time we get the log posterior, we decode it all before return
  for (int i = 0; i < logp.size(); i++) {
    float blank_score = std::exp(logp[i][0]);
    if (blank_score > opts_.blank_skip_thresh * opts_.blank_scale) {
      // Blank-dominated frame: skip decoding but remember it in case we
      // need to replay it between two identical symbols.
      VLOG(3) << "skipping frame " << num_frames_ << " score " << blank_score;
      is_last_frame_blank_ = true;
      last_frame_prob_ = logp[i];
    } else {
      // Get the best symbol
      int cur_best =
          std::max_element(logp[i].begin(), logp[i].end()) - logp[i].begin();
      // Optional, adding one blank frame if we has skipped it in two same
      // symbols
      if (cur_best != 0 && is_last_frame_blank_ && cur_best == last_best_) {
        decodable_.AcceptLoglikes(last_frame_prob_);
        decoder_.AdvanceDecoding(&decodable_, 1);
        decoded_frames_mapping_.push_back(num_frames_ - 1);
        VLOG(2) << "Adding blank frame at symbol " << cur_best;
      }
      last_best_ = cur_best;

      decodable_.AcceptLoglikes(logp[i]);
      decoder_.AdvanceDecoding(&decodable_, 1);
      decoded_frames_mapping_.push_back(num_frames_);
      is_last_frame_blank_ = false;
    }
    num_frames_++;
  }
  // Get the best path
  inputs_.clear();
  outputs_.clear();
  likelihood_.clear();
  if (decoded_frames_mapping_.size() > 0) {
    inputs_.resize(1);
    outputs_.resize(1);
    likelihood_.resize(1);
    kaldi::Lattice lat;
    // Partial best path (use_final_probs = false: utterance not finished).
    decoder_.GetBestPath(&lat, false);
    std::vector<int> alignment;
    kaldi::LatticeWeight weight;
    fst::GetLinearSymbolSequence(lat, &alignment, &outputs_[0], &weight);
    ConvertToInputs(alignment, &inputs_[0]);
    RemoveContinuousTags(&outputs_[0]);
    VLOG(3) << weight.Value1() << " " << weight.Value2();
    // Likelihood is the negated total lattice cost (graph + acoustic).
    likelihood_[0] = -(weight.Value1() + weight.Value2());
  }
}
// Finishes decoding and publishes the final n-best results. For nbest == 1
// only the best path is extracted; otherwise the full lattice is converted
// and ShortestPath produces the n-best list. Fills inputs_ (CTC token
// sequences), outputs_ (word-level symbols with boundary tags removed),
// likelihood_ and times_ for each hypothesis.
void CtcWfstBeamSearch::FinalizeSearch() {
  decodable_.SetFinish();
  decoder_.FinalizeDecoding();
  inputs_.clear();
  outputs_.clear();
  likelihood_.clear();
  times_.clear();
  if (decoded_frames_mapping_.size() > 0) {
    std::vector<kaldi::Lattice> nbest_lats;
    if (opts_.nbest == 1) {
      kaldi::Lattice lat;
      decoder_.GetBestPath(&lat, true);
      nbest_lats.push_back(std::move(lat));
    } else {
      // Get N-best path by lattice(CompactLattice)
      kaldi::CompactLattice clat;
      decoder_.GetLattice(&clat, true);
      kaldi::Lattice lat, nbest_lat;
      fst::ConvertLattice(clat, &lat);
      // TODO(Binbin Zhang): it's n-best word lists here, not character n-best
      fst::ShortestPath(lat, &nbest_lat, opts_.nbest);
      fst::ConvertNbestToVector(nbest_lat, &nbest_lats);
    }
    int nbest = nbest_lats.size();
    inputs_.resize(nbest);
    outputs_.resize(nbest);
    likelihood_.resize(nbest);
    times_.resize(nbest);
    for (int i = 0; i < nbest; i++) {
      kaldi::LatticeWeight weight;
      std::vector<int> alignment;
      fst::GetLinearSymbolSequence(nbest_lats[i], &alignment, &outputs_[i],
                                   &weight);
      ConvertToInputs(alignment, &inputs_[i], &times_[i]);
      RemoveContinuousTags(&outputs_[i]);
      // Likelihood is the negated total lattice cost (graph + acoustic).
      likelihood_[i] = -(weight.Value1() + weight.Value2());
    }
  }
}
void
CtcWfstBeamSearch
::
ConvertToInputs
(
const
std
::
vector
<
int
>&
alignment
,
std
::
vector
<
int
>*
input
,
std
::
vector
<
int
>*
time
)
{
input
->
clear
();
if
(
time
!=
nullptr
)
time
->
clear
();
for
(
int
cur
=
0
;
cur
<
alignment
.
size
();
++
cur
)
{
// ignore blank
if
(
alignment
[
cur
]
-
1
==
0
)
continue
;
// merge continuous same label
if
(
cur
>
0
&&
alignment
[
cur
]
==
alignment
[
cur
-
1
])
continue
;
input
->
push_back
(
alignment
[
cur
]
-
1
);
if
(
time
!=
nullptr
)
{
time
->
push_back
(
decoded_frames_mapping_
[
cur
]);
}
}
}
// Collapses consecutive duplicate <context>/</context> tag symbols in
// `output` (duplicates can appear when adjacent phrases share boundaries).
// No-op when context biasing is disabled.
void CtcWfstBeamSearch::RemoveContinuousTags(std::vector<int>* output) {
  if (context_graph_) {
    for (auto it = output->begin(); it != output->end();) {
      if (*it == context_graph_->start_tag_id() ||
          *it == context_graph_->end_tag_id()) {
        if (it + 1 != output->end() && *it == *(it + 1)) {
          // erase() returns the next valid iterator; do not advance again.
          it = output->erase(it);
          continue;
        }
      }
      ++it;
    }
  }
}
}
// namespace wenet
Prev
1
…
17
18
19
20
21
22
23
24
25
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment