Commit 97ed428d authored by Caroline Chen, committed by Facebook GitHub Bot

Add lexicon free CTC decoder (#2342)

Summary:
Adds support for lexicon-free decoding, based on [fairseq's](https://github.com/pytorch/fairseq/blob/main/examples/speech_recognition/new/decoders/flashlight_decoder.py#L53) implementation. The new decoder reached numerical parity with fairseq's decoder in offline experiments.
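
A minimal usage sketch of the new path (file paths are placeholders; `lexicon=None` selects the lexicon-free decoder, which expects a token-level LM such as a character n-gram, or no LM at all):

```python
import torch
from torchaudio.prototype.ctc_decoder import ctc_decoder

# lexicon=None selects the new lexicon-free decoder.
decoder = ctc_decoder(
    lexicon=None,
    tokens="tokens.txt",     # placeholder token file
    lm="kenlm_char.arpa",    # placeholder token-level LM, or None
)

emissions = torch.rand(1, 100, 8)  # (batch, frames, num_tokens) CTC emissions
results = decoder(emissions)       # list of shape (batch, nbest) of Hypothesis
```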

Follow-ups:
- Add pretrained LM support for lexicon-free decoding
- Add an example to the tutorial
- Replace the flashlight C++ source code with the flashlight-text submodule
- [optional] fairseq compatibility test

Pull Request resolved: https://github.com/pytorch/audio/pull/2342

Reviewed By: nateanl

Differential Revision: D35856104

Pulled By: carolineechen

fbshipit-source-id: b64286550984df906ebb747e82f6fb1f21948ac7
parent 7c249d17
......@@ -6,11 +6,11 @@ torchaudio.prototype.ctc_decoder
Decoder Class
-------------
-LexiconDecoder
-~~~~~~~~~~~~~~
+CTCDecoder
+~~~~~~~~~~

-.. autoclass:: LexiconDecoder
+.. autoclass:: CTCDecoder
.. automethod:: __call__
......@@ -24,10 +24,10 @@ Hypothesis
Factory Function
----------------
-lexicon_decoder
-~~~~~~~~~~~~~~~
+ctc_decoder
+~~~~~~~~~~~

-.. autoclass:: lexicon_decoder
+.. autoclass:: ctc_decoder
Utility Function
----------------
......
......@@ -233,7 +233,7 @@ print(files)
# Beam Search Decoder
# ~~~~~~~~~~~~~~~~~~~
# The decoder can be constructed using the factory function
-# :py:func:`lexicon_decoder <torchaudio.prototype.ctc_decoder.lexicon_decoder>`.
+# :py:func:`ctc_decoder <torchaudio.prototype.ctc_decoder.ctc_decoder>`.
# In addition to the previously mentioned components, it also takes in various beam
# search decoding parameters and token/word parameters.
#
......@@ -241,12 +241,12 @@ print(files)
# `lm` parameter.
#
-from torchaudio.prototype.ctc_decoder import lexicon_decoder
+from torchaudio.prototype.ctc_decoder import ctc_decoder
LM_WEIGHT = 3.23
WORD_SCORE = -0.26
-beam_search_decoder = lexicon_decoder(
+beam_search_decoder = ctc_decoder(
lexicon=files.lexicon,
tokens=files.tokens,
lm=files.lm,
......@@ -395,7 +395,7 @@ plot_alignments(waveform[0], emission, predicted_tokens, timesteps)
# In this section, we go a little bit more in depth about some different
# parameters and tradeoffs. For the full list of customizable parameters,
# please refer to the
-# :py:func:`documentation <torchaudio.prototype.ctc_decoder.lexicon_decoder>`.
+# :py:func:`documentation <torchaudio.prototype.ctc_decoder.ctc_decoder>`.
#
......@@ -450,7 +450,7 @@ for i in range(3):
beam_sizes = [1, 5, 50, 500]
for beam_size in beam_sizes:
-    beam_search_decoder = lexicon_decoder(
+    beam_search_decoder = ctc_decoder(
lexicon=files.lexicon,
tokens=files.tokens,
lm=files.lm,
......@@ -476,7 +476,7 @@ num_tokens = len(tokens)
beam_size_tokens = [1, 5, 10, num_tokens]
for beam_size_token in beam_size_tokens:
-    beam_search_decoder = lexicon_decoder(
+    beam_search_decoder = ctc_decoder(
lexicon=files.lexicon,
tokens=files.tokens,
lm=files.lm,
......@@ -503,7 +503,7 @@ for beam_size_token in beam_size_tokens:
beam_thresholds = [1, 5, 10, 25]
for beam_threshold in beam_thresholds:
-    beam_search_decoder = lexicon_decoder(
+    beam_search_decoder = ctc_decoder(
lexicon=files.lexicon,
tokens=files.tokens,
lm=files.lm,
......@@ -529,7 +529,7 @@ for beam_threshold in beam_thresholds:
lm_weights = [0, LM_WEIGHT, 15]
for lm_weight in lm_weights:
-    beam_search_decoder = lexicon_decoder(
+    beam_search_decoder = ctc_decoder(
lexicon=files.lexicon,
tokens=files.tokens,
lm=files.lm,
......
......@@ -9,10 +9,10 @@ import pytest
],
)
def test_decoder_from_pretrained(model, expected, emissions):
-    from torchaudio.prototype.ctc_decoder import lexicon_decoder, download_pretrained_files
+    from torchaudio.prototype.ctc_decoder import ctc_decoder, download_pretrained_files

    pretrained_files = download_pretrained_files(model)

-    decoder = lexicon_decoder(
+    decoder = ctc_decoder(
lexicon=pretrained_files.lexicon,
tokens=pretrained_files.tokens,
lm=pretrained_files.lm,
......
\data\
ngram 1=8
ngram 2=8
ngram 3=8
ngram 4=8
ngram 5=8
\1-grams:
-1.146128 <unk> 0
0 <s> -0.30103
-0.8731268 </s> 0
-0.70679533 f -0.30103
-0.70679533 o -0.30103
-0.8731268 b -0.30103
-0.8731268 a -0.30103
-0.8731268 r -0.30103
\2-grams:
-0.24644431 r </s> 0
-0.22314323 <s> f -0.30103
-0.57694924 o f -0.30103
-0.22314323 f o -0.30103
-0.57694924 o o -0.30103
-0.6314696 o b -0.30103
-0.24644431 b a -0.30103
-0.24644431 a r -0.30103
\3-grams:
-0.105970904 a r </s> 0
-0.41743615 o o f -0.30103
-0.097394995 <s> f o -0.30103
-0.097394995 o f o -0.30103
-0.19898036 f o o -0.30103
-0.43555236 o o b -0.30103
-0.105970904 o b a -0.30103
-0.105970904 b a r -0.30103
\4-grams:
-0.049761247 b a r </s> 0
-0.4462542 f o o f -0.30103
-0.045972984 o o f o -0.30103
-0.08819265 <s> f o o -0.30103
-0.08819265 o f o o -0.30103
-0.286727 f o o b -0.30103
-0.049761247 o o b a -0.30103
-0.049761247 o b a r -0.30103
\5-grams:
-0.02416831 o b a r </s>
-0.36759996 <s> f o o f
-0.022378458 f o o f o
-0.041861475 o o f o o
-0.29381964 <s> f o o b
-0.12011856 o f o o b
-0.02416831 f o o b a
-0.02416831 o o b a r
\end\
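
For readers unfamiliar with the ARPA format of the test LM above: each entry is `log10(probability) n-gram [log10(backoff)]`. A quick sanity check of the bigram entry `-0.22314323 f o -0.30103` (illustrative arithmetic only):

```python
# ARPA files store base-10 log probabilities and backoff weights.
p_o_given_f = 10 ** -0.22314323   # P(o | f) ~= 0.6
backoff_f_o = 10 ** -0.30103      # backoff weight = 0.5 (log10(2) ~= 0.30103)
print(p_o_given_f, backoff_f_o)
```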
......@@ -10,18 +10,25 @@ from torchaudio_unittest.common_utils import (
)
NUM_TOKENS = 8
@skipIfNoCtcDecoder
class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
-    def _get_decoder(self, tokens=None, use_lm=True, **kwargs):
-        from torchaudio.prototype.ctc_decoder import lexicon_decoder
+    def _get_decoder(self, tokens=None, use_lm=True, use_lexicon=True, **kwargs):
+        from torchaudio.prototype.ctc_decoder import ctc_decoder

-        lexicon_file = get_asset_path("decoder/lexicon.txt")
-        kenlm_file = get_asset_path("decoder/kenlm.arpa") if use_lm else None
+        if use_lexicon:
+            lexicon_file = get_asset_path("decoder/lexicon.txt")
+            kenlm_file = get_asset_path("decoder/kenlm.arpa") if use_lm else None
+        else:
+            lexicon_file = None
+            kenlm_file = get_asset_path("decoder/kenlm_char.arpa") if use_lm else None

         if tokens is None:
             tokens = get_asset_path("decoder/tokens.txt")

-        return lexicon_decoder(
+        return ctc_decoder(
             lexicon=lexicon_file,
             tokens=tokens,
             lm=kenlm_file,
......@@ -29,7 +36,7 @@ class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
)
def _get_emissions(self):
-        B, T, N = 4, 15, 10
+        B, T, N = 4, 15, NUM_TOKENS
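        # the emission width N must match the size of the decoder's token set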
torch.manual_seed(0)
emissions = torch.rand(B, T, N)
......@@ -41,39 +48,46 @@ class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
    @parameterized.expand(
        list(
            itertools.product(
                [get_asset_path("decoder/tokens.txt"), ["-", "|", "f", "o", "b", "a", "r"]],
                [True, False],
+                [True, False],
            )
        ),
    )
-    def test_construct_decoder(self, tokens, use_lm):
-        self._get_decoder(tokens=tokens, use_lm=use_lm)
+    def test_construct_decoder(self, tokens, use_lm, use_lexicon):
+        self._get_decoder(tokens=tokens, use_lm=use_lm, use_lexicon=use_lexicon)

-    def test_no_lm_decoder(self):
-        """Check that using no LM produces the same result as using an LM with 0 lm_weight"""
-        kenlm_decoder = self._get_decoder(lm_weight=0)
-        zerolm_decoder = self._get_decoder(use_lm=False)
-        emissions = self._get_emissions()
-        kenlm_results = kenlm_decoder(emissions)
-        zerolm_results = zerolm_decoder(emissions)
-        self.assertEqual(kenlm_results, zerolm_results)

-    def test_shape(self):
+    @parameterized.expand(
+        [(True,), (False,)],
+    )
+    def test_shape(self, use_lexicon):
         emissions = self._get_emissions()
-        decoder = self._get_decoder()
+        decoder = self._get_decoder(use_lexicon=use_lexicon)
         results = decoder(emissions)
         self.assertEqual(len(results), emissions.shape[0])

-    def test_timesteps_shape(self):
+    @parameterized.expand(
+        [(True,), (False,)],
+    )
+    def test_timesteps_shape(self, use_lexicon):
         """Each token should correspond with a timestep"""
         emissions = self._get_emissions()
-        decoder = self._get_decoder()
+        decoder = self._get_decoder(use_lexicon=use_lexicon)
         results = decoder(emissions)
         for i in range(emissions.shape[0]):
             result = results[i][0]
             self.assertEqual(result.tokens.shape, result.timesteps.shape)

+    def test_no_lm_decoder(self):
+        """Check that using no LM produces the same result as using an LM with 0 lm_weight"""
+        kenlm_decoder = self._get_decoder(lm_weight=0)
+        zerolm_decoder = self._get_decoder(use_lm=False)
+        emissions = self._get_emissions()
+        kenlm_results = kenlm_decoder(emissions)
+        zerolm_results = zerolm_decoder(emissions)
+        self.assertEqual(kenlm_results, zerolm_results)
def test_get_timesteps(self):
unprocessed_tokens = torch.tensor([2, 2, 0, 3, 3, 3, 0, 3])
decoder = self._get_decoder()
......
......@@ -137,6 +137,7 @@ if (BUILD_CTC_DECODER)
set(
LIBTORCHAUDIO_DECODER_SOURCES
decoder/src/decoder/LexiconDecoder.cpp
decoder/src/decoder/LexiconFreeDecoder.cpp
decoder/src/decoder/Trie.cpp
decoder/src/decoder/Utils.cpp
decoder/src/decoder/lm/KenLM.cpp
......
#include <torch/extension.h>
#include "torchaudio/csrc/decoder/src/decoder/LexiconDecoder.h"
#include "torchaudio/csrc/decoder/src/decoder/LexiconFreeDecoder.h"
#include "torchaudio/csrc/decoder/src/decoder/lm/KenLM.h"
#include "torchaudio/csrc/decoder/src/decoder/lm/ZeroLM.h"
#include "torchaudio/csrc/decoder/src/dictionary/Dictionary.h"
......@@ -103,6 +104,22 @@ std::vector<DecodeResult> LexiconDecoder_decode(
return decoder.decode(reinterpret_cast<const float*>(emissions), T, N);
}
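
// Note: the Python caller passes torch.Tensor.data_ptr() as an integer
// (uintptr_t); these wrappers reinterpret it as the (T x N) float emission
// buffer expected by the decoder.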
void LexiconFreeDecoder_decodeStep(
LexiconFreeDecoder& decoder,
uintptr_t emissions,
int T,
int N) {
decoder.decodeStep(reinterpret_cast<const float*>(emissions), T, N);
}
std::vector<DecodeResult> LexiconFreeDecoder_decode(
LexiconFreeDecoder& decoder,
uintptr_t emissions,
int T,
int N) {
return decoder.decode(reinterpret_cast<const float*>(emissions), T, N);
}
void Dictionary_addEntry_0(
Dictionary& dict,
const std::string& entry,
......@@ -191,6 +208,34 @@ PYBIND11_MODULE(_torchaudio_decoder, m) {
.def_readwrite("log_add", &LexiconDecoderOptions::logAdd)
.def_readwrite("criterion_type", &LexiconDecoderOptions::criterionType);
py::class_<LexiconFreeDecoderOptions>(m, "_LexiconFreeDecoderOptions")
.def(
py::init<
const int,
const int,
const double,
const double,
const double,
const bool,
const CriterionType>(),
"beam_size"_a,
"beam_size_token"_a,
"beam_threshold"_a,
"lm_weight"_a,
"sil_score"_a,
"log_add"_a,
"criterion_type"_a)
.def_readwrite("beam_size", &LexiconFreeDecoderOptions::beamSize)
.def_readwrite(
"beam_size_token", &LexiconFreeDecoderOptions::beamSizeToken)
.def_readwrite(
"beam_threshold", &LexiconFreeDecoderOptions::beamThreshold)
.def_readwrite("lm_weight", &LexiconFreeDecoderOptions::lmWeight)
.def_readwrite("sil_score", &LexiconFreeDecoderOptions::silScore)
.def_readwrite("log_add", &LexiconFreeDecoderOptions::logAdd)
.def_readwrite(
"criterion_type", &LexiconFreeDecoderOptions::criterionType);
py::class_<DecodeResult>(m, "_DecodeResult")
.def(py::init<int>(), "length"_a)
.def_readwrite("score", &DecodeResult::score)
......@@ -226,6 +271,31 @@ PYBIND11_MODULE(_torchaudio_decoder, m) {
"look_back"_a = 0)
.def("get_all_final_hypothesis", &LexiconDecoder::getAllFinalHypothesis);
py::class_<LexiconFreeDecoder>(m, "_LexiconFreeDecoder")
.def(py::init<
LexiconFreeDecoderOptions,
const LMPtr,
const int,
const int,
const std::vector<float>&>())
.def("decode_begin", &LexiconFreeDecoder::decodeBegin)
.def(
"decode_step",
&LexiconFreeDecoder_decodeStep,
"emissions"_a,
"T"_a,
"N"_a)
.def("decode_end", &LexiconFreeDecoder::decodeEnd)
.def("decode", &LexiconFreeDecoder_decode, "emissions"_a, "T"_a, "N"_a)
.def("prune", &LexiconFreeDecoder::prune, "look_back"_a = 0)
.def(
"get_best_hypothesis",
&LexiconFreeDecoder::getBestHypothesis,
"look_back"_a = 0)
.def(
"get_all_final_hypothesis",
&LexiconFreeDecoder::getAllFinalHypothesis);
py::class_<Dictionary>(m, "_Dictionary")
.def(py::init<>())
.def(py::init<const std::vector<std::string>&>(), "tkns"_a)
......
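The bindings above can be exercised directly from Python; a sketch, assuming the extension is built (the silence/blank indices and emission sizes here are illustrative, not fixed by the bindings):

```python
import torch
from torchaudio._torchaudio_decoder import (
    _CriterionType,
    _LexiconFreeDecoder,
    _LexiconFreeDecoderOptions,
    _ZeroLM,
)

options = _LexiconFreeDecoderOptions(
    beam_size=50,
    beam_size_token=8,
    beam_threshold=50.0,
    lm_weight=2.0,
    sil_score=0.0,
    log_add=False,
    criterion_type=_CriterionType.CTC,
)
# constructor args: options, lm, silence index, blank index, ASG transitions
decoder = _LexiconFreeDecoder(options, _ZeroLM(), 1, 0, [])

emissions = torch.rand(15, 8)  # (T, N), float32 and contiguous
decoder.decode_begin()
decoder.decode_step(emissions.data_ptr(), 15, 8)
decoder.decode_end()
best = decoder.get_best_hypothesis()
```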
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#include <stdlib.h>
#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include "torchaudio/csrc/decoder/src/decoder/LexiconFreeDecoder.h"
namespace torchaudio {
namespace lib {
namespace text {
void LexiconFreeDecoder::decodeBegin() {
hyp_.clear();
hyp_.emplace(0, std::vector<LexiconFreeDecoderState>());
/* note: the LM resets itself via start() */
hyp_[0].emplace_back(0.0, lm_->start(0), nullptr, sil_);
nDecodedFrames_ = 0;
nPrunedFrames_ = 0;
}
void LexiconFreeDecoder::decodeStep(const float* emissions, int T, int N) {
int startFrame = nDecodedFrames_ - nPrunedFrames_;
// Extend hyp_ buffer
if (hyp_.size() < startFrame + T + 2) {
for (int i = hyp_.size(); i < startFrame + T + 2; i++) {
hyp_.emplace(i, std::vector<LexiconFreeDecoderState>());
}
}
std::vector<size_t> idx(N);
// Looping over all the frames
for (int t = 0; t < T; t++) {
std::iota(idx.begin(), idx.end(), 0);
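    // Top-k pruning: keep only the beamSizeToken highest-scoring tokens of
    // this frame before expanding hypotheses.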
if (N > opt_.beamSizeToken) {
std::partial_sort(
idx.begin(),
idx.begin() + opt_.beamSizeToken,
idx.end(),
[&t, &N, &emissions](const size_t& l, const size_t& r) {
return emissions[t * N + l] > emissions[t * N + r];
});
}
candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_);
for (const LexiconFreeDecoderState& prevHyp : hyp_[startFrame + t]) {
const int prevIdx = prevHyp.token;
for (int r = 0; r < std::min(opt_.beamSizeToken, N); ++r) {
int n = idx[r];
double amScore = emissions[t * N + n];
if (nDecodedFrames_ + t > 0 &&
opt_.criterionType == CriterionType::ASG) {
amScore += transitions_[n * N + prevIdx];
}
double score = prevHyp.score + emissions[t * N + n];
if (n == sil_) {
score += opt_.silScore;
}
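      // Three ways to extend a hypothesis: (1) a genuinely new token advances
      // the LM state and adds lmWeight * lmScore; (2) a CTC blank keeps the
      // LM state and sets prevBlank; (3) a repeated token keeps the LM state
      // with no new LM score.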
if ((opt_.criterionType == CriterionType::ASG && n != prevIdx) ||
(opt_.criterionType == CriterionType::CTC && n != blank_ &&
(n != prevIdx || prevHyp.prevBlank))) {
auto lmStateScorePair = lm_->score(prevHyp.lmState, n);
auto lmScore = lmStateScorePair.second;
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score + opt_.lmWeight * lmScore,
lmStateScorePair.first,
&prevHyp,
n,
false, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore + lmScore);
} else if (opt_.criterionType == CriterionType::CTC && n == blank_) {
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score,
prevHyp.lmState,
&prevHyp,
n,
true, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore);
} else {
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
score,
prevHyp.lmState,
&prevHyp,
n,
false, // prevBlank
prevHyp.amScore + amScore,
prevHyp.lmScore);
}
}
}
candidatesStore(
candidates_,
candidatePtrs_,
hyp_[startFrame + t + 1],
opt_.beamSize,
candidatesBestScore_ - opt_.beamThreshold,
opt_.logAdd,
false);
updateLMCache(lm_, hyp_[startFrame + t + 1]);
}
nDecodedFrames_ += T;
}
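
// Finalize decoding: apply the LM end-of-sentence score to every surviving
// hypothesis and store the result one frame past the last decoded frame.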
void LexiconFreeDecoder::decodeEnd() {
candidatesReset(candidatesBestScore_, candidates_, candidatePtrs_);
for (const LexiconFreeDecoderState& prevHyp :
hyp_[nDecodedFrames_ - nPrunedFrames_]) {
const LMStatePtr& prevLmState = prevHyp.lmState;
auto lmStateScorePair = lm_->finish(prevLmState);
auto lmScore = lmStateScorePair.second;
candidatesAdd(
candidates_,
candidatesBestScore_,
opt_.beamThreshold,
prevHyp.score + opt_.lmWeight * lmScore,
lmStateScorePair.first,
&prevHyp,
sil_,
false, // prevBlank
prevHyp.amScore,
prevHyp.lmScore + lmScore);
}
candidatesStore(
candidates_,
candidatePtrs_,
hyp_[nDecodedFrames_ - nPrunedFrames_ + 1],
opt_.beamSize,
candidatesBestScore_ - opt_.beamThreshold,
opt_.logAdd,
true);
++nDecodedFrames_;
}
std::vector<DecodeResult> LexiconFreeDecoder::getAllFinalHypothesis() const {
int finalFrame = nDecodedFrames_ - nPrunedFrames_;
return getAllHypothesis(hyp_.find(finalFrame)->second, finalFrame);
}
DecodeResult LexiconFreeDecoder::getBestHypothesis(int lookBack) const {
int finalFrame = nDecodedFrames_ - nPrunedFrames_;
const LexiconFreeDecoderState* bestNode =
findBestAncestor(hyp_.find(finalFrame)->second, lookBack);
return getHypothesis(bestNode, nDecodedFrames_ - nPrunedFrames_ - lookBack);
}
int LexiconFreeDecoder::nHypothesis() const {
int finalFrame = nDecodedFrames_ - nPrunedFrames_;
return hyp_.find(finalFrame)->second.size();
}
int LexiconFreeDecoder::nDecodedFramesInBuffer() const {
return nDecodedFrames_ - nPrunedFrames_ + 1;
}
void LexiconFreeDecoder::prune(int lookBack) {
if (nDecodedFrames_ - nPrunedFrames_ - lookBack < 1) {
return; // Not enough decoded frames to prune
}
/* (1) Find the last emitted word in the best path */
int finalFrame = nDecodedFrames_ - nPrunedFrames_;
const LexiconFreeDecoderState* bestNode =
findBestAncestor(hyp_.find(finalFrame)->second, lookBack);
if (!bestNode) {
return; // Not enough decoded frames to prune
}
int startFrame = nDecodedFrames_ - nPrunedFrames_ - lookBack;
if (startFrame < 1) {
return; // Not enough decoded frames to prune
}
/* (2) Move things from back of hyp_ to front and normalize scores */
pruneAndNormalize(hyp_, startFrame, lookBack);
nPrunedFrames_ = nDecodedFrames_ - lookBack;
}
} // namespace text
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include <unordered_map>
#include "torchaudio/csrc/decoder/src/decoder/Decoder.h"
#include "torchaudio/csrc/decoder/src/decoder/lm/LM.h"
namespace torchaudio {
namespace lib {
namespace text {
struct LexiconFreeDecoderOptions {
int beamSize; // Maximum number of hypotheses we hold after each step
int beamSizeToken; // Maximum number of tokens we consider at each step
double beamThreshold; // Threshold to prune hypotheses
double lmWeight; // Weight of lm
double silScore; // Silence insertion score
bool logAdd; // Whether to use log-add (instead of max) when merging hypotheses
CriterionType criterionType; // CTC or ASG
};
/**
* LexiconFreeDecoderState stores information for each hypothesis in the beam.
*/
struct LexiconFreeDecoderState {
double score; // Accumulated total score so far
LMStatePtr lmState; // Language model state
const LexiconFreeDecoderState* parent; // Parent hypothesis
int token; // Label of token
bool prevBlank; // If previous hypothesis is blank (for CTC only)
double amScore; // Accumulated AM score so far
double lmScore; // Accumulated LM score so far
LexiconFreeDecoderState(
const double score,
const LMStatePtr& lmState,
const LexiconFreeDecoderState* parent,
const int token,
const bool prevBlank = false,
const double amScore = 0,
const double lmScore = 0)
: score(score),
lmState(lmState),
parent(parent),
token(token),
prevBlank(prevBlank),
amScore(amScore),
lmScore(lmScore) {}
LexiconFreeDecoderState()
: score(0),
lmState(nullptr),
parent(nullptr),
token(-1),
prevBlank(false),
amScore(0.),
lmScore(0.) {}
int compareNoScoreStates(const LexiconFreeDecoderState* node) const {
int lmCmp = lmState->compare(node->lmState);
if (lmCmp != 0) {
return lmCmp > 0 ? 1 : -1;
} else if (token != node->token) {
return token > node->token ? 1 : -1;
} else if (prevBlank != node->prevBlank) {
return prevBlank > node->prevBlank ? 1 : -1;
}
return 0;
}
int getWord() const {
return -1;
}
bool isComplete() const {
return true;
}
};
/**
* Decoder implements a beam search decoder that finds the word transcription
* W maximizing:
*
* AM(W) + lmWeight_ * log(P_{lm}(W)) + silScore_ * |{i| pi_i = <sil>}|
*
* where P_{lm}(W) is the language model score, pi_i is the value for the i-th
* frame in the path leading to W and AM(W) is the (unnormalized) acoustic model
* score of the transcription W. We are allowed to generate words from all the
* possible combinations of tokens.
*/
class LexiconFreeDecoder : public Decoder {
public:
LexiconFreeDecoder(
LexiconFreeDecoderOptions opt,
const LMPtr& lm,
const int sil,
const int blank,
const std::vector<float>& transitions)
: opt_(std::move(opt)),
lm_(lm),
transitions_(transitions),
sil_(sil),
blank_(blank) {}
void decodeBegin() override;
void decodeStep(const float* emissions, int T, int N) override;
void decodeEnd() override;
int nHypothesis() const;
void prune(int lookBack = 0) override;
int nDecodedFramesInBuffer() const override;
DecodeResult getBestHypothesis(int lookBack = 0) const override;
std::vector<DecodeResult> getAllFinalHypothesis() const override;
protected:
LexiconFreeDecoderOptions opt_;
LMPtr lm_;
std::vector<float> transitions_;
// All the new candidate hypotheses (possibly more than beamSize) proposed
// from the hypotheses of the previous frame
std::vector<LexiconFreeDecoderState> candidates_;
// This vector enables efficient sorting and merging of candidates_: instead
// of moving the objects around, we only sort pointers
std::vector<LexiconFreeDecoderState*> candidatePtrs_;
// Best candidate score of current frame
double candidatesBestScore_;
// Index of silence label
int sil_;
// Index of blank label (for CTC)
int blank_;
// Hypotheses for all frames decoded so far, keyed by frame index
std::unordered_map<int, std::vector<LexiconFreeDecoderState>> hyp_;
// These two variables are used for online decoding and hypothesis pruning
int nDecodedFrames_; // Total number of decoded frames.
int nPrunedFrames_; // Total number of pruned frames from hyp_.
};
} // namespace text
} // namespace lib
} // namespace torchaudio
......@@ -2,7 +2,7 @@ import torchaudio
try:
torchaudio._extension._load_lib("libtorchaudio_decoder")
-    from .ctc_decoder import Hypothesis, LexiconDecoder, lexicon_decoder, download_pretrained_files
+    from .ctc_decoder import Hypothesis, CTCDecoder, ctc_decoder, lexicon_decoder, download_pretrained_files
except ImportError as err:
raise ImportError(
"flashlight decoder bindings are required to use this functionality. "
......@@ -12,7 +12,8 @@ except ImportError as err:
__all__ = [
"Hypothesis",
"LexiconDecoder",
"CTCDecoder",
"ctc_decoder",
"lexicon_decoder",
"download_pretrained_files",
]
import itertools as it
import warnings
from collections import namedtuple
from typing import Dict, List, Optional, Union, NamedTuple
......@@ -8,7 +9,9 @@ from torchaudio._torchaudio_decoder import (
_LM,
_KenLM,
_LexiconDecoder,
_LexiconFreeDecoder,
_LexiconDecoderOptions,
_LexiconFreeDecoderOptions,
_SmearingMode,
_Trie,
_Dictionary,
......@@ -18,14 +21,14 @@ from torchaudio._torchaudio_decoder import (
)
from torchaudio.utils import download_asset
__all__ = ["Hypothesis", "LexiconDecoder", "lexicon_decoder"]
__all__ = ["Hypothesis", "CTCDecoder", "ctc_decoder", "lexicon_decoder", "download_pretrained_files"]
_PretrainedFiles = namedtuple("PretrainedFiles", ["lexicon", "tokens", "lm"])
class Hypothesis(NamedTuple):
r"""Represents hypothesis generated by CTC beam search decoder :py:func`LexiconDecoder`.
r"""Represents hypothesis generated by CTC beam search decoder :py:func`CTCDecoder`.
:ivar torch.LongTensor tokens: Predicted sequence of token IDs. Shape `(L, )`, where
`L` is the length of the output sequence
......@@ -40,8 +43,8 @@ class Hypothesis(NamedTuple):
timesteps: torch.IntTensor
-class LexiconDecoder:
-    """torchaudio.prototype.ctc_decoder.LexiconDecoder()
+class CTCDecoder:
+    """torchaudio.prototype.ctc_decoder.CTCDecoder()
.. devices:: CPU
......@@ -49,15 +52,15 @@ class LexiconDecoder:
Note:
To build the decoder, please use factory function
-        :py:func:`lexicon_decoder`.
+        :py:func:`ctc_decoder`.
Args:
nbest (int): number of best decodings to return
-        lexicon (Dict): lexicon mapping of words to spellings
+        lexicon (Dict or None): lexicon mapping of words to spellings, or None for lexicon-free decoding
word_dict (_Dictionary): dictionary of words
tokens_dict (_Dictionary): dictionary of tokens
lm (_LM): language model
-        decoder_options (_LexiconDecoderOptions): parameters used for beam search decoding
+        decoder_options (_LexiconDecoderOptions or _LexiconFreeDecoderOptions): parameters used for beam search decoding
         blank_token (str): token corresponding to blank
sil_token (str): token corresponding to silence
unk_word (str): word corresponding to unknown
......@@ -66,11 +69,11 @@ class LexiconDecoder:
def __init__(
self,
nbest: int,
-        lexicon: Dict,
+        lexicon: Optional[Dict],
word_dict: _Dictionary,
tokens_dict: _Dictionary,
lm: _LM,
-        decoder_options: _LexiconDecoderOptions,
+        decoder_options: Union[_LexiconDecoderOptions, _LexiconFreeDecoderOptions],
blank_token: str,
sil_token: str,
unk_word: str,
......@@ -78,33 +81,36 @@ class LexiconDecoder:
self.nbest = nbest
self.word_dict = word_dict
self.tokens_dict = tokens_dict
-        unk_word = word_dict.get_index(unk_word)
         self.blank = self.tokens_dict.get_index(blank_token)
         silence = self.tokens_dict.get_index(sil_token)
-        vocab_size = self.tokens_dict.index_size()
-        trie = _Trie(vocab_size, silence)
-        start_state = lm.start(False)
-        for word, spellings in lexicon.items():
-            word_idx = self.word_dict.get_index(word)
-            _, score = lm.score(start_state, word_idx)
-            for spelling in spellings:
-                spelling_idx = [self.tokens_dict.get_index(token) for token in spelling]
-                trie.insert(spelling_idx, word_idx, score)
-        trie.smear(_SmearingMode.MAX)
-        self.decoder = _LexiconDecoder(
-            decoder_options,
-            trie,
-            lm,
-            silence,
-            self.blank,
-            unk_word,
-            [],
-            False,  # word level LM
-        )
+        if lexicon:
+            unk_word = word_dict.get_index(unk_word)
+            vocab_size = self.tokens_dict.index_size()
+            trie = _Trie(vocab_size, silence)
+            start_state = lm.start(False)
+            for word, spellings in lexicon.items():
+                word_idx = self.word_dict.get_index(word)
+                _, score = lm.score(start_state, word_idx)
+                for spelling in spellings:
+                    spelling_idx = [self.tokens_dict.get_index(token) for token in spelling]
+                    trie.insert(spelling_idx, word_idx, score)
+            trie.smear(_SmearingMode.MAX)
+            self.decoder = _LexiconDecoder(
+                decoder_options,
+                trie,
+                lm,
+                silence,
+                self.blank,
+                unk_word,
+                [],
+                False,  # word level LM
+            )
+        else:
+            self.decoder = _LexiconFreeDecoder(decoder_options, lm, silence, self.blank, [])
def _get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
idxs = (g[0] for g in it.groupby(idxs))
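        # (groupby collapses consecutive duplicates, e.g. [2, 2, 0, 3, 3] -> [2, 0, 3])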
......@@ -153,6 +159,7 @@ class LexiconDecoder:
float_bytes = 4
hypos = []
for b in range(B):
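            # compute the raw float pointer for utterance b (float32 => 4 bytes per element)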
emissions_ptr = emissions.data_ptr() + float_bytes * b * emissions.stride(0)
......@@ -186,8 +193,8 @@ class LexiconDecoder:
return [self.tokens_dict.get_entry(idx.item()) for idx in idxs]
-def lexicon_decoder(
-    lexicon: str,
+def ctc_decoder(
+    lexicon: Optional[str],
tokens: Union[str, List[str]],
lm: Optional[str] = None,
nbest: int = 1,
......@@ -202,14 +209,15 @@ def lexicon_decoder(
blank_token: str = "-",
sil_token: str = "|",
unk_word: str = "<unk>",
-) -> LexiconDecoder:
+) -> CTCDecoder:
"""
    Builds CTC beam search decoder from
*Flashlight* [:footcite:`kahn2022flashlight`].
Args:
-        lexicon (str): lexicon file containing the possible words and corresponding spellings.
-            Each line consists of a word and its space separated spelling
+        lexicon (str or None): lexicon file containing the possible words and corresponding spellings.
+            Each line consists of a word and its space-separated spelling. If `None`, uses
+            lexicon-free decoding.
        tokens (str or List[str]): file or list containing valid tokens. If using a file, tokens
            that map to the same index are expected to be on the same line
lm (str or None, optional): file containing language model, or `None` if not using a language model
......@@ -228,34 +236,51 @@ def lexicon_decoder(
unk_word (str, optional): word corresponding to unknown (Default: "<unk>")
Returns:
-        LexiconDecoder: decoder
+        CTCDecoder: decoder
Example
-        >>> decoder = lexicon_decoder(
+        >>> decoder = ctc_decoder(
>>> lexicon="lexicon.txt",
>>> tokens="tokens.txt",
>>> lm="kenlm.bin",
>>> )
>>> results = decoder(emissions) # List of shape (B, nbest) of Hypotheses
"""
-    lexicon = _load_words(lexicon)
-    word_dict = _create_word_dict(lexicon)
-    lm = _KenLM(lm, word_dict) if lm else _ZeroLM()
     tokens_dict = _Dictionary(tokens)

-    decoder_options = _LexiconDecoderOptions(
-        beam_size=beam_size,
-        beam_size_token=beam_size_token or tokens_dict.index_size(),
-        beam_threshold=beam_threshold,
-        lm_weight=lm_weight,
-        word_score=word_score,
-        unk_score=unk_score,
-        sil_score=sil_score,
-        log_add=log_add,
-        criterion_type=_CriterionType.CTC,
-    )
+    if lexicon is not None:
+        lexicon = _load_words(lexicon)
+        word_dict = _create_word_dict(lexicon)
+        lm = _KenLM(lm, word_dict) if lm else _ZeroLM()
+        decoder_options = _LexiconDecoderOptions(
+            beam_size=beam_size,
+            beam_size_token=beam_size_token or tokens_dict.index_size(),
+            beam_threshold=beam_threshold,
+            lm_weight=lm_weight,
+            word_score=word_score,
+            unk_score=unk_score,
+            sil_score=sil_score,
+            log_add=log_add,
+            criterion_type=_CriterionType.CTC,
+        )
+    else:
+        d = {tokens_dict.get_entry(i): [[tokens_dict.get_entry(i)]] for i in range(tokens_dict.index_size())}
+        d[unk_word] = [[unk_word]]
+        word_dict = _create_word_dict(d)
+        lm = _KenLM(lm, word_dict) if lm else _ZeroLM()
+        decoder_options = _LexiconFreeDecoderOptions(
+            beam_size=beam_size,
+            beam_size_token=beam_size_token or tokens_dict.index_size(),
+            beam_threshold=beam_threshold,
+            lm_weight=lm_weight,
+            sil_score=sil_score,
+            log_add=log_add,
+            criterion_type=_CriterionType.CTC,
+        )

-    return LexiconDecoder(
+    return CTCDecoder(
nbest=nbest,
lexicon=lexicon,
word_dict=word_dict,
......@@ -268,6 +293,44 @@ def lexicon_decoder(
)
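
One detail worth noting in the lexicon-free branch above: flashlight's LM wrapper scores entries of a word dictionary, so the factory builds a pseudo-dictionary in which every token acts as its own single-token "word"; a token-level LM (e.g. a character n-gram) then plugs into the unchanged interface. Roughly:

```python
# Illustrative expansion of the pseudo-lexicon built when lexicon is None.
tokens = ["-", "|", "f", "o", "b", "a", "r"]
d = {t: [[t]] for t in tokens}  # each token is a "word" spelled by itself
d["<unk>"] = [["<unk>"]]
```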
def lexicon_decoder(
lexicon: str,
tokens: Union[str, List[str]],
lm: Optional[str] = None,
nbest: int = 1,
beam_size: int = 50,
beam_size_token: Optional[int] = None,
beam_threshold: float = 50,
lm_weight: float = 2,
word_score: float = 0,
unk_score: float = float("-inf"),
sil_score: float = 0,
log_add: bool = False,
blank_token: str = "-",
sil_token: str = "|",
unk_word: str = "<unk>",
) -> CTCDecoder:
warnings.warn("`lexicon_decoder` is now deprecated. Please use `ctc_decoder` instead.")
return ctc_decoder(
lexicon=lexicon,
tokens=tokens,
lm=lm,
nbest=nbest,
beam_size=beam_size,
beam_size_token=beam_size_token,
beam_threshold=beam_threshold,
lm_weight=lm_weight,
word_score=word_score,
unk_score=unk_score,
sil_score=sil_score,
log_add=log_add,
blank_token=blank_token,
sil_token=sil_token,
unk_word=unk_word,
)
def _get_filenames(model: str) -> _PretrainedFiles:
if model not in ["librispeech", "librispeech-3-gram", "librispeech-4-gram"]:
raise ValueError(
......