Commit 93024ace authored by Caroline Chen, committed by Facebook GitHub Bot

Move CTC beam search decoder to beta (#2410)

Summary:
Move CTC beam search decoder out of prototype to new `torchaudio.models.decoder` module.

hwangjeff, mthrok: any thoughts on the new module and naming, and whether we should move the RNN-T beam search here as well?

Pull Request resolved: https://github.com/pytorch/audio/pull/2410

Reviewed By: mthrok

Differential Revision: D36784521

Pulled By: carolineechen

fbshipit-source-id: a2ec52f86bba66e03327a9af0c5df8bbefcd67ed
parent b374cc7b
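
For readers skimming the diff below, the essence of the change for user code is that the decoder moves out of the prototype namespace into a beta module, and the factory function is renamed from `lexicon_decoder` to `ctc_decoder`. A minimal sketch of the migration (assuming a torchaudio build that ships the decoder extension):

# Old, prototype location (kept as a deprecation shim; see the end of the diff):
# from torchaudio.prototype.ctc_decoder import lexicon_decoder, download_pretrained_files

# New, beta location introduced by this commit:
from torchaudio.models.decoder import ctc_decoder, download_pretrained_files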
@@ -45,6 +45,7 @@ API References
    transforms
    datasets
    models
+   models.decoder
    pipelines
    sox_effects
    compliance.kaldi
...
-torchaudio.prototype.ctc_decoder
-================================
-.. currentmodule:: torchaudio.prototype.ctc_decoder
+.. role:: hidden
+    :class: hidden-section
+torchaudio.models.decoder
+=========================
+.. currentmodule:: torchaudio.models.decoder
+.. py:module:: torchaudio.models.decoder
 Decoder Class
 -------------
@@ -16,10 +21,10 @@ CTCDecoder
 .. automethod:: idxs_to_tokens
-Hypothesis
-~~~~~~~~~~
-.. autoclass:: Hypothesis
+CTCHypothesis
+~~~~~~~~~~~~~
+.. autoclass:: CTCHypothesis
 Factory Function
 ----------------
...
@@ -14,9 +14,8 @@ imported explicitly, e.g.
 .. code-block:: python
-    import torchaudio.prototype.ctc_decoder
+    import torchaudio.prototype.models
 .. toctree::
-   prototype.ctc_decoder
    prototype.models
    prototype.pipelines
@@ -4,7 +4,7 @@ from typing import Optional
 import torch
 import torchaudio
-from torchaudio.prototype.ctc_decoder import download_pretrained_files, lexicon_decoder
+from torchaudio.models.decoder import ctc_decoder, download_pretrained_files
 logger = logging.getLogger(__name__)
@@ -18,7 +18,7 @@ def run_inference(args):
     # get decoder files
     files = download_pretrained_files("librispeech-4-gram")
-    decoder = lexicon_decoder(
+    decoder = ctc_decoder(
         lexicon=files.lexicon,
         tokens=files.tokens,
         lm=files.lm,
...
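
The rename is call-compatible: as the removed `lexicon_decoder` wrapper further down in the diff shows, it simply forwarded its keyword arguments to `ctc_decoder`, so a script like the one above only needs its import and the factory name updated. A sketch, assuming the pretrained LibriSpeech files can be downloaded:

from torchaudio.models.decoder import ctc_decoder, download_pretrained_files

files = download_pretrained_files("librispeech-4-gram")
# Previously: decoder = lexicon_decoder(lexicon=..., tokens=..., lm=...)
decoder = ctc_decoder(
    lexicon=files.lexicon,
    tokens=files.tokens,
    lm=files.lm,
)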
@@ -73,7 +73,7 @@ import torch
 import torchaudio
 try:
-    import torchaudio.prototype.ctc_decoder
+    from torchaudio.models.decoder import ctc_decoder
 except ModuleNotFoundError:
     try:
         import google.colab
@@ -208,13 +208,13 @@ print(tokens)
 # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 #
 # Pretrained files for the LibriSpeech dataset can be downloaded using
-# :py:func:`download_pretrained_files <torchaudio.prototype.ctc_decoder.download_pretrained_files>`.
+# :py:func:`download_pretrained_files <torchaudio.models.decoder.download_pretrained_files>`.
 #
 # Note: this cell may take a couple of minutes to run, as the language
 # model can be large
 #
-from torchaudio.prototype.ctc_decoder import download_pretrained_files
+from torchaudio.models.decoder import download_pretrained_files
 files = download_pretrained_files("librispeech-4-gram")
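
The return value is a named tuple (see the `_PretrainedFiles` definition later in the diff) whose fields hold local paths to the downloaded assets. A small sketch of inspecting it, assuming network access on first use:

from torchaudio.models.decoder import download_pretrained_files

files = download_pretrained_files("librispeech-4-gram")
print(files.lexicon)  # path to the lexicon file
print(files.tokens)   # path to the token list
print(files.lm)       # path to the 4-gram language model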
@@ -233,7 +233,7 @@ print(files)
 # Beam Search Decoder
 # ~~~~~~~~~~~~~~~~~~~
 # The decoder can be constructed using the factory function
-# :py:func:`ctc_decoder <torchaudio.prototype.ctc_decoder.ctc_decoder>`.
+# :py:func:`ctc_decoder <torchaudio.models.decoder.ctc_decoder>`.
 # In addition to the previously mentioned components, it also takes in various beam
 # search decoding parameters and token/word parameters.
 #
@@ -241,7 +241,7 @@ print(files)
 # `lm` parameter.
 #
-from torchaudio.prototype.ctc_decoder import ctc_decoder
+from torchaudio.models.decoder import ctc_decoder
 LM_WEIGHT = 3.23
 WORD_SCORE = -0.26
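
A sketch of the construction this part of the tutorial describes, wiring the tuning constants above into the factory function; `nbest=3` matches the nbest discussion later in the tutorial, and the remaining parameters are left at the defaults visible in the (removed) `lexicon_decoder` signature further down:

from torchaudio.models.decoder import ctc_decoder, download_pretrained_files

files = download_pretrained_files("librispeech-4-gram")

LM_WEIGHT = 3.23
WORD_SCORE = -0.26

beam_search_decoder = ctc_decoder(
    lexicon=files.lexicon,
    tokens=files.tokens,
    lm=files.lm,          # n-gram language model passed via the `lm` parameter
    nbest=3,
    lm_weight=LM_WEIGHT,
    word_score=WORD_SCORE,
)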
@@ -295,7 +295,7 @@ greedy_decoder = GreedyCTCDecoder(tokens)
 #
 # Now that we have the data, acoustic model, and decoder, we can perform
 # inference. The output of the beam search decoder is of type
-# :py:func:`torchaudio.prototype.ctc_decoder.Hypothesis`, consisting of the
+# :py:func:`torchaudio.models.decoder.CTCHypothesis`, consisting of the
 # predicted token IDs, corresponding words, hypothesis score, and timesteps
 # corresponding to the token IDs. Recall the transcript corresponding to the
 # waveform is
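
A minimal sketch of consuming that output, assuming `beam_search_decoder` from the previous snippet and an `emission` tensor of shape `(batch, frame, num_tokens)` produced by the acoustic model:

# The decoder returns List[List[CTCHypothesis]]: one nbest list per batch item.
beam_search_result = beam_search_decoder(emission)
best = beam_search_result[0][0]  # top-scoring hypothesis for the first utterance

print(" ".join(best.words).strip())  # word-level transcript
print(best.score)                    # hypothesis score
print(best.tokens)                   # predicted token IDs (torch.LongTensor)
print(best.timesteps)                # frame indices aligned with the token IDs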
@@ -395,7 +395,7 @@ plot_alignments(waveform[0], emission, predicted_tokens, timesteps)
 # In this section, we go a little bit more in depth about some different
 # parameters and tradeoffs. For the full list of customizable parameters,
 # please refer to the
-# :py:func:`documentation <torchaudio.prototype.ctc_decoder.ctc_decoder>`.
+# :py:func:`documentation <torchaudio.models.decoder.ctc_decoder>`.
 #
@@ -419,7 +419,7 @@ def print_decoded(decoder, emission, param, param_value):
 # nbest
 # ~~~~~
 #
-# This parameter indicates the number of best Hypothesis to return, which
+# This parameter indicates the number of best hypotheses to return, which
 # is a property that is not possible with the greedy decoder. For
 # instance, by setting ``nbest=3`` when constructing the beam search
 # decoder earlier, we can now access the hypotheses with the top 3 scores.
...
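
A short sketch of what `nbest=3` enables, reusing `beam_search_decoder` and `emission` from the snippets above; hypotheses come back sorted best-first:

for i, hyp in enumerate(beam_search_decoder(emission)[0]):
    print(f"{i + 1}: {' '.join(hyp.words)} (score: {hyp.score})")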
@@ -9,7 +9,7 @@ import pytest
     ],
 )
 def test_decoder_from_pretrained(model, expected, emissions):
-    from torchaudio.prototype.ctc_decoder import ctc_decoder, download_pretrained_files
+    from torchaudio.models.decoder import ctc_decoder, download_pretrained_files
     pretrained_files = download_pretrained_files(model)
     decoder = ctc_decoder(
...
@@ -136,7 +136,7 @@ def is_ctc_decoder_available():
     global _IS_CTC_DECODER_AVAILABLE
     if _IS_CTC_DECODER_AVAILABLE is None:
         try:
-            from torchaudio.prototype.ctc_decoder import CTCDecoder  # noqa: F401
+            from torchaudio.models.decoder import CTCDecoder  # noqa: F401
             _IS_CTC_DECODER_AVAILABLE = True
         except Exception:
...
@@ -16,7 +16,7 @@ NUM_TOKENS = 8
 @skipIfNoCtcDecoder
 class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
     def _get_decoder(self, tokens=None, use_lm=True, use_lexicon=True, **kwargs):
-        from torchaudio.prototype.ctc_decoder import ctc_decoder
+        from torchaudio.models.decoder import ctc_decoder
         if use_lexicon:
             lexicon_file = get_asset_path("decoder/lexicon.txt")
...
+_INITIALIZED = False
+_LAZILY_IMPORTED = [
+    "CTCHypothesis",
+    "CTCDecoder",
+    "ctc_decoder",
+    "download_pretrained_files",
+]
+
+
+def _init_extension():
+    import torchaudio
+
+    torchaudio._extension._load_lib("libtorchaudio_decoder")
+
+    global _INITIALIZED
+    _INITIALIZED = True
+
+
+def __getattr__(name: str):
+    if name in _LAZILY_IMPORTED:
+        if not _INITIALIZED:
+            _init_extension()
+        try:
+            from . import _ctc_decoder
+        except AttributeError as err:
+            raise RuntimeError(
+                "CTC decoder requires the decoder extension. Please set BUILD_CTC_DECODER=1 when building from source."
+            ) from err
+        item = getattr(_ctc_decoder, name)
+        globals()[name] = item
+        return item
+    raise AttributeError(f"module {__name__} has no attribute {name}")
+
+
+def __dir__():
+    return sorted(__all__ + _LAZILY_IMPORTED)
+
+
+__all__ = []
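
The new `torchaudio/models/decoder/__init__.py` above relies on a module-level `__getattr__` (PEP 562), so `libtorchaudio_decoder` is loaded only when one of the lazily exported names is first accessed. A hedged sketch of the resulting behavior:

import torchaudio.models.decoder as decoder_module

# Importing the package does not load the decoder extension yet.
# The first attribute access goes through __getattr__, which calls
# _init_extension() and imports the _ctc_decoder submodule; on builds
# without BUILD_CTC_DECODER=1 this raises the RuntimeError shown above.
ctc_decoder = decoder_module.ctc_decoder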
 import itertools as it
-import warnings
 from collections import namedtuple
 from typing import Dict, List, NamedTuple, Optional, Union
@@ -21,13 +20,13 @@ from torchaudio._torchaudio_decoder import (
 )
 from torchaudio.utils import download_asset
-__all__ = ["Hypothesis", "CTCDecoder", "ctc_decoder", "lexicon_decoder", "download_pretrained_files"]
+__all__ = ["CTCHypothesis", "CTCDecoder", "ctc_decoder", "download_pretrained_files"]
 _PretrainedFiles = namedtuple("PretrainedFiles", ["lexicon", "tokens", "lm"])
-class Hypothesis(NamedTuple):
+class CTCHypothesis(NamedTuple):
     r"""Represents hypothesis generated by CTC beam search decoder :py:func:`CTCDecoder`.
     :ivar torch.LongTensor tokens: Predicted sequence of token IDs. Shape `(L, )`, where
@@ -44,8 +43,7 @@ class Hypothesis:
 class CTCDecoder:
-    """torchaudio.prototype.ctc_decoder.CTCDecoder()
+    """
     .. devices:: CPU
     Lexically constrained CTC beam search decoder from *Flashlight* [:footcite:`kahn2022flashlight`].
@@ -128,10 +126,12 @@ class CTCDecoder:
             timesteps.append(i)
         return torch.IntTensor(timesteps)
-    def __call__(self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None) -> List[List[Hypothesis]]:
+    def __call__(
+        self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None
+    ) -> List[List[CTCHypothesis]]:
         # Overriding the signature so that the return type is correct on Sphinx
         """__call__(self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None) -> \
-            List[List[torchaudio.prototype.ctc_decoder.Hypothesis]]
+            List[List[torchaudio.models.decoder.CTCHypothesis]]
         Args:
             emissions (torch.FloatTensor): CPU tensor of shape `(batch, frame, num_tokens)` storing sequences of
@@ -140,7 +140,7 @@ class CTCDecoder:
                 in time axis of the output Tensor in each batch.
         Returns:
-            List[List[Hypothesis]]:
+            List[List[CTCHypothesis]]:
                 List of sorted best hypotheses for each audio sequence in the batch.
         """
@@ -168,7 +168,7 @@ class CTCDecoder:
             nbest_results = results[: self.nbest]
             hypos.append(
                 [
-                    Hypothesis(
+                    CTCHypothesis(
                         tokens=self._get_tokens(result.tokens),
                         words=[self.word_dict.get_entry(x) for x in result.words if x >= 0],
                         score=result.score,
@@ -293,44 +293,6 @@ def ctc_decoder(
     )
-def lexicon_decoder(
-    lexicon: str,
-    tokens: Union[str, List[str]],
-    lm: Optional[str] = None,
-    nbest: int = 1,
-    beam_size: int = 50,
-    beam_size_token: Optional[int] = None,
-    beam_threshold: float = 50,
-    lm_weight: float = 2,
-    word_score: float = 0,
-    unk_score: float = float("-inf"),
-    sil_score: float = 0,
-    log_add: bool = False,
-    blank_token: str = "-",
-    sil_token: str = "|",
-    unk_word: str = "<unk>",
-) -> CTCDecoder:
-    warnings.warn("`lexicon_decoder` is now deprecated. Please use `ctc_decoder` instead.")
-    return ctc_decoder(
-        lexicon=lexicon,
-        tokens=tokens,
-        lm=lm,
-        nbest=nbest,
-        beam_size=beam_size,
-        beam_size_token=beam_size_token,
-        beam_threshold=beam_threshold,
-        lm_weight=lm_weight,
-        word_score=word_score,
-        unk_score=unk_score,
-        sil_score=sil_score,
-        log_add=log_add,
-        blank_token=blank_token,
-        sil_token=sil_token,
-        unk_word=unk_word,
-    )
 def _get_filenames(model: str) -> _PretrainedFiles:
     if model not in ["librispeech", "librispeech-3-gram", "librispeech-4-gram"]:
         raise ValueError(
...
-_INITIALIZED = False
-_LAZILY_IMPORTED = [
-    "Hypothesis",
-    "CTCDecoder",
-    "ctc_decoder",
-    "lexicon_decoder",
-    "download_pretrained_files",
-]
-def _init_extension():
-    import torchaudio
-    torchaudio._extension._load_lib("libtorchaudio_decoder")
-    global _INITIALIZED
-    _INITIALIZED = True
-def __getattr__(name: str):
-    if name in _LAZILY_IMPORTED:
-        if not _INITIALIZED:
-            _init_extension()
-        try:
-            from . import _ctc_decoder
-        except AttributeError as err:
-            raise RuntimeError(
-                "CTC decoder requires the decoder extension. Please set BUILD_CTC_DECODER=1 when building from source."
-            ) from err
-        item = getattr(_ctc_decoder, name)
-        globals()[name] = item
-        return item
+def __getattr__(name: str):
+    if name in ["ctc_decoder", "lexicon_decoder"]:
+        import warnings
+        from torchaudio.models.decoder import ctc_decoder
+        warnings.warn(
+            f"{__name__}.{name} has been moved to torchaudio.models.decoder.ctc_decoder",
+            DeprecationWarning,
+        )
+        if name == "lexicon_decoder":
+            global lexicon_decoder
+            lexicon_decoder = ctc_decoder
+            return lexicon_decoder
+        else:
+            return ctc_decoder
+    elif name == "download_pretrained_files":
+        import warnings
+        from torchaudio.models.decoder import download_pretrained_files
+        return download_pretrained_files
     raise AttributeError(f"module {__name__} has no attribute {name}")
 def __dir__():
-    return sorted(__all__ + _LAZILY_IMPORTED)
-__all__ = []
+    return ["ctc_decoder", "lexicon_decoder", "download_pretrained_files"]
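
With this shim, `torchaudio.prototype.ctc_decoder` keeps resolving the old names but forwards them to the new module and emits a `DeprecationWarning`. A sketch of the observable behavior, assuming the decoder extension is available:

import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Old import path: resolved through the shim's module-level __getattr__.
    from torchaudio.prototype.ctc_decoder import lexicon_decoder

# lexicon_decoder is now simply torchaudio.models.decoder.ctc_decoder,
# and the recorded warning points at the new location.
print([str(w.message) for w in caught])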