Commit aca5591c authored by hwangjeff, committed by Facebook GitHub Bot

Move ASR features out of prototype (#2187)

Summary:
Moves ASR features out of `torchaudio.prototype`. Specifically, merges contents of `torchaudio.prototype.models` into `torchaudio.models` and contents of `torchaudio.prototype.pipelines` into `torchaudio.pipelines` and updates refs, tests, and docs accordingly.

Pull Request resolved: https://github.com/pytorch/audio/pull/2187

Reviewed By: nateanl, mthrok

Differential Revision: D33918092

Pulled By: hwangjeff

fbshipit-source-id: f003f289a7e5d7d43f85b7c270b58bdf2ed6344c
parent ff15ba1b
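
For downstream code that still imports from the old locations, the move described in the summary amounts to a mechanical rename. The following is a minimal sketch of such a rewrite, assuming only the two module moves named in the summary; `migrate_imports` and `_MOVED` are hypothetical names used for illustration and are not part of torchaudio.

```python
import re

# Mapping taken from the commit summary. Only these two prototype
# subpackages moved; torchaudio.prototype.ctc_decoder stays where it is.
_MOVED = {
    "torchaudio.prototype.models": "torchaudio.models",
    "torchaudio.prototype.pipelines": "torchaudio.pipelines",
}

# The trailing \b leaves longer names such as
# "torchaudio.prototype.models_extra" untouched.
_PATTERN = re.compile(r"\b(" + "|".join(re.escape(k) for k in _MOVED) + r")\b")


def migrate_imports(source: str) -> str:
    """Rewrite references to the moved prototype subpackages (hypothetical helper)."""
    return _PATTERN.sub(lambda m: _MOVED[m.group(1)], source)


if __name__ == "__main__":
    old = "from torchaudio.prototype.models import Hypothesis, RNNTBeamSearch, emformer_rnnt_base"
    print(migrate_imports(old))
```

The helper works on source text only, so it can be tried without torchaudio installed; references to `torchaudio.prototype.ctc_decoder` pass through unchanged.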
@@ -10,6 +10,12 @@ torchaudio.models
 The models subpackage contains definitions of models for addressing common audio tasks.
+Conformer
+~~~~~~~~~
+.. autoclass:: Conformer
+.. automethod:: forward
 ConvTasNet
 ~~~~~~~~~~
@@ -26,6 +32,67 @@ DeepSpeech
 .. automethod:: forward
+Emformer
+~~~~~~~~
+.. autoclass:: Emformer
+.. automethod:: forward
+.. automethod:: infer
+RNN-T
+~~~~~
+Model
+-----
+RNNT
+^^^^
+.. autoclass:: RNNT
+.. automethod:: forward
+.. automethod:: transcribe_streaming
+.. automethod:: transcribe
+.. automethod:: predict
+.. automethod:: join
+Factory Functions
+-----------------
+emformer_rnnt_model
+^^^^^^^^^^^^^^^^^^^
+.. autofunction:: emformer_rnnt_model
+emformer_rnnt_base
+^^^^^^^^^^^^^^^^^^
+.. autofunction:: emformer_rnnt_base
+Decoder
+-------
+RNNTBeamSearch
+^^^^^^^^^^^^^^
+.. autoclass:: RNNTBeamSearch
+.. automethod:: forward
+.. automethod:: infer
+Hypothesis
+^^^^^^^^^^
+.. autoclass:: Hypothesis
 Tacotron2
 ~~~~~~~~~
......
@@ -7,6 +7,44 @@ torchaudio.pipelines
 The pipelines subpackage contains API to access the models with pretrained weights, and information/helper functions associated the pretrained weights.
+RNN-T Streaming/Non-Streaming ASR
+---------------------------------
+RNNTBundle
+~~~~~~~~~~
+.. autoclass:: RNNTBundle
+:members: sample_rate, n_fft, n_mels, hop_length, segment_length, right_context_length
+.. automethod:: get_decoder() -> torchaudio.models.RNNTBeamSearch
+.. automethod:: get_feature_extractor() -> RNNTBundle.FeatureExtractor
+.. automethod:: get_streaming_feature_extractor() -> RNNTBundle.FeatureExtractor
+.. automethod:: get_token_processor() -> RNNTBundle.TokenProcessor
+RNNTBundle - FeatureExtractor
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: torchaudio.pipelines::RNNTBundle.FeatureExtractor
+:special-members: __call__
+RNNTBundle - TokenProcessor
+~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. autoclass:: torchaudio.pipelines::RNNTBundle.TokenProcessor
+:special-members: __call__
+EMFORMER_RNNT_BASE_LIBRISPEECH
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+.. container:: py attribute
+.. autodata:: EMFORMER_RNNT_BASE_LIBRISPEECH
+:no-value:
 wav2vec 2.0 / HuBERT - Representation Learning
 ----------------------------------------------
......
torchaudio.prototype.models
===========================
.. py:module:: torchaudio.prototype.models
.. currentmodule:: torchaudio.prototype.models
The models subpackage contains definitions of models and components for addressing common audio tasks.
Model Classes
-------------
Conformer
~~~~~~~~~
.. autoclass:: Conformer
.. automethod:: forward
Emformer
~~~~~~~~
.. autoclass:: Emformer
.. automethod:: forward
.. automethod:: infer
RNNT
~~~~
.. autoclass:: RNNT
.. automethod:: forward
.. automethod:: transcribe_streaming
.. automethod:: transcribe
.. automethod:: predict
.. automethod:: join
Model Factory Functions
-----------------------
emformer_rnnt_model
~~~~~~~~~~~~~~~~~~~
.. autofunction:: emformer_rnnt_model
emformer_rnnt_base
~~~~~~~~~~~~~~~~~~
.. autofunction:: emformer_rnnt_base
Decoder Classes
---------------
RNNTBeamSearch
~~~~~~~~~~~~~~
.. autoclass:: RNNTBeamSearch
.. automethod:: forward
.. automethod:: infer
Hypothesis
~~~~~~~~~~
.. autoclass:: Hypothesis
References
----------
.. footbibliography::
torchaudio.prototype.pipelines
==============================
.. py:module:: torchaudio.prototype.pipelines
.. currentmodule:: torchaudio.prototype.pipelines
The pipelines subpackage contains APIs to models with pretrained weights and relevant utilities.
RNN-T Streaming/Non-Streaming ASR
---------------------------------
RNNTBundle
~~~~~~~~~~
.. autoclass:: RNNTBundle
:members: sample_rate, n_fft, n_mels, hop_length, segment_length, right_context_length
.. automethod:: get_decoder() -> torchaudio.prototype.models.RNNTBeamSearch
.. automethod:: get_feature_extractor() -> RNNTBundle.FeatureExtractor
.. automethod:: get_streaming_feature_extractor() -> RNNTBundle.FeatureExtractor
.. automethod:: get_token_processor() -> RNNTBundle.TokenProcessor
RNNTBundle - FeatureExtractor
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: torchaudio.prototype.pipelines::RNNTBundle.FeatureExtractor
:special-members: __call__
RNNTBundle - TokenProcessor
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: torchaudio.prototype.pipelines::RNNTBundle.TokenProcessor
:special-members: __call__
EMFORMER_RNNT_BASE_LIBRISPEECH
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. container:: py attribute
.. autodata:: EMFORMER_RNNT_BASE_LIBRISPEECH
:no-value:
@@ -14,9 +14,7 @@ imported explicitly, e.g.
 .. code-block:: python
-import torchaudio.prototype.models
+import torchaudio.prototype.ctc_decoder
 .. toctree::
 prototype.ctc_decoder
-prototype.models
-prototype.pipelines
@@ -9,7 +9,7 @@ import torch
 import torchaudio
 import torchaudio.functional as F
 from pytorch_lightning import LightningModule
-from torchaudio.prototype.models import Hypothesis, RNNTBeamSearch, emformer_rnnt_base
+from torchaudio.models import Hypothesis, RNNTBeamSearch, emformer_rnnt_base
 from utils import GAIN, piecewise_linear_log, spectrogram_transform
......
 import pytest
 import torchaudio
-from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
+from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
 @pytest.mark.parametrize(
......
 import torch
 from torchaudio_unittest.common_utils import PytorchTestCase
-from torchaudio_unittest.prototype.conformer_test_impl import ConformerTestImpl
+from torchaudio_unittest.models.conformer.conformer_test_impl import ConformerTestImpl
 class ConformerFloat32CPUTest(ConformerTestImpl, PytorchTestCase):
......
 import torch
 from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
-from torchaudio_unittest.prototype.conformer_test_impl import ConformerTestImpl
+from torchaudio_unittest.models.conformer.conformer_test_impl import ConformerTestImpl
 @skipIfNoCuda
......
 import torch
-from torchaudio.prototype.models import Conformer
+from torchaudio.models import Conformer
 from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
......
 import torch
 from torchaudio_unittest.common_utils import PytorchTestCase
-from torchaudio_unittest.prototype.emformer_test_impl import EmformerTestImpl
+from torchaudio_unittest.models.emformer.emformer_test_impl import EmformerTestImpl
 class EmformerFloat32CPUTest(EmformerTestImpl, PytorchTestCase):
......
 import torch
 from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
-from torchaudio_unittest.prototype.emformer_test_impl import EmformerTestImpl
+from torchaudio_unittest.models.emformer.emformer_test_impl import EmformerTestImpl
 @skipIfNoCuda
......
 import torch
-from torchaudio.prototype.models import Emformer
+from torchaudio.models import Emformer
 from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
......
 import torch
 from torchaudio_unittest.common_utils import PytorchTestCase
-from torchaudio_unittest.prototype.rnnt_test_impl import RNNTTestImpl
+from torchaudio_unittest.models.rnnt.rnnt_test_impl import RNNTTestImpl
 class RNNTFloat32CPUTest(RNNTTestImpl, PytorchTestCase):
......
 import torch
 from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
-from torchaudio_unittest.prototype.rnnt_test_impl import RNNTTestImpl
+from torchaudio_unittest.models.rnnt.rnnt_test_impl import RNNTTestImpl
 @skipIfNoCuda
......
 import torch
-from torchaudio.prototype.models import emformer_rnnt_model
+from torchaudio.models import emformer_rnnt_model
 from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
......