Commit 67cdf882 authored by hwangjeff, committed by Facebook GitHub Bot
Browse files

Reorganize RNN-T components in prototype module (#2110)

Summary:
Regroup RNN-T components under `torchaudio.prototype.models` and `torchaudio.prototype.pipelines`.

Updated docs: https://492321-90321822-gh.circle-artifacts.com/0/docs/prototype.html

Pull Request resolved: https://github.com/pytorch/audio/pull/2110

Reviewed By: carolineechen, mthrok

Differential Revision: D33354116

Pulled By: hwangjeff

fbshipit-source-id: 9cf4afed548cb173d56211c16d31bcfa25a8e4cb
parent 572cd2e2
......@@ -57,8 +57,9 @@ Prototype API References
:caption: Prototype API Reference
prototype
prototype.rnnt
prototype.ctc_decoder
prototype.models
prototype.pipelines
Getting Started
---------------
......
torchaudio.prototype.rnnt
=========================
torchaudio.prototype.models
===========================
.. py:module:: torchaudio.prototype
.. py:module:: torchaudio.prototype.models
.. currentmodule:: torchaudio.prototype.models
The models subpackage contains definitions of models and components for addressing common audio tasks.
.. currentmodule:: torchaudio.prototype
Model Classes
-------------
......@@ -15,7 +18,6 @@ Conformer
.. automethod:: forward
Emformer
~~~~~~~~
......@@ -25,7 +27,6 @@ Emformer
.. automethod:: infer
RNNT
~~~~
......@@ -41,24 +42,26 @@ RNNT
.. automethod:: join
Model Factory Functions
-----------------------
emformer_rnnt_model
~~~~~~~~~~~~~~~~~~~
.. autofunction:: emformer_rnnt_model
emformer_rnnt_base
~~~~~~~~~~~~~~~~~~
.. autofunction:: emformer_rnnt_base
emformer_rnnt_model
~~~~~~~~~~~~~~~~~~~
.. autofunction:: emformer_rnnt_model
Decoder Classes
---------------
RNNTBeamSearch
~~~~~~~~~~~~
~~~~~~~~~~~~~~
.. autoclass:: RNNTBeamSearch
......@@ -66,43 +69,11 @@ RNNTBeamSearch
.. automethod:: infer
Hypothesis
~~~~~~~~~~
.. autoclass:: Hypothesis
Pipeline Primitives (Pre-trained Models)
----------------------------------------
RNNTBundle
~~~~~~~~~~
.. autoclass:: RNNTBundle
:members: sample_rate, n_fft, n_mels, hop_length, segment_length, right_context_length
.. automethod:: get_decoder
.. automethod:: get_feature_extractor
.. automethod:: get_streaming_feature_extractor
.. automethod:: get_token_processor
.. autoclass:: torchaudio.prototype::RNNTBundle.FeatureExtractor
:special-members: __call__
.. autoclass:: torchaudio.prototype::RNNTBundle.TokenProcessor
:special-members: __call__
EMFORMER_RNNT_BASE_LIBRISPEECH
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. container:: py attribute
.. autodata:: EMFORMER_RNNT_BASE_LIBRISPEECH
:no-value:
References
----------
......
torchaudio.prototype.pipelines
==============================
.. py:module:: torchaudio.prototype.pipelines
.. currentmodule:: torchaudio.prototype.pipelines
The pipelines subpackage contains APIs for accessing models with pretrained weights, along with related utilities.
RNN-T Streaming/Non-Streaming ASR
---------------------------------
RNNTBundle
~~~~~~~~~~
.. autoclass:: RNNTBundle
:members: sample_rate, n_fft, n_mels, hop_length, segment_length, right_context_length
.. automethod:: get_decoder
.. automethod:: get_feature_extractor
.. automethod:: get_streaming_feature_extractor
.. automethod:: get_token_processor
.. autoclass:: torchaudio.prototype.pipelines::RNNTBundle.FeatureExtractor
:special-members: __call__
.. autoclass:: torchaudio.prototype.pipelines::RNNTBundle.TokenProcessor
:special-members: __call__
EMFORMER_RNNT_BASE_LIBRISPEECH
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. container:: py attribute
.. autodata:: EMFORMER_RNNT_BASE_LIBRISPEECH
:no-value:
\ No newline at end of file
......@@ -14,8 +14,9 @@ imported explicitly, e.g.
.. code-block:: python
import torchaudio.prototype.rnnt
import torchaudio.prototype.models
.. toctree::
prototype.rnnt
prototype.ctc_decoder
prototype.models
prototype.pipelines
......@@ -9,8 +9,7 @@ import torch
import torchaudio
import torchaudio.functional as F
from pytorch_lightning import LightningModule
from torchaudio.prototype.rnnt import emformer_rnnt_base
from torchaudio.prototype.rnnt_decoder import Hypothesis, RNNTBeamSearch
from torchaudio.prototype.models import Hypothesis, RNNTBeamSearch, emformer_rnnt_base
Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])
......
......@@ -3,7 +3,7 @@ from argparse import ArgumentParser
import torch
import torchaudio
from torchaudio.prototype.rnnt_pipeline import EMFORMER_RNNT_BASE_LIBRISPEECH
from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
def cli_main():
......
import torch
from torchaudio.prototype import Conformer
from torchaudio.prototype.models import Conformer
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
......
import torch
from torchaudio.prototype import Emformer
from torchaudio.prototype.models import Emformer
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
......
import torch
from torchaudio.prototype import RNNTBeamSearch, emformer_rnnt_model
from torchaudio.prototype.models import RNNTBeamSearch, emformer_rnnt_model
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
......
import torch
from torchaudio.prototype.rnnt import emformer_rnnt_model
from torchaudio.prototype.models import emformer_rnnt_model
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
......
from .conformer import Conformer
from .emformer import Emformer
from .rnnt import RNNT, emformer_rnnt_base, emformer_rnnt_model
from .rnnt_decoder import Hypothesis, RNNTBeamSearch
from .rnnt_pipeline import EMFORMER_RNNT_BASE_LIBRISPEECH, RNNTBundle
__all__ = [
"Conformer",
"Emformer",
"Hypothesis",
"RNNT",
"RNNTBeamSearch",
"emformer_rnnt_base",
"emformer_rnnt_model",
"EMFORMER_RNNT_BASE_LIBRISPEECH",
"RNNTBundle",
]
from .conformer import Conformer
from .emformer import Emformer
from .rnnt import RNNT, emformer_rnnt_base, emformer_rnnt_model
from .rnnt_decoder import Hypothesis, RNNTBeamSearch
__all__ = [
"Conformer",
"Emformer",
"Hypothesis",
"RNNT",
"RNNTBeamSearch",
"emformer_rnnt_base",
"emformer_rnnt_model",
]
......@@ -2,7 +2,7 @@ from typing import List, Optional, Tuple
import torch
from .emformer import Emformer
from torchaudio.prototype.models import Emformer
__all__ = ["RNNT", "emformer_rnnt_base", "emformer_rnnt_model"]
......@@ -426,7 +426,7 @@ class _Joiner(torch.nn.Module):
class RNNT(torch.nn.Module):
r"""torchaudio.prototype.rnnt.RNNT()
r"""torchaudio.prototype.models.RNNT()
Recurrent neural network transducer (RNN-T) model.
......
......@@ -2,7 +2,7 @@ from typing import Callable, Dict, List, Optional, NamedTuple, Tuple
import torch
from .rnnt import RNNT
from torchaudio.prototype.models import RNNT
__all__ = ["Hypothesis", "RNNTBeamSearch"]
......
from .rnnt_pipeline import EMFORMER_RNNT_BASE_LIBRISPEECH, RNNTBundle
__all__ = [
"EMFORMER_RNNT_BASE_LIBRISPEECH",
"RNNTBundle",
]
......@@ -8,7 +8,7 @@ from typing import Callable, List, Tuple
import torch
import torchaudio
from torchaudio._internal import download_url_to_file, load_state_dict_from_url, module_utils
from torchaudio.prototype import RNNT, RNNTBeamSearch, emformer_rnnt_base
from torchaudio.prototype.models import RNNT, RNNTBeamSearch, emformer_rnnt_base
__all__ = []
......@@ -157,7 +157,7 @@ class _SentencePieceTokenProcessor(_TokenProcessor):
@dataclass
class RNNTBundle:
"""torchaudio.prototype.rnnt_pipeline.RNNTBundle()
"""torchaudio.prototype.pipelines.RNNTBundle()
Dataclass that bundles components for performing automatic speech recognition (ASR, speech-to-text)
inference with an RNN-T model.
......@@ -175,7 +175,7 @@ class RNNTBundle:
Example
>>> import torchaudio
>>> from torchaudio.prototype.rnnt_pipeline import EMFORMER_RNNT_BASE_LIBRISPEECH
>>> from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
>>> import torch
>>>
>>> # Non-streaming inference.
......@@ -378,7 +378,7 @@ EMFORMER_RNNT_BASE_LIBRISPEECH = RNNTBundle(
)
EMFORMER_RNNT_BASE_LIBRISPEECH.__doc__ = """Pre-trained Emformer-RNNT-based ASR pipeline capable of performing both streaming and non-streaming inference.
The underlying model is constructed by :py:func:`torchaudio.prototypes.emformer_rnnt_base`
The underlying model is constructed by :py:func:`torchaudio.prototype.models.emformer_rnnt_base`
and utilizes weights trained on LibriSpeech using training script ``train.py``
`here <https://github.com/pytorch/audio/tree/main/examples/asr/librispeech_emformer_rnnt>`__ with default arguments.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment