Commit aca5591c authored by hwangjeff, committed by Facebook GitHub Bot

Move ASR features out of prototype (#2187)

Summary:
Moves ASR features out of `torchaudio.prototype`. Specifically, merges the contents of `torchaudio.prototype.models` into `torchaudio.models` and the contents of `torchaudio.prototype.pipelines` into `torchaudio.pipelines`, and updates references, tests, and docs accordingly.
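For downstream code, the change amounts to an import-path update. A minimal before/after sketch using names that appear in the diff below:

```python
# Before this change (prototype namespace):
# from torchaudio.prototype.models import RNNT, RNNTBeamSearch, emformer_rnnt_base
# from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH

# After this change (stable namespace):
from torchaudio.models import RNNT, RNNTBeamSearch, emformer_rnnt_base
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
```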

Pull Request resolved: https://github.com/pytorch/audio/pull/2187

Reviewed By: nateanl, mthrok

Differential Revision: D33918092

Pulled By: hwangjeff

fbshipit-source-id: f003f289a7e5d7d43f85b7c270b58bdf2ed6344c
parent ff15ba1b
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
-from torchaudio_unittest.prototype.rnnt_decoder_test_impl import RNNTBeamSearchTestImpl
+from torchaudio_unittest.models.rnnt_decoder.rnnt_decoder_test_impl import RNNTBeamSearchTestImpl
class RNNTBeamSearchFloat32CPUTest(RNNTBeamSearchTestImpl, PytorchTestCase):
...
import torch
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
-from torchaudio_unittest.prototype.rnnt_decoder_test_impl import RNNTBeamSearchTestImpl
+from torchaudio_unittest.models.rnnt_decoder.rnnt_decoder_test_impl import RNNTBeamSearchTestImpl
@skipIfNoCuda
...
import torch
-from torchaudio.prototype.models import RNNTBeamSearch, emformer_rnnt_model
+from torchaudio.models import RNNTBeamSearch, emformer_rnnt_model
from torchaudio_unittest.common_utils import TestBaseMixin, torch_script
...
+from .conformer import Conformer
from .conv_tasnet import ConvTasNet
from .deepspeech import DeepSpeech
+from .emformer import Emformer
+from .rnnt import RNNT, emformer_rnnt_base, emformer_rnnt_model
+from .rnnt_decoder import Hypothesis, RNNTBeamSearch
from .tacotron2 import Tacotron2
from .wav2letter import Wav2Letter
from .wav2vec2 import (
@@ -19,6 +23,7 @@ from .wav2vec2 import (
)
from .wavernn import WaveRNN
__all__ = [
    "Wav2Letter",
    "WaveRNN",
@@ -38,4 +43,11 @@ __all__ = [
    "hubert_pretrain_large",
    "hubert_pretrain_xlarge",
    "Tacotron2",
+    "Conformer",
+    "Emformer",
+    "Hypothesis",
+    "RNNT",
+    "RNNTBeamSearch",
+    "emformer_rnnt_base",
+    "emformer_rnnt_model",
]
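With these exports in place, the Emformer-RNNT factory and beam-search decoder are reachable from `torchaudio.models`. A hedged construction sketch; the no-argument call to `emformer_rnnt_base` and the blank-token index are assumptions, not taken from this diff:

```python
from torchaudio.models import RNNTBeamSearch, emformer_rnnt_base

# Build the base Emformer-RNNT model; assumes the factory takes no arguments.
model = emformer_rnnt_base()

# Wrap it in a beam-search decoder; the blank-token index of 4096 is an assumption
# corresponding to a 4096-token vocabulary plus blank.
decoder = RNNTBeamSearch(model, blank=4096)
```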
from typing import List, Optional, Tuple
import torch
-from torchaudio.prototype.models import Emformer
+from torchaudio.models import Emformer
__all__ = ["RNNT", "emformer_rnnt_base", "emformer_rnnt_model"]
@@ -278,7 +278,7 @@ class _Transcriber(torch.nn.Module):
class _Predictor(torch.nn.Module):
-    r"""Recurrent neural network transducer (RNN-T) transcription network.
+    r"""Recurrent neural network transducer (RNN-T) prediction network.
    Args:
        num_symbols (int): size of target token lexicon.
@@ -425,7 +425,7 @@ class _Joiner(torch.nn.Module):
class RNNT(torch.nn.Module):
-    r"""torchaudio.prototype.models.RNNT()
+    r"""torchaudio.models.RNNT()
    Recurrent neural network transducer (RNN-T) model.
...
from typing import Callable, Dict, List, Optional, NamedTuple, Tuple
import torch
-from torchaudio.prototype.models import RNNT
+from torchaudio.models import RNNT
__all__ = ["Hypothesis", "RNNTBeamSearch"]
...
@@ -32,6 +32,8 @@ from ._wav2vec2.impl import (
    HUBERT_ASR_LARGE,
    HUBERT_ASR_XLARGE,
)
+from .rnnt_pipeline import EMFORMER_RNNT_BASE_LIBRISPEECH, RNNTBundle
__all__ = [
    "Wav2Vec2Bundle",
@@ -64,4 +66,6 @@ __all__ = [
    "TACOTRON2_GRIFFINLIM_PHONE_LJSPEECH",
    "TACOTRON2_WAVERNN_CHAR_LJSPEECH",
    "TACOTRON2_WAVERNN_PHONE_LJSPEECH",
+    "RNNTBundle",
+    "EMFORMER_RNNT_BASE_LIBRISPEECH",
]
@@ -9,7 +9,7 @@ from typing import Callable, List, Tuple
import torch
import torchaudio
from torchaudio._internal import download_url_to_file, load_state_dict_from_url, module_utils
-from torchaudio.prototype.models import RNNT, RNNTBeamSearch, emformer_rnnt_base
+from torchaudio.models import RNNT, RNNTBeamSearch, emformer_rnnt_base
__all__ = []
@@ -158,7 +158,7 @@ class _SentencePieceTokenProcessor(_TokenProcessor):
@dataclass
class RNNTBundle:
-    """torchaudio.prototype.pipelines.RNNTBundle()
+    """torchaudio.pipelines.RNNTBundle()
    Dataclass that bundles components for performing automatic speech recognition (ASR, speech-to-text)
    inference with an RNN-T model.
@@ -176,7 +176,7 @@ class RNNTBundle:
    Example
        >>> import torchaudio
-        >>> from torchaudio.prototype.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
+        >>> from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH
        >>> import torch
        >>>
        >>> # Non-streaming inference.
@@ -379,7 +379,7 @@ EMFORMER_RNNT_BASE_LIBRISPEECH = RNNTBundle(
)
EMFORMER_RNNT_BASE_LIBRISPEECH.__doc__ = """Pre-trained Emformer-RNNT-based ASR pipeline capable of performing both streaming and non-streaming inference.
-The underlying model is constructed by :py:func:`torchaudio.prototype.models.emformer_rnnt_base`
+The underlying model is constructed by :py:func:`torchaudio.models.emformer_rnnt_base`
and utilizes weights trained on LibriSpeech using training script ``train.py``
`here <https://github.com/pytorch/audio/tree/main/examples/asr/librispeech_emformer_rnnt>`__ with default arguments.
...
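For reference, the collapsed docstring example above corresponds roughly to the following non-streaming usage. This is a hedged sketch: the bundle accessor names, call signatures, hypothesis indexing, and the input file are assumptions based on the `RNNTBundle` docstring rather than lines shown in this diff.

```python
import torch
import torchaudio
from torchaudio.pipelines import EMFORMER_RNNT_BASE_LIBRISPEECH

bundle = EMFORMER_RNNT_BASE_LIBRISPEECH
feature_extractor = bundle.get_feature_extractor()  # assumed accessor
decoder = bundle.get_decoder()                       # assumed accessor
token_processor = bundle.get_token_processor()       # assumed accessor

# "speech.wav" is a hypothetical recording at the bundle's expected sample rate.
waveform, sample_rate = torchaudio.load("speech.wav")

with torch.no_grad():
    features, length = feature_extractor(waveform.squeeze())
    hypotheses = decoder(features, length, 10)       # top-10 beam-search hypotheses

# Convert the token sequence of the best hypothesis to text.
text = token_processor(hypotheses[0][0])
print(text)
```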
from .conformer import Conformer
from .emformer import Emformer
from .rnnt import RNNT, emformer_rnnt_base, emformer_rnnt_model
from .rnnt_decoder import Hypothesis, RNNTBeamSearch
__all__ = [
    "Conformer",
    "Emformer",
    "Hypothesis",
    "RNNT",
    "RNNTBeamSearch",
    "emformer_rnnt_base",
    "emformer_rnnt_model",
]

from .rnnt_pipeline import EMFORMER_RNNT_BASE_LIBRISPEECH, RNNTBundle
__all__ = [
    "EMFORMER_RNNT_BASE_LIBRISPEECH",
    "RNNTBundle",
]