Commit 39b6343d authored by moto, committed by Facebook GitHub Bot

Migrate CTC decoder code (#2580)

Summary:
This commit gets rid of our copy of the CTC decoder code and
replaces it with the upstream Flashlight-Text repo.

Pull Request resolved: https://github.com/pytorch/audio/pull/2580

Reviewed By: carolineechen

Differential Revision: D38244906

Pulled By: mthrok

fbshipit-source-id: d274240fc67675552d19ff35e9a363b9b9048721
parent 919fd0c4
...
@@ -3,5 +3,8 @@
 url = https://github.com/kaldi-asr/kaldi
 ignore = dirty
 [submodule "third_party/kenlm/submodule"]
-path = third_party/kenlm/submodule
+path = third_party/kenlm/kenlm
 url = https://github.com/kpu/kenlm
+[submodule "flashlight-text"]
+path = third_party/flashlight-text/submodule
+url = https://github.com/flashlight/text
...
-set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fvisibility=hidden")
+set(CMAKE_CXX_VISIBILITY_PRESET hidden)
 ################################################################################
 # sox
...
@@ -22,4 +22,5 @@ if (BUILD_CTC_DECODER)
 add_subdirectory(bzip2)
 add_subdirectory(lzma)
 add_subdirectory(kenlm)
+add_subdirectory(flashlight-text)
 endif()
# Custom CMakeLists for building flashlight-text decoder
#
# The main differences from the upstream CMakeLists.txt of flashlight-text:
#
# 1. Build the compression libraries statically and make KenLM self-contained
# 2. Build KenLM without Boost by compiling only what is used by flashlight-text
# 3. Build KenLM and flashlight-text in one go (not required, but a nice-to-have)
# 4. Tweak the location of the bindings so that it is easier for the TorchAudio
#    build process to pick them up.
#    (the upstream CMakeLists.txt does not install them in the same location as
#    libflashlight-text)
# 5. Tweak the names of the bindings. (remove suffixes like cpython-37m-darwin)
set(CMAKE_CXX_VISIBILITY_PRESET default)
set(
libflashlight_src
submodule/flashlight/lib/text/decoder/Utils.cpp
submodule/flashlight/lib/text/decoder/lm/KenLM.cpp
submodule/flashlight/lib/text/decoder/lm/ZeroLM.cpp
submodule/flashlight/lib/text/decoder/lm/ConvLM.cpp
submodule/flashlight/lib/text/decoder/LexiconDecoder.cpp
submodule/flashlight/lib/text/decoder/LexiconFreeDecoder.cpp
submodule/flashlight/lib/text/decoder/LexiconFreeSeq2SeqDecoder.cpp
submodule/flashlight/lib/text/decoder/LexiconSeq2SeqDecoder.cpp
submodule/flashlight/lib/text/decoder/Trie.cpp
submodule/flashlight/lib/text/String.cpp
submodule/flashlight/lib/text/dictionary/Utils.cpp
submodule/flashlight/lib/text/dictionary/Dictionary.cpp
)
torchaudio_library(
libflashlight-text
"${libflashlight_src}"
submodule
""
FL_TEXT_USE_KENLM
)
# TODO: update torchaudio_library to handle private links
target_link_libraries(
libflashlight-text
PRIVATE
kenlm)
if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
torchaudio_extension(
flashlight_lib_text_dictionary
submodule/bindings/python/flashlight/lib/text/_dictionary.cpp
submodule
libflashlight-text
""
)
torchaudio_extension(
flashlight_lib_text_decoder
submodule/bindings/python/flashlight/lib/text/_decoder.cpp
submodule
libflashlight-text
FL_TEXT_USE_KENLM
)
endif()
Subproject commit 98028c7da83d66c2aba6f5f8708c063d266ca5a4
 set(
 KENLM_UTIL_SOURCES
-submodule/util/bit_packing.cc
-submodule/util/double-conversion/bignum.cc
-submodule/util/double-conversion/bignum-dtoa.cc
-submodule/util/double-conversion/cached-powers.cc
-submodule/util/double-conversion/diy-fp.cc
-submodule/util/double-conversion/double-conversion.cc
-submodule/util/double-conversion/fast-dtoa.cc
-submodule/util/double-conversion/fixed-dtoa.cc
-submodule/util/double-conversion/strtod.cc
-submodule/util/ersatz_progress.cc
-submodule/util/exception.cc
-submodule/util/file.cc
-submodule/util/file_piece.cc
-submodule/util/float_to_string.cc
-submodule/util/integer_to_string.cc
-submodule/util/mmap.cc
-submodule/util/murmur_hash.cc
-submodule/util/pool.cc
-submodule/util/read_compressed.cc
-submodule/util/scoped.cc
-submodule/util/spaces.cc
-submodule/util/string_piece.cc
+kenlm/util/bit_packing.cc
+kenlm/util/double-conversion/bignum.cc
+kenlm/util/double-conversion/bignum-dtoa.cc
+kenlm/util/double-conversion/cached-powers.cc
+kenlm/util/double-conversion/diy-fp.cc
+kenlm/util/double-conversion/double-conversion.cc
+kenlm/util/double-conversion/fast-dtoa.cc
+kenlm/util/double-conversion/fixed-dtoa.cc
+kenlm/util/double-conversion/strtod.cc
+kenlm/util/ersatz_progress.cc
+kenlm/util/exception.cc
+kenlm/util/file.cc
+kenlm/util/file_piece.cc
+kenlm/util/float_to_string.cc
+kenlm/util/integer_to_string.cc
+kenlm/util/mmap.cc
+kenlm/util/murmur_hash.cc
+kenlm/util/pool.cc
+kenlm/util/read_compressed.cc
+kenlm/util/scoped.cc
+kenlm/util/spaces.cc
+kenlm/util/string_piece.cc
 )
 set(
 KENLM_SOURCES
-submodule/lm/bhiksha.cc
-submodule/lm/binary_format.cc
-submodule/lm/config.cc
-submodule/lm/lm_exception.cc
-submodule/lm/model.cc
-submodule/lm/quantize.cc
-submodule/lm/read_arpa.cc
-submodule/lm/search_hashed.cc
-submodule/lm/search_trie.cc
-submodule/lm/trie.cc
-submodule/lm/trie_sort.cc
-submodule/lm/value_build.cc
-submodule/lm/virtual_interface.cc
-submodule/lm/vocab.cc
+kenlm/lm/bhiksha.cc
+kenlm/lm/binary_format.cc
+kenlm/lm/config.cc
+kenlm/lm/lm_exception.cc
+kenlm/lm/model.cc
+kenlm/lm/quantize.cc
+kenlm/lm/read_arpa.cc
+kenlm/lm/search_hashed.cc
+kenlm/lm/search_trie.cc
+kenlm/lm/trie.cc
+kenlm/lm/trie_sort.cc
+kenlm/lm/value_build.cc
+kenlm/lm/virtual_interface.cc
+kenlm/lm/vocab.cc
 )
 add_library(
...
@@ -51,8 +51,7 @@ add_library(
 target_include_directories(
 kenlm
 BEFORE
-PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/../install/include"
-PUBLIC submodule
+PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}"
 )
 target_compile_definitions(
...
@@ -65,6 +64,7 @@ target_compile_definitions(
 target_link_libraries(
 kenlm
+PRIVATE
 zlib
 bzip2
 lzma
...
...
@@ -51,8 +51,9 @@ def get_ext_modules():
     if _BUILD_CTC_DECODER:
         modules.extend(
             [
-                Extension(name="torchaudio.lib.libtorchaudio_decoder", sources=[]),
-                Extension(name="torchaudio._torchaudio_decoder", sources=[]),
+                Extension(name="torchaudio.lib.libflashlight-text", sources=[]),
+                Extension(name="torchaudio.flashlight_lib_text_decoder", sources=[]),
+                Extension(name="torchaudio.flashlight_lib_text_dictionary", sources=[]),
             ]
         )
     if _USE_FFMPEG:
...
...
@@ -119,6 +119,7 @@ endif()
 #------------------------------------------------------------------------------#
 # END OF CUSTOMIZATION LOGICS
 #------------------------------------------------------------------------------#
+
 torchaudio_library(
 libtorchaudio
 "${LIBTORCHAUDIO_SOURCES}"
...
@@ -127,38 +128,6 @@ torchaudio_library(
 "${LIBTORCHAUDIO_COMPILE_DEFINITIONS}"
 )
-################################################################################
-# libtorchaudio_decoder.so
-################################################################################
-if (BUILD_CTC_DECODER)
-set(
-LIBTORCHAUDIO_DECODER_SOURCES
-decoder/src/decoder/LexiconDecoder.cpp
-decoder/src/decoder/LexiconFreeDecoder.cpp
-decoder/src/decoder/Trie.cpp
-decoder/src/decoder/Utils.cpp
-decoder/src/decoder/lm/KenLM.cpp
-decoder/src/decoder/lm/ZeroLM.cpp
-decoder/src/dictionary/Dictionary.cpp
-decoder/src/dictionary/String.cpp
-decoder/src/dictionary/System.cpp
-decoder/src/dictionary/Utils.cpp
-)
-set(
-LIBTORCHAUDIO_DECODER_DEFINITIONS
-BUILD_CTC_DECODER
-)
-torchaudio_library(
-libtorchaudio_decoder
-"${LIBTORCHAUDIO_DECODER_SOURCES}"
-"${PROJECT_SOURCE_DIR}"
-"torch;kenlm"
-"${LIBTORCHAUDIO_COMPILE_DEFINITIONS};${LIBTORCHAUDIO_DECODER_DEFINITIONS}"
-)
-endif()
+# TODO: Add libtorchaudio_decoder
 if (APPLE)
 set(TORCHAUDIO_LIBRARY libtorchaudio CACHE INTERNAL "")
 else()
...
@@ -224,15 +193,6 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
 libtorchaudio
 "${LIBTORCHAUDIO_COMPILE_DEFINITIONS}"
 )
-if(BUILD_CTC_DECODER)
-torchaudio_extension(
-_torchaudio_decoder
-decoder/bindings/pybind.cpp
-""
-"libtorchaudio_decoder"
-"${LIBTORCHAUDIO_DECODER_DEFINITIONS}"
-)
-endif()
 if(USE_FFMPEG)
 set(
 FFMPEG_EXTENSION_SOURCES
...
#include <torch/extension.h>
#include "torchaudio/csrc/decoder/src/decoder/LexiconDecoder.h"
#include "torchaudio/csrc/decoder/src/decoder/LexiconFreeDecoder.h"
#include "torchaudio/csrc/decoder/src/decoder/lm/KenLM.h"
#include "torchaudio/csrc/decoder/src/decoder/lm/ZeroLM.h"
#include "torchaudio/csrc/decoder/src/dictionary/Dictionary.h"
#include "torchaudio/csrc/decoder/src/dictionary/Utils.h"
namespace py = pybind11;
using namespace torchaudio::lib::text;
using namespace py::literals;
/**
* Some hackery that lets pybind11 handle shared_ptr<void> (for old LMStatePtr).
* See: https://github.com/pybind/pybind11/issues/820
* PYBIND11_MAKE_OPAQUE(std::shared_ptr<void>);
* and inside PYBIND11_MODULE
* py::class_<std::shared_ptr<void>>(m, "encapsulated_data");
*/
namespace {
/**
* A pybind11 "alias type" for abstract class LM, allowing one to subclass LM
* with a custom LM defined purely in Python. For those who don't want to build
* with KenLM, or have their own custom LM implementation.
* See: https://pybind11.readthedocs.io/en/stable/advanced/classes.html
*
* TODO: ensure this works. Last time Jeff tried this there were slicing issues,
* see https://github.com/pybind/pybind11/issues/1546 for workarounds.
* This is low-pri since we assume most people can just build with KenLM.
*/
class PyLM : public LM {
using LM::LM;
// needed for pybind11 or else it won't compile
using LMOutput = std::pair<LMStatePtr, float>;
LMStatePtr start(bool startWithNothing) override {
PYBIND11_OVERLOAD_PURE(LMStatePtr, LM, start, startWithNothing);
}
LMOutput score(const LMStatePtr& state, const int usrTokenIdx) override {
PYBIND11_OVERLOAD_PURE(LMOutput, LM, score, state, usrTokenIdx);
}
LMOutput finish(const LMStatePtr& state) override {
PYBIND11_OVERLOAD_PURE(LMOutput, LM, finish, state);
}
};
/**
 * Using a custom Python LMState derived from LMState does not work with a
 * custom Python LM (derived from PyLM), because we would need to cast the
 * LMState in the score and finish functions to the derived class
 * (for example via obj.__class__ = CustomPyLMState), which causes the error
 * "TypeError: __class__ assignment: 'CustomPyLMState' deallocator differs
 * from 'flashlight.text.decoder._decoder.LMState'"
 * (for details see https://github.com/pybind/pybind11/issues/1640).
 * To define a custom LM you can instead introduce a map inside the LM which
 * maps each LMState to additional state info (shared pointers pointing to the
 * same underlying object will have the same id in Python in score and finish)
*
* ```python
* from flashlight.lib.text.decoder import LM
* class MyPyLM(LM):
* mapping_states = dict() # store simple additional int for each state
*
* def __init__(self):
* LM.__init__(self)
*
* def start(self, start_with_nothing):
* state = LMState()
* self.mapping_states[state] = 0
* return state
*
* def score(self, state, index):
* outstate = state.child(index)
* if outstate not in self.mapping_states:
* self.mapping_states[outstate] = self.mapping_states[state] + 1
* return (outstate, -numpy.random.random())
*
* def finish(self, state):
* outstate = state.child(-1)
* if outstate not in self.mapping_states:
* self.mapping_states[outstate] = self.mapping_states[state] + 1
* return (outstate, -1)
*```
*/
void LexiconDecoder_decodeStep(
LexiconDecoder& decoder,
uintptr_t emissions,
int T,
int N) {
decoder.decodeStep(reinterpret_cast<const float*>(emissions), T, N);
}
std::vector<DecodeResult> LexiconDecoder_decode(
LexiconDecoder& decoder,
uintptr_t emissions,
int T,
int N) {
return decoder.decode(reinterpret_cast<const float*>(emissions), T, N);
}
void LexiconFreeDecoder_decodeStep(
LexiconFreeDecoder& decoder,
uintptr_t emissions,
int T,
int N) {
decoder.decodeStep(reinterpret_cast<const float*>(emissions), T, N);
}
std::vector<DecodeResult> LexiconFreeDecoder_decode(
LexiconFreeDecoder& decoder,
uintptr_t emissions,
int T,
int N) {
return decoder.decode(reinterpret_cast<const float*>(emissions), T, N);
}
void Dictionary_addEntry_0(
Dictionary& dict,
const std::string& entry,
int idx) {
dict.addEntry(entry, idx);
}
void Dictionary_addEntry_1(Dictionary& dict, const std::string& entry) {
dict.addEntry(entry);
}
PYBIND11_MODULE(_torchaudio_decoder, m) {
#ifdef BUILD_CTC_DECODER
py::enum_<SmearingMode>(m, "_SmearingMode")
.value("NONE", SmearingMode::NONE)
.value("MAX", SmearingMode::MAX)
.value("LOGADD", SmearingMode::LOGADD);
py::class_<TrieNode, TrieNodePtr>(m, "_TrieNode")
.def(py::init<int>(), "idx"_a)
.def_readwrite("children", &TrieNode::children)
.def_readwrite("idx", &TrieNode::idx)
.def_readwrite("labels", &TrieNode::labels)
.def_readwrite("scores", &TrieNode::scores)
.def_readwrite("max_score", &TrieNode::maxScore);
py::class_<Trie, TriePtr>(m, "_Trie")
.def(py::init<int, int>(), "max_children"_a, "root_idx"_a)
.def("get_root", &Trie::getRoot)
.def("insert", &Trie::insert, "indices"_a, "label"_a, "score"_a)
.def("search", &Trie::search, "indices"_a)
.def("smear", &Trie::smear, "smear_mode"_a);
py::class_<LM, LMPtr, PyLM>(m, "_LM")
.def(py::init<>())
.def("start", &LM::start, "start_with_nothing"_a)
.def("score", &LM::score, "state"_a, "usr_token_idx"_a)
.def("finish", &LM::finish, "state"_a);
py::class_<LMState, LMStatePtr>(m, "_LMState")
.def(py::init<>())
.def_readwrite("children", &LMState::children)
.def("compare", &LMState::compare, "state"_a)
.def("child", &LMState::child<LMState>, "usr_index"_a);
py::class_<KenLM, KenLMPtr, LM>(m, "_KenLM")
.def(
py::init<const std::string&, const Dictionary&>(),
"path"_a,
"usr_token_dict"_a);
py::class_<ZeroLM, ZeroLMPtr, LM>(m, "_ZeroLM").def(py::init<>());
py::enum_<CriterionType>(m, "_CriterionType")
.value("ASG", CriterionType::ASG)
.value("CTC", CriterionType::CTC);
py::class_<LexiconDecoderOptions>(m, "_LexiconDecoderOptions")
.def(
py::init<
const int,
const int,
const double,
const double,
const double,
const double,
const double,
const bool,
const CriterionType>(),
"beam_size"_a,
"beam_size_token"_a,
"beam_threshold"_a,
"lm_weight"_a,
"word_score"_a,
"unk_score"_a,
"sil_score"_a,
"log_add"_a,
"criterion_type"_a)
.def_readwrite("beam_size", &LexiconDecoderOptions::beamSize)
.def_readwrite("beam_size_token", &LexiconDecoderOptions::beamSizeToken)
.def_readwrite("beam_threshold", &LexiconDecoderOptions::beamThreshold)
.def_readwrite("lm_weight", &LexiconDecoderOptions::lmWeight)
.def_readwrite("word_score", &LexiconDecoderOptions::wordScore)
.def_readwrite("unk_score", &LexiconDecoderOptions::unkScore)
.def_readwrite("sil_score", &LexiconDecoderOptions::silScore)
.def_readwrite("log_add", &LexiconDecoderOptions::logAdd)
.def_readwrite("criterion_type", &LexiconDecoderOptions::criterionType);
py::class_<LexiconFreeDecoderOptions>(m, "_LexiconFreeDecoderOptions")
.def(
py::init<
const int,
const int,
const double,
const double,
const double,
const bool,
const CriterionType>(),
"beam_size"_a,
"beam_size_token"_a,
"beam_threshold"_a,
"lm_weight"_a,
"sil_score"_a,
"log_add"_a,
"criterion_type"_a)
.def_readwrite("beam_size", &LexiconFreeDecoderOptions::beamSize)
.def_readwrite(
"beam_size_token", &LexiconFreeDecoderOptions::beamSizeToken)
.def_readwrite(
"beam_threshold", &LexiconFreeDecoderOptions::beamThreshold)
.def_readwrite("lm_weight", &LexiconFreeDecoderOptions::lmWeight)
.def_readwrite("sil_score", &LexiconFreeDecoderOptions::silScore)
.def_readwrite("log_add", &LexiconFreeDecoderOptions::logAdd)
.def_readwrite(
"criterion_type", &LexiconFreeDecoderOptions::criterionType);
py::class_<DecodeResult>(m, "_DecodeResult")
.def(py::init<int>(), "length"_a)
.def_readwrite("score", &DecodeResult::score)
.def_readwrite("amScore", &DecodeResult::amScore)
.def_readwrite("lmScore", &DecodeResult::lmScore)
.def_readwrite("words", &DecodeResult::words)
.def_readwrite("tokens", &DecodeResult::tokens);
// NB: `decode` and `decodeStep` expect raw emissions pointers.
py::class_<LexiconDecoder>(m, "_LexiconDecoder")
.def(py::init<
LexiconDecoderOptions,
const TriePtr,
const LMPtr,
const int,
const int,
const int,
const std::vector<float>&,
const bool>())
.def("decode_begin", &LexiconDecoder::decodeBegin)
.def(
"decode_step",
&LexiconDecoder_decodeStep,
"emissions"_a,
"T"_a,
"N"_a)
.def("decode_end", &LexiconDecoder::decodeEnd)
.def("decode", &LexiconDecoder_decode, "emissions"_a, "T"_a, "N"_a)
.def("prune", &LexiconDecoder::prune, "look_back"_a = 0)
.def(
"get_best_hypothesis",
&LexiconDecoder::getBestHypothesis,
"look_back"_a = 0)
.def("get_all_final_hypothesis", &LexiconDecoder::getAllFinalHypothesis);
py::class_<LexiconFreeDecoder>(m, "_LexiconFreeDecoder")
.def(py::init<
LexiconFreeDecoderOptions,
const LMPtr,
const int,
const int,
const std::vector<float>&>())
.def("decode_begin", &LexiconFreeDecoder::decodeBegin)
.def(
"decode_step",
&LexiconFreeDecoder_decodeStep,
"emissions"_a,
"T"_a,
"N"_a)
.def("decode_end", &LexiconFreeDecoder::decodeEnd)
.def("decode", &LexiconFreeDecoder_decode, "emissions"_a, "T"_a, "N"_a)
.def("prune", &LexiconFreeDecoder::prune, "look_back"_a = 0)
.def(
"get_best_hypothesis",
&LexiconFreeDecoder::getBestHypothesis,
"look_back"_a = 0)
.def(
"get_all_final_hypothesis",
&LexiconFreeDecoder::getAllFinalHypothesis);
py::class_<Dictionary>(m, "_Dictionary")
.def(py::init<>())
.def(py::init<const std::vector<std::string>&>(), "tkns"_a)
.def(py::init<const std::string&>(), "filename"_a)
.def("entry_size", &Dictionary::entrySize)
.def("index_size", &Dictionary::indexSize)
.def("add_entry", &Dictionary_addEntry_0, "entry"_a, "idx"_a)
.def("add_entry", &Dictionary_addEntry_1, "entry"_a)
.def("get_entry", &Dictionary::getEntry, "idx"_a)
.def("set_default_index", &Dictionary::setDefaultIndex, "idx"_a)
.def("get_index", &Dictionary::getIndex, "entry"_a)
.def("contains", &Dictionary::contains, "entry"_a)
.def("is_contiguous", &Dictionary::isContiguous)
.def(
"map_entries_to_indices",
&Dictionary::mapEntriesToIndices,
"entries"_a)
.def(
"map_indices_to_entries",
&Dictionary::mapIndicesToEntries,
"indices"_a);
m.def("_create_word_dict", &createWordDict, "lexicon"_a);
m.def("_load_words", &loadWords, "filename"_a, "max_words"_a = -1);
#endif
}
} // namespace
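The wrappers above take the emissions buffer as a `uintptr_t` and `reinterpret_cast` it back to `const float*`. A minimal caller-side sketch of that convention (illustrative only; it assumes the file-local wrapper and the decoder types are visible, and in Python the same integer would come from a tensor's data pointer):

```cpp
#include <cstdint>
#include <vector>

// Sketch: pass the address of a row-major T x N float block through the
// integer-based wrapper API shown above.
std::vector<DecodeResult> runDecode(LexiconDecoder& decoder, int T, int N) {
  std::vector<float> emissions(static_cast<std::size_t>(T) * N, 0.0f);
  auto ptr = reinterpret_cast<uintptr_t>(emissions.data());
  return LexiconDecoder_decode(decoder, ptr, T, N);
}
```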
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include "torchaudio/csrc/decoder/src/decoder/Utils.h"
namespace torchaudio {
namespace lib {
namespace text {
enum class CriterionType { ASG = 0, CTC = 1, S2S = 2 };
/**
 * Decoder supports two typical use cases:
 * Offline manner:
 *  decoder.decode(someData) [returns all hypotheses (transcriptions)]
 *
 * Online manner:
 *  decoder.decodeBegin() [called only at the beginning of the stream]
 *  while (stream)
 *    decoder.decodeStep(someData) [one or more calls]
 *    decoder.getBestHypothesis() [returns the best hypothesis (transcription)]
 *    decoder.prune() [prunes the hypothesis space]
 *  decoder.decodeEnd() [called only at the end of the stream]
 *
 * Note: decoder.prune() deletes hypotheses up to the time at which it is
 * called, to support online decoding. It will also add an offset to the
 * scores in the beam to avoid underflow/overflow.
*
*/
class Decoder {
public:
Decoder() = default;
virtual ~Decoder() = default;
/* Initialize the decoder before starting to consume emissions */
virtual void decodeBegin() {}
/* Consume emissions in T x N chunks and increase the hypothesis space */
virtual void decodeStep(const float* emissions, int T, int N) = 0;
/* Finish up decoding after consuming all emissions */
virtual void decodeEnd() {}
/* Offline decode function, which consumes all emissions at once */
virtual std::vector<DecodeResult> decode(
const float* emissions,
int T,
int N) {
decodeBegin();
decodeStep(emissions, T, N);
decodeEnd();
return getAllFinalHypothesis();
}
/* Prune the hypothesis space */
virtual void prune(int lookBack = 0) = 0;
/* Get the number of decoded frames in the buffer */
virtual int nDecodedFramesInBuffer() const = 0;
/*
 * Get the best completed hypothesis which is `lookBack` frames ahead of the
 * last one in the buffer. For decoding that requires a lexicon, a completed
 * hypothesis means no partial word appears at the end.
*/
virtual DecodeResult getBestHypothesis(int lookBack = 0) const = 0;
/* Get all the final hypotheses */
virtual std::vector<DecodeResult> getAllFinalHypothesis() const = 0;
};
} // namespace text
} // namespace lib
} // namespace torchaudio
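For reference, a minimal sketch of the online use case described in the class comment above, written against the `Decoder` interface; the chunk list is a stand-in for a real audio stream:

```cpp
#include <vector>
#include "torchaudio/csrc/decoder/src/decoder/Decoder.h"

using torchaudio::lib::text::DecodeResult;
using torchaudio::lib::text::Decoder;

// Sketch: streaming decoding loop. Each chunk is a row-major T x N block of
// emissions with the same N across chunks.
std::vector<DecodeResult> streamingDecode(
    Decoder& decoder,
    const std::vector<std::vector<float>>& chunks,
    int N) {
  decoder.decodeBegin();
  for (const auto& chunk : chunks) {
    int T = static_cast<int>(chunk.size()) / N;
    decoder.decodeStep(chunk.data(), T, N);
    DecodeResult best = decoder.getBestHypothesis();  // partial result so far
    (void)best;      // e.g. emit a partial transcript here
    decoder.prune(); // drop frames outside the look-back window
  }
  decoder.decodeEnd();
  return decoder.getAllFinalHypothesis();
}
```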
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#include <stdlib.h>
#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include <unordered_map>
#include "torchaudio/csrc/decoder/src/decoder/LexiconDecoder.h"
namespace torchaudio {
namespace lib {
namespace text {
void LexiconDecoder::decodeBegin() {
hyp_.clear();
hyp_.emplace(0, std::vector<LexiconDecoderState>());
/* note: the LM resets itself with start() */
hyp_[0].emplace_back(
0.0, lm_->start(0), lexicon_->getRoot(), nullptr, sil_, -1);
nDecodedFrames_ = 0;
nPrunedFrames_ = 0;
}
void LexiconDecoder::decodeStep(const float* emissions, int T, int N) {
int startFrame = nDecodedFrames_ - nPrunedFrames_;
// Extend hyp_ buffer
if (hyp_.size() < startFrame + T + 2) {
for (int i = hyp_.size(); i < startFrame + T + 2; i++) {
hyp_.emplace(i, std::vector<LexiconDecoderState>());
}
}
std::vector<size_t> idx(N);
for (int t = 0; t < T; t++) {
std::iota(idx.begin(), idx.end(), 0);
if (N > opt_.beamSizeToken) {
std::partial_sort(
idx.begin(),
idx.begin() + opt_.beamSizeToken,
idx.end(),
[&t, &N, &emissions](const size_t& l, const size_t& r) {
return emissions[t * N + l] > emissions[t * N + r];
});
}
candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_);
for (const LexiconDecoderState& prevHyp : hyp_[startFrame + t]) {
const TrieNode* prevLex = prevHyp.lex;
const int prevIdx = prevHyp.token;
const float lexMaxScore =
prevLex == lexicon_->getRoot() ? 0 : prevLex->maxScore;
/* (1) Try children */
for (int r = 0; r < std::min(opt_.beamSizeToken, N); ++r) {
int n = idx[r];
auto iter = prevLex->children.find(n);
if (iter == prevLex->children.end()) {
continue;
}
const TrieNodePtr& lex = iter->second;
double amScore = emissions[t * N + n];
if (nDecodedFrames_ + t > 0 &&
opt_.criterionType == CriterionType::ASG) {
amScore += transitions_[n * N + prevIdx];
}
double score = prevHyp.score + amScore;
if (n == sil_) {
score += opt_.silScore;
}
LMStatePtr lmState;
double lmScore = 0.;
if (isLmToken_) {
auto lmStateScorePair = lm_->score(prevHyp.lmState, n);
lmState = lmStateScorePair.first;
lmScore = lmStateScorePair.second;
}
// We eat-up a new token
if (opt_.criterionType != CriterionType::CTC || prevHyp.prevBlank ||
n != prevIdx) {
if (!lex->children.empty()) {
if (!isLmToken_) {
lmState = prevHyp.lmState;
lmScore = lex->maxScore - lexMaxScore;
}
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score + opt_.lmWeight * lmScore,
lmState,
lex.get(),
&prevHyp,
n,
-1,
false, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore + lmScore);
}
}
// If we got a true word
for (auto label : lex->labels) {
if (prevLex == lexicon_->getRoot() && prevHyp.token == n) {
// This is to avoid a situation where, when there is a word with a
// single-token spelling (e.g. X -> x) in the lexicon and token `x`
// is predicted in several consecutive frames, the word `X` would be
// emitted multiple times. This violates the property of CTC, where
// there must be a blank token in between to predict 2 identical
// tokens consecutively.
continue;
}
if (!isLmToken_) {
auto lmStateScorePair = lm_->score(prevHyp.lmState, label);
lmState = lmStateScorePair.first;
lmScore = lmStateScorePair.second - lexMaxScore;
}
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score + opt_.lmWeight * lmScore + opt_.wordScore,
lmState,
lexicon_->getRoot(),
&prevHyp,
n,
label,
false, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore + lmScore);
}
// If we got an unknown word
if (lex->labels.empty() && (opt_.unkScore > kNegativeInfinity)) {
if (!isLmToken_) {
auto lmStateScorePair = lm_->score(prevHyp.lmState, unk_);
lmState = lmStateScorePair.first;
lmScore = lmStateScorePair.second - lexMaxScore;
}
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score + opt_.lmWeight * lmScore + opt_.unkScore,
lmState,
lexicon_->getRoot(),
&prevHyp,
n,
unk_,
false, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore + lmScore);
}
}
/* (2) Try same lexicon node */
if (opt_.criterionType != CriterionType::CTC || !prevHyp.prevBlank ||
prevLex == lexicon_->getRoot()) {
int n = prevLex == lexicon_->getRoot() ? sil_ : prevIdx;
double amScore = emissions[t * N + n];
if (nDecodedFrames_ + t > 0 &&
opt_.criterionType == CriterionType::ASG) {
amScore += transitions_[n * N + prevIdx];
}
double score = prevHyp.score + amScore;
if (n == sil_) {
score += opt_.silScore;
}
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score,
prevHyp.lmState,
prevLex,
&prevHyp,
n,
-1,
false, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore);
}
/* (3) CTC only, try blank */
if (opt_.criterionType == CriterionType::CTC) {
int n = blank_;
double amScore = emissions[t * N + n];
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
prevHyp.score + amScore,
prevHyp.lmState,
prevLex,
&prevHyp,
n,
-1,
true, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore);
}
// finish proposing
}
candidatesStore(
candidates_,
candidatePtrs_,
hyp_[startFrame + t + 1],
opt_.beamSize,
candidatesBestScore_ - opt_.beamThreshold,
opt_.logAdd,
false);
updateLMCache(lm_, hyp_[startFrame + t + 1]);
}
nDecodedFrames_ += T;
}
void LexiconDecoder::decodeEnd() {
candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_);
bool hasNiceEnding = false;
for (const LexiconDecoderState& prevHyp :
hyp_[nDecodedFrames_ - nPrunedFrames_]) {
if (prevHyp.lex == lexicon_->getRoot()) {
hasNiceEnding = true;
break;
}
}
for (const LexiconDecoderState& prevHyp :
hyp_[nDecodedFrames_ - nPrunedFrames_]) {
const TrieNode* prevLex = prevHyp.lex;
const LMStatePtr& prevLmState = prevHyp.lmState;
if (!hasNiceEnding || prevHyp.lex == lexicon_->getRoot()) {
auto lmStateScorePair = lm_->finish(prevLmState);
auto lmScore = lmStateScorePair.second;
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
prevHyp.score + opt_.lmWeight * lmScore,
lmStateScorePair.first,
prevLex,
&prevHyp,
sil_,
-1,
false, // prevBlank
prevHyp.amScore,
prevHyp.lmScore + lmScore);
}
}
candidatesStore(
candidates_,
candidatePtrs_,
hyp_[nDecodedFrames_ - nPrunedFrames_ + 1],
opt_.beamSize,
candidatesBestScore_ - opt_.beamThreshold,
opt_.logAdd,
true);
++nDecodedFrames_;
}
std::vector<DecodeResult> LexiconDecoder::getAllFinalHypothesis() const {
int finalFrame = nDecodedFrames_ - nPrunedFrames_;
if (finalFrame < 1) {
return std::vector<DecodeResult>{};
}
return getAllHypothesis(hyp_.find(finalFrame)->second, finalFrame);
}
DecodeResult LexiconDecoder::getBestHypothesis(int lookBack) const {
if (nDecodedFrames_ - nPrunedFrames_ - lookBack < 1) {
return DecodeResult();
}
const LexiconDecoderState* bestNode = findBestAncestor(
hyp_.find(nDecodedFrames_ - nPrunedFrames_)->second, lookBack);
return getHypothesis(bestNode, nDecodedFrames_ - nPrunedFrames_ - lookBack);
}
int LexiconDecoder::nHypothesis() const {
int finalFrame = nDecodedFrames_ - nPrunedFrames_;
return hyp_.find(finalFrame)->second.size();
}
int LexiconDecoder::nDecodedFramesInBuffer() const {
return nDecodedFrames_ - nPrunedFrames_ + 1;
}
void LexiconDecoder::prune(int lookBack) {
if (nDecodedFrames_ - nPrunedFrames_ - lookBack < 1) {
return; // Not enough decoded frames to prune
}
/* (1) Find the last emitted word in the best path */
const LexiconDecoderState* bestNode = findBestAncestor(
hyp_.find(nDecodedFrames_ - nPrunedFrames_)->second, lookBack);
if (!bestNode) {
return; // Not enough decoded frames to prune
}
int startFrame = nDecodedFrames_ - nPrunedFrames_ - lookBack;
if (startFrame < 1) {
return; // Not enough decoded frames to prune
}
/* (2) Move things from back of hyp_ to front and normalize scores */
pruneAndNormalize(hyp_, startFrame, lookBack);
nPrunedFrames_ = nDecodedFrames_ - lookBack;
}
} // namespace text
} // namespace lib
} // namespace torchaudio
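Condensing the three proposal branches of `decodeStep` above into one expression (a summary sketch, not code from this commit), each candidate extends its parent score as

$$
s_{\text{cand}} \;=\; s_{\text{prev}} \;+\; \underbrace{e_{t,n}\,\bigl[+\,\mathrm{trans}_{n,\mathrm{prev}}\bigr]}_{\texttt{amScore}} \;+\; [\,n = \mathrm{sil}\,]\cdot\texttt{silScore} \;+\; \texttt{lmWeight}\cdot s_{\mathrm{lm}} \;+\; \begin{cases}\texttt{wordScore} & \text{word emitted}\\ \texttt{unkScore} & \langle\text{unk}\rangle\text{ emitted}\\ 0 & \text{otherwise}\end{cases}
$$

where the bracketed transition term applies only to ASG, and the same-node and blank branches contribute no LM term ($s_{\mathrm{lm}} = 0$).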
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include <unordered_map>
#include "torchaudio/csrc/decoder/src/decoder/Decoder.h"
#include "torchaudio/csrc/decoder/src/decoder/Trie.h"
#include "torchaudio/csrc/decoder/src/decoder/lm/LM.h"
namespace torchaudio {
namespace lib {
namespace text {
struct LexiconDecoderOptions {
int beamSize; // Maximum number of hypotheses we hold after each step
int beamSizeToken; // Maximum number of tokens we consider at each step
double beamThreshold; // Threshold to prune hypotheses
double lmWeight; // Weight of the LM score
double wordScore; // Word insertion score
double unkScore; // Unknown word insertion score
double silScore; // Silence insertion score
bool logAdd; // Whether to use logadd when merging hypotheses
CriterionType criterionType; // CTC or ASG
};
/**
* LexiconDecoderState stores information for each hypothesis in the beam.
*/
struct LexiconDecoderState {
double score; // Accumulated total score so far
LMStatePtr lmState; // Language model state
const TrieNode* lex; // Trie node in the lexicon
const LexiconDecoderState* parent; // Parent hypothesis
int token; // Label of token
int word; // Label of word (-1 if incomplete)
bool prevBlank; // If previous hypothesis is blank (for CTC only)
double amScore; // Accumulated AM score so far
double lmScore; // Accumulated LM score so far
LexiconDecoderState(
const double score,
const LMStatePtr& lmState,
const TrieNode* lex,
const LexiconDecoderState* parent,
const int token,
const int word,
const bool prevBlank = false,
const double amScore = 0,
const double lmScore = 0)
: score(score),
lmState(lmState),
lex(lex),
parent(parent),
token(token),
word(word),
prevBlank(prevBlank),
amScore(amScore),
lmScore(lmScore) {}
LexiconDecoderState()
: score(0.),
lmState(nullptr),
lex(nullptr),
parent(nullptr),
token(-1),
word(-1),
prevBlank(false),
amScore(0.),
lmScore(0.) {}
int compareNoScoreStates(const LexiconDecoderState* node) const {
int lmCmp = lmState->compare(node->lmState);
if (lmCmp != 0) {
return lmCmp > 0 ? 1 : -1;
} else if (lex != node->lex) {
return lex > node->lex ? 1 : -1;
} else if (token != node->token) {
return token > node->token ? 1 : -1;
} else if (prevBlank != node->prevBlank) {
return prevBlank > node->prevBlank ? 1 : -1;
}
return 0;
}
int getWord() const {
return word;
}
bool isComplete() const {
return !parent || parent->word >= 0;
}
};
/**
 * Decoder implements a beam search decoder that finds the word transcription
* W maximizing:
*
* AM(W) + lmWeight_ * log(P_{lm}(W)) + wordScore_ * |W_known| + unkScore_ *
* |W_unknown| + silScore_ * |{i| pi_i = <sil>}|
*
* where P_{lm}(W) is the language model score, pi_i is the value for the i-th
* frame in the path leading to W and AM(W) is the (unnormalized) acoustic model
* score of the transcription W. Note that the lexicon is used to limit the
* search space and all candidate words are generated from it if unkScore is
* -inf, otherwise <UNK> will be generated for OOVs.
*/
class LexiconDecoder : public Decoder {
public:
LexiconDecoder(
LexiconDecoderOptions opt,
const TriePtr& lexicon,
const LMPtr& lm,
const int sil,
const int blank,
const int unk,
const std::vector<float>& transitions,
const bool isLmToken)
: opt_(std::move(opt)),
lexicon_(lexicon),
lm_(lm),
sil_(sil),
blank_(blank),
unk_(unk),
transitions_(transitions),
isLmToken_(isLmToken) {}
void decodeBegin() override;
void decodeStep(const float* emissions, int T, int N) override;
void decodeEnd() override;
int nHypothesis() const;
void prune(int lookBack = 0) override;
int nDecodedFramesInBuffer() const override;
DecodeResult getBestHypothesis(int lookBack = 0) const override;
std::vector<DecodeResult> getAllFinalHypothesis() const override;
protected:
LexiconDecoderOptions opt_;
// Lexicon trie to restrict beam-search decoder
TriePtr lexicon_;
LMPtr lm_;
// Index of silence label
int sil_;
// Index of blank label (for CTC)
int blank_;
// Index of unknown word
int unk_;
// matrix of transitions (for ASG criterion)
std::vector<float> transitions_;
// if LM is token-level (operates on the same level as acoustic model)
// or it is word-level (in case of false)
bool isLmToken_;
// All the new hypothesis candidates (can be more than beamSize) proposed
// based on the ones from the previous frame
std::vector<LexiconDecoderState> candidates_;
// This vector is designed for efficient sorting and merging of candidates_;
// instead of moving around objects, we only need to sort pointers
std::vector<LexiconDecoderState*> candidatePtrs_;
// Best candidate score of current frame
double candidatesBestScore_;
// Vector of hypotheses for all the frames so far
std::unordered_map<int, std::vector<LexiconDecoderState>> hyp_;
// These 2 variables are used for online decoding, for hypothesis pruning
int nDecodedFrames_; // Total number of decoded frames.
int nPrunedFrames_; // Total number of pruned frames from hyp_.
};
} // namespace text
} // namespace lib
} // namespace torchaudio
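A minimal construction sketch against the declarations above; the trie, LM, and token indices are placeholders the caller must supply, and the option values are illustrative only:

```cpp
#include <limits>
#include "torchaudio/csrc/decoder/src/decoder/LexiconDecoder.h"

using namespace torchaudio::lib::text;

// Sketch: wire up a CTC LexiconDecoder with a word-level LM.
LexiconDecoder makeDecoder(
    const TriePtr& trie, const LMPtr& lm, int sil, int blank, int unk) {
  LexiconDecoderOptions opt;
  opt.beamSize = 50;
  opt.beamSizeToken = 30;
  opt.beamThreshold = 100.0;
  opt.lmWeight = 2.0;
  opt.wordScore = 0.0;
  opt.unkScore = -std::numeric_limits<double>::infinity(); // disable <unk>
  opt.silScore = 0.0;
  opt.logAdd = false;
  opt.criterionType = CriterionType::CTC;
  return LexiconDecoder(
      opt, trie, lm, sil, blank, unk,
      /*transitions=*/{}, /*isLmToken=*/false);
}
```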
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#include <stdlib.h>
#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include "torchaudio/csrc/decoder/src/decoder/LexiconFreeDecoder.h"
namespace torchaudio {
namespace lib {
namespace text {
void LexiconFreeDecoder::decodeBegin() {
hyp_.clear();
hyp_.emplace(0, std::vector<LexiconFreeDecoderState>());
/* note: the LM resets itself with start() */
hyp_[0].emplace_back(0.0, lm_->start(0), nullptr, sil_);
nDecodedFrames_ = 0;
nPrunedFrames_ = 0;
}
void LexiconFreeDecoder::decodeStep(const float* emissions, int T, int N) {
int startFrame = nDecodedFrames_ - nPrunedFrames_;
// Extend hyp_ buffer
if (hyp_.size() < startFrame + T + 2) {
for (int i = hyp_.size(); i < startFrame + T + 2; i++) {
hyp_.emplace(i, std::vector<LexiconFreeDecoderState>());
}
}
std::vector<size_t> idx(N);
// Looping over all the frames
for (int t = 0; t < T; t++) {
std::iota(idx.begin(), idx.end(), 0);
if (N > opt_.beamSizeToken) {
std::partial_sort(
idx.begin(),
idx.begin() + opt_.beamSizeToken,
idx.end(),
[&t, &N, &emissions](const size_t& l, const size_t& r) {
return emissions[t * N + l] > emissions[t * N + r];
});
}
candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_);
for (const LexiconFreeDecoderState& prevHyp : hyp_[startFrame + t]) {
const int prevIdx = prevHyp.token;
for (int r = 0; r < std::min(opt_.beamSizeToken, N); ++r) {
int n = idx[r];
double amScore = emissions[t * N + n];
if (nDecodedFrames_ + t > 0 &&
opt_.criterionType == CriterionType::ASG) {
amScore += transitions_[n * N + prevIdx];
}
double score = prevHyp.score + amScore;
if (n == sil_) {
score += opt_.silScore;
}
if ((opt_.criterionType == CriterionType::ASG && n != prevIdx) ||
(opt_.criterionType == CriterionType::CTC && n != blank_ &&
(n != prevIdx || prevHyp.prevBlank))) {
auto lmStateScorePair = lm_->score(prevHyp.lmState, n);
auto lmScore = lmStateScorePair.second;
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score + opt_.lmWeight * lmScore,
lmStateScorePair.first,
&prevHyp,
n,
false, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore + lmScore);
} else if (opt_.criterionType == CriterionType::CTC && n == blank_) {
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score,
prevHyp.lmState,
&prevHyp,
n,
true, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore);
} else {
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score,
prevHyp.lmState,
&prevHyp,
n,
false, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore);
}
}
}
candidatesStore(
candidates_,
candidatePtrs_,
hyp_[startFrame + t + 1],
opt_.beamSize,
candidatesBestScore_ - opt_.beamThreshold,
opt_.logAdd,
false);
updateLMCache(lm_, hyp_[startFrame + t + 1]);
}
nDecodedFrames_ += T;
}
void LexiconFreeDecoder::decodeEnd() {
candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_);
for (const LexiconFreeDecoderState& prevHyp :
hyp_[nDecodedFrames_ - nPrunedFrames_]) {
const LMStatePtr& prevLmState = prevHyp.lmState;
auto lmStateScorePair = lm_->finish(prevLmState);
auto lmScore = lmStateScorePair.second;
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
prevHyp.score + opt_.lmWeight * lmScore,
lmStateScorePair.first,
&prevHyp,
sil_,
false, // prevBlank
prevHyp.amScore,
prevHyp.lmScore + lmScore);
}
candidatesStore(
candidates_,
candidatePtrs_,
hyp_[nDecodedFrames_ - nPrunedFrames_ + 1],
opt_.beamSize,
candidatesBestScore_ - opt_.beamThreshold,
opt_.logAdd,
true);
++nDecodedFrames_;
}
std::vector<DecodeResult> LexiconFreeDecoder::getAllFinalHypothesis() const {
int finalFrame = nDecodedFrames_ - nPrunedFrames_;
return getAllHypothesis(hyp_.find(finalFrame)->second, finalFrame);
}
DecodeResult LexiconFreeDecoder::getBestHypothesis(int lookBack) const {
int finalFrame = nDecodedFrames_ - nPrunedFrames_;
const LexiconFreeDecoderState* bestNode =
findBestAncestor(hyp_.find(finalFrame)->second, lookBack);
return getHypothesis(bestNode, nDecodedFrames_ - nPrunedFrames_ - lookBack);
}
int LexiconFreeDecoder::nHypothesis() const {
int finalFrame = nDecodedFrames_ - nPrunedFrames_;
return hyp_.find(finalFrame)->second.size();
}
int LexiconFreeDecoder::nDecodedFramesInBuffer() const {
return nDecodedFrames_ - nPrunedFrames_ + 1;
}
void LexiconFreeDecoder::prune(int lookBack) {
if (nDecodedFrames_ - nPrunedFrames_ - lookBack < 1) {
return; // Not enough decoded frames to prune
}
/* (1) Find the last emitted word in the best path */
int finalFrame = nDecodedFrames_ - nPrunedFrames_;
const LexiconFreeDecoderState* bestNode =
findBestAncestor(hyp_.find(finalFrame)->second, lookBack);
if (!bestNode) {
return; // Not enough decoded frames to prune
}
int startFrame = nDecodedFrames_ - nPrunedFrames_ - lookBack;
if (startFrame < 1) {
return; // Not enough decoded frames to prune
}
/* (2) Move things from back of hyp_ to front and normalize scores */
pruneAndNormalize(hyp_, startFrame, lookBack);
nPrunedFrames_ = nDecodedFrames_ - lookBack;
}
} // namespace text
} // namespace lib
} // namespace torchaudio
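The branch conditions in `decodeStep` above encode the usual CTC rule: a non-blank token extends the transcript only if it differs from the previous token or a blank intervened. The same rule as a standalone collapse function (standard CTC, shown for illustration; not code from this repo):

```cpp
#include <vector>

// Sketch: collapse a per-frame CTC token path into an output sequence,
// mirroring the (n != blank && (n != prev || prevBlank)) condition above.
std::vector<int> ctcCollapse(const std::vector<int>& path, int blank) {
  std::vector<int> out;
  int prev = -1;
  bool prevBlank = true; // treat the stream start like a preceding blank
  for (int n : path) {
    if (n != blank && (n != prev || prevBlank)) {
      out.push_back(n);
    }
    prevBlank = (n == blank);
    prev = n;
  }
  return out;
}
```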
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include <unordered_map>
#include "torchaudio/csrc/decoder/src/decoder/Decoder.h"
#include "torchaudio/csrc/decoder/src/decoder/lm/LM.h"
namespace torchaudio {
namespace lib {
namespace text {
struct LexiconFreeDecoderOptions {
int beamSize; // Maximum number of hypotheses we hold after each step
int beamSizeToken; // Maximum number of tokens we consider at each step
double beamThreshold; // Threshold to prune hypotheses
double lmWeight; // Weight of lm
double silScore; // Silence insertion score
bool logAdd;
CriterionType criterionType; // CTC or ASG
};
/**
* LexiconFreeDecoderState stores information for each hypothesis in the beam.
*/
struct LexiconFreeDecoderState {
double score; // Accumulated total score so far
LMStatePtr lmState; // Language model state
const LexiconFreeDecoderState* parent; // Parent hypothesis
int token; // Label of token
bool prevBlank; // If previous hypothesis is blank (for CTC only)
double amScore; // Accumulated AM score so far
double lmScore; // Accumulated LM score so far
LexiconFreeDecoderState(
const double score,
const LMStatePtr& lmState,
const LexiconFreeDecoderState* parent,
const int token,
const bool prevBlank = false,
const double amScore = 0,
const double lmScore = 0)
: score(score),
lmState(lmState),
parent(parent),
token(token),
prevBlank(prevBlank),
amScore(amScore),
lmScore(lmScore) {}
LexiconFreeDecoderState()
: score(0),
lmState(nullptr),
parent(nullptr),
token(-1),
prevBlank(false),
amScore(0.),
lmScore(0.) {}
int compareNoScoreStates(const LexiconFreeDecoderState* node) const {
int lmCmp = lmState->compare(node->lmState);
if (lmCmp != 0) {
return lmCmp > 0 ? 1 : -1;
} else if (token != node->token) {
return token > node->token ? 1 : -1;
} else if (prevBlank != node->prevBlank) {
return prevBlank > node->prevBlank ? 1 : -1;
}
return 0;
}
int getWord() const {
return -1;
}
bool isComplete() const {
return true;
}
};
/**
 * Decoder implements a beam search decoder that finds the word transcription
* W maximizing:
*
* AM(W) + lmWeight_ * log(P_{lm}(W)) + silScore_ * |{i| pi_i = <sil>}|
*
* where P_{lm}(W) is the language model score, pi_i is the value for the i-th
* frame in the path leading to W and AM(W) is the (unnormalized) acoustic model
* score of the transcription W. We are allowed to generate words from all the
* possible combinations of tokens.
*/
class LexiconFreeDecoder : public Decoder {
public:
LexiconFreeDecoder(
LexiconFreeDecoderOptions opt,
const LMPtr& lm,
const int sil,
const int blank,
const std::vector<float>& transitions)
: opt_(std::move(opt)),
lm_(lm),
transitions_(transitions),
sil_(sil),
blank_(blank) {}
void decodeBegin() override;
void decodeStep(const float* emissions, int T, int N) override;
void decodeEnd() override;
int nHypothesis() const;
void prune(int lookBack = 0) override;
int nDecodedFramesInBuffer() const override;
DecodeResult getBestHypothesis(int lookBack = 0) const override;
std::vector<DecodeResult> getAllFinalHypothesis() const override;
protected:
LexiconFreeDecoderOptions opt_;
LMPtr lm_;
std::vector<float> transitions_;
// All the new hypothesis candidates (can be more than beamSize) proposed
// based on the ones from the previous frame
std::vector<LexiconFreeDecoderState> candidates_;
// This vector is designed for efficient sorting and merging of candidates_;
// instead of moving around objects, we only need to sort pointers
std::vector<LexiconFreeDecoderState*> candidatePtrs_;
// Best candidate score of current frame
double candidatesBestScore_;
// Index of silence label
int sil_;
// Index of blank label (for CTC)
int blank_;
// Vector of hypotheses for all the frames so far
std::unordered_map<int, std::vector<LexiconFreeDecoderState>> hyp_;
// These 2 variables are used for online decoding, for hypothesis pruning
int nDecodedFrames_; // Total number of decoded frames.
int nPrunedFrames_; // Total number of pruned frames from hyp_.
};
} // namespace text
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#include <math.h>
#include <stdlib.h>
#include <iostream>
#include <limits>
#include "torchaudio/csrc/decoder/src/decoder/Trie.h"
namespace torchaudio {
namespace lib {
namespace text {
const double kMinusLogThreshold = -39.14;
const TrieNode* Trie::getRoot() const {
return root_.get();
}
TrieNodePtr Trie::insert(
const std::vector<int>& indices,
int label,
float score) {
TrieNodePtr node = root_;
for (int i = 0; i < indices.size(); i++) {
int idx = indices[i];
if (idx < 0 || idx >= maxChildren_) {
throw std::out_of_range(
"[Trie] Invalid letter index: " + std::to_string(idx));
}
if (node->children.find(idx) == node->children.end()) {
node->children[idx] = std::make_shared<TrieNode>(idx);
}
node = node->children[idx];
}
if (node->labels.size() < kTrieMaxLabel) {
node->labels.push_back(label);
node->scores.push_back(score);
} else {
std::cerr << "[Trie] Trie label number reached limit: " << kTrieMaxLabel
<< "\n";
}
return node;
}
TrieNodePtr Trie::search(const std::vector<int>& indices) {
TrieNodePtr node = root_;
for (auto idx : indices) {
if (idx < 0 || idx >= maxChildren_) {
throw std::out_of_range(
"[Trie] Invalid letter index: " + std::to_string(idx));
}
if (node->children.find(idx) == node->children.end()) {
return nullptr;
}
node = node->children[idx];
}
return node;
}
/* logadd */
double TrieLogAdd(double log_a, double log_b) {
double minusdif;
if (log_a < log_b) {
std::swap(log_a, log_b);
}
minusdif = log_b - log_a;
if (minusdif < kMinusLogThreshold) {
return log_a;
} else {
return log_a + log1p(exp(minusdif));
}
}
void smearNode(TrieNodePtr node, SmearingMode smearMode) {
node->maxScore = -std::numeric_limits<float>::infinity();
for (auto score : node->scores) {
node->maxScore = TrieLogAdd(node->maxScore, score);
}
for (auto child : node->children) {
auto childNode = child.second;
smearNode(childNode, smearMode);
if (smearMode == SmearingMode::LOGADD) {
node->maxScore = TrieLogAdd(node->maxScore, childNode->maxScore);
} else if (
smearMode == SmearingMode::MAX &&
childNode->maxScore > node->maxScore) {
node->maxScore = childNode->maxScore;
}
}
}
void Trie::smear(SmearingMode smearMode) {
if (smearMode != SmearingMode::NONE) {
smearNode(root_, smearMode);
}
}
} // namespace text
} // namespace lib
} // namespace torchaudio
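`TrieLogAdd` above is a numerically stable two-argument log-sum-exp: after swapping so that $a \ge b$, it computes

$$
\operatorname{logadd}(a, b) \;=\; \log\bigl(e^{a} + e^{b}\bigr) \;=\; a + \log\!\bigl(1 + e^{\,b - a}\bigr),
$$

and `kMinusLogThreshold` short-circuits to $a$ once $b - a$ is so negative that the correction term $\log(1 + e^{\,b-a})$ is negligible.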
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include <memory>
#include <unordered_map>
#include <vector>
namespace torchaudio {
namespace lib {
namespace text {
constexpr int kTrieMaxLabel = 6;
enum class SmearingMode {
NONE = 0,
MAX = 1,
LOGADD = 2,
};
/**
* TrieNode is the trie node structure in Trie.
*/
struct TrieNode {
explicit TrieNode(int idx)
: children(std::unordered_map<int, std::shared_ptr<TrieNode>>()),
idx(idx),
maxScore(0) {
labels.reserve(kTrieMaxLabel);
scores.reserve(kTrieMaxLabel);
}
// Pointers to the children of a node
std::unordered_map<int, std::shared_ptr<TrieNode>> children;
// Node index
int idx;
// Labels of words that are constructed from the given path. Note that
// `labels` is nonempty only if the current node represents a completed token.
std::vector<int> labels;
// Scores (`scores` should have the same size as `labels`)
std::vector<float> scores;
// Maximum score of all the labels if this node is a leaf,
// otherwise it will be the value after trie smearing.
float maxScore;
};
using TrieNodePtr = std::shared_ptr<TrieNode>;
/**
 * Trie is used to store the lexicon in the language model. We use it to limit
 * the search space in the decoder and quickly look up scores for a given token
 * (completed word) or make predictions for incomplete ones based on smearing.
*/
class Trie {
public:
Trie(int maxChildren, int rootIdx)
: root_(std::make_shared<TrieNode>(rootIdx)), maxChildren_(maxChildren) {}
/* Return the root node pointer */
const TrieNode* getRoot() const;
/* Insert a token into trie with label */
TrieNodePtr insert(const std::vector<int>& indices, int label, float score);
/* Get the labels for a given token */
TrieNodePtr search(const std::vector<int>& indices);
/**
 * Smearing the trie using the valid labels inserted in the trie so as to get
 * a score on each node (incomplete token).
 * For example, if smear_mode is MAX, then for node "a" in path "c"->"a", we
 * will select the maximum score from all its children like "c"->"a"->"t",
 * "c"->"a"->"n", "c"->"a"->"r"->"e" and so on.
 * This process is carried out recursively on all the nodes.
 */
void smear(const SmearingMode smear_mode);
private:
TrieNodePtr root_;
int maxChildren_; // The maximum number of children for each node. It is
// usually the number of letters or phonemes.
};
using TriePtr = std::shared_ptr<Trie>;
} // namespace text
} // namespace lib
} // namespace torchaudio
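A minimal sketch of populating and smearing a trie with the API above (token indices, labels, and scores are made-up placeholders):

```cpp
#include <memory>
#include "torchaudio/csrc/decoder/src/decoder/Trie.h"

using namespace torchaudio::lib::text;

// Sketch: build a tiny lexicon trie and smear it so that inner nodes carry
// look-ahead scores for incomplete tokens.
TriePtr buildToyTrie(int nTokens, int silIdx) {
  auto trie = std::make_shared<Trie>(nTokens, silIdx);
  trie->insert({2, 0, 19}, /*label=*/7, /*score=*/-1.5f); // e.g. "cat"
  trie->insert({2, 0, 17}, /*label=*/8, /*score=*/-2.0f); // e.g. "car"
  trie->smear(SmearingMode::MAX); // inner nodes get the max over their subtree
  return trie;
}
```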
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
namespace torchaudio {
namespace lib {
namespace text {
// Placeholder
} // namespace text
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include <algorithm>
#include <cmath>
#include <limits>
#include <unordered_map>
#include <vector>
#include "torchaudio/csrc/decoder/src/decoder/lm/LM.h"
namespace torchaudio {
namespace lib {
namespace text {
/* ===================== Definitions ===================== */
const double kNegativeInfinity = -std::numeric_limits<double>::infinity();
const int kLookBackLimit = 100;
struct DecodeResult {
double score;
double amScore;
double lmScore;
std::vector<int> words;
std::vector<int> tokens;
explicit DecodeResult(int length = 0)
: score(0), words(length, -1), tokens(length, -1) {}
};
/* ===================== Candidate-related operations ===================== */
template <class DecoderState>
void candidatesReset(
double& candidatesBestScore,
std::vector<DecoderState>& candidates,
std::vector<DecoderState*>& candidatePtrs) {
candidatesBestScore = kNegativeInfinity;
candidates.clear();
candidatePtrs.clear();
}
template <class DecoderState, class... Args>
void candidatesAdd(
std::vector<DecoderState>& candidates,
double& candidatesBestScore,
const double beamThreshold,
const double score,
const Args&... args) {
if (score >= candidatesBestScore) {
candidatesBestScore = score;
}
if (score >= candidatesBestScore - beamThreshold) {
candidates.emplace_back(score, args...);
}
}
template <class DecoderState>
void candidatesStore(
std::vector<DecoderState>& candidates,
std::vector<DecoderState*>& candidatePtrs,
std::vector<DecoderState>& outputs,
const int beamSize,
const double threshold,
const bool logAdd,
const bool returnSorted) {
outputs.clear();
if (candidates.empty()) {
return;
}
/* 1. Select valid candidates */
for (auto& candidate : candidates) {
if (candidate.score >= threshold) {
candidatePtrs.emplace_back(&candidate);
}
}
/* 2. Merge candidates */
std::sort(
candidatePtrs.begin(),
candidatePtrs.end(),
[](const DecoderState* node1, const DecoderState* node2) {
int cmp = node1->compareNoScoreStates(node2);
return cmp == 0 ? node1->score > node2->score : cmp > 0;
});
int nHypAfterMerging = 1;
for (int i = 1; i < candidatePtrs.size(); i++) {
if (candidatePtrs[i]->compareNoScoreStates(
candidatePtrs[nHypAfterMerging - 1]) != 0) {
// Distinct candidate
candidatePtrs[nHypAfterMerging] = candidatePtrs[i];
nHypAfterMerging++;
} else {
// Same candidate
double maxScore = std::max(
candidatePtrs[nHypAfterMerging - 1]->score, candidatePtrs[i]->score);
if (logAdd) {
double minScore = std::min(
candidatePtrs[nHypAfterMerging - 1]->score,
candidatePtrs[i]->score);
candidatePtrs[nHypAfterMerging - 1]->score =
maxScore + std::log1p(std::exp(minScore - maxScore));
} else {
candidatePtrs[nHypAfterMerging - 1]->score = maxScore;
}
}
}
candidatePtrs.resize(nHypAfterMerging);
/* 3. Sort and prune */
auto compareNodeScore = [](const DecoderState* node1,
const DecoderState* node2) {
return node1->score > node2->score;
};
int nValidHyp = candidatePtrs.size();
int finalSize = std::min(nValidHyp, beamSize);
if (!returnSorted && nValidHyp > beamSize) {
std::nth_element(
candidatePtrs.begin(),
candidatePtrs.begin() + finalSize,
candidatePtrs.begin() + nValidHyp,
compareNodeScore);
} else if (returnSorted) {
std::partial_sort(
candidatePtrs.begin(),
candidatePtrs.begin() + finalSize,
candidatePtrs.begin() + nValidHyp,
compareNodeScore);
}
for (int i = 0; i < finalSize; i++) {
outputs.emplace_back(std::move(*candidatePtrs[i]));
}
}
/* ===================== Result-related operations ===================== */
template <class DecoderState>
DecodeResult getHypothesis(const DecoderState* node, const int finalFrame) {
const DecoderState* node_ = node;
if (!node_) {
return DecodeResult();
}
DecodeResult res(finalFrame + 1);
res.score = node_->score;
res.amScore = node_->amScore;
res.lmScore = node_->lmScore;
int i = 0;
while (node_) {
res.words[finalFrame - i] = node_->getWord();
res.tokens[finalFrame - i] = node_->token;
node_ = node_->parent;
i++;
}
return res;
}
template <class DecoderState>
std::vector<DecodeResult> getAllHypothesis(
const std::vector<DecoderState>& finalHyps,
const int finalFrame) {
int nHyp = finalHyps.size();
std::vector<DecodeResult> res(nHyp);
for (int r = 0; r < nHyp; r++) {
const DecoderState* node = &finalHyps[r];
res[r] = getHypothesis(node, finalFrame);
}
return res;
}
template <class DecoderState>
const DecoderState* findBestAncestor(
const std::vector<DecoderState>& finalHyps,
int& lookBack) {
int nHyp = finalHyps.size();
if (nHyp == 0) {
return nullptr;
}
double bestScore = finalHyps.front().score;
const DecoderState* bestNode = finalHyps.data();
for (int r = 1; r < nHyp; r++) {
const DecoderState* node = &finalHyps[r];
if (node->score > bestScore) {
bestScore = node->score;
bestNode = node;
}
}
int n = 0;
while (bestNode && n < lookBack) {
n++;
bestNode = bestNode->parent;
}
const int maxLookBack = lookBack + kLookBackLimit;
while (bestNode) {
// Check for first emitted word.
if (bestNode->isComplete()) {
break;
}
n++;
bestNode = bestNode->parent;
if (n == maxLookBack) {
break;
}
}
lookBack = n;
return bestNode;
}
template <class DecoderState>
void pruneAndNormalize(
std::unordered_map<int, std::vector<DecoderState>>& hypothesis,
const int startFrame,
const int lookBack) {
/* 1. Move things from back of hypothesis to front. */
for (int i = 0; i < hypothesis.size(); i++) {
if (i <= lookBack) {
hypothesis[i].swap(hypothesis[i + startFrame]);
} else {
hypothesis[i].clear();
}
}
/* 2. Avoid further back-tracking */
for (DecoderState& hyp : hypothesis[0]) {
hyp.parent = nullptr;
}
/* 3. Avoid score underflow/overflow. */
double largestScore = hypothesis[lookBack].front().score;
for (int i = 1; i < hypothesis[lookBack].size(); i++) {
if (largestScore < hypothesis[lookBack][i].score) {
largestScore = hypothesis[lookBack][i].score;
}
}
for (int i = 0; i < hypothesis[lookBack].size(); i++) {
hypothesis[lookBack][i].score -= largestScore;
}
}
/* ===================== LM-related operations ===================== */
template <class DecoderState>
void updateLMCache(const LMPtr& lm, std::vector<DecoderState>& hypothesis) {
// For ConvLM update cache
std::vector<LMStatePtr> states;
for (const auto& hyp : hypothesis) {
states.emplace_back(hyp.lmState);
}
lm->updateCache(states);
}
} // namespace text
} // namespace lib
} // namespace torchaudio
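The candidate helpers above gate hypotheses in two passes: `candidatesAdd` keeps a candidate with score $s$ only if $s \ge s^{*}_{\text{so far}} - \texttt{beamThreshold}$ against the running best score, and `candidatesStore` then re-filters against the final best $s^{*}$, so only candidates satisfying

$$
s \;\ge\; s^{*} - \texttt{beamThreshold}
$$

survive; those are then deduplicated via `compareNoScoreStates` (merged by max, or by logadd when `logAdd` is set) and truncated to `beamSize`.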
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#include "torchaudio/csrc/decoder/src/decoder/lm/KenLM.h"
#include <stdexcept>
#ifdef USE_KENLM_FROM_LANGTECH
#include "language_technology/jedi/lm/model.hh"
#else
#include "lm/model.hh"
#endif
namespace torchaudio {
namespace lib {
namespace text {
KenLMState::KenLMState() : ken_(std::make_unique<lm::ngram::State>()) {}
KenLM::KenLM(const std::string& path, const Dictionary& usrTknDict) {
// Load LM
model_.reset(lm::ngram::LoadVirtual(path.c_str()));
if (!model_) {
throw std::runtime_error("[KenLM] LM loading failed.");
}
vocab_ = &model_->BaseVocabulary();
if (!vocab_) {
throw std::runtime_error("[KenLM] LM vocabulary loading failed.");
}
// Create index map
usrToLmIdxMap_.resize(usrTknDict.indexSize());
for (int i = 0; i < usrTknDict.indexSize(); i++) {
auto token = usrTknDict.getEntry(i);
int lmIdx = vocab_->Index(token.c_str());
usrToLmIdxMap_[i] = lmIdx;
}
}
LMStatePtr KenLM::start(bool startWithNothing) {
auto outState = std::make_shared<KenLMState>();
if (startWithNothing) {
model_->NullContextWrite(outState->ken());
} else {
model_->BeginSentenceWrite(outState->ken());
}
return outState;
}
std::pair<LMStatePtr, float> KenLM::score(
const LMStatePtr& state,
const int usrTokenIdx) {
if (usrTokenIdx < 0 || usrTokenIdx >= usrToLmIdxMap_.size()) {
throw std::runtime_error(
"[KenLM] Invalid user token index: " + std::to_string(usrTokenIdx));
}
auto inState = std::static_pointer_cast<KenLMState>(state);
auto outState = inState->child<KenLMState>(usrTokenIdx);
float score = model_->BaseScore(
inState->ken(), usrToLmIdxMap_[usrTokenIdx], outState->ken());
return std::make_pair(std::move(outState), score);
}
std::pair<LMStatePtr, float> KenLM::finish(const LMStatePtr& state) {
auto inState = std::static_pointer_cast<KenLMState>(state);
auto outState = inState->child<KenLMState>(-1);
float score =
model_->BaseScore(inState->ken(), vocab_->EndSentence(), outState->ken());
return std::make_pair(std::move(outState), score);
}
} // namespace text
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include <memory>
#include "torchaudio/csrc/decoder/src/decoder/lm/LM.h"
#include "torchaudio/csrc/decoder/src/dictionary/Dictionary.h"
// Forward declarations to avoid including KenLM headers
namespace lm {
namespace base {
struct Vocabulary;
struct Model;
} // namespace base
namespace ngram {
struct State;
} // namespace ngram
} // namespace lm
namespace torchaudio {
namespace lib {
namespace text {
/**
 * KenLMState is a state object from KenLM, which contains the context length,
 * indices, and compare functions
* https://github.com/kpu/kenlm/blob/master/lm/state.hh.
*/
struct KenLMState : LMState {
KenLMState();
std::unique_ptr<lm::ngram::State> ken_;
lm::ngram::State* ken() {
return ken_.get();
}
};
/**
* KenLM extends LM by using the toolkit https://kheafield.com/code/kenlm/.
*/
class KenLM : public LM {
public:
KenLM(const std::string& path, const Dictionary& usrTknDict);
LMStatePtr start(bool startWithNothing) override;
std::pair<LMStatePtr, float> score(
const LMStatePtr& state,
const int usrTokenIdx) override;
std::pair<LMStatePtr, float> finish(const LMStatePtr& state) override;
private:
std::shared_ptr<lm::base::Model> model_;
const lm::base::Vocabulary* vocab_;
};
using KenLMPtr = std::shared_ptr<KenLM>;
} // namespace text
} // namespace lib
} // namespace torchaudio
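A minimal scoring sketch against the KenLM wrapper declared above; the model path and token dictionary are placeholders supplied by the caller:

```cpp
#include <string>
#include <vector>
#include "torchaudio/csrc/decoder/src/decoder/lm/KenLM.h"

using namespace torchaudio::lib::text;

// Sketch: accumulate the LM score of a token sequence with the KenLM wrapper.
float scoreSequence(
    const std::string& lmPath,
    const Dictionary& usrTokenDict,
    const std::vector<int>& tokens) {
  KenLM lm(lmPath, usrTokenDict);
  LMStatePtr state = lm.start(/*startWithNothing=*/false);
  float total = 0.0f;
  for (int tok : tokens) {
    auto out = lm.score(state, tok); // (next state, score) pair
    total += out.second;
    state = out.first;
  }
  total += lm.finish(state).second; // end-of-sentence score
  return total;
}
```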