Commit aca5591c authored by hwangjeff, committed by Facebook GitHub Bot

Move ASR features out of prototype (#2187)

Summary:
Moves ASR features out of `torchaudio.prototype`. Specifically, merges the contents of `torchaudio.prototype.models` into `torchaudio.models` and the contents of `torchaudio.prototype.pipelines` into `torchaudio.pipelines`, and updates references, tests, and docs accordingly.
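For reference, a minimal sketch of how imports change with this move (the old prototype paths are shown commented out; symbol availability assumes a torchaudio build that includes this change):

# Previously (prototype namespaces):
#   from torchaudio.prototype.models import RNNT, RNNTBeamSearch, emformer_rnnt_base
#   from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH

# After this change (stable namespaces):
from torchaudio.models import RNNT, RNNTBeamSearch, emformer_rnnt_base
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH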

Pull Request resolved: https://github.com/pytorch/audio/pull/2187

Reviewed By: nateanl, mthrok

Differential Revision: D33918092

Pulled By: hwangjeff

fbshipit-source-id: f003f289a7e5d7d43f85b7c270b58bdf2ed6344c
parent ff15ba1b
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
-from torchaudio_unittest.prototype.rnnt_decoder_test_impl import RNNTBeamSearchTestImpl
+from torchaudio_unittest.models.rnnt_decoder.rnnt_decoder_test_impl import RNNTBeamSearchTestImpl
class RNNTBeamSearchFloat32CPUTest(RNNTBeamSearchTestImpl, PytorchTestCase):
......
import torch
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
-from torchaudio_unittest.prototype.rnnt_decoder_test_impl import RNNTBeamSearchTestImpl
+from torchaudio_unittest.models.rnnt_decoder.rnnt_decoder_test_impl import RNNTBeamSearchTestImpl
@skipIfNoCuda
......
import torch
-from torchaudio.prototype.models import RNNTBeamSearch, emformer_rnnt_model
+from torchaudio.models import RNNTBeamSearch, emformer_rnnt_model
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
......
from .conformer import Conformer
from .conv_tasnet import ConvTasNet
from .deepspeech import DeepSpeech
from .emformer import Emformer
from .rnnt import RNNT, emformer_rnnt_base, emformer_rnnt_model
from .rnnt_decoder import Hypothesis, RNNTBeamSearch
from .tacotron2 import Tacotron2
from .wav2letter import Wav2Letter
from .wav2vec2 import (
@@ -19,6 +23,7 @@ from .wav2vec2 import (
)
from .wavernn import WaveRNN
__all__ = [
"Wav2Letter",
"WaveRNN",
@@ -38,4 +43,11 @@ __all__ = [
"hubert_pretrain_large",
"hubert_pretrain_xlarge",
"Tacotron2",
"Conformer",
"Emformer",
"Hypothesis",
"RNNT",
"RNNTBeamSearch",
"emformer_rnnt_base",
"emformer_rnnt_model",
]
from typing import List, Optional, Tuple
import torch
-from torchaudio.prototype.models import Emformer
+from torchaudio.models import Emformer
__all__ = ["RNNT", "emformer_rnnt_base", "emformer_rnnt_model"]
@@ -278,7 +278,7 @@ class _Transcriber(torch.nn.Module):
class _Predictor(torch.nn.Module):
r"""Recurrent neural network transducer (RNN-T) transcription network.
r"""Recurrent neural network transducer (RNN-T) prediction network.
Args:
num_symbols (int): size of target token lexicon.
@@ -425,7 +425,7 @@ class _Joiner(torch.nn.Module):
class RNNT(torch.nn.Module):
r"""torchaudio.prototype.models.RNNT()
r"""torchaudio.models.RNNT()
Recurrent neural network transducer (RNN-T) model.
......
from typing import Callable, Dict, List, Optional, NamedTuple, Tuple
import torch
-from torchaudio.prototype.models import RNNT
+from torchaudio.models import RNNT
__all__ = ["Hypothesis", "RNNTBeamSearch"]
......
@@ -32,6 +32,8 @@ from ._wav2vec2.impl import (
HUBERT_ASR_LARGE,
HUBERT_ASR_XLARGE,
)
from .rnnt_pipeline import EMFORMER_RNNT_BASE_LIBRISPEECH, RNNTBundle
__all__ = [
"Wav2Vec2Bundle",
@@ -64,4 +66,6 @@ __all__ = [
"TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH",
"TACOTRON2_WAVERNN_CHAR_LJSPEECH",
"TACOTRON2_WAVERNN_PHONE_LJSPEECH",
"RNNTBundle",
"EMFORMER_RNNT_BASE_LIBRISPEECH",
]
@@ -9,7 +9,7 @@ from typing import Callable, List, Tuple
import torch
import torchaudio
from torchaudio._internal import download_url_to_file, load_state_dict_from_url, module_utils
-from torchaudio.prototype.models import RNNT, RNNTBeamSearch, emformer_rnnt_base
+from torchaudio.models import RNNT, RNNTBeamSearch, emformer_rnnt_base
__all__ = []
@@ -158,7 +158,7 @@ class _SentencePieceTokenProcessor(_TokenProcessor):
@dataclass
class RNNTBundle:
"""torchaudio.prototype.pipelines.RNNTBundle()
"""torchaudio.pipelines.RNNTBundle()
Dataclass that bundles components for performing automatic speech recognition (ASR, speech-to-text)
inference with an RNN-T model.
@@ -176,7 +176,7 @@ class RNNTBundle:
Example
>>> import torchaudio
->>> from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
+>>> from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
>>> import torch
>>>
>>> # Non-streaming inference.
@@ -379,7 +379,7 @@ EMFORMER_RNNT_BASE_LIBRISPEECH = RNNTBundle(
)
EMFORMER_RNNT_BASE_LIBRISPEECH.__doc__ = """Pre-trained Emformer-RNNT-based ASR pipeline capable of performing both streaming and non-streaming inference.
-The underlying model is constructed by :py:func:`torchaudio.prototype.models.emformer_rnnt_base`
+The underlying model is constructed by :py:func:`torchaudio.models.emformer_rnnt_base`
and utilizes weights trained on LibriSpeech using training script ``train.py``
`here <https://github.com/pytorch/audio/tree/main/examples/asr/librispeech_emformer_rnnt>`__ with default arguments.
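For orientation, a rough end-to-end sketch of non-streaming inference with the promoted bundle; get_feature_extractor, get_decoder, and get_token_processor follow the RNNTBundle interface referenced above, while the audio file name, the resampling step, and the indexing of the best hypothesis' token list are illustrative assumptions:

import torch
import torchaudio
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH

# Build the bundled components for ASR inference.
feature_extractor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_feature_extractor()
decoder = EMFORMER_RNNT_BASE_LIBRISPEECH.get_decoder()
token_processor = EMFORMER_RNNT_BASE_LIBRISPEECH.get_token_processor()

# "speech.wav" is a placeholder; resample to the bundle's expected rate.
waveform, sample_rate = torchaudio.load("speech.wav")
waveform = torchaudio.functional.resample(
    waveform, sample_rate, EMFORMER_RNNT_BASE_LIBRISPEECH.sample_rate
)

# Non-streaming inference: featurize, run beam search, detokenize.
with torch.no_grad():
    features, length = feature_extractor(waveform.squeeze(0))
    hypotheses = decoder(features, length, 10)  # 10 = beam width
# Assumes the first element of a hypothesis is its token ID list.
transcript = token_processor(hypotheses[0][0])
print(transcript)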
......
from .conformer import Conformer
from .emformer import Emformer
from .rnnt import RNNT, emformer_rnnt_base, emformer_rnnt_model
from .rnnt_decoder import Hypothesis, RNNTBeamSearch
__all__ = [
"Conformer",
"Emformer",
"Hypothesis",
"RNNT",
"RNNTBeamSearch",
"emformer_rnnt_base",
"emformer_rnnt_model",
]
from .rnnt_pipeline import EMFORMER_RNNT_BASE_LIBRISPEECH, RNNTBundle
__all__ = [
"EMFORMER_RNNT_BASE_LIBRISPEECH",
"RNNTBundle",
]