Commit 67cdf882 authored by hwangjeff, committed by Facebook GitHub Bot

Reorganize RNN-T components in prototype module (#2110)

Summary:
Regroup RNN-T components under `torchaudio.prototype.models` and `torchaudio.prototype.pipelines`.

Updated docs: https://492321-90321822-gh.circle-artifacts.com/0/docs/prototype.html

Pull Request resolved: https://github.com/pytorch/audio/pull/2110

Reviewed By: carolineechen, mthrok

Differential Revision: D33354116

Pulled By: hwangjeff

fbshipit-source-id: 9cf4afed548cb173d56211c16d31bcfa25a8e4cb
parent 572cd2e2
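
The regrouping summarized above changes the import path of every public RNN-T name. As a rough sketch (not part of this commit), the hypothetical helper below maps each name to its new module and regenerates the import lines. `NEW_HOME` and `new_import_lines` are illustrative names only, and the mapping is read off the `__all__` lists that appear later in this diff.

```python
from collections import defaultdict
from typing import Dict, Iterable, List

# Hypothetical lookup table (not part of torchaudio): where each public name
# lives after this commit, read off the ``__all__`` lists shown in the diff.
NEW_HOME: Dict[str, str] = {
    "Conformer": "torchaudio.prototype.models",
    "Emformer": "torchaudio.prototype.models",
    "Hypothesis": "torchaudio.prototype.models",
    "RNNT": "torchaudio.prototype.models",
    "RNNTBeamSearch": "torchaudio.prototype.models",
    "emformer_rnnt_base": "torchaudio.prototype.models",
    "emformer_rnnt_model": "torchaudio.prototype.models",
    "EMFORMER_RNNT_BASE_LIBRISPEECH": "torchaudio.prototype.pipelines",
    "RNNTBundle": "torchaudio.prototype.pipelines",
}


def new_import_lines(names: Iterable[str]) -> List[str]:
    """Return post-commit ``from ... import ...`` lines for the given names."""
    grouped = defaultdict(list)
    for name in names:
        # Raises KeyError for names that this diff does not mention.
        grouped[NEW_HOME[name]].append(name)
    return [
        f"from {module} import {', '.join(sorted(grouped[module]))}"
        for module in sorted(grouped)
    ]


if __name__ == "__main__":
    for line in new_import_lines(["emformer_rnnt_base", "RNNTBundle", "Hypothesis"]):
        print(line)
```

The helper only builds strings, so it runs without torchaudio installed. Sorting the names keeps the generated lines stable across runs.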
@@ -57,8 +57,9 @@ Prototype API References
    :caption: Prototype API Reference
    prototype
-   prototype.rnnt
    prototype.ctc_decoder
+   prototype.models
+   prototype.pipelines
 Getting Started
 ---------------
......
torchaudio.prototype.rnnt torchaudio.prototype.models
========================= ===========================
.. py:module:: torchaudio.prototype .. py:module:: torchaudio.prototype.models
.. currentmodule:: torchaudio.prototype.models
The models subpackage contains definitions of models and components for addressing common audio tasks.
.. currentmodule:: torchaudio.prototype
Model Classes Model Classes
------------- -------------
...@@ -15,7 +18,6 @@ Conformer ...@@ -15,7 +18,6 @@ Conformer
.. automethod:: forward .. automethod:: forward
Emformer Emformer
~~~~~~~~ ~~~~~~~~
...@@ -25,7 +27,6 @@ Emformer ...@@ -25,7 +27,6 @@ Emformer
.. automethod:: infer .. automethod:: infer
RNNT RNNT
~~~~ ~~~~
...@@ -41,24 +42,26 @@ RNNT ...@@ -41,24 +42,26 @@ RNNT
.. automethod:: join .. automethod:: join
Model Factory Functions Model Factory Functions
----------------------- -----------------------
emformer_rnnt_model
~~~~~~~~~~~~~~~~~~~
.. autofunction:: emformer_rnnt_model
emformer_rnnt_base emformer_rnnt_base
~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~
.. autofunction:: emformer_rnnt_base .. autofunction:: emformer_rnnt_base
emformer_rnnt_model
~~~~~~~~~~~~~~~~~~~
.. autofunction:: emformer_rnnt_model
Decoder Classes Decoder Classes
--------------- ---------------
RNNTBeamSearch RNNTBeamSearch
~~~~~~~~~~~~ ~~~~~~~~~~~~~~
.. autoclass:: RNNTBeamSearch .. autoclass:: RNNTBeamSearch
...@@ -66,43 +69,11 @@ RNNTBeamSearch ...@@ -66,43 +69,11 @@ RNNTBeamSearch
.. automethod:: infer .. automethod:: infer
Hypothesis Hypothesis
~~~~~~~~~~ ~~~~~~~~~~
.. autoclass:: Hypothesis .. autoclass:: Hypothesis
Pipeline Primitives (Pre-trained Models)
----------------------------------------
RNNTBundle
~~~~~~~~~~
.. autoclass:: RNNTBundle
:members: sample_rate, n_fft, n_mels, hop_length, segment_length, right_context_length
.. automethod:: get_decoder
.. automethod:: get_feature_extractor
.. automethod:: get_streaming_feature_extractor
.. automethod:: get_token_processor
.. autoclass:: torchaudio.prototype::RNNTBundle.FeatureExtractor
:special-members: __call__
.. autoclass:: torchaudio.prototype::RNNTBundle.TokenProcessor
:special-members: __call__
EMFORMER_RNNT_BASE_LIBRISPEECH
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. container:: py attribute
.. autodata:: EMFORMER_RNNT_BASE_LIBRISPEECH
:no-value:
References References
---------- ----------
......
torchaudio.prototype.pipelines
==============================
.. py:module:: torchaudio.prototype.pipelines
.. currentmodule:: torchaudio.prototype.pipelines
The pipelines subpackage contains APIs to models with pretrained weights and relevant utilities.
RNN-T Streaming/Non-Streaming ASR
---------------------------------
RNNTBundle
~~~~~~~~~~
.. autoclass:: RNNTBundle
:members: sample_rate, n_fft, n_mels, hop_length, segment_length, right_context_length
.. automethod:: get_decoder
.. automethod:: get_feature_extractor
.. automethod:: get_streaming_feature_extractor
.. automethod:: get_token_processor
.. autoclass:: torchaudio.prototype.pipelines::RNNTBundle.FeatureExtractor
:special-members: __call__
.. autoclass:: torchaudio.prototype.pipelines::RNNTBundle.TokenProcessor
:special-members: __call__
EMFORMER_RNNT_BASE_LIBRISPEECH
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. container:: py attribute
.. autodata:: EMFORMER_RNNT_BASE_LIBRISPEECH
:no-value:
\ No newline at end of file
@@ -14,8 +14,9 @@ imported explicitly, e.g.
 .. code-block:: python
-   import torchaudio.prototype.rnnt
+   import torchaudio.prototype.models
 .. toctree::
-   prototype.rnnt
    prototype.ctc_decoder
+   prototype.models
+   prototype.pipelines
@@ -9,8 +9,7 @@ import torch
 import torchaudio
 import torchaudio.functional as F
 from pytorch_lightning import LightningModule
-from torchaudio.prototype.rnnt import emformer_rnnt_base
-from torchaudio.prototype.rnnt_decoder import Hypothesis, RNNTBeamSearch
+from torchaudio.prototype.models import Hypothesis, RNNTBeamSearch, emformer_rnnt_base
 Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])
......
@@ -3,7 +3,7 @@ from argparse import ArgumentParser
 import torch
 import torchaudio
-from torchaudio.prototype.rnnt_pipeline import EMFORMER_RNNT_BASE_LIBRISPEECH
+from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
 def cli_main():
......
 import torch
-from torchaudio.prototype import Conformer
+from torchaudio.prototype.models import Conformer
 from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
......
 import torch
-from torchaudio.prototype import Emformer
+from torchaudio.prototype.models import Emformer
 from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
......
 import torch
-from torchaudio.prototype import RNNTBeamSearch, emformer_rnnt_model
+from torchaudio.prototype.models import RNNTBeamSearch, emformer_rnnt_model
 from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
......
 import torch
-from torchaudio.prototype.rnnt import emformer_rnnt_model
+from torchaudio.prototype.models import emformer_rnnt_model
 from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
......
from .conformer import Conformer
from .emformer import Emformer
from .rnnt import RNNT, emformer_rnnt_base, emformer_rnnt_model
from .rnnt_decoder import Hypothesis, RNNTBeamSearch
from .rnnt_pipeline import EMFORMER_RNNT_BASE_LIBRISPEECH, RNNTBundle
__all__ = [
"Conformer",
"Emformer",
"Hypothesis",
"RNNT",
"RNNTBeamSearch",
"emformer_rnnt_base",
"emformer_rnnt_model",
"EMFORMER_RNNT_BASE_LIBRISPEECH",
"RNNTBundle",
]
from .conformer import Conformer
from .emformer import Emformer
from .rnnt import RNNT, emformer_rnnt_base, emformer_rnnt_model
from .rnnt_decoder import Hypothesis, RNNTBeamSearch
__all__ = [
"Conformer",
"Emformer",
"Hypothesis",
"RNNT",
"RNNTBeamSearch",
"emformer_rnnt_base",
"emformer_rnnt_model",
]
@@ -2,7 +2,7 @@ from typing import List, Optional, Tuple
 import torch
-from .emformer import Emformer
+from torchaudio.prototype.models import Emformer
 __all__ = ["RNNT", "emformer_rnnt_base", "emformer_rnnt_model"]
@@ -426,7 +426,7 @@ class _Joiner(torch.nn.Module):
 class RNNT(torch.nn.Module):
-    r"""torchaudio.prototype.rnnt.RNNT()
+    r"""torchaudio.prototype.models.RNNT()
     Recurrent neural network transducer (RNN-T) model.
......
@@ -2,7 +2,7 @@ from typing import Callable, Dict, List, Optional, NamedTuple, Tuple
 import torch
-from .rnnt import RNNT
+from torchaudio.prototype.models import RNNT
 __all__ = ["Hypothesis", "RNNTBeamSearch"]
......
from .rnnt_pipeline import EMFORMER_RNNT_BASE_LIBRISPEECH, RNNTBundle
__all__ = [
"EMFORMER_RNNT_BASE_LIBRISPEECH",
"RNNTBundle",
]
@@ -8,7 +8,7 @@ from typing import Callable, List, Tuple
 import torch
 import torchaudio
 from torchaudio._internal import download_url_to_file, load_state_dict_from_url, module_utils
-from torchaudio.prototype import RNNT, RNNTBeamSearch, emformer_rnnt_base
+from torchaudio.prototype.models import RNNT, RNNTBeamSearch, emformer_rnnt_base
 __all__ = []
@@ -157,7 +157,7 @@ class _SentencePieceTokenProcessor(_TokenProcessor):
 @dataclass
 class RNNTBundle:
-    """torchaudio.prototype.rnnt_pipeline.RNNTBundle()
+    """torchaudio.prototype.pipelines.RNNTBundle()
     Dataclass that bundles components for performing automatic speech recognition (ASR, speech-to-text)
     inference with an RNN-T model.
@@ -175,7 +175,7 @@ class RNNTBundle:
     Example
         >>> import torchaudio
-        >>> from torchaudio.prototype.rnnt_pipeline import EMFORMER_RNNT_BASE_LIBRISPEECH
+        >>> from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
         >>> import torch
         >>>
         >>> # Non-streaming inference.
@@ -378,7 +378,7 @@ EMFORMER_RNNT_BASE_LIBRISPEECH = RNNTBundle(
 )
 EMFORMER_RNNT_BASE_LIBRISPEECH.__doc__ = """Pre-trained Emformer-RNNT-based ASR pipeline capable of performing both streaming and non-streaming inference.
-    The underlying model is constructed by :py:func:`torchaudio.prototypes.emformer_rnnt_base`
+    The underlying model is constructed by :py:func:`torchaudio.prototype.models.emformer_rnnt_base`
     and utilizes weights trained on LibriSpeech using training script ``train.py``
     `here <https://github.com/pytorch/audio/tree/main/examples/asr/librispeech_emformer_rnnt>`__ with default arguments.
......