"docs/vscode:/vscode.git/clone" did not exist on "4db08045baa250148b1e176e9ac1d5797affcd75"
Commit a76b0066 authored by Caroline Chen, committed by Facebook GitHub Bot

Add Python CTC decoder API (#2089)

Summary:
Part of https://github.com/pytorch/audio/issues/2072 -- splitting up PR for easier review

This PR adds the Python decoder API and a basic README.

Pull Request resolved: https://github.com/pytorch/audio/pull/2089

Reviewed By: mthrok

Differential Revision: D33299818

Pulled By: carolineechen

fbshipit-source-id: 778ec3692331e95258d3734f0d4ab60b6618ddbc
parent 5859923a
@@ -73,6 +73,26 @@ Hypothesis

 .. autoclass:: Hypothesis

+KenLMLexiconDecoder
+~~~~~~~~~~~~~~~~~~~
+
+.. currentmodule:: torchaudio.prototype.ctc_decoder
+
+.. autoclass:: KenLMLexiconDecoder
+
+  .. automethod:: __call__
+
+  .. automethod:: idxs_to_tokens
+
+kenlm_lexicon_decoder
+~~~~~~~~~~~~~~~~~~~~~
+
+.. currentmodule:: torchaudio.prototype.ctc_decoder
+
+.. autoclass:: kenlm_lexicon_decoder
+
 References
 ~~~~~~~~~~
New test asset decoder/kenlm.arpa (referenced by the decoder test below):
\data\
ngram 1=6
ngram 2=9
ngram 3=8
\1-grams:
-0.8515802 <unk> 0
0 <s> -0.30103
-0.8515802 </s> 0
-0.8515802 foo -0.30103
-0.44013768 bar -0.30103
-0.6679358 foobar -0.30103
\2-grams:
-0.7091413 foo </s> 0
-0.6251838 bar </s> 0
-0.24384303 foobar </s> 0
-0.6251838 <s> foo -0.30103
-0.49434766 foo foo -0.30103
-0.39393726 bar foo -0.30103
-0.4582359 <s> bar -0.30103
-0.51359576 foo bar -0.30103
-0.56213206 <s> foobar -0.30103
\3-grams:
-0.45881382 bar foo </s>
-0.43354067 foo bar </s>
-0.105027884 <s> foobar </s>
-0.18033421 <s> foo foo
-0.38702002 bar foo foo
-0.15375455 <s> bar foo
-0.34500393 foo bar foo
-0.18492673 foo foo bar
\end\
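The entries above follow the standard ARPA format: the first column is a base-10 log probability, and the optional trailing column is a base-10 backoff weight. A quick sketch of how to read one bigram entry (plain ARPA convention, nothing torchaudio-specific):

```py
# "-0.6251838 <s> foo" in the \2-grams: section encodes
# P(foo | <s>) = 10 ** -0.6251838
prob = 10 ** -0.6251838
print(f"{prob:.3f}")  # ~0.237
```

The lexicon asset (decoder/lexicon.txt in the tests) maps each word to its token spelling, terminated by the word boundary token `|`: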
foo f o o |
bar b a r |
foobar f o o b a r |
@@ -7,6 +7,7 @@ from .case_utils import (
     TestBaseMixin,
     PytorchTestCase,
     TorchaudioTestCase,
+    skipIfNoCtcDecoder,
     skipIfNoCuda,
     skipIfNoExec,
     skipIfNoModule,
@@ -42,6 +43,7 @@ __all__ = [
     "TestBaseMixin",
     "PytorchTestCase",
     "TorchaudioTestCase",
+    "skipIfNoCtcDecoder",
     "skipIfNoCuda",
     "skipIfNoExec",
     "skipIfNoModule",
@@ -10,6 +10,7 @@ from torch.testing._internal.common_utils import TestCase as PytorchTestCase
 from torchaudio._internal.module_utils import is_module_available, is_sox_available, is_kaldi_available
 from .backend_utils import set_audio_backend
+from .ctc_decoder_utils import is_ctc_decoder_available


 class TempDirMixin:
@@ -115,6 +116,7 @@ def skipIfNoCuda(test_item):
 skipIfNoSox = unittest.skipIf(not is_sox_available(), reason="Sox not available")
 skipIfNoKaldi = unittest.skipIf(not is_kaldi_available(), reason="Kaldi not available")
+skipIfNoCtcDecoder = unittest.skipIf(not is_ctc_decoder_available(), reason="CTC decoder not available")
 skipIfRocm = unittest.skipIf(
     os.getenv("TORCHAUDIO_TEST_WITH_ROCM", "0") == "1", reason="test doesn't currently work on the ROCm stack"
 )
ctc_decoder_utils.py (new file, providing the availability check imported above):
def is_ctc_decoder_available():
    try:
        import torchaudio.prototype.ctc_decoder  # noqa: F401

        return True
    except ImportError:
        return False
New test file:

import torch
from torchaudio_unittest.common_utils import (
    TempDirMixin,
    TorchaudioTestCase,
    get_asset_path,
    skipIfNoCtcDecoder,
)


@skipIfNoCtcDecoder
class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
    def _get_decoder(self):
        from torchaudio.prototype.ctc_decoder import kenlm_lexicon_decoder

        lexicon_file = get_asset_path("decoder/lexicon.txt")
        tokens_file = get_asset_path("decoder/tokens.txt")
        kenlm_file = get_asset_path("decoder/kenlm.arpa")

        return kenlm_lexicon_decoder(
            lexicon=lexicon_file,
            tokens=tokens_file,
            kenlm=kenlm_file,
        )

    def test_construct_decoder(self):
        self._get_decoder()

    def test_shape(self):
        B, T, N = 4, 15, 10

        torch.manual_seed(0)
        emissions = torch.rand(B, T, N)

        decoder = self._get_decoder()
        results = decoder(emissions)

        self.assertEqual(len(results), B)

    def test_index_to_tokens(self):
        # decoder tokens: '-' '|' 'f' 'o' 'b' 'a' 'r'
        decoder = self._get_decoder()

        idxs = torch.LongTensor((1, 2, 1, 3, 5))
        tokens = decoder.idxs_to_tokens(idxs)

        expected_tokens = ["|", "f", "|", "o", "a"]
        self.assertEqual(tokens, expected_tokens)
# Flashlight Decoder Binding

CTC decoder with KenLM and lexicon support, based on the [flashlight](https://github.com/flashlight/flashlight) decoder implementation and the fairseq [KenLMDecoder](https://github.com/pytorch/fairseq/blob/fcca32258c8e8bcc9f9890bf4714fa2f96b6b3e1/examples/speech_recognition/new/decoders/flashlight_decoder.py#L53) Python wrapper.
## Setup
### Build torchaudio with decoder support
```
BUILD_CTC_DECODER=1 python setup.py develop
```
## Usage
```py
from torchaudio.prototype.ctc_decoder import kenlm_lexicon_decoder

decoder = kenlm_lexicon_decoder(lexicon="lexicon.txt", tokens="tokens.txt", kenlm="kenlm.arpa")
results = decoder(emissions)  # List[List[Hypothesis]] of dim (B, nbest)
best_transcripts = [" ".join(results[b][0].words).strip() for b in range(len(results))]
```
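Each entry in `results` is a `Hypothesis` named tuple: besides the decoded `words`, it carries the raw token IDs in `tokens` (which `decoder.idxs_to_tokens` maps back to token strings) and the hypothesis `score`.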
## Required Files

- tokens: the set of tokens for which the acoustic model generates probabilities
- lexicon: mapping from words to their corresponding spellings
- language model: n-gram KenLM language model (see the examples below)
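Mirroring the test assets added in this commit, minimal token and lexicon files could look as follows; the token inventory is taken from the decoder test, and the `kenlm.arpa` asset shown earlier serves as an example language model. The one-token-per-line layout is an assumption about the tokens file format.

tokens.txt, one token per line:

```
-
|
f
o
b
a
r
```

lexicon.txt, each word followed by its spelling and the word boundary token:

```
foo f o o |
bar b a r |
foobar f o o b a r |
```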
## Experiment Results
LibriSpeech dev-other and test-other WER results using pretrained [Wav2Vec2](https://arxiv.org/pdf/2006.11477.pdf) models of the BASE configuration.
| Model | Decoder | dev-other | test-other | beam search params |
| ----------- | ---------- | ----------- | ---------- |-------------------------------------------- |
| BASE_10M | Greedy | 51.6 | 51 | |
| | 4-gram LM | 15.95 | 15.9 | LM weight=3.23, word score=-0.26, beam=1500 |
| BASE_100H | Greedy | 13.6 | 13.3 | |
| | 4-gram LM | 8.5 | 8.8 | LM weight=2.15, word score=-0.52, beam=50 |
| BASE_960H | Greedy | 8.9 | 8.4 | |
| | 4-gram LM | 6.3 | 6.4 | LM weight=1.74, word score=0.52, beam=50 |
__init__.py of the torchaudio.prototype.ctc_decoder package (new file):

import torchaudio

try:
    torchaudio._extension._load_lib("libtorchaudio_decoder")
    from .ctc_decoder import KenLMLexiconDecoder, kenlm_lexicon_decoder
except ImportError as err:
    raise ImportError(
        "flashlight decoder bindings are required to use this functionality. "
        "Please set BUILD_CTC_DECODER=1 when building from source."
    ) from err

__all__ = [
    "KenLMLexiconDecoder",
    "kenlm_lexicon_decoder",
]
ctc_decoder.py (new file):

import itertools as it
from collections import namedtuple
from typing import Dict, List, Optional

import torch
from torchaudio._torchaudio_decoder import (
    _CriterionType,
    _KenLM,
    _LexiconDecoder,
    _LexiconDecoderOptions,
    _SmearingMode,
    _Trie,
    _Dictionary,
    _create_word_dict,
    _load_words,
)

__all__ = ["KenLMLexiconDecoder", "kenlm_lexicon_decoder"]

Hypothesis = namedtuple("Hypothesis", ["tokens", "words", "score"])
class KenLMLexiconDecoder:
    def __init__(
        self,
        nbest: int,
        lexicon: Dict,
        word_dict: _Dictionary,
        tokens_dict: _Dictionary,
        kenlm: _KenLM,
        decoder_options: _LexiconDecoderOptions,
        blank_token: str,
        sil_token: str,
        unk_word: str,
    ) -> None:
        """
        KenLM CTC Decoder with Lexicon constraint.

        Note:
            To build the decoder, please use the factory function kenlm_lexicon_decoder.

        Args:
            nbest (int): number of best decodings to return
            lexicon (Dict): lexicon mapping of words to spellings
            word_dict (_Dictionary): dictionary of words
            tokens_dict (_Dictionary): dictionary of tokens
            kenlm (_KenLM): n-gram KenLM language model
            decoder_options (_LexiconDecoderOptions): parameters used for beam search decoding
            blank_token (str): token corresponding to blank
            sil_token (str): token corresponding to silence
            unk_word (str): word corresponding to unknown
        """
        self.nbest = nbest
        self.word_dict = word_dict
        self.tokens_dict = tokens_dict

        unk_word = word_dict.get_index(unk_word)
        self.blank = self.tokens_dict.get_index(blank_token)
        silence = self.tokens_dict.get_index(sil_token)

        vocab_size = self.tokens_dict.index_size()
        trie = _Trie(vocab_size, silence)
        start_state = kenlm.start(False)

        # Build the lexicon trie: each word's spellings are inserted with the
        # word's unigram LM score so beam search can use it for lookahead.
        for word, spellings in lexicon.items():
            word_idx = self.word_dict.get_index(word)
            _, score = kenlm.score(start_state, word_idx)
            for spelling in spellings:
                spelling_idx = [self.tokens_dict.get_index(token) for token in spelling]
                trie.insert(spelling_idx, word_idx, score)
        trie.smear(_SmearingMode.MAX)

        self.decoder = _LexiconDecoder(
            decoder_options,
            trie,
            kenlm,
            silence,
            self.blank,
            unk_word,
            [],
            False,  # word level LM
        )
    def _get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
        # Standard CTC postprocessing: collapse consecutive duplicates, then drop blanks
        idxs = (g[0] for g in it.groupby(idxs))
        idxs = filter(lambda x: x != self.blank, idxs)
        return torch.LongTensor(list(idxs))
    def __call__(self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None) -> List[List[Hypothesis]]:
        """
        Args:
            emissions (torch.FloatTensor): tensor of shape `(batch, frame, num_tokens)` storing sequences of
                probability distributions over labels; output of acoustic model
            lengths (Tensor or None, optional): tensor of shape `(batch, )` storing the valid length,
                along the time axis, of each sequence in the batch

        Returns:
            List[List[Hypothesis]]:
                List of sorted best hypotheses for each audio sequence in the batch.

                Each hypothesis is a named tuple with the following fields:
                    tokens: torch.LongTensor of raw token IDs
                    score: hypothesis score
                    words: list of decoded words
        """
        assert emissions.dtype == torch.float32

        B, T, N = emissions.size()
        if lengths is None:
            lengths = torch.full((B,), T)

        float_bytes = 4
        hypos = []

        for b in range(B):
            # byte offset of batch element b within the emissions buffer
            emissions_ptr = emissions.data_ptr() + float_bytes * b * emissions.stride(0)

            results = self.decoder.decode(emissions_ptr, lengths[b], N)

            nbest_results = results[: self.nbest]
            hypos.append(
                [
                    Hypothesis(
                        self._get_tokens(result.tokens),  # token ids
                        [self.word_dict.get_entry(x) for x in result.words if x >= 0],  # words
                        result.score,  # score
                    )
                    for result in nbest_results
                ]
            )

        return hypos
    def idxs_to_tokens(self, idxs: torch.LongTensor) -> List:
        """
        Map raw token IDs into corresponding tokens

        Args:
            idxs (LongTensor): raw token IDs generated from decoder

        Returns:
            List: tokens corresponding to the input IDs
        """
        return [self.tokens_dict.get_entry(idx.item()) for idx in idxs]
def kenlm_lexicon_decoder(
    lexicon: str,
    tokens: str,
    kenlm: str,
    nbest: int = 1,
    beam_size: int = 50,
    beam_size_token: Optional[int] = None,
    beam_threshold: float = 50,
    lm_weight: float = 2,
    word_score: float = 0,
    unk_score: float = float("-inf"),
    sil_score: float = 0,
    log_add: bool = False,
    blank_token: str = "-",
    sil_token: str = "|",
    unk_word: str = "<unk>",
) -> KenLMLexiconDecoder:
    """
    Builds a KenLM CTC lexicon decoder with the given parameters

    Args:
        lexicon (str): lexicon file containing the possible words
        tokens (str): file containing valid tokens
        kenlm (str): file containing the language model
        nbest (int, optional): number of best decodings to return (Default: 1)
        beam_size (int, optional): max number of hypos to hold after each decode step (Default: 50)
        beam_size_token (int, optional): max number of tokens to consider at each decode step.
            If None, it is set to the total number of tokens (Default: None)
        beam_threshold (float, optional): threshold for pruning hypotheses (Default: 50)
        lm_weight (float, optional): weight of the language model (Default: 2)
        word_score (float, optional): word insertion score (Default: 0)
        unk_score (float, optional): unknown word insertion score (Default: -inf)
        sil_score (float, optional): silence insertion score (Default: 0)
        log_add (bool, optional): whether to use logadd when merging hypotheses (Default: False)
        blank_token (str, optional): token corresponding to blank (Default: "-")
        sil_token (str, optional): token corresponding to silence (Default: "|")
        unk_word (str, optional): word corresponding to unknown (Default: "<unk>")

    Returns:
        KenLMLexiconDecoder: decoder
    """
    lexicon = _load_words(lexicon)
    word_dict = _create_word_dict(lexicon)
    kenlm = _KenLM(kenlm, word_dict)
    tokens_dict = _Dictionary(tokens)

    decoder_options = _LexiconDecoderOptions(
        beam_size=beam_size,
        beam_size_token=beam_size_token or tokens_dict.index_size(),
        beam_threshold=beam_threshold,
        lm_weight=lm_weight,
        word_score=word_score,
        unk_score=unk_score,
        sil_score=sil_score,
        log_add=log_add,
        criterion_type=_CriterionType.CTC,
    )

    return KenLMLexiconDecoder(
        nbest=nbest,
        lexicon=lexicon,
        word_dict=word_dict,
        tokens_dict=tokens_dict,
        kenlm=kenlm,
        decoder_options=decoder_options,
        blank_token=blank_token,
        sil_token=sil_token,
        unk_word=unk_word,
    )
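For a quick end-to-end smoke test, here is a minimal sketch tying the pieces together. It assumes torchaudio was built with `BUILD_CTC_DECODER=1` and that the three asset files shown earlier sit in a local `decoder/` directory; the emissions are random stand-ins for acoustic model output.

```py
import torch
from torchaudio.prototype.ctc_decoder import kenlm_lexicon_decoder

# Build the decoder from the toy assets (token inventory: '-' '|' 'f' 'o' 'b' 'a' 'r')
decoder = kenlm_lexicon_decoder(
    lexicon="decoder/lexicon.txt",
    tokens="decoder/tokens.txt",
    kenlm="decoder/kenlm.arpa",
    nbest=3,
    beam_size=50,
)

B, T, N = 2, 20, 7  # batch, frames, number of tokens
emissions = torch.rand(B, T, N)  # stand-in for acoustic model output

results = decoder(emissions)  # List[List[Hypothesis]]: outer dim B, inner dim nbest
best = results[0][0]
print(" ".join(best.words))                 # top transcript for the first sequence
print(decoder.idxs_to_tokens(best.tokens))  # CTC-collapsed token IDs mapped back to tokens
print(best.score)                           # hypothesis score
```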