Unverified Commit 352d63c5 authored by yangarbiter's avatar yangarbiter Committed by GitHub
Browse files

Move Tacotron2 out of prototype (#1714)

parent 0f603eb9
...@@ -40,7 +40,6 @@ The :mod:`torchaudio` package consists of I/O, popular datasets and common audio ...@@ -40,7 +40,6 @@ The :mod:`torchaudio` package consists of I/O, popular datasets and common audio
kaldi_io kaldi_io
utils utils
rnnt_loss rnnt_loss
tacotron2
.. toctree:: .. toctree::
......
...@@ -25,6 +25,24 @@ DeepSpeech ...@@ -25,6 +25,24 @@ DeepSpeech
.. automethod:: forward .. automethod:: forward
Tacotron2
~~~~~~~~~
.. autoclass:: Tacotron2
.. automethod:: forward
.. automethod:: infer
Factory Functions
-----------------
tacotron2
---------
.. autofunction:: tacotron2
Wav2Letter Wav2Letter
~~~~~~~~~~ ~~~~~~~~~~
......
.. role:: hidden
:class: hidden-section
torchaudio.prototype.tacotron2
==============================
.. currentmodule:: torchaudio.prototype.tacotron2
.. note::
The Tacotron2 model is a prototype feature, see `here <https://pytorch.org/audio>`_ to learn more about the nomenclature.
It is only available within the nightlies, and also needs to be imported
explicitly using: :code:`from torchaudio.prototype.Tacotron2 import Tacotron2, tacotron2`.
Tacotron2
~~~~~~~~~
.. autoclass:: Tacotron2
.. automethod:: forward
.. automethod:: infer
Factory Functions
-----------------
tacotron2
---------
.. autofunction:: tacotron2
References
~~~~~~~~~~
.. footbibliography::
...@@ -11,8 +11,8 @@ import sys ...@@ -11,8 +11,8 @@ import sys
import torch import torch
import torchaudio import torchaudio
import numpy as np import numpy as np
from torchaudio.prototype.tacotron2 import Tacotron2 from torchaudio.models import Tacotron2
from torchaudio.prototype.tacotron2 import tacotron2 as pretrained_tacotron2 from torchaudio.models import tacotron2 as pretrained_tacotron2
from utils import prepare_input_sequence from utils import prepare_input_sequence
from datasets import InverseSpectralNormalization from datasets import InverseSpectralNormalization
...@@ -28,7 +28,7 @@ def parse_args(): ...@@ -28,7 +28,7 @@ def parse_args():
r""" r"""
Parse commandline arguments. Parse commandline arguments.
""" """
from torchaudio.prototype.tacotron2 import _MODEL_CONFIG_AND_URLS as tacotron2_config_and_urls from torchaudio.models.tacotron2 import _MODEL_CONFIG_AND_URLS as tacotron2_config_and_urls
from torchaudio.models.wavernn import _MODEL_CONFIG_AND_URLS as wavernn_config_and_urls from torchaudio.models.wavernn import _MODEL_CONFIG_AND_URLS as wavernn_config_and_urls
parser = argparse.ArgumentParser(description=__doc__) parser = argparse.ArgumentParser(description=__doc__)
......
...@@ -44,7 +44,7 @@ import torch.distributed as dist ...@@ -44,7 +44,7 @@ import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.optim import Adam from torch.optim import Adam
from torchaudio.prototype.tacotron2 import Tacotron2 from torchaudio.models import Tacotron2
from tqdm import tqdm from tqdm import tqdm
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
plt.switch_backend('agg') plt.switch_backend('agg')
......
from typing import Tuple from typing import Tuple
import torch import torch
from torch import Tensor from torch import Tensor
from torchaudio.prototype.tacotron2 import Tacotron2, _Encoder, _Decoder from torchaudio.models import Tacotron2
from torchaudio.models.tacotron2 import _Encoder, _Decoder
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TestBaseMixin, TestBaseMixin,
TempDirMixin, TempDirMixin,
......
...@@ -2,6 +2,7 @@ from .wav2letter import Wav2Letter ...@@ -2,6 +2,7 @@ from .wav2letter import Wav2Letter
from .wavernn import WaveRNN, wavernn from .wavernn import WaveRNN, wavernn
from .conv_tasnet import ConvTasNet from .conv_tasnet import ConvTasNet
from .deepspeech import DeepSpeech from .deepspeech import DeepSpeech
from .tacotron2 import Tacotron2, tacotron2
from .wav2vec2 import ( from .wav2vec2 import (
Wav2Vec2Model, Wav2Vec2Model,
wav2vec2_base, wav2vec2_base,
...@@ -20,4 +21,6 @@ __all__ = [ ...@@ -20,4 +21,6 @@ __all__ = [
'wav2vec2_base', 'wav2vec2_base',
'wav2vec2_large', 'wav2vec2_large',
'wav2vec2_large_lv60k', 'wav2vec2_large_lv60k',
'Tacotron2',
'tacotron2',
] ]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment