Unverified Commit 6b8f378b authored by moto's avatar moto Committed by GitHub
Browse files

Remove factory functions of tacotron2 and wavernn (#1874)

parent 5600bd25
......@@ -28,27 +28,12 @@ DeepSpeech
Tacotron2
~~~~~~~~~
Model
-----
Tacotron2
^^^^^^^^^
.. autoclass:: Tacotron2
.. automethod:: forward
.. automethod:: infer
Factory Functions
-----------------
tacotron2
^^^^^^^^^
.. autofunction:: tacotron2
Wav2Letter
~~~~~~~~~~
......@@ -131,26 +116,12 @@ import_fairseq_model
WaveRNN
~~~~~~~
Model
-----
WaveRNN
^^^^^^^
.. autoclass:: WaveRNN
.. automethod:: forward
.. automethod:: infer
Factory Functions
-----------------
wavernn
^^^^^^^
.. autofunction:: wavernn
References
~~~~~~~~~~
......
from .wav2letter import Wav2Letter
from .wavernn import WaveRNN, wavernn
from .wavernn import WaveRNN
from .conv_tasnet import ConvTasNet
from .deepspeech import DeepSpeech
from .tacotron2 import Tacotron2, tacotron2
from .tacotron2 import Tacotron2
from .wav2vec2 import (
Wav2Vec2Model,
wav2vec2_model,
......@@ -17,7 +17,6 @@ from .wav2vec2 import (
__all__ = [
'Wav2Letter',
'WaveRNN',
'wavernn',
'ConvTasNet',
'DeepSpeech',
'Wav2Vec2Model',
......@@ -29,5 +28,4 @@ __all__ = [
'hubert_large',
'hubert_xlarge',
'Tacotron2',
'tacotron2',
]
......@@ -27,77 +27,19 @@
import warnings
from math import sqrt
from typing import Tuple, List, Optional, Union, Any, Dict
from typing import Tuple, List, Optional, Union
import torch
from torch import nn
from torch import Tensor
from torch.nn import functional as F
from torch.hub import load_state_dict_from_url
# Public API of this module: the Tacotron2 model class and its
# pretrained-checkpoint factory function (see ``tacotron2`` below).
__all__ = [
    "Tacotron2",
    "tacotron2",
]
# Default keyword arguments shared by every pretrained checkpoint config in
# ``_MODEL_CONFIG_AND_URLS`` below; they are splatted into ``Tacotron2(**configs)``
# by the ``tacotron2`` factory, so each key must name a Tacotron2 constructor
# parameter.
_DEFAULT_PARAMETERS = {
    'mask_padding': False,
    'n_mels': 80,                          # number of mel bins in the output spectrogram
    'n_frames_per_step': 1,
    'symbol_embedding_dim': 512,
    'encoder_embedding_dim': 512,
    'encoder_n_convolution': 3,
    'encoder_kernel_size': 5,
    'decoder_rnn_dim': 1024,
    'decoder_max_step': 2000,              # hard cap on decoding steps during inference
    'decoder_dropout': 0.1,
    'decoder_early_stopping': True,
    'attention_rnn_dim': 1024,
    'attention_hidden_dim': 128,
    'attention_location_n_filter': 32,
    'attention_location_kernel_size': 31,
    'attention_dropout': 0.1,
    'prenet_dim': 256,
    'postnet_n_convolution': 5,
    'postnet_kernel_size': 5,
    'postnet_embedding_dim': 512,
    'gate_threshold': 0.5,                 # stop-token probability threshold
}
# Registry of pretrained Tacotron2 checkpoints, keyed by checkpoint name.
# Each value is ``(download_url, constructor_kwargs)``: the URL is passed to
# ``load_state_dict_from_url`` and the kwargs to ``Tacotron2(**configs)`` in
# the ``tacotron2`` factory function.
# ``n_symbol`` is the vocabulary size of the text front-end: 38 for the
# character-based models, 96 for the phoneme-based ones.
_MODEL_CONFIG_AND_URLS: Dict[str, Tuple[str, Dict[str, Any]]] = {
    'tacotron2_english_characters_1500_epochs_ljspeech': (
        'https://download.pytorch.org/models/audio/tacotron2_english_characters_1500_epochs_ljspeech.pth',
        dict(
            n_symbol=38,
            **_DEFAULT_PARAMETERS,
        )
    ),
    'tacotron2_english_characters_1500_epochs_wavernn_ljspeech': (
        'https://download.pytorch.org/models/audio/tacotron2_english_characters_1500_epochs_wavernn_ljspeech.pth',
        dict(
            n_symbol=38,
            **_DEFAULT_PARAMETERS,
        )
    ),
    'tacotron2_english_phonemes_1500_epochs_ljspeech': (
        'https://download.pytorch.org/models/audio/tacotron2_english_phonemes_1500_epochs_ljspeech.pth',
        dict(
            n_symbol=96,
            **_DEFAULT_PARAMETERS,
        )
    ),
    'tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech': (
        'https://download.pytorch.org/models/audio/tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech.pth',
        dict(
            n_symbol=96,
            **_DEFAULT_PARAMETERS,
        )
    )
}
def _get_linear_layer(
in_dim: int, out_dim: int, bias: bool = True, w_init_gain: str = "linear"
) -> torch.nn.Linear:
......@@ -1165,60 +1107,3 @@ class Tacotron2(nn.Module):
alignments = alignments.unfold(1, n_batch, n_batch).transpose(0, 2)
return mel_outputs_postnet, mel_specgram_lengths, alignments
def tacotron2(checkpoint_name: str) -> Tacotron2:
    r"""Get pretrained Tacotron2 model.

    Args:
        checkpoint_name (str): The name of the checkpoint to load. Available checkpoints:

            - ``"tacotron2_english_characters_1500_epochs_ljspeech"``:
              Tacotron2 model trained with english characters as the input, with 1500 epochs,
              and on the LJSpeech dataset.
              The model is trained using the code of `examples/pipeline_tacotron2/main.py
              <https://github.com/pytorch/audio/tree/master/examples/pipeline_tacotron2>`_
              with default parameters.

            - ``"tacotron2_english_characters_1500_epochs_wavernn_ljspeech"``:
              Tacotron2 model trained with english characters as the input, with 1500 epochs,
              and on the LJSpeech dataset.
              The model is trained using the code of `examples/pipeline_tacotron2/main.py
              <https://github.com/pytorch/audio/tree/master/examples/pipeline_tacotron2>`_.
              For the parameters, the `win_length` is set to 1100, `hop_length` to 275,
              `n_fft` to 2048, `mel_fmin` to 40, and `mel_fmax` to 11025.
              The audio settings here matches the audio settings used for the pretrained
              checkpoint name `"wavernn_10k_epochs_8bits_ljspeech"` for WaveRNN.

            - ``"tacotron2_english_phonemes_1500_epochs_ljspeech"``:
              Tacotron2 model trained with english phonemes as the input, with 1500 epochs,
              and on the LJSpeech dataset.
              The model is trained using the code of `examples/pipeline_tacotron2/main.py
              <https://github.com/pytorch/audio/tree/master/examples/pipeline_tacotron2>`_.
              The text preprocessor is set to the `"english_phonemes"`.

            - ``"tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech"``:
              Tacotron2 model trained with english phonemes as the input, with 1500 epochs,
              and on the LJSpeech dataset.
              The model is trained using the code of `examples/pipeline_tacotron2/main.py
              <https://github.com/pytorch/audio/tree/master/examples/pipeline_tacotron2>`_.
              The text preprocessor is set to the `"english_phonemes"`,
              `win_length` is set to 1100, `hop_length` to 275, `n_fft` to 2048,
              `mel_fmin` to 40, and `mel_fmax` to 11025.
              The audio settings here matches the audio settings used for the pretrained
              checkpoint name `"wavernn_10k_epochs_8bits_ljspeech"` for WaveRNN.

    Returns:
        Tacotron2: The model instance with pretrained weights loaded.

    Raises:
        ValueError: If ``checkpoint_name`` is not one of the registered checkpoints.
    """
    if checkpoint_name not in _MODEL_CONFIG_AND_URLS:
        raise ValueError(
            f"Unexpected checkpoint_name: '{checkpoint_name}'. "
            # Fixed punctuation: "are;" -> "are:" in the original message.
            f"Valid choices are: {list(_MODEL_CONFIG_AND_URLS.keys())}")
    url, configs = _MODEL_CONFIG_AND_URLS[checkpoint_name]
    model = Tacotron2(**configs)
    # torch.hub downloads the weights (or reuses the local cache) before loading.
    state_dict = load_state_dict_from_url(url, progress=False)
    model.load_state_dict(state_dict)
    return model
from typing import List, Tuple, Dict, Any, Optional
from typing import List, Tuple, Optional
import math
import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url
__all__ = [
"ResBlock",
......@@ -14,29 +12,9 @@ __all__ = [
"Stretch2d",
"UpsampleNetwork",
"WaveRNN",
"wavernn",
]
# Registry of pretrained WaveRNN checkpoints, keyed by checkpoint name.
# Each value is ``(download_url, constructor_kwargs)``: the URL is passed to
# ``load_state_dict_from_url`` and the kwargs to ``WaveRNN(**configs)`` in
# the ``wavernn`` factory function.
_MODEL_CONFIG_AND_URLS: Dict[str, Tuple[str, Dict[str, Any]]] = {
    'wavernn_10k_epochs_8bits_ljspeech': (
        'https://download.pytorch.org/models/audio/wavernn_10k_epochs_8bits_ljspeech.pth',
        {
            'upsample_scales': [5, 5, 11],
            'n_classes': 2 ** 8,  # n_bits = 8
            'hop_length': 275,
            'n_res_block': 10,
            'n_rnn': 512,
            'n_fc': 512,
            'kernel_size': 5,
            'n_freq': 80,
            'n_hidden': 128,
            'n_output': 128
        }
    )
}
class ResBlock(nn.Module):
r"""ResNet block based on *Efficient Neural Audio Synthesis* [:footcite:`kalchbrenner2018efficient`].
......@@ -424,28 +402,3 @@ class WaveRNN(nn.Module):
output.append(x)
return torch.stack(output).permute(1, 2, 0), lengths
def wavernn(checkpoint_name: str) -> WaveRNN:
    r"""Get pretrained WaveRNN model.

    Args:
        checkpoint_name (str): The name of the checkpoint to load. Available checkpoints:

            - ``"wavernn_10k_epochs_8bits_ljspeech"``:
              WaveRNN model trained with 10k epochs and 8 bits depth waveform on the LJSpeech dataset.
              The model is trained using the default parameters and code of the
              `examples/pipeline_wavernn/main.py
              <https://github.com/pytorch/audio/tree/master/examples/pipeline_wavernn>`_.

    Returns:
        WaveRNN: The model instance with pretrained weights loaded.

    Raises:
        ValueError: If ``checkpoint_name`` is not one of the registered checkpoints.
    """
    if checkpoint_name not in _MODEL_CONFIG_AND_URLS:
        raise ValueError(
            f"Unexpected checkpoint_name: '{checkpoint_name}'. "
            # Fixed punctuation: "are;" -> "are:" in the original message.
            f"Valid choices are: {list(_MODEL_CONFIG_AND_URLS.keys())}")
    url, configs = _MODEL_CONFIG_AND_URLS[checkpoint_name]
    model = WaveRNN(**configs)
    # torch.hub downloads the weights (or reuses the local cache) before loading.
    state_dict = load_state_dict_from_url(url, progress=False)
    model.load_state_dict(state_dict)
    return model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment