Commit 03a0d68e authored by Caroline Chen, committed by Facebook GitHub Bot

Add NNLM support to CTC Decoder (#2528)

Summary:
Expose flashlight's LM and LMState classes to support decoding with custom language models, including NN LMs.

The `ctc_decoder` API works as follows:
- To decode with KenLM, pass the path to the KenLM language model file via the `lm` parameter.
- To decode with a custom LM (including an NN LM), create a Python class that subclasses `CTCDecoderLM`, and pass an instance of it via the `lm` parameter. Additionally, create a file listing the LM's words in order of LM index, one word per line, and pass it via the `lm_dict` parameter (see the sketch after this list).
- To decode without a language model, leave `lm` as `None` (the default).
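
A minimal usage sketch of the three modes (asset paths are placeholders; the toy custom LM mirrors the `CustomZeroLM` test utility added in this PR):

```python
from torchaudio.models.decoder import ctc_decoder, CTCDecoderLM, CTCDecoderLMState

# 1. KenLM: pass the path to an .arpa/.bin LM file
kenlm_decoder = ctc_decoder(lexicon="lexicon.txt", tokens="tokens.txt", lm="kenlm.arpa")


# 2. Custom LM: subclass CTCDecoderLM and pass an instance
class ZeroLM(CTCDecoderLM):
    """Toy LM that scores every token 0.0, i.e. equivalent to decoding without an LM."""

    def __init__(self):
        CTCDecoderLM.__init__(self)

    def start(self, start_with_nothing: bool):
        return CTCDecoderLMState()

    def score(self, state: CTCDecoderLMState, token_index: int):
        return state.child(token_index), 0.0

    def finish(self, state: CTCDecoderLMState):
        return state, 0.0


custom_decoder = ctc_decoder(
    lexicon="lexicon.txt",
    tokens="tokens.txt",
    lm=ZeroLM(),
    lm_dict="lm_dict.txt",  # one word per line, sorted by LM index
)

# 3. No LM (the default)
plain_decoder = ctc_decoder(lexicon="lexicon.txt", tokens="tokens.txt", lm=None)
```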

Validated against the fairseq w2l decoder on a sample LibriSpeech dataset and LM. Code for validation can be found [here](https://github.com/facebookresearch/fairseq/compare/main...carolineechen:fairseq:ctc-decoder). Also added unit tests validating custom implementations of ZeroLM and KenLM, as well as decoding with a biased LM.

Follow ups:
- Train a simple LM on LibriSpeech and demonstrate its usage in a tutorial or the examples directory

cc jacobkahn

Pull Request resolved: https://github.com/pytorch/audio/pull/2528

Reviewed By: mthrok

Differential Revision: D38243802

Pulled By: carolineechen

fbshipit-source-id: 445e78f6c20bda655aabf819fc0f771fe68c73d7
parent c15eee23
@@ -9,23 +9,31 @@ NUM_TOKENS = 8
 @skipIfNoCtcDecoder
 class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
-    def _get_decoder(self, tokens=None, use_lm=True, use_lexicon=True, **kwargs):
-        from torchaudio.models.decoder import ctc_decoder
+    def _get_custom_kenlm(self, kenlm_file):
+        from .ctc_decoder_utils import CustomKenLM
-        if use_lexicon:
-            lexicon_file = get_asset_path("decoder/lexicon.txt")
-            kenlm_file = get_asset_path("decoder/kenlm.arpa") if use_lm else None
-        else:
-            lexicon_file = None
-            kenlm_file = get_asset_path("decoder/kenlm_char.arpa") if use_lm else None
+        dict_file = get_asset_path("decoder/lexicon.txt")
+        custom_lm = CustomKenLM(kenlm_file, dict_file)
+        return custom_lm
+
+    def _get_biased_nnlm(self, dict_file, keyword):
+        from .ctc_decoder_utils import BiasedLM, CustomBiasedLM
+
+        model = BiasedLM(dict_file, keyword)
+        biased_lm = CustomBiasedLM(model, dict_file)
+        return biased_lm
+
+    def _get_decoder(self, tokens=None, lm=None, use_lexicon=True, **kwargs):
+        from torchaudio.models.decoder import ctc_decoder
+
+        lexicon_file = get_asset_path("decoder/lexicon.txt") if use_lexicon else None
         if tokens is None:
             tokens = get_asset_path("decoder/tokens.txt")
         return ctc_decoder(
             lexicon=lexicon_file,
             tokens=tokens,
-            lm=kenlm_file,
+            lm=lm,
             **kwargs,
         )
@@ -40,13 +48,13 @@ class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
         list(
             itertools.product(
                 [get_asset_path("decoder/tokens.txt"), ["-", "|", "f", "o", "b", "a", "r"]],
-                [True, False],
+                [None, get_asset_path("decoder/kenlm.arpa")],
                 [True, False],
             )
         ),
     )
-    def test_construct_decoder(self, tokens, use_lm, use_lexicon):
-        self._get_decoder(tokens=tokens, use_lm=use_lm, use_lexicon=use_lexicon)
+    def test_construct_basic_decoder(self, tokens, lm, use_lexicon):
+        self._get_decoder(tokens=tokens, lm=lm, use_lexicon=use_lexicon)

     @parameterized.expand(
         [(True,), (False,)],
@@ -72,14 +80,63 @@ class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
         self.assertEqual(result.tokens.shape, result.timesteps.shape)

     def test_no_lm_decoder(self):
-        """Check that using no LM produces the same result as using an LM with 0 lm_weight"""
-        kenlm_decoder = self._get_decoder(lm_weight=0)
-        zerolm_decoder = self._get_decoder(use_lm=False)
+        """Check that the following produce the same result
+        - using no LM (C++ based implementation)
+        - using no LM (Custom Python based wrapper)
+        - using a (Ken)LM with 0 weight
+        """
+        from .ctc_decoder_utils import CustomZeroLM

         emissions = self._get_emissions()
+
+        custom_zerolm = CustomZeroLM()
+        zerolm_decoder_custom = self._get_decoder(lm=custom_zerolm)
+        zerolm_decoder_cpp = self._get_decoder(lm=None)
+        kenlm_file = get_asset_path("decoder/kenlm.arpa")
+        kenlm_decoder = self._get_decoder(lm=kenlm_file, lm_weight=0)
+
+        zerolm_custom_results = zerolm_decoder_custom(emissions)
+        zerolm_cpp_results = zerolm_decoder_cpp(emissions)
         kenlm_results = kenlm_decoder(emissions)
-        zerolm_results = zerolm_decoder(emissions)
-        self.assertEqual(kenlm_results, zerolm_results)
+        self.assertEqual(zerolm_cpp_results, zerolm_custom_results)
+        self.assertEqual(zerolm_cpp_results, kenlm_results)
+
+    def test_custom_kenlm_decoder(self):
+        """Check that creating a custom Python KenLM wrapper produces same results as C++ based KenLM"""
+        emissions = self._get_emissions()
+        kenlm_file = get_asset_path("decoder/kenlm.arpa")
+        custom_kenlm = self._get_custom_kenlm(kenlm_file)
+        kenlm_decoder_custom = self._get_decoder(lm=custom_kenlm)
+        kenlm_decoder_cpp = self._get_decoder(lm=kenlm_file)
+
+        kenlm_custom_results = kenlm_decoder_custom(emissions)
+        kenlm_cpp_results = kenlm_decoder_cpp(emissions)
+        self.assertEqual(kenlm_custom_results, kenlm_cpp_results)
+
+    @parameterized.expand(
+        [
+            (get_asset_path("decoder/nnlm_lex_dict.txt"), "foo", True),
+            (get_asset_path("decoder/nnlm_lexfree_dict.txt"), "f", False),
+        ]
+    )
+    def test_custom_nnlm_decoder(self, lm_dict, keyword, use_lexicon):
+        """Check that biased NNLM only produces biased words"""
+        emissions = self._get_emissions()
+        custom_nnlm = self._get_biased_nnlm(lm_dict, keyword)
+        nnlm_decoder = self._get_decoder(lm=custom_nnlm, lm_dict=lm_dict, use_lexicon=use_lexicon, lm_weight=10)
+
+        nnlm_results = nnlm_decoder(emissions)
+        if use_lexicon:
+            output = [result[0].words for result in nnlm_results]
+        else:
+            tokens = [nnlm_decoder.idxs_to_tokens(result[0].tokens) for result in nnlm_results]
+            output = [list(filter(("|").__ne__, t)) for t in tokens]  # filter out silence characters
+        lens = [len(out) for out in output]
+        expected = [[keyword] * len for len in lens]  # all of output should match the biased keyword
+        assert expected == output

     def test_get_timesteps(self):
         unprocessed_tokens = torch.tensor([2, 2, 0, 3, 3, 3, 0, 3])
......
import torch
from torchaudio.models.decoder import CTCDecoderLM, CTCDecoderLMState
from torchaudio.models.decoder._ctc_decoder import _create_word_dict, _Dictionary, _KenLM, _load_words
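
# Python reimplementation of flashlight's ZeroLM: every query returns a score of
# 0.0, so decoding is driven purely by the acoustic emissions.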
class CustomZeroLM(CTCDecoderLM):
    def __init__(self):
        CTCDecoderLM.__init__(self)

    def start(self, start_with_nothing: bool):
        return CTCDecoderLMState()

    def score(self, state: CTCDecoderLMState, token_index: int):
        return (state.child(token_index), 0.0)

    def finish(self, state: CTCDecoderLMState):
        return (state, 0.0)
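
# Python wrapper around the C++ KenLM binding; delegates start/score/finish to
# the underlying _KenLM model, which is keyed on the word dictionary.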
class CustomKenLM(CTCDecoderLM):
    def __init__(self, kenlm_file, dict_file):
        CTCDecoderLM.__init__(self)
        kenlm_dict = _create_word_dict(_load_words(dict_file))
        self.model = _KenLM(kenlm_file, kenlm_dict)

    def start(self, start_with_nothing: bool):
        return self.model.start(start_with_nothing)

    def score(self, state: CTCDecoderLMState, token_index: int):
        return self.model.score(state, token_index)

    def finish(self, state: CTCDecoderLMState):
        return self.model.finish(state)
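
# Toy "neural" LM: strongly favors a single keyword, strongly penalizes every
# other word, and rules out <unk> entirely.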
class BiasedLM(torch.nn.Module):
    def __init__(self, dict_file, keyword):
        super(BiasedLM, self).__init__()
        self.dictionary = _Dictionary(dict_file)
        self.keyword = keyword

    def forward(self, token_idx):
        if self.dictionary.get_entry(token_idx) == self.keyword:
            return torch.tensor(10)
        elif self.dictionary.get_entry(token_idx) == "<unk>":
            return torch.tensor(-torch.inf)
        return torch.tensor(-10)
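
# CTCDecoderLM wrapper around the BiasedLM module above, caching the score for
# each LM state so the model is evaluated at most once per state.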
class CustomBiasedLM(CTCDecoderLM):
    def __init__(self, model, dict_file):
        CTCDecoderLM.__init__(self)
        self.model = model
        self.vocab = _Dictionary(dict_file)
        self.eos = self.vocab.get_index("|")
        self.states = {}

        model.eval()

    def start(self, start_with_nothing: bool = False):
        state = CTCDecoderLMState()
        with torch.no_grad():
            score = self.model(self.eos)

        self.states[state] = score
        return state

    def score(self, state: CTCDecoderLMState, token_index: int):
        outstate = state.child(token_index)
        if outstate not in self.states:
            score = self.model(token_index)
            self.states[outstate] = score
        score = self.states[outstate]
        return outstate, score

    def finish(self, state: CTCDecoderLMState):
        return self.score(state, self.eos)
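
# Illustrative usage of the wrappers above (not part of this file; paths are
# placeholders rather than real test assets):
#
#   from torchaudio.models.decoder import ctc_decoder
#
#   decoder = ctc_decoder(
#       lexicon="lexicon.txt",
#       tokens="tokens.txt",
#       lm=CustomKenLM("kenlm.arpa", "lexicon.txt"),
#   )
#   results = decoder(emissions)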
@@ -2,6 +2,8 @@ _INITIALIZED = False
 _LAZILY_IMPORTED = [
     "CTCHypothesis",
     "CTCDecoder",
+    "CTCDecoderLM",
+    "CTCDecoderLMState",
     "ctc_decoder",
     "download_pretrained_files",
 ]
......
@@ -17,7 +17,8 @@ try:
     LexiconDecoderOptions as _LexiconDecoderOptions,
     LexiconFreeDecoder as _LexiconFreeDecoder,
     LexiconFreeDecoderOptions as _LexiconFreeDecoderOptions,
-    LM as _LM,
+    LM as CTCDecoderLM,
+    LMState as CTCDecoderLMState,
     SmearingMode as _SmearingMode,
     Trie as _Trie,
     ZeroLM as _ZeroLM,
@@ -36,7 +37,8 @@ except Exception:
     LexiconDecoderOptions as _LexiconDecoderOptions,
     LexiconFreeDecoder as _LexiconFreeDecoder,
     LexiconFreeDecoderOptions as _LexiconFreeDecoderOptions,
-    LM as _LM,
+    LM as CTCDecoderLM,
+    LMState as CTCDecoderLMState,
     SmearingMode as _SmearingMode,
     Trie as _Trie,
     ZeroLM as _ZeroLM,
@@ -48,12 +50,48 @@ except Exception:
     )

-__all__ = ["CTCHypothesis", "CTCDecoder", "ctc_decoder", "download_pretrained_files"]
+__all__ = [
+    "CTCHypothesis",
+    "CTCDecoder",
+    "CTCDecoderLM",
+    "CTCDecoderLMState",
+    "ctc_decoder",
+    "download_pretrained_files",
+]
_PretrainedFiles = namedtuple("PretrainedFiles", ["lexicon", "tokens", "lm"])
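
# Builds the lexicon trie consumed by the lexicon decoder: each spelling of each
# word is inserted with that word's LM start-state score, then max-smeared.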
def _construct_trie(tokens_dict, word_dict, lexicon, lm, silence):
    vocab_size = tokens_dict.index_size()
    trie = _Trie(vocab_size, silence)
    start_state = lm.start(False)

    for word, spellings in lexicon.items():
        word_idx = word_dict.get_index(word)
        _, score = lm.score(start_state, word_idx)
        for spelling in spellings:
            spelling_idx = [tokens_dict.get_index(token) for token in spelling]
            trie.insert(spelling_idx, word_idx, score)
    trie.smear(_SmearingMode.MAX)
    return trie
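
# Resolves the word-level dictionary for the LM: an explicit lm_dict file takes
# precedence, otherwise the lexicon is used; for lexicon-free decoding with a
# KenLM path, a dictionary is built from the token set itself.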
def _get_word_dict(lexicon, lm, lm_dict, tokens_dict, unk_word):
    word_dict = None
    if lm_dict is not None:
        word_dict = _Dictionary(lm_dict)

    if lexicon and word_dict is None:
        word_dict = _create_word_dict(lexicon)
    elif not lexicon and word_dict is None and type(lm) == str:
        d = {tokens_dict.get_entry(i): [[tokens_dict.get_entry(i)]] for i in range(tokens_dict.index_size())}
        d[unk_word] = [[unk_word]]
        word_dict = _create_word_dict(d)

    return word_dict
class CTCHypothesis(NamedTuple):
r"""Represents hypothesis generated by CTC beam search decoder :py:func:`CTCDecoder`.
@@ -89,7 +127,7 @@ class CTCDecoder:
         lexicon (Dict or None): lexicon mapping of words to spellings, or None for lexicon-free decoder
         word_dict (_Dictionary): dictionary of words
         tokens_dict (_Dictionary): dictionary of tokens
-        lm (_LM): language model
+        lm (CTCDecoderLM): language model. If using a lexicon, only word level LMs are currently supported
         decoder_options (_LexiconDecoderOptions or _LexiconFreeDecoderOptions): parameters used for beam search decoding
         blank_token (str): token corresponding to blank
         sil_token (str): token corresponding to silence
@@ -102,7 +140,7 @@
         lexicon: Optional[Dict],
         word_dict: _Dictionary,
         tokens_dict: _Dictionary,
-        lm: _LM,
+        lm: CTCDecoderLM,
         decoder_options: Union[_LexiconDecoderOptions, _LexiconFreeDecoderOptions],
         blank_token: str,
         sil_token: str,
@@ -113,21 +151,12 @@
         self.tokens_dict = tokens_dict
         self.blank = self.tokens_dict.get_index(blank_token)
         silence = self.tokens_dict.get_index(sil_token)
+        transitions = []

         if lexicon:
+            trie = _construct_trie(tokens_dict, word_dict, lexicon, lm, silence)
             unk_word = word_dict.get_index(unk_word)
-            vocab_size = self.tokens_dict.index_size()
-            trie = _Trie(vocab_size, silence)
-            start_state = lm.start(False)
-            for word, spellings in lexicon.items():
-                word_idx = self.word_dict.get_index(word)
-                _, score = lm.score(start_state, word_idx)
-                for spelling in spellings:
-                    spelling_idx = [self.tokens_dict.get_index(token) for token in spelling]
-                    trie.insert(spelling_idx, word_idx, score)
-            trie.smear(_SmearingMode.MAX)
             token_lm = False  # use word level LM
             self.decoder = _LexiconDecoder(
                 decoder_options,
@@ -136,11 +165,11 @@
                 silence,
                 self.blank,
                 unk_word,
-                [],
-                False,  # word level LM
+                transitions,
+                token_lm,
             )
         else:
-            self.decoder = _LexiconFreeDecoder(decoder_options, lm, silence, self.blank, [])
+            self.decoder = _LexiconFreeDecoder(decoder_options, lm, silence, self.blank, transitions)

     def _get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
         idxs = (g[0] for g in it.groupby(idxs))
@@ -194,7 +223,6 @@
         for b in range(B):
             emissions_ptr = emissions.data_ptr() + float_bytes * b * emissions.stride(0)
-
             results = self.decoder.decode(emissions_ptr, lengths[b], N)
             nbest_results = results[: self.nbest]
@@ -228,7 +256,8 @@
 def ctc_decoder(
     lexicon: Optional[str],
     tokens: Union[str, List[str]],
-    lm: Optional[str] = None,
+    lm: Union[str, CTCDecoderLM] = None,
+    lm_dict: Optional[str] = None,
     nbest: int = 1,
     beam_size: int = 50,
     beam_size_token: Optional[int] = None,
@@ -251,7 +280,12 @@ def ctc_decoder(
             decoding.
         tokens (str or List[str]): file or list containing valid tokens. If using a file, the expected
             format is for tokens mapping to the same index to be on the same line
-        lm (str or None, optional): file containing language model, or `None` if not using a language model
+        lm (str, CTCDecoderLM, or None, optional): either a path containing KenLM language model,
+            custom language model of type `CTCDecoderLM`, or `None` if not using a language model
+        lm_dict (str or None, optional): file consisting of the dictionary used for the LM, with a word
+            per line sorted by LM index. If decoding with a lexicon, entries in lm_dict must also occur
+            in the lexicon file. If `None`, dictionary for LM is constructed using the lexicon file.
+            (Default: None)
         nbest (int, optional): number of best decodings to return (Default: 1)
         beam_size (int, optional): max number of hypos to hold after each decode step (Default: 50)
         beam_size_token (int, optional): max number of tokens to consider at each decode step.
@@ -277,13 +311,14 @@ def ctc_decoder(
         >>> )
         >>> results = decoder(emissions)  # List of shape (B, nbest) of Hypotheses
     """
+    if lm_dict is not None and type(lm_dict) is not str:
+        raise ValueError("lm_dict must be None or str type.")
+
     tokens_dict = _Dictionary(tokens)

-    if lexicon is not None:
+    # decoder options
+    if lexicon:
         lexicon = _load_words(lexicon)
-        word_dict = _create_word_dict(lexicon)
-        lm = _KenLM(lm, word_dict) if lm else _ZeroLM()
         decoder_options = _LexiconDecoderOptions(
             beam_size=beam_size,
             beam_size_token=beam_size_token or tokens_dict.index_size(),
@@ -296,11 +331,6 @@
             criterion_type=_CriterionType.CTC,
         )
     else:
-        d = {tokens_dict.get_entry(i): [[tokens_dict.get_entry(i)]] for i in range(tokens_dict.index_size())}
-        d[unk_word] = [[unk_word]]
-        word_dict = _create_word_dict(d)
-        lm = _KenLM(lm, word_dict) if lm else _ZeroLM()
-
         decoder_options = _LexiconFreeDecoderOptions(
             beam_size=beam_size,
             beam_size_token=beam_size_token or tokens_dict.index_size(),
@@ -311,6 +341,14 @@
             criterion_type=_CriterionType.CTC,
         )

+    # construct word dict and language model
+    word_dict = _get_word_dict(lexicon, lm, lm_dict, tokens_dict, unk_word)
+    if type(lm) == str:
+        lm = _KenLM(lm, word_dict)
+    elif lm is None:
+        lm = _ZeroLM()
+
     return CTCDecoder(
         nbest=nbest,
         lexicon=lexicon,
......