Commit 93024ace authored by Caroline Chen's avatar Caroline Chen Committed by Facebook GitHub Bot
Browse files

Move CTC beam search decoder to beta (#2410)

Summary:
Move CTC beam search decoder out of prototype to new `torchaudio.models.decoder` module.

hwangjeff mthrok — any thoughts on the new module + naming, and whether we should move the RNN-T beam search here as well?

Pull Request resolved: https://github.com/pytorch/audio/pull/2410

Reviewed By: mthrok

Differential Revision: D36784521

Pulled By: carolineechen

fbshipit-source-id: a2ec52f86bba66e03327a9af0c5df8bbefcd67ed
parent b374cc7b
......@@ -45,6 +45,7 @@ API References
transforms
datasets
models
models.decoder
pipelines
sox_effects
compliance.kaldi
......
torchaudio.prototype.ctc_decoder
================================
.. role:: hidden
:class: hidden-section
.. currentmodule:: torchaudio.prototype.ctc_decoder
torchaudio.models.decoder
=========================
.. currentmodule:: torchaudio.models.decoder
.. py:module:: torchaudio.models.decoder
Decoder Class
-------------
......@@ -16,10 +21,10 @@ CTCDecoder
.. automethod:: idxs_to_tokens
Hypothesis
~~~~~~~~~~
CTCHypothesis
~~~~~~~~~~~~~
.. autoclass:: Hypothesis
.. autoclass:: CTCHypothesis
Factory Function
----------------
......
......@@ -14,9 +14,8 @@ imported explicitly, e.g.
.. code-block:: python
import torchaudio.prototype.ctc_decoder
import torchaudio.prototype.models
.. toctree::
prototype.ctc_decoder
prototype.models
prototype.pipelines
......@@ -4,7 +4,7 @@ from typing import Optional
import torch
import torchaudio
from torchaudio.prototype.ctc_decoder import download_pretrained_files, lexicon_decoder
from torchaudio.models.decoder import ctc_decoder, download_pretrained_files
logger = logging.getLogger(__name__)
......@@ -18,7 +18,7 @@ def run_inference(args):
# get decoder files
files = download_pretrained_files("librispeech-4-gram")
decoder = lexicon_decoder(
decoder = ctc_decoder(
lexicon=files.lexicon,
tokens=files.tokens,
lm=files.lm,
......
......@@ -73,7 +73,7 @@ import torch
import torchaudio
try:
import torchaudio.prototype.ctc_decoder
from torchaudio.models.decoder import ctc_decoder
except ModuleNotFoundError:
try:
import google.colab
......@@ -208,13 +208,13 @@ print(tokens)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Pretrained files for the LibriSpeech dataset can be downloaded using
# :py:func:`download_pretrained_files <torchaudio.prototype.ctc_decoder.download_pretrained_files>`.
# :py:func:`download_pretrained_files <torchaudio.models.decoder.download_pretrained_files>`.
#
# Note: this cell may take a couple of minutes to run, as the language
# model can be large
#
from torchaudio.prototype.ctc_decoder import download_pretrained_files
from torchaudio.models.decoder import download_pretrained_files
files = download_pretrained_files("librispeech-4-gram")
......@@ -233,7 +233,7 @@ print(files)
# Beam Search Decoder
# ~~~~~~~~~~~~~~~~~~~
# The decoder can be constructed using the factory function
# :py:func:`ctc_decoder <torchaudio.prototype.ctc_decoder.ctc_decoder>`.
# :py:func:`ctc_decoder <torchaudio.models.decoder.ctc_decoder>`.
# In addition to the previously mentioned components, it also takes in various beam
# search decoding parameters and token/word parameters.
#
......@@ -241,7 +241,7 @@ print(files)
# `lm` parameter.
#
from torchaudio.prototype.ctc_decoder import ctc_decoder
from torchaudio.models.decoder import ctc_decoder
LM_WEIGHT = 3.23
WORD_SCORE = -0.26
......@@ -295,7 +295,7 @@ greedy_decoder = GreedyCTCDecoder(tokens)
#
# Now that we have the data, acoustic model, and decoder, we can perform
# inference. The output of the beam search decoder is of type
# :py:func:`torchaudio.prototype.ctc_decoder.Hypothesis`, consisting of the
# :py:func:`torchaudio.models.decoder.CTCHypothesis`, consisting of the
# predicted token IDs, corresponding words, hypothesis score, and timesteps
# corresponding to the token IDs. Recall the transcript corresponding to the
# waveform is
......@@ -395,7 +395,7 @@ plot_alignments(waveform[0], emission, predicted_tokens, timesteps)
# In this section, we go a little bit more in depth about some different
# parameters and tradeoffs. For the full list of customizable parameters,
# please refer to the
# :py:func:`documentation <torchaudio.prototype.ctc_decoder.ctc_decoder>`.
# :py:func:`documentation <torchaudio.models.decoder.ctc_decoder>`.
#
......@@ -419,7 +419,7 @@ def print_decoded(decoder, emission, param, param_value):
# nbest
# ~~~~~
#
# This parameter indicates the number of best Hypothesis to return, which
# This parameter indicates the number of best hypotheses to return, which
# is a property that is not possible with the greedy decoder. For
# instance, by setting ``nbest=3`` when constructing the beam search
# decoder earlier, we can now access the hypotheses with the top 3 scores.
......
......@@ -9,7 +9,7 @@ import pytest
],
)
def test_decoder_from_pretrained(model, expected, emissions):
from torchaudio.prototype.ctc_decoder import ctc_decoder, download_pretrained_files
from torchaudio.models.decoder import ctc_decoder, download_pretrained_files
pretrained_files = download_pretrained_files(model)
decoder = ctc_decoder(
......
......@@ -136,7 +136,7 @@ def is_ctc_decoder_available():
global _IS_CTC_DECODER_AVAILABLE
if _IS_CTC_DECODER_AVAILABLE is None:
try:
from torchaudio.prototype.ctc_decoder import CTCDecoder # noqa: F401
from torchaudio.models.decoder import CTCDecoder # noqa: F401
_IS_CTC_DECODER_AVAILABLE = True
except Exception:
......
......@@ -16,7 +16,7 @@ NUM_TOKENS = 8
@skipIfNoCtcDecoder
class CTCDecoderTest(TempDirMixin, TorchaudioTestCase):
def _get_decoder(self, tokens=None, use_lm=True, use_lexicon=True, **kwargs):
from torchaudio.prototype.ctc_decoder import ctc_decoder
from torchaudio.models.decoder import ctc_decoder
if use_lexicon:
lexicon_file = get_asset_path("decoder/lexicon.txt")
......
# Public decoder symbols that are resolved lazily, on first attribute access,
# because they require the native decoder extension to be loaded first.
_LAZILY_IMPORTED = [
    "CTCHypothesis",
    "CTCDecoder",
    "ctc_decoder",
    "download_pretrained_files",
]
_INITIALIZED = False


def _init_extension():
    """Load the native ``libtorchaudio_decoder`` extension exactly once."""
    import torchaudio

    torchaudio._extension._load_lib("libtorchaudio_decoder")
    global _INITIALIZED
    _INITIALIZED = True


def __getattr__(name: str):
    """PEP 562 module hook: import decoder symbols lazily on first access."""
    if name not in _LAZILY_IMPORTED:
        raise AttributeError(f"module {__name__} has no attribute {name}")
    if not _INITIALIZED:
        _init_extension()
    try:
        from . import _ctc_decoder
    except AttributeError as err:
        raise RuntimeError(
            "CTC decoder requires the decoder extension. Please set BUILD_CTC_DECODER=1 when building from source."
        ) from err
    resolved = getattr(_ctc_decoder, name)
    # Cache the resolved object on the module so later lookups bypass this hook.
    globals()[name] = resolved
    return resolved


def __dir__():
    # Advertise the lazy names alongside whatever has been exported explicitly.
    return sorted(__all__ + _LAZILY_IMPORTED)


__all__ = []
import itertools as it
import warnings
from collections import namedtuple
from typing import Dict, List, NamedTuple, Optional, Union
......@@ -21,13 +20,13 @@ from torchaudio._torchaudio_decoder import (
)
from torchaudio.utils import download_asset
__all__ = ["Hypothesis", "CTCDecoder", "ctc_decoder", "lexicon_decoder", "download_pretrained_files"]
__all__ = ["CTCHypothesis", "CTCDecoder", "ctc_decoder", "download_pretrained_files"]
_PretrainedFiles = namedtuple("PretrainedFiles", ["lexicon", "tokens", "lm"])
class Hypothesis(NamedTuple):
class CTCHypothesis(NamedTuple):
r"""Represents hypothesis generated by CTC beam search decoder :py:class:`CTCDecoder`.
:ivar torch.LongTensor tokens: Predicted sequence of token IDs. Shape `(L, )`, where
......@@ -44,8 +43,7 @@ class Hypothesis(NamedTuple):
class CTCDecoder:
"""torchaudio.prototype.ctc_decoder.CTCDecoder()
"""
.. devices:: CPU
Lexically constrained CTC beam search decoder from *Flashlight* [:footcite:`kahn2022flashlight`].
......@@ -128,10 +126,12 @@ class CTCDecoder:
timesteps.append(i)
return torch.IntTensor(timesteps)
def __call__(self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None) -> List[List[Hypothesis]]:
def __call__(
self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None
) -> List[List[CTCHypothesis]]:
# Overriding the signature so that the return type is correct on Sphinx
"""__call__(self, emissions: torch.FloatTensor, lengths: Optional[torch.Tensor] = None) -> \
List[List[torchaudio.prototype.ctc_decoder.Hypothesis]]
List[List[torchaudio.models.decoder.CTCHypothesis]]
Args:
emissions (torch.FloatTensor): CPU tensor of shape `(batch, frame, num_tokens)` storing sequences of
......@@ -140,7 +140,7 @@ class CTCDecoder:
in time axis of the output Tensor in each batch.
Returns:
List[List[Hypothesis]]:
List[List[CTCHypothesis]]:
List of sorted best hypotheses for each audio sequence in the batch.
"""
......@@ -168,7 +168,7 @@ class CTCDecoder:
nbest_results = results[: self.nbest]
hypos.append(
[
Hypothesis(
CTCHypothesis(
tokens=self._get_tokens(result.tokens),
words=[self.word_dict.get_entry(x) for x in result.words if x >= 0],
score=result.score,
......@@ -293,44 +293,6 @@ def ctc_decoder(
)
def lexicon_decoder(
    lexicon: str,
    tokens: Union[str, List[str]],
    lm: Optional[str] = None,
    nbest: int = 1,
    beam_size: int = 50,
    beam_size_token: Optional[int] = None,
    beam_threshold: float = 50,
    lm_weight: float = 2,
    word_score: float = 0,
    unk_score: float = float("-inf"),
    sil_score: float = 0,
    log_add: bool = False,
    blank_token: str = "-",
    sil_token: str = "|",
    unk_word: str = "<unk>",
) -> "CTCDecoder":
    """Deprecated alias of :py:func:`ctc_decoder`.

    Emits a :class:`DeprecationWarning` and forwards all arguments unchanged.
    See :py:func:`ctc_decoder` for the meaning of each parameter.

    Returns:
        CTCDecoder: the decoder built by :py:func:`ctc_decoder`.
    """
    # Use DeprecationWarning (consistent with the torchaudio.prototype shim)
    # and stacklevel=2 so the warning points at the caller, not this wrapper.
    warnings.warn(
        "`lexicon_decoder` is now deprecated. Please use `ctc_decoder` instead.",
        DeprecationWarning,
        stacklevel=2,
    )
    return ctc_decoder(
        lexicon=lexicon,
        tokens=tokens,
        lm=lm,
        nbest=nbest,
        beam_size=beam_size,
        beam_size_token=beam_size_token,
        beam_threshold=beam_threshold,
        lm_weight=lm_weight,
        word_score=word_score,
        unk_score=unk_score,
        sil_score=sil_score,
        log_add=log_add,
        blank_token=blank_token,
        sil_token=sil_token,
        unk_word=unk_word,
    )
def _get_filenames(model: str) -> _PretrainedFiles:
if model not in ["librispeech", "librispeech-3-gram", "librispeech-4-gram"]:
raise ValueError(
......
_INITIALIZED = False
_LAZILY_IMPORTED = [
"Hypothesis",
"CTCDecoder",
"ctc_decoder",
"lexicon_decoder",
"download_pretrained_files",
]
def __getattr__(name: str):
if name in ["ctc_decoder", "lexicon_decoder"]:
import warnings
from torchaudio.models.decoder import ctc_decoder
def _init_extension():
import torchaudio
warnings.warn(
f"{__name__}.{name} has been moved to torchaudio.models.decoder.ctc_decoder",
DeprecationWarning,
)
torchaudio._extension._load_lib("libtorchaudio_decoder")
if name == "lexicon_decoder":
global lexicon_decoder
lexicon_decoder = ctc_decoder
return lexicon_decoder
else:
return ctc_decoder
elif name == "download_pretrained_files":
import warnings
global _INITIALIZED
_INITIALIZED = True
from torchaudio.models.decoder import download_pretrained_files
return download_pretrained_files
def __getattr__(name: str):
if name in _LAZILY_IMPORTED:
if not _INITIALIZED:
_init_extension()
try:
from . import _ctc_decoder
except AttributeError as err:
raise RuntimeError(
"CTC decoder requires the decoder extension. Please set BUILD_CTC_DECODER=1 when building from source."
) from err
item = getattr(_ctc_decoder, name)
globals()[name] = item
return item
raise AttributeError(f"module {__name__} has no attribute {name}")
def __dir__():
return sorted(__all__ + _LAZILY_IMPORTED)
__all__ = []
return ["ctc_decoder", "lexicon_decoder", "download_pretrained_files"]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment