Commit 97ed428d authored by Caroline Chen, committed by Facebook GitHub Bot

Add lexicon free CTC decoder (#2342)

Summary:
Add support for lexicon-free decoding based on [fairseq's](https://github.com/pytorch/fairseq/blob/main/examples/speech_recognition/new/decoders/flashlight_decoder.py#L53) implementation. The decoder reached numerical parity with fairseq's decoder in offline experimentation. A usage sketch of the new API follows the follow-up list below.

Follow-ups:
- Add pretrained LM support for lexicon-free decoding
- Add an example to the tutorial
- Replace the flashlight C++ source code with the flashlight text submodule
- [optional] fairseq compatibility test
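
As a quick illustration, here is a minimal, hedged sketch of the new lexicon-free path (the token file path and the `emissions` tensor are placeholders; `lexicon=None` is what routes construction to the new `_LexiconFreeDecoder` backend):

```python
import torch
from torchaudio.prototype.ctc_decoder import ctc_decoder

# Placeholder inputs: a token file with one token per line, and CPU float32
# emissions of shape (batch, frames, num_tokens) from an acoustic model.
emissions = torch.randn(1, 100, 8).log_softmax(-1)

decoder = ctc_decoder(
    lexicon=None,         # None selects lexicon-free decoding (this PR)
    tokens="tokens.txt",  # placeholder path
    lm=None,              # optional KenLM file; None falls back to _ZeroLM
)
hypos = decoder(emissions)  # list of shape (B, nbest) of Hypothesis
```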

Pull Request resolved: https://github.com/pytorch/audio/pull/2342

Reviewed By: nateanl

Differential Revision: D35856104

Pulled By: carolineechen

fbshipit-source-id: b64286550984df906ebb747e82f6fb1f21948ac7
parent 7c249d17
@@ -6,11 +6,11 @@ torchaudio.prototype.ctc_decoder
 Decoder Class
 -------------
-LexiconDecoder
-~~~~~~~~~~~~~~
+CTCDecoder
+~~~~~~~~~~
-.. autoclass:: LexiconDecoder
+.. autoclass:: CTCDecoder
   .. automethod:: __call__
@@ -24,10 +24,10 @@ Hypothesis
 Factory Function
 ----------------
-lexicon_decoder
-~~~~~~~~~~~~~~~
+ctc_decoder
+~~~~~~~~~~~
-.. autoclass:: lexicon_decoder
+.. autoclass:: ctc_decoder
 Utility Function
 ----------------
...
@@ -233,7 +233,7 @@ print(files)
 # Beam Search Decoder
 # ~~~~~~~~~~~~~~~~~~~
 # The decoder can be constructed using the factory function
-# :py:func:`lexicon_decoder <torchaudio.prototype.ctc_decoder.lexicon_decoder>`.
+# :py:func:`ctc_decoder <torchaudio.prototype.ctc_decoder.ctc_decoder>`.
 # In addition to the previously mentioned components, it also takes in various beam
 # search decoding parameters and token/word parameters.
 #
@@ -241,12 +241,12 @@ print(files)
 # `lm` parameter.
 #
-from torchaudio.prototype.ctc_decoder import lexicon_decoder
+from torchaudio.prototype.ctc_decoder import ctc_decoder
 LM_WEIGHT = 3.23
 WORD_SCORE = -0.26
-beam_search_decoder = lexicon_decoder(
+beam_search_decoder = ctc_decoder(
     lexicon=files.lexicon,
     tokens=files.tokens,
     lm=files.lm,
@@ -395,7 +395,7 @@ plot_alignments(waveform[0], emission, predicted_tokens, timesteps)
 # In this section, we go a little bit more in depth about some different
 # parameters and tradeoffs. For the full list of customizable parameters,
 # please refer to the
-# :py:func:`documentation <torchaudio.prototype.ctc_decoder.lexicon_decoder>`.
+# :py:func:`documentation <torchaudio.prototype.ctc_decoder.ctc_decoder>`.
 #
@@ -450,7 +450,7 @@ for i in range(3):
 beam_sizes = [1, 5, 50, 500]
 for beam_size in beam_sizes:
-    beam_search_decoder = lexicon_decoder(
+    beam_search_decoder = ctc_decoder(
         lexicon=files.lexicon,
         tokens=files.tokens,
         lm=files.lm,
@@ -476,7 +476,7 @@ num_tokens = len(tokens)
 beam_size_tokens = [1, 5, 10, num_tokens]
 for beam_size_token in beam_size_tokens:
-    beam_search_decoder = lexicon_decoder(
+    beam_search_decoder = ctc_decoder(
         lexicon=files.lexicon,
         tokens=files.tokens,
         lm=files.lm,
@@ -503,7 +503,7 @@ for beam_size_token in beam_size_tokens:
 beam_thresholds = [1, 5, 10, 25]
 for beam_threshold in beam_thresholds:
-    beam_search_decoder = lexicon_decoder(
+    beam_search_decoder = ctc_decoder(
         lexicon=files.lexicon,
         tokens=files.tokens,
         lm=files.lm,
@@ -529,7 +529,7 @@ for beam_threshold in beam_thresholds:
 lm_weights = [0, LM_WEIGHT, 15]
 for lm_weight in lm_weights:
-    beam_search_decoder = lexicon_decoder(
+    beam_search_decoder = ctc_decoder(
         lexicon=files.lexicon,
         tokens=files.tokens,
         lm=files.lm,
...
@@ -9,10 +9,10 @@ import pytest
     ],
 )
 def test_decoder_from_pretrained(model, expected, emissions):
-    from torchaudio.prototype.ctc_decoder import lexicon_decoder, download_pretrained_files
+    from torchaudio.prototype.ctc_decoder import ctc_decoder, download_pretrained_files
     pretrained_files = download_pretrained_files(model)
-    decoder = lexicon_decoder(
+    decoder = ctc_decoder(
         lexicon=pretrained_files.lexicon,
         tokens=pretrained_files.tokens,
         lm=pretrained_files.lm,
...
\data\
ngram 1=8
ngram 2=8
ngram 3=8
ngram 4=8
ngram 5=8
\1-grams:
-1.146128 <unk> 0
0 <s> -0.30103
-0.8731268 </s> 0
-0.70679533 f -0.30103
-0.70679533 o -0.30103
-0.8731268 b -0.30103
-0.8731268 a -0.30103
-0.8731268 r -0.30103
\2-grams:
-0.24644431 r </s> 0
-0.22314323 <s> f -0.30103
-0.57694924 o f -0.30103
-0.22314323 f o -0.30103
-0.57694924 o o -0.30103
-0.6314696 o b -0.30103
-0.24644431 b a -0.30103
-0.24644431 a r -0.30103
\3-grams:
-0.105970904 a r </s> 0
-0.41743615 o o f -0.30103
-0.097394995 <s> f o -0.30103
-0.097394995 o f o -0.30103
-0.19898036 f o o -0.30103
-0.43555236 o o b -0.30103
-0.105970904 o b a -0.30103
-0.105970904 b a r -0.30103
\4-grams:
-0.049761247 b a r </s> 0
-0.4462542 f o o f -0.30103
-0.045972984 o o f o -0.30103
-0.08819265 <s> f o o -0.30103
-0.08819265 o f o o -0.30103
-0.286727 f o o b -0.30103
-0.049761247 o o b a -0.30103
-0.049761247 o b a r -0.30103
\5-grams:
-0.02416831 o b a r </s>
-0.36759996 <s> f o o f
-0.022378458 f o o f o
-0.041861475 o o f o o
-0.29381964 <s> f o o b
-0.12011856 o f o o b
-0.02416831 f o o b a
-0.02416831 o o b a r
\end\
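
For readers unfamiliar with the ARPA format of this character-level test LM: each line reads `log10(p) n-gram [backoff-weight]`. Below is a hedged sketch of the standard backoff lookup those columns imply (a toy helper for illustration, not KenLM's actual trie-based implementation):

```python
def score(ngrams, backoffs, context, word):
    """log10 P(word | context) with standard ARPA/Katz backoff."""
    if context + (word,) in ngrams:
        return ngrams[context + (word,)]
    if not context:
        # in practice an unseen unigram falls back to the <unk> entry
        return ngrams.get((word,), ngrams[("<unk>",)])
    # pay the backoff weight of the context, then retry with a shorter context
    return backoffs.get(context, 0.0) + score(ngrams, backoffs, context[1:], word)

# A few entries transcribed from the table above:
ngrams = {("<unk>",): -1.146128, ("f",): -0.70679533, ("f", "o"): -0.22314323}
backoffs = {("f",): -0.30103}
assert score(ngrams, backoffs, ("f",), "o") == -0.22314323  # bigram hit
print(score(ngrams, backoffs, ("f",), "f"))                 # backoff: -0.30103 + -0.70679533
```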
@@ -10,18 +10,25 @@ from torchaudio_unittest.common_utils import (
 )
+NUM_TOKENS = 8
 @skipIfNoCtcDecoder
 class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
-    def _get_decoder(self, tokens=None, use_lm=True, **kwargs):
-        from torchaudio.prototype.ctc_decoder import lexicon_decoder
-        lexicon_file = get_asset_path("decoder/lexicon.txt")
-        kenlm_file = get_asset_path("decoder/kenlm.arpa") if use_lm else None
+    def _get_decoder(self, tokens=None, use_lm=True, use_lexicon=True, **kwargs):
+        from torchaudio.prototype.ctc_decoder import ctc_decoder
+        if use_lexicon:
+            lexicon_file = get_asset_path("decoder/lexicon.txt")
+            kenlm_file = get_asset_path("decoder/kenlm.arpa") if use_lm else None
+        else:
+            lexicon_file = None
+            kenlm_file = get_asset_path("decoder/kenlm_char.arpa") if use_lm else None
         if tokens is None:
             tokens = get_asset_path("decoder/tokens.txt")
-        return lexicon_decoder(
+        return ctc_decoder(
             lexicon=lexicon_file,
             tokens=tokens,
             lm=kenlm_file,
@@ -29,7 +36,7 @@ class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
         )
     def _get_emissions(self):
-        B, T, N = 4, 15, 10
+        B, T, N = 4, 15, NUM_TOKENS
         torch.manual_seed(0)
         emissions = torch.rand(B, T, N)
@@ -41,39 +48,46 @@ class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
             itertools.product(
                 [get_asset_path("decoder/tokens.txt"), ["-", "|", "f", "o", "b", "a", "r"]],
                 [True, False],
+                [True, False],
             )
         ),
     )
-    def test_construct_decoder(self, tokens, use_lm):
-        self._get_decoder(tokens=tokens, use_lm=use_lm)
+    def test_construct_decoder(self, tokens, use_lm, use_lexicon):
+        self._get_decoder(tokens=tokens, use_lm=use_lm, use_lexicon=use_lexicon)
-    def test_no_lm_decoder(self):
-        """Check that using no LM produces the same result as using an LM with 0 lm_weight"""
-        kenlm_decoder = self._get_decoder(lm_weight=0)
-        zerolm_decoder = self._get_decoder(use_lm=False)
-        emissions = self._get_emissions()
-        kenlm_results = kenlm_decoder(emissions)
-        zerolm_results = zerolm_decoder(emissions)
-        self.assertEqual(kenlm_results, zerolm_results)
-    def test_shape(self):
-        emissions = self._get_emissions()
-        decoder = self._get_decoder()
+    @parameterized.expand(
+        [(True,), (False,)],
+    )
+    def test_shape(self, use_lexicon):
+        emissions = self._get_emissions()
+        decoder = self._get_decoder(use_lexicon=use_lexicon)
         results = decoder(emissions)
         self.assertEqual(len(results), emissions.shape[0])
-    def test_timesteps_shape(self):
+    @parameterized.expand(
+        [(True,), (False,)],
+    )
+    def test_timesteps_shape(self, use_lexicon):
         """Each token should correspond with a timestep"""
         emissions = self._get_emissions()
-        decoder = self._get_decoder()
+        decoder = self._get_decoder(use_lexicon=use_lexicon)
         results = decoder(emissions)
         for i in range(emissions.shape[0]):
             result = results[i][0]
             self.assertEqual(result.tokens.shape, result.timesteps.shape)
+    def test_no_lm_decoder(self):
+        """Check that using no LM produces the same result as using an LM with 0 lm_weight"""
+        kenlm_decoder = self._get_decoder(lm_weight=0)
+        zerolm_decoder = self._get_decoder(use_lm=False)
+        emissions = self._get_emissions()
+        kenlm_results = kenlm_decoder(emissions)
+        zerolm_results = zerolm_decoder(emissions)
+        self.assertEqual(kenlm_results, zerolm_results)
     def test_get_timesteps(self):
         unprocessed_tokens = torch.tensor([2, 2, 0, 3, 3, 3, 0, 3])
         decoder = self._get_decoder()
...
@@ -137,6 +137,7 @@ if (BUILD_CTC_DECODER)
   set(
     LIBTORCHAUDIO_DECODER_SOURCES
     decoder/src/decoder/LexiconDecoder.cpp
+    decoder/src/decoder/LexiconFreeDecoder.cpp
     decoder/src/decoder/Trie.cpp
     decoder/src/decoder/Utils.cpp
     decoder/src/decoder/lm/KenLM.cpp
...
 #include <torch/extension.h>
 #include "torchaudio/csrc/decoder/src/decoder/LexiconDecoder.h"
+#include "torchaudio/csrc/decoder/src/decoder/LexiconFreeDecoder.h"
 #include "torchaudio/csrc/decoder/src/decoder/lm/KenLM.h"
 #include "torchaudio/csrc/decoder/src/decoder/lm/ZeroLM.h"
 #include "torchaudio/csrc/decoder/src/dictionary/Dictionary.h"
@@ -103,6 +104,22 @@ std::vector<DecodeResult> LexiconDecoder_decode(
   return decoder.decode(reinterpret_cast<const float*>(emissions), T, N);
 }
+void LexiconFreeDecoder_decodeStep(
+    LexiconFreeDecoder& decoder,
+    uintptr_t emissions,
+    int T,
+    int N) {
+  decoder.decodeStep(reinterpret_cast<const float*>(emissions), T, N);
+}
+std::vector<DecodeResult> LexiconFreeDecoder_decode(
+    LexiconFreeDecoder& decoder,
+    uintptr_t emissions,
+    int T,
+    int N) {
+  return decoder.decode(reinterpret_cast<const float*>(emissions), T, N);
+}
 void Dictionary_addEntry_0(
     Dictionary& dict,
     const std::string& entry,
@@ -191,6 +208,34 @@ PYBIND11_MODULE(_torchaudio_decoder, m) {
       .def_readwrite("log_add", &LexiconDecoderOptions::logAdd)
       .def_readwrite("criterion_type", &LexiconDecoderOptions::criterionType);
+  py::class_<LexiconFreeDecoderOptions>(m, "_LexiconFreeDecoderOptions")
+      .def(
+          py::init<
+              const int,
+              const int,
+              const double,
+              const double,
+              const double,
+              const bool,
+              const CriterionType>(),
+          "beam_size"_a,
+          "beam_size_token"_a,
+          "beam_threshold"_a,
+          "lm_weight"_a,
+          "sil_score"_a,
+          "log_add"_a,
+          "criterion_type"_a)
+      .def_readwrite("beam_size", &LexiconFreeDecoderOptions::beamSize)
+      .def_readwrite(
+          "beam_size_token", &LexiconFreeDecoderOptions::beamSizeToken)
+      .def_readwrite(
+          "beam_threshold", &LexiconFreeDecoderOptions::beamThreshold)
+      .def_readwrite("lm_weight", &LexiconFreeDecoderOptions::lmWeight)
+      .def_readwrite("sil_score", &LexiconFreeDecoderOptions::silScore)
+      .def_readwrite("log_add", &LexiconFreeDecoderOptions::logAdd)
+      .def_readwrite(
+          "criterion_type", &LexiconFreeDecoderOptions::criterionType);
   py::class_<DecodeResult>(m, "_DecodeResult")
       .def(py::init<int>(), "length"_a)
       .def_readwrite("score", &DecodeResult::score)
@@ -226,6 +271,31 @@ PYBIND11_MODULE(_torchaudio_decoder, m) {
           "look_back"_a = 0)
       .def("get_all_final_hypothesis", &LexiconDecoder::getAllFinalHypothesis);
+  py::class_<LexiconFreeDecoder>(m, "_LexiconFreeDecoder")
+      .def(py::init<
+           LexiconFreeDecoderOptions,
+           const LMPtr,
+           const int,
+           const int,
+           const std::vector<float>&>())
+      .def("decode_begin", &LexiconFreeDecoder::decodeBegin)
+      .def(
+          "decode_step",
+          &LexiconFreeDecoder_decodeStep,
+          "emissions"_a,
+          "T"_a,
+          "N"_a)
+      .def("decode_end", &LexiconFreeDecoder::decodeEnd)
+      .def("decode", &LexiconFreeDecoder_decode, "emissions"_a, "T"_a, "N"_a)
+      .def("prune", &LexiconFreeDecoder::prune, "look_back"_a = 0)
+      .def(
+          "get_best_hypothesis",
+          &LexiconFreeDecoder::getBestHypothesis,
+          "look_back"_a = 0)
+      .def(
+          "get_all_final_hypothesis",
+          &LexiconFreeDecoder::getAllFinalHypothesis);
   py::class_<Dictionary>(m, "_Dictionary")
       .def(py::init<>())
       .def(py::init<const std::vector<std::string>&>(), "tkns"_a)
...
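
For orientation, these wrappers exist so Python can hand the emissions buffer over without a copy: a `torch.Tensor`'s `data_ptr()` becomes the `uintptr_t` that is `reinterpret_cast` to `const float*`. A hedged construction-and-call sketch using the binding names registered above (the token indices and option values are illustrative only):

```python
import torch
from torchaudio._torchaudio_decoder import (
    _CriterionType, _LexiconFreeDecoder, _LexiconFreeDecoderOptions, _ZeroLM,
)

# Options mirror the keyword names registered in the bindings above.
opts = _LexiconFreeDecoderOptions(
    beam_size=50, beam_size_token=8, beam_threshold=50.0,
    lm_weight=0.0, sil_score=0.0, log_add=False,
    criterion_type=_CriterionType.CTC,
)
sil, blank = 1, 0  # assumed indices of "|" and "-" in the token set
decoder = _LexiconFreeDecoder(opts, _ZeroLM(), sil, blank, [])

# One utterance: (T, N) float32, CPU, contiguous emissions.
emissions = torch.randn(100, 8, dtype=torch.float32).log_softmax(-1).contiguous()
T, N = emissions.shape

# data_ptr() supplies the integer address; no data is copied.
results = decoder.decode(emissions.data_ptr(), T, N)
```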
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#include <stdlib.h>
#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include "torchaudio/csrc/decoder/src/decoder/LexiconFreeDecoder.h"
namespace torchaudio {
namespace lib {
namespace text {
void LexiconFreeDecoder::decodeBegin() {
hyp_.clear();
hyp_.emplace(0, std::vector<LexiconFreeDecoderState>());
  /* note: the LM resets itself via start() */
hyp_[0].emplace_back(0.0, lm_->start(0), nullptr, sil_);
nDecodedFrames_ = 0;
nPrunedFrames_ = 0;
}
void LexiconFreeDecoder::decodeStep(const float* emissions, int T, int N) {
int startFrame = nDecodedFrames_ - nPrunedFrames_;
// Extend hyp_ buffer
if (hyp_.size() < startFrame + T + 2) {
for (int i = hyp_.size(); i < startFrame + T + 2; i++) {
hyp_.emplace(i, std::vector<LexiconFreeDecoderState>());
}
}
std::vector<size_t> idx(N);
// Looping over all the frames
for (int t = 0; t < T; t++) {
std::iota(idx.begin(), idx.end(), 0);
if (N > opt_.beamSizeToken) {
std::partial_sort(
idx.begin(),
idx.begin() + opt_.beamSizeToken,
idx.end(),
[&t, &N, &emissions](const size_t& l, const size_t& r) {
return emissions[t * N + l] > emissions[t * N + r];
});
}
candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_);
for (const LexiconFreeDecoderState& prevHyp : hyp_[startFrame + t]) {
const int prevIdx = prevHyp.token;
for (int r = 0; r < std::min(opt_.beamSizeToken, N); ++r) {
int n = idx[r];
double amScore = emissions[t * N + n];
if (nDecodedFrames_ + t > 0 &&
opt_.criterionType == CriterionType::ASG) {
amScore += transitions_[n * N + prevIdx];
}
double score = prevHyp.score + emissions[t * N + n];
if (n == sil_) {
score += opt_.silScore;
}
if ((opt_.criterionType == CriterionType::ASG && n != prevIdx) ||
(opt_.criterionType == CriterionType::CTC && n != blank_ &&
(n != prevIdx || prevHyp.prevBlank))) {
auto lmStateScorePair = lm_->score(prevHyp.lmState, n);
auto lmScore = lmStateScorePair.second;
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score + opt_.lmWeight * lmScore,
lmStateScorePair.first,
&prevHyp,
n,
false, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore + lmScore);
} else if (opt_.criterionType == CriterionType::CTC && n == blank_) {
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score,
prevHyp.lmState,
&prevHyp,
n,
true, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore);
} else {
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score,
prevHyp.lmState,
&prevHyp,
n,
false, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore);
}
}
}
candidatesStore(
candidates_,
candidatePtrs_,
hyp_[startFrame + t + 1],
opt_.beamSize,
candidatesBestScore_ - opt_.beamThreshold,
opt_.logAdd,
false);
updateLMCache(lm_, hyp_[startFrame + t + 1]);
}
nDecodedFrames_ += T;
}
void LexiconFreeDecoder::decodeEnd() {
candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_);
for (const LexiconFreeDecoderState& prevHyp :
hyp_[nDecodedFrames_ - nPrunedFrames_]) {
const LMStatePtr& prevLmState = prevHyp.lmState;
auto lmStateScorePair = lm_->finish(prevLmState);
auto lmScore = lmStateScorePair.second;
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
prevHyp.score + opt_.lmWeight * lmScore,
lmStateScorePair.first,
&prevHyp,
sil_,
false, // prevBlank
prevHyp.amScore,
prevHyp.lmScore + lmScore);
}
candidatesStore(
candidates_,
candidatePtrs_,
hyp_[nDecodedFrames_ - nPrunedFrames_ + 1],
opt_.beamSize,
candidatesBestScore_ - opt_.beamThreshold,
opt_.logAdd,
true);
++nDecodedFrames_;
}
std::vector<DecodeResult> LexiconFreeDecoder::getAllFinalHypothesis() const {
int finalFrame = nDecodedFrames_ - nPrunedFrames_;
return getAllHypothesis(hyp_.find(finalFrame)->second, finalFrame);
}
DecodeResult LexiconFreeDecoder::getBestHypothesis(int lookBack) const {
int finalFrame = nDecodedFrames_ - nPrunedFrames_;
const LexiconFreeDecoderState* bestNode =
findBestAncestor(hyp_.find(finalFrame)->second, lookBack);
return getHypothesis(bestNode, nDecodedFrames_ - nPrunedFrames_ - lookBack);
}
int LexiconFreeDecoder::nHypothesis() const {
int finalFrame = nDecodedFrames_ - nPrunedFrames_;
return hyp_.find(finalFrame)->second.size();
}
int LexiconFreeDecoder::nDecodedFramesInBuffer() const {
return nDecodedFrames_ - nPrunedFrames_ + 1;
}
void LexiconFreeDecoder::prune(int lookBack) {
if (nDecodedFrames_ - nPrunedFrames_ - lookBack < 1) {
return; // Not enough decoded frames to prune
}
/* (1) Find the last emitted word in the best path */
int finalFrame = nDecodedFrames_ - nPrunedFrames_;
const LexiconFreeDecoderState* bestNode =
findBestAncestor(hyp_.find(finalFrame)->second, lookBack);
if (!bestNode) {
return; // Not enough decoded frames to prune
}
int startFrame = nDecodedFrames_ - nPrunedFrames_ - lookBack;
if (startFrame < 1) {
return; // Not enough decoded frames to prune
}
/* (2) Move things from back of hyp_ to front and normalize scores */
pruneAndNormalize(hyp_, startFrame, lookBack);
nPrunedFrames_ = nDecodedFrames_ - lookBack;
}
} // namespace text
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include <unordered_map>
#include "torchaudio/csrc/decoder/src/decoder/Decoder.h"
#include "torchaudio/csrc/decoder/src/decoder/lm/LM.h"
namespace torchaudio {
namespace lib {
namespace text {
struct LexiconFreeDecoderOptions {
int beamSize; // Maximum number of hypotheses we hold after each step
int beamSizeToken; // Maximum number of tokens we consider at each step
double beamThreshold; // Threshold to prune hypotheses
double lmWeight; // Weight of lm
double silScore; // Silence insertion score
bool logAdd;
CriterionType criterionType; // CTC or ASG
};
/**
* LexiconFreeDecoderState stores information for each hypothesis in the beam.
*/
struct LexiconFreeDecoderState {
double score; // Accumulated total score so far
LMStatePtr lmState; // Language model state
const LexiconFreeDecoderState* parent; // Parent hypothesis
int token; // Label of token
bool prevBlank; // If previous hypothesis is blank (for CTC only)
double amScore; // Accumulated AM score so far
double lmScore; // Accumulated LM score so far
LexiconFreeDecoderState(
const double score,
const LMStatePtr& lmState,
const LexiconFreeDecoderState* parent,
const int token,
const bool prevBlank = false,
const double amScore = 0,
const double lmScore = 0)
: score(score),
lmState(lmState),
parent(parent),
token(token),
prevBlank(prevBlank),
amScore(amScore),
lmScore(lmScore) {}
LexiconFreeDecoderState()
: score(0),
lmState(nullptr),
parent(nullptr),
token(-1),
prevBlank(false),
amScore(0.),
lmScore(0.) {}
int compareNoScoreStates(const LexiconFreeDecoderState* node) const {
int lmCmp = lmState->compare(node->lmState);
if (lmCmp != 0) {
return lmCmp > 0 ? 1 : -1;
} else if (token != node->token) {
return token > node->token ? 1 : -1;
} else if (prevBlank != node->prevBlank) {
return prevBlank > node->prevBlank ? 1 : -1;
}
return 0;
}
int getWord() const {
return -1;
}
bool isComplete() const {
return true;
}
};
/**
 * Decoder implements a beam search decoder that finds the word transcription
* W maximizing:
*
* AM(W) + lmWeight_ * log(P_{lm}(W)) + silScore_ * |{i| pi_i = <sil>}|
*
* where P_{lm}(W) is the language model score, pi_i is the value for the i-th
* frame in the path leading to W and AM(W) is the (unnormalized) acoustic model
* score of the transcription W. We are allowed to generate words from all the
* possible combinations of tokens.
*/
class LexiconFreeDecoder : public Decoder {
public:
LexiconFreeDecoder(
LexiconFreeDecoderOptions opt,
const LMPtr& lm,
const int sil,
const int blank,
const std::vector<float>& transitions)
: opt_(std::move(opt)),
lm_(lm),
transitions_(transitions),
sil_(sil),
blank_(blank) {}
void decodeBegin() override;
void decodeStep(const float* emissions, int T, int N) override;
void decodeEnd() override;
int nHypothesis() const;
void prune(int lookBack = 0) override;
int nDecodedFramesInBuffer() const override;
DecodeResult getBestHypothesis(int lookBack = 0) const override;
std::vector<DecodeResult> getAllFinalHypothesis() const override;
protected:
LexiconFreeDecoderOptions opt_;
LMPtr lm_;
std::vector<float> transitions_;
// All the hypotheses new candidates (can be larger than beamsize) proposed
// based on the ones from previous frame
std::vector<LexiconFreeDecoderState> candidates_;
// This vector is designed for efficient sorting and merging the candidates_,
// so instead of moving around objects, we only need to sort pointers
std::vector<LexiconFreeDecoderState*> candidatePtrs_;
// Best candidate score of current frame
double candidatesBestScore_;
// Index of silence label
int sil_;
// Index of blank label (for CTC)
int blank_;
// Vector of hypothesis for all the frames so far
std::unordered_map<int, std::vector<LexiconFreeDecoderState>> hyp_;
// These 2 variables are used for online decoding, for hypothesis pruning
int nDecodedFrames_; // Total number of decoded frames.
int nPrunedFrames_; // Total number of pruned frames from hyp_.
};
} // namespace text
} // namespace lib
} // namespace torchaudio
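
To make the scoring rule in the header comment concrete, here is a hedged, pure-Python rendering of one `decodeStep` frame for the CTC criterion (beam-threshold pruning and the optional log-add merge are omitted; `lm.score` follows the `LM` interface used by the C++ code, and LM states are assumed hashable):

```python
def step(hyps, frame_scores, lm, blank, sil, lm_weight, sil_score, beam_size):
    """One frame of lexicon-free CTC beam search.

    hyps: list of (score, lm_state, prev_token, prev_blank) tuples.
    frame_scores: per-token acoustic scores for this frame.
    """
    candidates = {}
    for score, lm_state, prev_tok, prev_blank in hyps:
        for n, am in enumerate(frame_scores):
            new = score + am + (sil_score if n == sil else 0.0)
            if n != blank and (n != prev_tok or prev_blank):
                # a genuinely new token: advance the LM and charge its score
                lm_state_n, lm_s = lm.score(lm_state, n)
                key = (lm_state_n, n, False)
                new += lm_weight * lm_s
            else:
                # blank, or a repeat that CTC collapses: LM state unchanged
                key = (lm_state, n, n == blank)
            # merge hypotheses with identical (lm_state, token, prev_blank);
            # the C++ code can log-add instead of max when log_add is set
            if key not in candidates or new > candidates[key]:
                candidates[key] = new
    best = sorted(candidates.items(), key=lambda kv: -kv[1])[:beam_size]
    return [(s, st, tok, pb) for (st, tok, pb), s in best]
```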
@@ -2,7 +2,7 @@ import torchaudio
 try:
     torchaudio._extension._load_lib("libtorchaudio_decoder")
-    from .ctc_decoder import Hypothesis, LexiconDecoder, lexicon_decoder, download_pretrained_files
+    from .ctc_decoder import Hypothesis, CTCDecoder, ctc_decoder, lexicon_decoder, download_pretrained_files
 except ImportError as err:
     raise ImportError(
         "flashlight decoder bindings are required to use this functionality. "
@@ -12,7 +12,8 @@ except ImportError as err:
 __all__ = [
     "Hypothesis",
-    "LexiconDecoder",
+    "CTCDecoder",
+    "ctc_decoder",
     "lexicon_decoder",
     "download_pretrained_files",
 ]
 import itertools as it
+import warnings
 from collections import namedtuple
 from typing import Dict, List, Optional, Union, NamedTuple
@@ -8,7 +9,9 @@ from torchaudio._torchaudio_decoder import (
     _LM,
     _KenLM,
     _LexiconDecoder,
+    _LexiconFreeDecoder,
     _LexiconDecoderOptions,
+    _LexiconFreeDecoderOptions,
     _SmearingMode,
     _Trie,
     _Dictionary,
@@ -18,14 +21,14 @@ from torchaudio._torchaudio_decoder import (
 )
 from torchaudio.utils import download_asset
-__all__ = ["Hypothesis", "LexiconDecoder", "lexicon_decoder"]
+__all__ = ["Hypothesis", "CTCDecoder", "ctc_decoder", "lexicon_decoder", "download_pretrained_files"]
 _PretrainedFiles = namedtuple("PretrainedFiles", ["lexicon", "tokens", "lm"])
 class Hypothesis(NamedTuple):
-    r"""Represents hypothesis generated by CTC beam search decoder :py:func:`LexiconDecoder`.
+    r"""Represents hypothesis generated by CTC beam search decoder :py:func:`CTCDecoder`.
     :ivar torch.LongTensor tokens: Predicted sequence of token IDs. Shape `(L, )`, where
         `L` is the length of the output sequence
@@ -40,8 +43,8 @@ class Hypothesis(NamedTuple):
     timesteps: torch.IntTensor
-class LexiconDecoder:
-    """torchaudio.prototype.ctc_decoder.LexiconDecoder()
+class CTCDecoder:
+    """torchaudio.prototype.ctc_decoder.CTCDecoder()
     .. devices:: CPU
@@ -49,15 +52,15 @@ class LexiconDecoder:
     Note:
         To build the decoder, please use factory function
-        :py:func:`lexicon_decoder`.
+        :py:func:`ctc_decoder`.
     Args:
         nbest (int): number of best decodings to return
-        lexicon (Dict): lexicon mapping of words to spellings
+        lexicon (Dict or None): lexicon mapping of words to spellings, or None for lexicon free decoder
         word_dict (_Dictionary): dictionary of words
         tokens_dict (_Dictionary): dictionary of tokens
         lm (_LM): language model
-        decoder_options (_LexiconDecoderOptions): parameters used for beam search decoding
+        decoder_options (_LexiconDecoderOptions or _LexiconFreeDecoderOptions): parameters used for beam search decoding
         blank_token (str): token corresponding to blank
         sil_token (str): token corresponding to silence
         unk_word (str): word corresponding to unknown
@@ -66,11 +69,11 @@ class LexiconDecoder:
     def __init__(
         self,
         nbest: int,
-        lexicon: Dict,
+        lexicon: Optional[Dict],
         word_dict: _Dictionary,
         tokens_dict: _Dictionary,
         lm: _LM,
-        decoder_options: _LexiconDecoderOptions,
+        decoder_options: Union[_LexiconDecoderOptions, _LexiconFreeDecoderOptions],
         blank_token: str,
         sil_token: str,
         unk_word: str,
@@ -78,33 +81,36 @@ class LexiconDecoder:
         self.nbest = nbest
         self.word_dict = word_dict
         self.tokens_dict = tokens_dict
-        unk_word = word_dict.get_index(unk_word)
         self.blank = self.tokens_dict.get_index(blank_token)
         silence = self.tokens_dict.get_index(sil_token)
-        vocab_size = self.tokens_dict.index_size()
-        trie = _Trie(vocab_size, silence)
-        start_state = lm.start(False)
-        for word, spellings in lexicon.items():
-            word_idx = self.word_dict.get_index(word)
-            _, score = lm.score(start_state, word_idx)
-            for spelling in spellings:
-                spelling_idx = [self.tokens_dict.get_index(token) for token in spelling]
-                trie.insert(spelling_idx, word_idx, score)
-        trie.smear(_SmearingMode.MAX)
-        self.decoder = _LexiconDecoder(
-            decoder_options,
-            trie,
-            lm,
-            silence,
-            self.blank,
-            unk_word,
-            [],
-            False,  # word level LM
-        )
+        if lexicon:
+            unk_word = word_dict.get_index(unk_word)
+            vocab_size = self.tokens_dict.index_size()
+            trie = _Trie(vocab_size, silence)
+            start_state = lm.start(False)
+            for word, spellings in lexicon.items():
+                word_idx = self.word_dict.get_index(word)
+                _, score = lm.score(start_state, word_idx)
+                for spelling in spellings:
+                    spelling_idx = [self.tokens_dict.get_index(token) for token in spelling]
+                    trie.insert(spelling_idx, word_idx, score)
+            trie.smear(_SmearingMode.MAX)
+            self.decoder = _LexiconDecoder(
+                decoder_options,
+                trie,
+                lm,
+                silence,
+                self.blank,
+                unk_word,
+                [],
+                False,  # word level LM
+            )
+        else:
+            self.decoder = _LexiconFreeDecoder(decoder_options, lm, silence, self.blank, [])
     def _get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
         idxs = (g[0] for g in it.groupby(idxs))
@@ -153,6 +159,7 @@ class LexiconDecoder:
         float_bytes = 4
         hypos = []
+
         for b in range(B):
             emissions_ptr = emissions.data_ptr() + float_bytes * b * emissions.stride(0)
@@ -186,8 +193,8 @@ class LexiconDecoder:
         return [self.tokens_dict.get_entry(idx.item()) for idx in idxs]
-def lexicon_decoder(
-    lexicon: str,
+def ctc_decoder(
+    lexicon: Optional[str],
     tokens: Union[str, List[str]],
     lm: Optional[str] = None,
     nbest: int = 1,
@@ -202,14 +209,15 @@ def lexicon_decoder(
     blank_token: str = "-",
     sil_token: str = "|",
     unk_word: str = "<unk>",
-) -> LexiconDecoder:
+) -> CTCDecoder:
     """
     Builds lexically constrained CTC beam search decoder from
     *Flashlight* [:footcite:`kahn2022flashlight`].
     Args:
-        lexicon (str): lexicon file containing the possible words and corresponding spellings.
-            Each line consists of a word and its space separated spelling
+        lexicon (str or None): lexicon file containing the possible words and corresponding spellings.
+            Each line consists of a word and its space separated spelling. If `None`, uses lexicon free
+            decoding.
         tokens (str or List[str]): file or list containing valid tokens. If using a file, the expected
             format is for tokens mapping to the same index to be on the same line
        lm (str or None, optional): file containing language model, or `None` if not using a language model
@@ -228,34 +236,51 @@ def lexicon_decoder(
        unk_word (str, optional): word corresponding to unknown (Default: "<unk>")
     Returns:
-        LexiconDecoder: decoder
+        CTCDecoder: decoder
     Example
-        >>> decoder = lexicon_decoder(
+        >>> decoder = ctc_decoder(
        >>>     lexicon="lexicon.txt",
        >>>     tokens="tokens.txt",
        >>>     lm="kenlm.bin",
        >>> )
        >>> results = decoder(emissions)  # List of shape (B, nbest) of Hypotheses
     """
-    lexicon = _load_words(lexicon)
-    word_dict = _create_word_dict(lexicon)
-    lm = _KenLM(lm, word_dict) if lm else _ZeroLM()
     tokens_dict = _Dictionary(tokens)
-    decoder_options = _LexiconDecoderOptions(
-        beam_size=beam_size,
-        beam_size_token=beam_size_token or tokens_dict.index_size(),
-        beam_threshold=beam_threshold,
-        lm_weight=lm_weight,
-        word_score=word_score,
-        unk_score=unk_score,
-        sil_score=sil_score,
-        log_add=log_add,
-        criterion_type=_CriterionType.CTC,
-    )
+    if lexicon is not None:
+        lexicon = _load_words(lexicon)
+        word_dict = _create_word_dict(lexicon)
+        lm = _KenLM(lm, word_dict) if lm else _ZeroLM()
+
+        decoder_options = _LexiconDecoderOptions(
+            beam_size=beam_size,
+            beam_size_token=beam_size_token or tokens_dict.index_size(),
+            beam_threshold=beam_threshold,
+            lm_weight=lm_weight,
+            word_score=word_score,
+            unk_score=unk_score,
+            sil_score=sil_score,
+            log_add=log_add,
+            criterion_type=_CriterionType.CTC,
+        )
+    else:
+        d = {tokens_dict.get_entry(i): [[tokens_dict.get_entry(i)]] for i in range(tokens_dict.index_size())}
+        d[unk_word] = [[unk_word]]
+        word_dict = _create_word_dict(d)
+        lm = _KenLM(lm, word_dict) if lm else _ZeroLM()
+        decoder_options = _LexiconFreeDecoderOptions(
+            beam_size=beam_size,
+            beam_size_token=beam_size_token or tokens_dict.index_size(),
+            beam_threshold=beam_threshold,
+            lm_weight=lm_weight,
+            sil_score=sil_score,
+            log_add=log_add,
+            criterion_type=_CriterionType.CTC,
+        )
-    return LexiconDecoder(
+    return CTCDecoder(
         nbest=nbest,
         lexicon=lexicon,
         word_dict=word_dict,
@@ -268,6 +293,44 @@ def lexicon_decoder(
     )
+def lexicon_decoder(
+    lexicon: str,
+    tokens: Union[str, List[str]],
+    lm: Optional[str] = None,
+    nbest: int = 1,
+    beam_size: int = 50,
+    beam_size_token: Optional[int] = None,
+    beam_threshold: float = 50,
+    lm_weight: float = 2,
+    word_score: float = 0,
+    unk_score: float = float("-inf"),
+    sil_score: float = 0,
+    log_add: bool = False,
+    blank_token: str = "-",
+    sil_token: str = "|",
+    unk_word: str = "<unk>",
+) -> CTCDecoder:
+    warnings.warn("`lexicon_decoder` is now deprecated. Please use `ctc_decoder` instead.")
+    return ctc_decoder(
+        lexicon=lexicon,
+        tokens=tokens,
+        lm=lm,
+        nbest=nbest,
+        beam_size=beam_size,
+        beam_size_token=beam_size_token,
+        beam_threshold=beam_threshold,
+        lm_weight=lm_weight,
+        word_score=word_score,
+        unk_score=unk_score,
+        sil_score=sil_score,
+        log_add=log_add,
+        blank_token=blank_token,
+        sil_token=sil_token,
+        unk_word=unk_word,
+    )
 def _get_filenames(model: str) -> _PretrainedFiles:
     if model not in ["librispeech", "librispeech-3-gram", "librispeech-4-gram"]:
         raise ValueError(
...