"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "3a278d701d3a0bba25ad52891653330ece2cb472"
Commit 4c3fa875 authored by Caroline Chen's avatar Caroline Chen Committed by Facebook GitHub Bot
Browse files

Add no lm support for CTC decoder (#2174)

Summary:
Add support for running the CTC lexicon decoder without a language model by adding a no-op language model `ZeroLM` that returns a score of 0 for every query. Generalize the decoder class/API a bit to support this, adding it as an option for the kenlm decoder at the moment (it will likely be separated out from kenlm when adding support for other kinds of LMs in the future).

Pull Request resolved: https://github.com/pytorch/audio/pull/2174

Reviewed By: hwangjeff, nateanl

Differential Revision: D33798674

Pulled By: carolineechen

fbshipit-source-id: ef8265f1d046011b143597b3b7c691566b08dcde
parent 2cb87c6b
...@@ -24,7 +24,7 @@ Hypothesis ...@@ -24,7 +24,7 @@ Hypothesis
Factory Function Factory Function
---------------- ----------------
kenlm_lexicon_decoder lexicon_decoder
~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: kenlm_lexicon_decoder .. autoclass:: lexicon_decoder
...@@ -4,7 +4,7 @@ from typing import Optional ...@@ -4,7 +4,7 @@ from typing import Optional
import torch import torch
import torchaudio import torchaudio
from torchaudio.prototype.ctc_decoder import kenlm_lexicon_decoder from torchaudio.prototype.ctc_decoder import lexicon_decoder
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -31,10 +31,10 @@ def run_inference(args): ...@@ -31,10 +31,10 @@ def run_inference(args):
kenlm_file = f"{hub_dir}/kenlm.bin" kenlm_file = f"{hub_dir}/kenlm.bin"
_download_files(lexicon_file, kenlm_file) _download_files(lexicon_file, kenlm_file)
decoder = kenlm_lexicon_decoder( decoder = lexicon_decoder(
lexicon=lexicon_file, lexicon=lexicon_file,
tokens=tokens, tokens=tokens,
kenlm=kenlm_file, lm=None,
nbest=1, nbest=1,
beam_size=1500, beam_size=1500,
beam_size_token=None, beam_size_token=None,
......
...@@ -177,20 +177,23 @@ torch.hub.download_url_to_file(kenlm_url, kenlm_file) ...@@ -177,20 +177,23 @@ torch.hub.download_url_to_file(kenlm_url, kenlm_file)
# ----------------------------- # -----------------------------
# #
# The decoder can be constructed using the factory function # The decoder can be constructed using the factory function
# :py:func:`kenlm_lexicon_decoder <torchaudio.prototype.ctc_decoder.kenlm_lexicon_decoder>`. # :py:func:`lexicon_decoder <torchaudio.prototype.ctc_decoder.lexicon_decoder>`.
# In addition to the previously mentioned components, it also takes in various beam # In addition to the previously mentioned components, it also takes in various beam
# search decoding parameters and token/word parameters. # search decoding parameters and token/word parameters.
# #
# This decoder can also be run without a language model by passing in `None` into the
# `lm` parameter.
#
from torchaudio.prototype.ctc_decoder import kenlm_lexicon_decoder from torchaudio.prototype.ctc_decoder import lexicon_decoder
LM_WEIGHT = 3.23 LM_WEIGHT = 3.23
WORD_SCORE = -0.26 WORD_SCORE = -0.26
beam_search_decoder = kenlm_lexicon_decoder( beam_search_decoder = lexicon_decoder(
lexicon=lexicon_file, lexicon=lexicon_file,
tokens=tokens, tokens=tokens,
kenlm=kenlm_file, lm=kenlm_file,
nbest=3, nbest=3,
beam_size=1500, beam_size=1500,
lm_weight=LM_WEIGHT, lm_weight=LM_WEIGHT,
...@@ -290,7 +293,7 @@ print(f"WER: {beam_search_wer}") ...@@ -290,7 +293,7 @@ print(f"WER: {beam_search_wer}")
# In this section, we go a little bit more in depth about some different # In this section, we go a little bit more in depth about some different
# parameters and tradeoffs. For the full list of customizable parameters, # parameters and tradeoffs. For the full list of customizable parameters,
# please refer to the # please refer to the
# :py:func:`documentation <torchaudio.prototype.ctc_decoder.kenlm_lexicon_decoder>`. # :py:func:`documentation <torchaudio.prototype.ctc_decoder.lexicon_decoder>`.
# #
...@@ -345,10 +348,10 @@ for i in range(3): ...@@ -345,10 +348,10 @@ for i in range(3):
beam_sizes = [1, 5, 50, 500] beam_sizes = [1, 5, 50, 500]
for beam_size in beam_sizes: for beam_size in beam_sizes:
beam_search_decoder = kenlm_lexicon_decoder( beam_search_decoder = lexicon_decoder(
lexicon=lexicon_file, lexicon=lexicon_file,
tokens=tokens, tokens=tokens,
kenlm=kenlm_file, lm=kenlm_file,
beam_size=beam_size, beam_size=beam_size,
lm_weight=LM_WEIGHT, lm_weight=LM_WEIGHT,
word_score=WORD_SCORE, word_score=WORD_SCORE,
...@@ -371,10 +374,10 @@ num_tokens = len(tokens) ...@@ -371,10 +374,10 @@ num_tokens = len(tokens)
beam_size_tokens = [1, 5, 10, num_tokens] beam_size_tokens = [1, 5, 10, num_tokens]
for beam_size_token in beam_size_tokens: for beam_size_token in beam_size_tokens:
beam_search_decoder = kenlm_lexicon_decoder( beam_search_decoder = lexicon_decoder(
lexicon=lexicon_file, lexicon=lexicon_file,
tokens=tokens, tokens=tokens,
kenlm=kenlm_file, lm=kenlm_file,
beam_size_token=beam_size_token, beam_size_token=beam_size_token,
lm_weight=LM_WEIGHT, lm_weight=LM_WEIGHT,
word_score=WORD_SCORE, word_score=WORD_SCORE,
...@@ -398,10 +401,10 @@ for beam_size_token in beam_size_tokens: ...@@ -398,10 +401,10 @@ for beam_size_token in beam_size_tokens:
beam_thresholds = [1, 5, 10, 25] beam_thresholds = [1, 5, 10, 25]
for beam_threshold in beam_thresholds: for beam_threshold in beam_thresholds:
beam_search_decoder = kenlm_lexicon_decoder( beam_search_decoder = lexicon_decoder(
lexicon=lexicon_file, lexicon=lexicon_file,
tokens=tokens, tokens=tokens,
kenlm=kenlm_file, lm=kenlm_file,
beam_threshold=beam_threshold, beam_threshold=beam_threshold,
lm_weight=LM_WEIGHT, lm_weight=LM_WEIGHT,
word_score=WORD_SCORE, word_score=WORD_SCORE,
...@@ -424,10 +427,10 @@ for beam_threshold in beam_thresholds: ...@@ -424,10 +427,10 @@ for beam_threshold in beam_thresholds:
lm_weights = [0, LM_WEIGHT, 15] lm_weights = [0, LM_WEIGHT, 15]
for lm_weight in lm_weights: for lm_weight in lm_weights:
beam_search_decoder = kenlm_lexicon_decoder( beam_search_decoder = lexicon_decoder(
lexicon=lexicon_file, lexicon=lexicon_file,
tokens=tokens, tokens=tokens,
kenlm=kenlm_file, lm=kenlm_file,
lm_weight=lm_weight, lm_weight=lm_weight,
word_score=WORD_SCORE, word_score=WORD_SCORE,
) )
......
import itertools
import torch import torch
from parameterized import parameterized from parameterized import parameterized
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
...@@ -10,35 +12,57 @@ from torchaudio_unittest.common_utils import ( ...@@ -10,35 +12,57 @@ from torchaudio_unittest.common_utils import (
@skipIfNoCtcDecoder @skipIfNoCtcDecoder
class CTCDecoderTest(TempDirMixin, TorchaudioTestCase): class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
def _get_decoder(self, tokens=None): def _get_decoder(self, tokens=None, use_lm=True, **kwargs):
from torchaudio.prototype.ctc_decoder import kenlm_lexicon_decoder from torchaudio.prototype.ctc_decoder import lexicon_decoder
lexicon_file = get_asset_path("decoder/lexicon.txt") lexicon_file = get_asset_path("decoder/lexicon.txt")
kenlm_file = get_asset_path("decoder/kenlm.arpa") kenlm_file = get_asset_path("decoder/kenlm.arpa") if use_lm else None
if tokens is None: if tokens is None:
tokens = get_asset_path("decoder/tokens.txt") tokens = get_asset_path("decoder/tokens.txt")
return kenlm_lexicon_decoder( return lexicon_decoder(
lexicon=lexicon_file, lexicon=lexicon_file,
tokens=tokens, tokens=tokens,
kenlm=kenlm_file, lm=kenlm_file,
**kwargs,
) )
@parameterized.expand([(get_asset_path("decoder/tokens.txt"),), (["-", "|", "f", "o", "b", "a", "r"],)]) def _get_emissions(self):
def test_construct_decoder(self, tokens):
self._get_decoder(tokens)
def test_shape(self):
B, T, N = 4, 15, 10 B, T, N = 4, 15, 10
torch.manual_seed(0) torch.manual_seed(0)
emissions = torch.rand(B, T, N) emissions = torch.rand(B, T, N)
return emissions
@parameterized.expand(
list(
itertools.product(
[get_asset_path("decoder/tokens.txt"), ["-", "|", "f", "o", "b", "a", "r"]],
[True, False],
)
),
)
def test_construct_decoder(self, tokens, use_lm):
self._get_decoder(tokens=tokens, use_lm=use_lm)
def test_no_lm_decoder(self):
"""Check that using no LM produces the same result as using an LM with 0 lm_weight"""
kenlm_decoder = self._get_decoder(lm_weight=0)
zerolm_decoder = self._get_decoder(use_lm=False)
emissions = self._get_emissions()
kenlm_results = kenlm_decoder(emissions)
zerolm_results = zerolm_decoder(emissions)
self.assertEqual(kenlm_results, zerolm_results)
def test_shape(self):
emissions = self._get_emissions()
decoder = self._get_decoder() decoder = self._get_decoder()
results = decoder(emissions)
self.assertEqual(len(results), B) results = decoder(emissions)
self.assertEqual(len(results), emissions.shape[0])
@parameterized.expand([(get_asset_path("decoder/tokens.txt"),), (["-", "|", "f", "o", "b", "a", "r"],)]) @parameterized.expand([(get_asset_path("decoder/tokens.txt"),), (["-", "|", "f", "o", "b", "a", "r"],)])
def test_index_to_tokens(self, tokens): def test_index_to_tokens(self, tokens):
......
...@@ -137,6 +137,7 @@ if (BUILD_CTC_DECODER) ...@@ -137,6 +137,7 @@ if (BUILD_CTC_DECODER)
decoder/src/decoder/Trie.cpp decoder/src/decoder/Trie.cpp
decoder/src/decoder/Utils.cpp decoder/src/decoder/Utils.cpp
decoder/src/decoder/lm/KenLM.cpp decoder/src/decoder/lm/KenLM.cpp
decoder/src/decoder/lm/ZeroLM.cpp
decoder/src/dictionary/Dictionary.cpp decoder/src/dictionary/Dictionary.cpp
decoder/src/dictionary/String.cpp decoder/src/dictionary/String.cpp
decoder/src/dictionary/System.cpp decoder/src/dictionary/System.cpp
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include "torchaudio/csrc/decoder/src/decoder/LexiconDecoder.h" #include "torchaudio/csrc/decoder/src/decoder/LexiconDecoder.h"
#include "torchaudio/csrc/decoder/src/decoder/lm/KenLM.h" #include "torchaudio/csrc/decoder/src/decoder/lm/KenLM.h"
#include "torchaudio/csrc/decoder/src/decoder/lm/ZeroLM.h"
#include "torchaudio/csrc/decoder/src/dictionary/Dictionary.h" #include "torchaudio/csrc/decoder/src/dictionary/Dictionary.h"
#include "torchaudio/csrc/decoder/src/dictionary/Utils.h" #include "torchaudio/csrc/decoder/src/dictionary/Utils.h"
...@@ -153,6 +154,8 @@ PYBIND11_MODULE(_torchaudio_decoder, m) { ...@@ -153,6 +154,8 @@ PYBIND11_MODULE(_torchaudio_decoder, m) {
"path"_a, "path"_a,
"usr_token_dict"_a); "usr_token_dict"_a);
py::class_<ZeroLM, ZeroLMPtr, LM>(m, "_ZeroLM").def(py::init<>());
py::enum_<CriterionType>(m, "_CriterionType") py::enum_<CriterionType>(m, "_CriterionType")
.value("ASG", CriterionType::ASG) .value("ASG", CriterionType::ASG)
.value("CTC", CriterionType::CTC); .value("CTC", CriterionType::CTC);
......
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#include "torchaudio/csrc/decoder/src/decoder/lm/ZeroLM.h"

#include <stdexcept>

namespace torchaudio {
namespace lib {
namespace text {

// Begin decoding from a fresh root state. The flag is irrelevant here since
// every state scores 0 regardless of history.
LMStatePtr ZeroLM::start(bool /* startWithNothing */) {
  return std::make_shared<LMState>();
}

// Advance to (or lazily create) the child state for `usrTokenIdx` and emit a
// zero score. The state is used only for trie-style child bookkeeping, never
// for actual language-model scoring.
std::pair<LMStatePtr, float> ZeroLM::score(
    const LMStatePtr& state,
    const int usrTokenIdx) {
  auto next = state->child<LMState>(usrTokenIdx);
  return {std::move(next), 0.0};
}

// End-of-sentence contributes nothing: stay on the same state, score 0.
std::pair<LMStatePtr, float> ZeroLM::finish(const LMStatePtr& state) {
  return {state, 0.0};
}

} // namespace text
} // namespace lib
} // namespace torchaudio
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* This source code is licensed under the MIT-style license found in the
* https://github.com/flashlight/flashlight/blob/d385b2150872fd7bf106601475d8719a703fe9ee/LICENSE
*/
#pragma once
#include "torchaudio/csrc/decoder/src/decoder/lm/LM.h"
namespace torchaudio {
namespace lib {
namespace text {
/**
 * ZeroLM is a dummy language model class, which mimics the behavior of a
 * uni-gram language model but always returns 0 as score.
 */
class ZeroLM : public LM {
 public:
  // Returns a fresh root LMState; `startWithNothing` is accepted only to
  // satisfy the LM interface (scores are 0 regardless of history).
  LMStatePtr start(bool startWithNothing) override;
  // Returns the child state for `usrTokenIdx` paired with a score of 0.
  std::pair<LMStatePtr, float> score(
      const LMStatePtr& state,
      const int usrTokenIdx) override;
  // Returns the same state paired with a score of 0 (end-of-sentence is free).
  std::pair<LMStatePtr, float> finish(const LMStatePtr& state) override;
};
// Shared-ownership handle, matching the convention of the other LM wrappers.
using ZeroLMPtr = std::shared_ptr<ZeroLM>;
} // namespace text
} // namespace lib
} // namespace torchaudio
...@@ -2,7 +2,7 @@ import torchaudio ...@@ -2,7 +2,7 @@ import torchaudio
try: try:
torchaudio._extension._load_lib("libtorchaudio_decoder") torchaudio._extension._load_lib("libtorchaudio_decoder")
from .ctc_decoder import Hypothesis, KenLMLexiconDecoder, kenlm_lexicon_decoder from .ctc_decoder import Hypothesis, LexiconDecoder, lexicon_decoder
except ImportError as err: except ImportError as err:
raise ImportError( raise ImportError(
"flashlight decoder bindings are required to use this functionality. " "flashlight decoder bindings are required to use this functionality. "
...@@ -12,6 +12,6 @@ except ImportError as err: ...@@ -12,6 +12,6 @@ except ImportError as err:
__all__ = [ __all__ = [
"Hypothesis", "Hypothesis",
"KenLMLexiconDecoder", "LexiconDecoder",
"kenlm_lexicon_decoder", "lexicon_decoder",
] ]
...@@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Union, NamedTuple ...@@ -4,6 +4,7 @@ from typing import Dict, List, Optional, Union, NamedTuple
import torch import torch
from torchaudio._torchaudio_decoder import ( from torchaudio._torchaudio_decoder import (
_CriterionType, _CriterionType,
_LM,
_KenLM, _KenLM,
_LexiconDecoder, _LexiconDecoder,
_LexiconDecoderOptions, _LexiconDecoderOptions,
...@@ -12,14 +13,15 @@ from torchaudio._torchaudio_decoder import ( ...@@ -12,14 +13,15 @@ from torchaudio._torchaudio_decoder import (
_Dictionary, _Dictionary,
_create_word_dict, _create_word_dict,
_load_words, _load_words,
_ZeroLM,
) )
__all__ = ["Hypothesis", "KenLMLexiconDecoder", "kenlm_lexicon_decoder"] __all__ = ["Hypothesis", "LexiconDecoder", "lexicon_decoder"]
class Hypothesis(NamedTuple): class Hypothesis(NamedTuple):
r"""Represents hypothesis generated by CTC beam search decoder :py:func`KenLMLexiconDecoder`. r"""Represents hypothesis generated by CTC beam search decoder :py:func`LexiconDecoder`.
:ivar torch.LongTensor tokens: Predicted sequence of token IDs :ivar torch.LongTensor tokens: Predicted sequence of token IDs
:ivar List[str] words: List of predicted words :ivar List[str] words: List of predicted words
...@@ -30,12 +32,12 @@ class Hypothesis(NamedTuple): ...@@ -30,12 +32,12 @@ class Hypothesis(NamedTuple):
score: float score: float
class KenLMLexiconDecoder: class LexiconDecoder:
"""torchaudio.prototype.ctc_decoder.KenLMLexiconDecoder() """torchaudio.prototype.ctc_decoder.LexiconDecoder()
Note: Note:
To build the decoder, please use factory function To build the decoder, please use factory function
:py:func:`kenlm_lexicon_decoder`. :py:func:`lexicon_decoder`.
""" """
...@@ -45,24 +47,24 @@ class KenLMLexiconDecoder: ...@@ -45,24 +47,24 @@ class KenLMLexiconDecoder:
lexicon: Dict, lexicon: Dict,
word_dict: _Dictionary, word_dict: _Dictionary,
tokens_dict: _Dictionary, tokens_dict: _Dictionary,
kenlm: _KenLM, lm: _LM,
decoder_options: _LexiconDecoderOptions, decoder_options: _LexiconDecoderOptions,
blank_token: str, blank_token: str,
sil_token: str, sil_token: str,
unk_word: str, unk_word: str,
) -> None: ) -> None:
""" """
KenLM CTC Decoder with Lexicon constraint. CTC Decoder with Lexicon constraint.
Note: Note:
To build the decoder, please use the factory function kenlm_lexicon_decoder. To build the decoder, please use the factory function lexicon_decoder.
Args: Args:
nbest (int): number of best decodings to return nbest (int): number of best decodings to return
lexicon (Dict): lexicon mapping of words to spellings lexicon (Dict): lexicon mapping of words to spellings
word_dict (_Dictionary): dictionary of words word_dict (_Dictionary): dictionary of words
tokens_dict (_Dictionary): dictionary of tokens tokens_dict (_Dictionary): dictionary of tokens
kenlm (_KenLM): n-gram KenLM language model lm (_LM): language model
decoder_options (_LexiconDecoderOptions): parameters used for beam search decoding decoder_options (_LexiconDecoderOptions): parameters used for beam search decoding
blank_token (str): token corresopnding to blank blank_token (str): token corresopnding to blank
sil_token (str): token corresponding to silence sil_token (str): token corresponding to silence
...@@ -79,11 +81,11 @@ class KenLMLexiconDecoder: ...@@ -79,11 +81,11 @@ class KenLMLexiconDecoder:
vocab_size = self.tokens_dict.index_size() vocab_size = self.tokens_dict.index_size()
trie = _Trie(vocab_size, silence) trie = _Trie(vocab_size, silence)
start_state = kenlm.start(False) start_state = lm.start(False)
for word, spellings in lexicon.items(): for word, spellings in lexicon.items():
word_idx = self.word_dict.get_index(word) word_idx = self.word_dict.get_index(word)
_, score = kenlm.score(start_state, word_idx) _, score = lm.score(start_state, word_idx)
for spelling in spellings: for spelling in spellings:
spelling_idx = [self.tokens_dict.get_index(token) for token in spelling] spelling_idx = [self.tokens_dict.get_index(token) for token in spelling]
trie.insert(spelling_idx, word_idx, score) trie.insert(spelling_idx, word_idx, score)
...@@ -92,7 +94,7 @@ class KenLMLexiconDecoder: ...@@ -92,7 +94,7 @@ class KenLMLexiconDecoder:
self.decoder = _LexiconDecoder( self.decoder = _LexiconDecoder(
decoder_options, decoder_options,
trie, trie,
kenlm, lm,
silence, silence,
self.blank, self.blank,
unk_word, unk_word,
...@@ -160,10 +162,10 @@ class KenLMLexiconDecoder: ...@@ -160,10 +162,10 @@ class KenLMLexiconDecoder:
return [self.tokens_dict.get_entry(idx.item()) for idx in idxs] return [self.tokens_dict.get_entry(idx.item()) for idx in idxs]
def kenlm_lexicon_decoder( def lexicon_decoder(
lexicon: str, lexicon: str,
tokens: Union[str, List[str]], tokens: Union[str, List[str]],
kenlm: str, lm: str = None,
nbest: int = 1, nbest: int = 1,
beam_size: int = 50, beam_size: int = 50,
beam_size_token: Optional[int] = None, beam_size_token: Optional[int] = None,
...@@ -176,7 +178,7 @@ def kenlm_lexicon_decoder( ...@@ -176,7 +178,7 @@ def kenlm_lexicon_decoder(
blank_token: str = "-", blank_token: str = "-",
sil_token: str = "|", sil_token: str = "|",
unk_word: str = "<unk>", unk_word: str = "<unk>",
) -> KenLMLexiconDecoder: ) -> LexiconDecoder:
""" """
Builds Ken LM CTC Lexicon Decoder with given parameters Builds Ken LM CTC Lexicon Decoder with given parameters
...@@ -185,7 +187,7 @@ def kenlm_lexicon_decoder( ...@@ -185,7 +187,7 @@ def kenlm_lexicon_decoder(
Each line consists of a word and its space separated spelling Each line consists of a word and its space separated spelling
tokens (str or List[str]): file or list containing valid tokens. If using a file, the expected tokens (str or List[str]): file or list containing valid tokens. If using a file, the expected
format is for tokens mapping to the same index to be on the same line format is for tokens mapping to the same index to be on the same line
kenlm (str): file containing languge model lm (str or None, optional): file containing language model, or `None` if not using a language model
nbest (int, optional): number of best decodings to return (Default: 1) nbest (int, optional): number of best decodings to return (Default: 1)
beam_size (int, optional): max number of hypos to hold after each decode step (Default: 50) beam_size (int, optional): max number of hypos to hold after each decode step (Default: 50)
beam_size_token (int, optional): max number of tokens to consider at each decode step. beam_size_token (int, optional): max number of tokens to consider at each decode step.
...@@ -201,19 +203,19 @@ def kenlm_lexicon_decoder( ...@@ -201,19 +203,19 @@ def kenlm_lexicon_decoder(
unk_word (str, optional): word corresponding to unknown (Default: "<unk>") unk_word (str, optional): word corresponding to unknown (Default: "<unk>")
Returns: Returns:
KenLMLexiconDecoder: decoder LexiconDecoder: decoder
Example Example
>>> decoder = kenlm_lexicon_decoder( >>> decoder = lexicon_decoder(
>>> lexicon="lexicon.txt", >>> lexicon="lexicon.txt",
>>> tokens="tokens.txt", >>> tokens="tokens.txt",
>>> kenlm="kenlm.bin", >>> lm="kenlm.bin",
>>> ) >>> )
>>> results = decoder(emissions) # List of shape (B, nbest) of Hypotheses >>> results = decoder(emissions) # List of shape (B, nbest) of Hypotheses
""" """
lexicon = _load_words(lexicon) lexicon = _load_words(lexicon)
word_dict = _create_word_dict(lexicon) word_dict = _create_word_dict(lexicon)
kenlm = _KenLM(kenlm, word_dict) lm = _KenLM(lm, word_dict) if lm else _ZeroLM()
tokens_dict = _Dictionary(tokens) tokens_dict = _Dictionary(tokens)
decoder_options = _LexiconDecoderOptions( decoder_options = _LexiconDecoderOptions(
...@@ -228,12 +230,12 @@ def kenlm_lexicon_decoder( ...@@ -228,12 +230,12 @@ def kenlm_lexicon_decoder(
criterion_type=_CriterionType.CTC, criterion_type=_CriterionType.CTC,
) )
return KenLMLexiconDecoder( return LexiconDecoder(
nbest=nbest, nbest=nbest,
lexicon=lexicon, lexicon=lexicon,
word_dict=word_dict, word_dict=word_dict,
tokens_dict=tokens_dict, tokens_dict=tokens_dict,
kenlm=kenlm, lm=lm,
decoder_options=decoder_options, decoder_options=decoder_options,
blank_token=blank_token, blank_token=blank_token,
sil_token=sil_token, sil_token=sil_token,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment