Commit 03a0d68e authored by Caroline Chen, committed by Facebook GitHub Bot

Add NNLM support to CTC Decoder (#2528)

Summary:
Expose flashlight's LM and LMState classes to support decoding with custom language models, including NN LMs.

The `ctc_decoder` API is as follows:
- To decode with KenLM, pass the KenLM language model path to the `lm` parameter.
- To decode with a custom LM (including an NN LM), subclass `CTCDecoderLM` in Python and pass an instance to the `lm` parameter. Additionally, create a file listing the LM's words in order of LM index, one word per line, and pass that file to `lm_dict` (see the sketch below).
- To decode without a language model, set `lm` to `None` (default).
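
A minimal sketch of the custom LM interface (modeled on the `CustomZeroLM` test helper added in this PR; the class name `UniformLM` and the file paths are illustrative, and the constant 0.0 scores make it equivalent to decoding without an LM):

```python
from torchaudio.models.decoder import CTCDecoderLM, CTCDecoderLMState, ctc_decoder


class UniformLM(CTCDecoderLM):
    """Toy LM that scores every token 0.0, i.e. equivalent to no LM."""

    def __init__(self):
        CTCDecoderLM.__init__(self)

    def start(self, start_with_nothing: bool):
        # return the root LM state
        return CTCDecoderLMState()

    def score(self, state: CTCDecoderLMState, token_index: int):
        # return (next LM state, LM score for extending `state` with `token_index`)
        return state.child(token_index), 0.0

    def finish(self, state: CTCDecoderLMState):
        # score for ending the sequence at `state`
        return state, 0.0


# paths are placeholders
decoder = ctc_decoder(lexicon="lexicon.txt", tokens="tokens.txt", lm=UniformLM())
```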

Validated against the fairseq w2l decoder on a sample LibriSpeech dataset and LM. Code for validation can be found [here](https://github.com/facebookresearch/fairseq/compare/main...carolineechen:fairseq:ctc-decoder). Unit tests were also added that validate custom implementations of ZeroLM and KenLM, as well as decoding with a biased LM.

Follow-ups:
- Train a simple LM on LibriSpeech and demonstrate usage in a tutorial or the examples directory

cc jacobkahn

Pull Request resolved: https://github.com/pytorch/audio/pull/2528

Reviewed By: mthrok

Differential Revision: D38243802

Pulled By: carolineechen

fbshipit-source-id: 445e78f6c20bda655aabf819fc0f771fe68c73d7
parent c15eee23
@@ -9,23 +9,31 @@ NUM_TOKENS = 8
 @skipIfNoCtcDecoder
 class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
-    def _get_decoder(self, tokens=None, use_lm=True, use_lexicon=True, **kwargs):
-        from torchaudio.models.decoder import ctc_decoder
-
-        if use_lexicon:
-            lexicon_file = get_asset_path("decoder/lexicon.txt")
-            kenlm_file = get_asset_path("decoder/kenlm.arpa") if use_lm else None
-        else:
-            lexicon_file = None
-            kenlm_file = get_asset_path("decoder/kenlm_char.arpa") if use_lm else None
+    def _get_custom_kenlm(self, kenlm_file):
+        from .ctc_decoder_utils import CustomKenLM
+
+        dict_file = get_asset_path("decoder/lexicon.txt")
+        custom_lm = CustomKenLM(kenlm_file, dict_file)
+        return custom_lm
+
+    def _get_biased_nnlm(self, dict_file, keyword):
+        from .ctc_decoder_utils import BiasedLM, CustomBiasedLM
+
+        model = BiasedLM(dict_file, keyword)
+        biased_lm = CustomBiasedLM(model, dict_file)
+        return biased_lm
+
+    def _get_decoder(self, tokens=None, lm=None, use_lexicon=True, **kwargs):
+        from torchaudio.models.decoder import ctc_decoder
+
+        lexicon_file = get_asset_path("decoder/lexicon.txt") if use_lexicon else None

         if tokens is None:
             tokens = get_asset_path("decoder/tokens.txt")

         return ctc_decoder(
             lexicon=lexicon_file,
             tokens=tokens,
-            lm=kenlm_file,
+            lm=lm,
             **kwargs,
         )
@@ -40,13 +48,13 @@ class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
         list(
             itertools.product(
                 [get_asset_path("decoder/tokens.txt"), ["-", "|", "f", "o", "b", "a", "r"]],
-                [True, False],
+                [None, get_asset_path("decoder/kenlm.arpa")],
                 [True, False],
             )
         ),
     )
-    def test_construct_decoder(self, tokens, use_lm, use_lexicon):
-        self._get_decoder(tokens=tokens, use_lm=use_lm, use_lexicon=use_lexicon)
+    def test_construct_basic_decoder(self, tokens, lm, use_lexicon):
+        self._get_decoder(tokens=tokens, lm=lm, use_lexicon=use_lexicon)

     @parameterized.expand(
         [(True,), (False,)],
@@ -72,14 +80,63 @@ class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
         self.assertEqual(result.tokens.shape, result.timesteps.shape)

     def test_no_lm_decoder(self):
-        """Check that using no LM produces the same result as using an LM with 0 lm_weight"""
-        kenlm_decoder = self._get_decoder(lm_weight=0)
-        zerolm_decoder = self._get_decoder(use_lm=False)
+        """Check that the following produce the same result
+        - using no LM (C++ based implementation)
+        - using no LM (custom Python based wrapper)
+        - using a (Ken)LM with 0 weight
+        """
+        from .ctc_decoder_utils import CustomZeroLM
+
         emissions = self._get_emissions()
+
+        custom_zerolm = CustomZeroLM()
+        zerolm_decoder_custom = self._get_decoder(lm=custom_zerolm)
+        zerolm_decoder_cpp = self._get_decoder(lm=None)
+
+        kenlm_file = get_asset_path("decoder/kenlm.arpa")
+        kenlm_decoder = self._get_decoder(lm=kenlm_file, lm_weight=0)
+
+        zerolm_custom_results = zerolm_decoder_custom(emissions)
+        zerolm_cpp_results = zerolm_decoder_cpp(emissions)
         kenlm_results = kenlm_decoder(emissions)
-        zerolm_results = zerolm_decoder(emissions)
-        self.assertEqual(kenlm_results, zerolm_results)
+
+        self.assertEqual(zerolm_cpp_results, zerolm_custom_results)
+        self.assertEqual(zerolm_cpp_results, kenlm_results)
+
+    def test_custom_kenlm_decoder(self):
+        """Check that a custom Python KenLM wrapper produces the same results as the C++ based KenLM"""
+        emissions = self._get_emissions()
+
+        kenlm_file = get_asset_path("decoder/kenlm.arpa")
+        custom_kenlm = self._get_custom_kenlm(kenlm_file)
+        kenlm_decoder_custom = self._get_decoder(lm=custom_kenlm)
+        kenlm_decoder_cpp = self._get_decoder(lm=kenlm_file)
+
+        kenlm_custom_results = kenlm_decoder_custom(emissions)
+        kenlm_cpp_results = kenlm_decoder_cpp(emissions)
+
+        self.assertEqual(kenlm_custom_results, kenlm_cpp_results)
+
+    @parameterized.expand(
+        [
+            (get_asset_path("decoder/nnlm_lex_dict.txt"), "foo", True),
+            (get_asset_path("decoder/nnlm_lexfree_dict.txt"), "f", False),
+        ]
+    )
+    def test_custom_nnlm_decoder(self, lm_dict, keyword, use_lexicon):
+        """Check that the biased NNLM only produces the biased keyword"""
+        emissions = self._get_emissions()
+
+        custom_nnlm = self._get_biased_nnlm(lm_dict, keyword)
+        nnlm_decoder = self._get_decoder(lm=custom_nnlm, lm_dict=lm_dict, use_lexicon=use_lexicon, lm_weight=10)
+        nnlm_results = nnlm_decoder(emissions)
+
+        if use_lexicon:
+            output = [result[0].words for result in nnlm_results]
+        else:
+            tokens = [nnlm_decoder.idxs_to_tokens(result[0].tokens) for result in nnlm_results]
+            output = [list(filter(("|").__ne__, t)) for t in tokens]  # filter out silence characters
+
+        lens = [len(out) for out in output]
+        expected = [[keyword] * length for length in lens]  # all of the output should match the biased keyword
+        assert expected == output

     def test_get_timesteps(self):
         unprocessed_tokens = torch.tensor([2, 2, 0, 3, 3, 3, 0, 3])
...

ctc_decoder_utils.py (new test helper file):
import torch
from torchaudio.models.decoder import CTCDecoderLM, CTCDecoderLMState
from torchaudio.models.decoder._ctc_decoder import _create_word_dict, _Dictionary, _KenLM, _load_words


class CustomZeroLM(CTCDecoderLM):
    def __init__(self):
        CTCDecoderLM.__init__(self)

    def start(self, start_with_nothing: bool):
        return CTCDecoderLMState()

    def score(self, state: CTCDecoderLMState, token_index: int):
        return (state.child(token_index), 0.0)

    def finish(self, state: CTCDecoderLMState):
        return (state, 0.0)


class CustomKenLM(CTCDecoderLM):
    def __init__(self, kenlm_file, dict_file):
        CTCDecoderLM.__init__(self)
        kenlm_dict = _create_word_dict(_load_words(dict_file))
        self.model = _KenLM(kenlm_file, kenlm_dict)

    def start(self, start_with_nothing: bool):
        return self.model.start(start_with_nothing)

    def score(self, state: CTCDecoderLMState, token_index: int):
        return self.model.score(state, token_index)

    def finish(self, state: CTCDecoderLMState):
        return self.model.finish(state)


class BiasedLM(torch.nn.Module):
    def __init__(self, dict_file, keyword):
        super(BiasedLM, self).__init__()
        self.dictionary = _Dictionary(dict_file)
        self.keyword = keyword

    def forward(self, token_idx):
        if self.dictionary.get_entry(token_idx) == self.keyword:
            return torch.tensor(10)
        elif self.dictionary.get_entry(token_idx) == "<unk>":
            return torch.tensor(-torch.inf)
        return torch.tensor(-10)


class CustomBiasedLM(CTCDecoderLM):
    def __init__(self, model, dict_file):
        CTCDecoderLM.__init__(self)
        self.model = model
        self.vocab = _Dictionary(dict_file)
        self.eos = self.vocab.get_index("|")
        self.states = {}
        model.eval()

    def start(self, start_with_nothing: bool = False):
        state = CTCDecoderLMState()
        with torch.no_grad():
            score = self.model(self.eos)
        self.states[state] = score
        return state

    def score(self, state: CTCDecoderLMState, token_index: int):
        outstate = state.child(token_index)
        if outstate not in self.states:
            score = self.model(token_index)
            self.states[outstate] = score
        score = self.states[outstate]
        return outstate, score

    def finish(self, state: CTCDecoderLMState):
        return self.score(state, self.eos)
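
A hedged usage sketch for the helpers above (paths are placeholders; the final line mirrors the `ctc_decoder` docstring example, given `emissions` from an acoustic model):

    from torchaudio.models.decoder import ctc_decoder

    custom_lm = CustomKenLM("kenlm.arpa", "lexicon.txt")
    decoder = ctc_decoder(lexicon="lexicon.txt", tokens="tokens.txt", lm=custom_lm)
    results = decoder(emissions)  # List of shape (B, nbest) of Hypotheses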
...

torchaudio/models/decoder/__init__.py:

@@ -2,6 +2,8 @@ _INITIALIZED = False
 _LAZILY_IMPORTED = [
     "CTCHypothesis",
     "CTCDecoder",
+    "CTCDecoderLM",
+    "CTCDecoderLMState",
     "ctc_decoder",
     "download_pretrained_files",
 ]
...

torchaudio/models/decoder/_ctc_decoder.py:
@@ -17,7 +17,8 @@ try:
     LexiconDecoderOptions as _LexiconDecoderOptions,
     LexiconFreeDecoder as _LexiconFreeDecoder,
     LexiconFreeDecoderOptions as _LexiconFreeDecoderOptions,
-    LM as _LM,
+    LM as CTCDecoderLM,
+    LMState as CTCDecoderLMState,
     SmearingMode as _SmearingMode,
     Trie as _Trie,
     ZeroLM as _ZeroLM,
@@ -36,7 +37,8 @@ except Exception:
     LexiconDecoderOptions as _LexiconDecoderOptions,
     LexiconFreeDecoder as _LexiconFreeDecoder,
     LexiconFreeDecoderOptions as _LexiconFreeDecoderOptions,
-    LM as _LM,
+    LM as CTCDecoderLM,
+    LMState as CTCDecoderLMState,
     SmearingMode as _SmearingMode,
     Trie as _Trie,
     ZeroLM as _ZeroLM,
@@ -48,12 +50,48 @@ except Exception:
     )

-__all__ = ["CTCHypothesis", "CTCDecoder", "ctc_decoder", "download_pretrained_files"]
+__all__ = [
+    "CTCHypothesis",
+    "CTCDecoder",
+    "CTCDecoderLM",
+    "CTCDecoderLMState",
+    "ctc_decoder",
+    "download_pretrained_files",
+]

 _PretrainedFiles = namedtuple("PretrainedFiles", ["lexicon", "tokens", "lm"])


+def _construct_trie(tokens_dict, word_dict, lexicon, lm, silence):
+    vocab_size = tokens_dict.index_size()
+    trie = _Trie(vocab_size, silence)
+    start_state = lm.start(False)
+
+    for word, spellings in lexicon.items():
+        word_idx = word_dict.get_index(word)
+        _, score = lm.score(start_state, word_idx)
+        for spelling in spellings:
+            spelling_idx = [tokens_dict.get_index(token) for token in spelling]
+            trie.insert(spelling_idx, word_idx, score)
+    trie.smear(_SmearingMode.MAX)
+    return trie
+
+
+def _get_word_dict(lexicon, lm, lm_dict, tokens_dict, unk_word):
+    word_dict = None
+    if lm_dict is not None:
+        word_dict = _Dictionary(lm_dict)
+
+    if lexicon and word_dict is None:
+        word_dict = _create_word_dict(lexicon)
+    elif not lexicon and word_dict is None and type(lm) == str:
+        d = {tokens_dict.get_entry(i): [[tokens_dict.get_entry(i)]] for i in range(tokens_dict.index_size())}
+        d[unk_word] = [[unk_word]]
+        word_dict = _create_word_dict(d)
+
+    return word_dict
+
+
 class CTCHypothesis(NamedTuple):
     r"""Represents hypothesis generated by CTC beam search decoder :py:func:`CTCDecoder`.
@@ -89,7 +127,7 @@ class CTCDecoder:
         lexicon (Dict or None): lexicon mapping of words to spellings, or None for lexicon-free decoder
         word_dict (_Dictionary): dictionary of words
         tokens_dict (_Dictionary): dictionary of tokens
-        lm (_LM): language model
+        lm (CTCDecoderLM): language model. If using a lexicon, only word level LMs are currently supported
         decoder_options (_LexiconDecoderOptions or _LexiconFreeDecoderOptions): parameters used for beam search decoding
         blank_token (str): token corresponding to blank
         sil_token (str): token corresponding to silence
@@ -102,7 +140,7 @@ class CTCDecoder:
         lexicon: Optional[Dict],
         word_dict: _Dictionary,
         tokens_dict: _Dictionary,
-        lm: _LM,
+        lm: CTCDecoderLM,
         decoder_options: Union[_LexiconDecoderOptions, _LexiconFreeDecoderOptions],
         blank_token: str,
         sil_token: str,
@@ -113,21 +151,12 @@ class CTCDecoder:
         self.tokens_dict = tokens_dict
         self.blank = self.tokens_dict.get_index(blank_token)
         silence = self.tokens_dict.get_index(sil_token)
+        transitions = []

         if lexicon:
+            trie = _construct_trie(tokens_dict, word_dict, lexicon, lm, silence)
             unk_word = word_dict.get_index(unk_word)
+            token_lm = False  # use word level LM
-            vocab_size = self.tokens_dict.index_size()
-            trie = _Trie(vocab_size, silence)
-            start_state = lm.start(False)
-            for word, spellings in lexicon.items():
-                word_idx = self.word_dict.get_index(word)
-                _, score = lm.score(start_state, word_idx)
-                for spelling in spellings:
-                    spelling_idx = [self.tokens_dict.get_index(token) for token in spelling]
-                    trie.insert(spelling_idx, word_idx, score)
-            trie.smear(_SmearingMode.MAX)

             self.decoder = _LexiconDecoder(
                 decoder_options,
@@ -136,11 +165,11 @@ class CTCDecoder:
                 silence,
                 self.blank,
                 unk_word,
-                [],
-                False,  # word level LM
+                transitions,
+                token_lm,
             )
         else:
-            self.decoder = _LexiconFreeDecoder(decoder_options, lm, silence, self.blank, [])
+            self.decoder = _LexiconFreeDecoder(decoder_options, lm, silence, self.blank, transitions)

     def _get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
         idxs = (g[0] for g in it.groupby(idxs))
@@ -194,7 +223,6 @@ class CTCDecoder:
         for b in range(B):
             emissions_ptr = emissions.data_ptr() + float_bytes * b * emissions.stride(0)
             results = self.decoder.decode(emissions_ptr, lengths[b], N)
-
             nbest_results = results[: self.nbest]
@@ -228,7 +256,8 @@ class CTCDecoder:
 def ctc_decoder(
     lexicon: Optional[str],
     tokens: Union[str, List[str]],
-    lm: Optional[str] = None,
+    lm: Union[str, CTCDecoderLM] = None,
+    lm_dict: Optional[str] = None,
     nbest: int = 1,
     beam_size: int = 50,
     beam_size_token: Optional[int] = None,
@@ -251,7 +280,12 @@ def ctc_decoder(
             decoding.
         tokens (str or List[str]): file or list containing valid tokens. If using a file, the expected
             format is for tokens mapping to the same index to be on the same line
-        lm (str or None, optional): file containing language model, or `None` if not using a language model
+        lm (str, CTCDecoderLM, or None, optional): either a path to a KenLM language model,
+            a custom language model of type `CTCDecoderLM`, or `None` if not using a language model
+        lm_dict (str or None, optional): file consisting of the dictionary used for the LM, with a word
+            per line sorted by LM index. If decoding with a lexicon, entries in lm_dict must also occur
+            in the lexicon file. If `None`, the dictionary for the LM is constructed using the lexicon file.
+            (Default: None)
         nbest (int, optional): number of best decodings to return (Default: 1)
         beam_size (int, optional): max number of hypos to hold after each decode step (Default: 50)
         beam_size_token (int, optional): max number of tokens to consider at each decode step.
@@ -277,13 +311,14 @@ def ctc_decoder(
         >>> )
         >>> results = decoder(emissions)  # List of shape (B, nbest) of Hypotheses
     """
+    if lm_dict is not None and type(lm_dict) is not str:
+        raise ValueError("lm_dict must be None or str type.")
+
     tokens_dict = _Dictionary(tokens)

-    if lexicon is not None:
+    # decoder options
+    if lexicon:
         lexicon = _load_words(lexicon)
-        word_dict = _create_word_dict(lexicon)
-        lm = _KenLM(lm, word_dict) if lm else _ZeroLM()
         decoder_options = _LexiconDecoderOptions(
             beam_size=beam_size,
             beam_size_token=beam_size_token or tokens_dict.index_size(),
@@ -296,11 +331,6 @@ def ctc_decoder(
             criterion_type=_CriterionType.CTC,
         )
     else:
-        d = {tokens_dict.get_entry(i): [[tokens_dict.get_entry(i)]] for i in range(tokens_dict.index_size())}
-        d[unk_word] = [[unk_word]]
-        word_dict = _create_word_dict(d)
-        lm = _KenLM(lm, word_dict) if lm else _ZeroLM()
-
         decoder_options = _LexiconFreeDecoderOptions(
             beam_size=beam_size,
             beam_size_token=beam_size_token or tokens_dict.index_size(),
@@ -311,6 +341,14 @@ def ctc_decoder(
             criterion_type=_CriterionType.CTC,
         )

+    # construct word dict and language model
+    word_dict = _get_word_dict(lexicon, lm, lm_dict, tokens_dict, unk_word)
+
+    if type(lm) == str:
+        lm = _KenLM(lm, word_dict)
+    elif lm is None:
+        lm = _ZeroLM()
+
     return CTCDecoder(
         nbest=nbest,
         lexicon=lexicon,
...