Unverified commit 6b8f378b, authored by moto, committed by GitHub

Remove factory functions of tacotron2 and wavernn (#1874)
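The `Tacotron2` and `WaveRNN` classes stay in `torchaudio.models`; only the `tacotron2()` and `wavernn()` checkpoint factories and their URL/configuration tables are deleted. Code that relied on the factories can instantiate the classes directly and load the published state dicts by URL, as sketched after each removed function below.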

parent 5600bd25
docs/source/models.rst

@@ -28,27 +28,12 @@ DeepSpeech
 Tacotron2
 ~~~~~~~~~
-Model
------
-Tacotron2
-^^^^^^^^^
 .. autoclass:: Tacotron2
   .. automethod:: forward
   .. automethod:: infer
-Factory Functions
------------------
-tacotron2
-^^^^^^^^^
-.. autofunction:: tacotron2
 Wav2Letter
 ~~~~~~~~~~
@@ -131,26 +116,12 @@ import_fairseq_model
 WaveRNN
 ~~~~~~~
-Model
------
-WaveRNN
-^^^^^^^
 .. autoclass:: WaveRNN
   .. automethod:: forward
   .. automethod:: infer
-Factory Functions
------------------
-wavernn
-^^^^^^^
-.. autofunction:: wavernn
 References
 ~~~~~~~~~~
torchaudio/models/__init__.py

 from .wav2letter import Wav2Letter
-from .wavernn import WaveRNN, wavernn
+from .wavernn import WaveRNN
 from .conv_tasnet import ConvTasNet
 from .deepspeech import DeepSpeech
-from .tacotron2 import Tacotron2, tacotron2
+from .tacotron2 import Tacotron2
 from .wav2vec2 import (
     Wav2Vec2Model,
     wav2vec2_model,
@@ -17,7 +17,6 @@ from .wav2vec2 import (
 __all__ = [
     'Wav2Letter',
     'WaveRNN',
-    'wavernn',
     'ConvTasNet',
     'DeepSpeech',
     'Wav2Vec2Model',
@@ -29,5 +28,4 @@ __all__ = [
     'hubert_large',
     'hubert_xlarge',
     'Tacotron2',
-    'tacotron2',
 ]
torchaudio/models/tacotron2.py

@@ -27,77 +27,19 @@
 import warnings
 from math import sqrt
-from typing import Tuple, List, Optional, Union, Any, Dict
+from typing import Tuple, List, Optional, Union
 import torch
 from torch import nn
 from torch import Tensor
 from torch.nn import functional as F
-from torch.hub import load_state_dict_from_url
 __all__ = [
     "Tacotron2",
-    "tacotron2",
 ]
-_DEFAULT_PARAMETERS = {
-    'mask_padding': False,
-    'n_mels': 80,
-    'n_frames_per_step': 1,
-    'symbol_embedding_dim': 512,
-    'encoder_embedding_dim': 512,
-    'encoder_n_convolution': 3,
-    'encoder_kernel_size': 5,
-    'decoder_rnn_dim': 1024,
-    'decoder_max_step': 2000,
-    'decoder_dropout': 0.1,
-    'decoder_early_stopping': True,
-    'attention_rnn_dim': 1024,
-    'attention_hidden_dim': 128,
-    'attention_location_n_filter': 32,
-    'attention_location_kernel_size': 31,
-    'attention_dropout': 0.1,
-    'prenet_dim': 256,
-    'postnet_n_convolution': 5,
-    'postnet_kernel_size': 5,
-    'postnet_embedding_dim': 512,
-    'gate_threshold': 0.5,
-}
-_MODEL_CONFIG_AND_URLS: Dict[str, Tuple[str, Dict[str, Any]]] = {
-    'tacotron2_english_characters_1500_epochs_ljspeech': (
-        'https://download.pytorch.org/models/audio/tacotron2_english_characters_1500_epochs_ljspeech.pth',
-        dict(
-            n_symbol=38,
-            **_DEFAULT_PARAMETERS,
-        )
-    ),
-    'tacotron2_english_characters_1500_epochs_wavernn_ljspeech': (
-        'https://download.pytorch.org/models/audio/tacotron2_english_characters_1500_epochs_wavernn_ljspeech.pth',
-        dict(
-            n_symbol=38,
-            **_DEFAULT_PARAMETERS,
-        )
-    ),
-    'tacotron2_english_phonemes_1500_epochs_ljspeech': (
-        'https://download.pytorch.org/models/audio/tacotron2_english_phonemes_1500_epochs_ljspeech.pth',
-        dict(
-            n_symbol=96,
-            **_DEFAULT_PARAMETERS,
-        )
-    ),
-    'tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech': (
-        'https://download.pytorch.org/models/audio/tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech.pth',
-        dict(
-            n_symbol=96,
-            **_DEFAULT_PARAMETERS,
-        )
-    )
-}
 def _get_linear_layer(
     in_dim: int, out_dim: int, bias: bool = True, w_init_gain: str = "linear"
 ) -> torch.nn.Linear:
@@ -1165,60 +1107,3 @@ class Tacotron2(nn.Module):
         alignments = alignments.unfold(1, n_batch, n_batch).transpose(0, 2)
         return mel_outputs_postnet, mel_specgram_lengths, alignments
-def tacotron2(checkpoint_name: str) -> Tacotron2:
-    r"""Get pretrained Tacotron2 model.
-
-    Args:
-        checkpoint_name (str): The name of the checkpoint to load. Available checkpoints:
-
-            - ``"tacotron2_english_characters_1500_epochs_ljspeech"``:
-                Tacotron2 model trained with english characters as the input, for 1500 epochs,
-                on the LJSpeech dataset.
-                The model is trained using the code of `examples/pipeline_tacotron2/main.py
-                <https://github.com/pytorch/audio/tree/master/examples/pipeline_tacotron2>`_
-                with default parameters.
-
-            - ``"tacotron2_english_characters_1500_epochs_wavernn_ljspeech"``:
-                Tacotron2 model trained with english characters as the input, for 1500 epochs,
-                on the LJSpeech dataset.
-                The model is trained using the code of `examples/pipeline_tacotron2/main.py
-                <https://github.com/pytorch/audio/tree/master/examples/pipeline_tacotron2>`_.
-                For the parameters, `win_length` is set to 1100, `hop_length` to 275,
-                `n_fft` to 2048, `mel_fmin` to 40, and `mel_fmax` to 11025.
-                These audio settings match the ones used for the pretrained
-                WaveRNN checkpoint `"wavernn_10k_epochs_8bits_ljspeech"`.
-
-            - ``"tacotron2_english_phonemes_1500_epochs_ljspeech"``:
-                Tacotron2 model trained with english phonemes as the input, for 1500 epochs,
-                on the LJSpeech dataset.
-                The model is trained using the code of `examples/pipeline_tacotron2/main.py
-                <https://github.com/pytorch/audio/tree/master/examples/pipeline_tacotron2>`_.
-                The text preprocessor is set to `"english_phonemes"`.
-
-            - ``"tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech"``:
-                Tacotron2 model trained with english phonemes as the input, for 1500 epochs,
-                on the LJSpeech dataset.
-                The model is trained using the code of `examples/pipeline_tacotron2/main.py
-                <https://github.com/pytorch/audio/tree/master/examples/pipeline_tacotron2>`_.
-                The text preprocessor is set to `"english_phonemes"`,
-                `win_length` is set to 1100, `hop_length` to 275, `n_fft` to 2048,
-                `mel_fmin` to 40, and `mel_fmax` to 11025.
-                These audio settings match the ones used for the pretrained
-                WaveRNN checkpoint `"wavernn_10k_epochs_8bits_ljspeech"`.
-    """
-    if checkpoint_name not in _MODEL_CONFIG_AND_URLS:
-        raise ValueError(
-            f"Unexpected checkpoint_name: '{checkpoint_name}'. "
-            f"Valid choices are; {list(_MODEL_CONFIG_AND_URLS.keys())}")
-    url, configs = _MODEL_CONFIG_AND_URLS[checkpoint_name]
-    model = Tacotron2(**configs)
-    state_dict = load_state_dict_from_url(url, progress=False)
-    model.load_state_dict(state_dict)
-    return model
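For reference, a minimal migration sketch that reproduces the removed `tacotron2("tacotron2_english_characters_1500_epochs_ljspeech")` call by hand. The URL and every hyperparameter are copied from the deleted `_DEFAULT_PARAMETERS` / `_MODEL_CONFIG_AND_URLS` tables above; this is an illustration of one way to keep old code working, not a torchaudio API:

from torch.hub import load_state_dict_from_url
from torchaudio.models import Tacotron2

# Checkpoint URL copied from the deleted _MODEL_CONFIG_AND_URLS table.
_URL = ("https://download.pytorch.org/models/audio/"
        "tacotron2_english_characters_1500_epochs_ljspeech.pth")

# Hyperparameters copied from the deleted _DEFAULT_PARAMETERS table,
# plus n_symbol=38 used by the character-based English checkpoints.
model = Tacotron2(
    n_symbol=38,
    mask_padding=False,
    n_mels=80,
    n_frames_per_step=1,
    symbol_embedding_dim=512,
    encoder_embedding_dim=512,
    encoder_n_convolution=3,
    encoder_kernel_size=5,
    decoder_rnn_dim=1024,
    decoder_max_step=2000,
    decoder_dropout=0.1,
    decoder_early_stopping=True,
    attention_rnn_dim=1024,
    attention_hidden_dim=128,
    attention_location_n_filter=32,
    attention_location_kernel_size=31,
    attention_dropout=0.1,
    prenet_dim=256,
    postnet_n_convolution=5,
    postnet_kernel_size=5,
    postnet_embedding_dim=512,
    gate_threshold=0.5,
)
model.load_state_dict(load_state_dict_from_url(_URL, progress=False))
model.eval()  # switch to inference mode before calling infer()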
torchaudio/models/wavernn.py

-from typing import List, Tuple, Dict, Any, Optional
+from typing import List, Tuple, Optional
 import math
 import torch
 from torch import Tensor
 from torch import nn
 import torch.nn.functional as F
-from torch.hub import load_state_dict_from_url
 __all__ = [
     "ResBlock",
@@ -14,29 +12,9 @@ __all__ = [
     "Stretch2d",
     "UpsampleNetwork",
     "WaveRNN",
-    "wavernn",
 ]
-_MODEL_CONFIG_AND_URLS: Dict[str, Tuple[str, Dict[str, Any]]] = {
-    'wavernn_10k_epochs_8bits_ljspeech': (
-        'https://download.pytorch.org/models/audio/wavernn_10k_epochs_8bits_ljspeech.pth',
-        {
-            'upsample_scales': [5, 5, 11],
-            'n_classes': 2 ** 8,  # n_bits = 8
-            'hop_length': 275,
-            'n_res_block': 10,
-            'n_rnn': 512,
-            'n_fc': 512,
-            'kernel_size': 5,
-            'n_freq': 80,
-            'n_hidden': 128,
-            'n_output': 128
-        }
-    )
-}
 class ResBlock(nn.Module):
     r"""ResNet block based on *Efficient Neural Audio Synthesis* [:footcite:`kalchbrenner2018efficient`].
@@ -424,28 +402,3 @@ class WaveRNN(nn.Module):
             output.append(x)
         return torch.stack(output).permute(1, 2, 0), lengths
-def wavernn(checkpoint_name: str) -> WaveRNN:
-    r"""Get pretrained WaveRNN model.
-
-    Args:
-        checkpoint_name (str): The name of the checkpoint to load. Available checkpoints:
-
-            - ``"wavernn_10k_epochs_8bits_ljspeech"``:
-                WaveRNN model trained for 10k epochs on 8-bit depth waveforms of the LJSpeech dataset.
-                The model is trained using the default parameters and code of
-                `examples/pipeline_wavernn/main.py
-                <https://github.com/pytorch/audio/tree/master/examples/pipeline_wavernn>`_.
-    """
-    if checkpoint_name not in _MODEL_CONFIG_AND_URLS:
-        raise ValueError(
-            f"Unexpected checkpoint_name: '{checkpoint_name}'. "
-            f"Valid choices are; {list(_MODEL_CONFIG_AND_URLS.keys())}")
-    url, configs = _MODEL_CONFIG_AND_URLS[checkpoint_name]
-    model = WaveRNN(**configs)
-    state_dict = load_state_dict_from_url(url, progress=False)
-    model.load_state_dict(state_dict)
-    return model
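And the matching sketch for the removed `wavernn("wavernn_10k_epochs_8bits_ljspeech")` factory, with the URL and configuration copied from the deleted `_MODEL_CONFIG_AND_URLS` table above (again an illustration, not an official API):

from torch.hub import load_state_dict_from_url
from torchaudio.models import WaveRNN

# Checkpoint URL and configuration copied from the deleted table.
_URL = ("https://download.pytorch.org/models/audio/"
        "wavernn_10k_epochs_8bits_ljspeech.pth")

model = WaveRNN(
    upsample_scales=[5, 5, 11],
    n_classes=2 ** 8,  # n_bits = 8
    hop_length=275,
    n_res_block=10,
    n_rnn=512,
    n_fc=512,
    kernel_size=5,
    n_freq=80,
    n_hidden=128,
    n_output=128,
)
model.load_state_dict(load_state_dict_from_url(_URL, progress=False))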