Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
217fb684
Commit
217fb684
authored
Oct 15, 2021
by
moto
Browse files
Remove factory functions of tacotron2 and wavernn (#1874)
parent
7260ad2e
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
4 additions
and
197 deletions
+4
-197
docs/source/models.rst
docs/source/models.rst
+0
-29
torchaudio/models/__init__.py
torchaudio/models/__init__.py
+2
-4
torchaudio/models/tacotron2.py
torchaudio/models/tacotron2.py
+1
-116
torchaudio/models/wavernn.py
torchaudio/models/wavernn.py
+1
-48
No files found.
docs/source/models.rst
View file @
217fb684
...
@@ -28,27 +28,12 @@ DeepSpeech
...
@@ -28,27 +28,12 @@ DeepSpeech
Tacotron2
Tacotron2
~~~~~~~~~
~~~~~~~~~
Model
-----
Tacotron2
^^^^^^^^^
.. autoclass:: Tacotron2
.. autoclass:: Tacotron2
.. automethod:: forward
.. automethod:: forward
.. automethod:: infer
.. automethod:: infer
Factory Functions
-----------------
tacotron2
^^^^^^^^^
.. autofunction:: tacotron2
Wav2Letter
Wav2Letter
~~~~~~~~~~
~~~~~~~~~~
...
@@ -131,26 +116,12 @@ import_fairseq_model
...
@@ -131,26 +116,12 @@ import_fairseq_model
WaveRNN
WaveRNN
~~~~~~~
~~~~~~~
Model
-----
WaveRNN
^^^^^^^
.. autoclass:: WaveRNN
.. autoclass:: WaveRNN
.. automethod:: forward
.. automethod:: forward
.. automethod:: infer
.. automethod:: infer
Factory Functions
-----------------
wavernn
^^^^^^^
.. autofunction:: wavernn
References
References
~~~~~~~~~~
~~~~~~~~~~
...
...
torchaudio/models/__init__.py
View file @
217fb684
from
.wav2letter
import
Wav2Letter
from
.wav2letter
import
Wav2Letter
from
.wavernn
import
WaveRNN
,
wavernn
from
.wavernn
import
WaveRNN
from
.conv_tasnet
import
ConvTasNet
from
.conv_tasnet
import
ConvTasNet
from
.deepspeech
import
DeepSpeech
from
.deepspeech
import
DeepSpeech
from
.tacotron2
import
Tacotron2
,
tacotron2
from
.tacotron2
import
Tacotron2
from
.wav2vec2
import
(
from
.wav2vec2
import
(
Wav2Vec2Model
,
Wav2Vec2Model
,
wav2vec2_model
,
wav2vec2_model
,
...
@@ -17,7 +17,6 @@ from .wav2vec2 import (
...
@@ -17,7 +17,6 @@ from .wav2vec2 import (
__all__
=
[
__all__
=
[
'Wav2Letter'
,
'Wav2Letter'
,
'WaveRNN'
,
'WaveRNN'
,
'wavernn'
,
'ConvTasNet'
,
'ConvTasNet'
,
'DeepSpeech'
,
'DeepSpeech'
,
'Wav2Vec2Model'
,
'Wav2Vec2Model'
,
...
@@ -29,5 +28,4 @@ __all__ = [
...
@@ -29,5 +28,4 @@ __all__ = [
'hubert_large'
,
'hubert_large'
,
'hubert_xlarge'
,
'hubert_xlarge'
,
'Tacotron2'
,
'Tacotron2'
,
'tacotron2'
,
]
]
torchaudio/models/tacotron2.py
View file @
217fb684
...
@@ -27,77 +27,19 @@
...
@@ -27,77 +27,19 @@
import
warnings
import
warnings
from
math
import
sqrt
from
math
import
sqrt
from
typing
import
Tuple
,
List
,
Optional
,
Union
,
Any
,
Dict
from
typing
import
Tuple
,
List
,
Optional
,
Union
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
torch
import
Tensor
from
torch
import
Tensor
from
torch.nn
import
functional
as
F
from
torch.nn
import
functional
as
F
from
torch.hub
import
load_state_dict_from_url
__all__
=
[
__all__
=
[
"Tacotron2"
,
"Tacotron2"
,
"tacotron2"
,
]
]
_DEFAULT_PARAMETERS
=
{
'mask_padding'
:
False
,
'n_mels'
:
80
,
'n_frames_per_step'
:
1
,
'symbol_embedding_dim'
:
512
,
'encoder_embedding_dim'
:
512
,
'encoder_n_convolution'
:
3
,
'encoder_kernel_size'
:
5
,
'decoder_rnn_dim'
:
1024
,
'decoder_max_step'
:
2000
,
'decoder_dropout'
:
0.1
,
'decoder_early_stopping'
:
True
,
'attention_rnn_dim'
:
1024
,
'attention_hidden_dim'
:
128
,
'attention_location_n_filter'
:
32
,
'attention_location_kernel_size'
:
31
,
'attention_dropout'
:
0.1
,
'prenet_dim'
:
256
,
'postnet_n_convolution'
:
5
,
'postnet_kernel_size'
:
5
,
'postnet_embedding_dim'
:
512
,
'gate_threshold'
:
0.5
,
}
# Mapping from checkpoint name to (download URL, Tacotron2 constructor kwargs).
# The checkpoint file name matches its key; character-input checkpoints use a
# 38-symbol table, phoneme-input checkpoints a 96-symbol table.
_MODEL_CONFIG_AND_URLS: Dict[str, Tuple[str, Dict[str, Any]]] = {
    name: (
        f'https://download.pytorch.org/models/audio/{name}.pth',
        dict(n_symbol=n_symbol, **_DEFAULT_PARAMETERS),
    )
    for name, n_symbol in [
        ('tacotron2_english_characters_1500_epochs_ljspeech', 38),
        ('tacotron2_english_characters_1500_epochs_wavernn_ljspeech', 38),
        ('tacotron2_english_phonemes_1500_epochs_ljspeech', 96),
        ('tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech', 96),
    ]
}
def
_get_linear_layer
(
def
_get_linear_layer
(
in_dim
:
int
,
out_dim
:
int
,
bias
:
bool
=
True
,
w_init_gain
:
str
=
"linear"
in_dim
:
int
,
out_dim
:
int
,
bias
:
bool
=
True
,
w_init_gain
:
str
=
"linear"
)
->
torch
.
nn
.
Linear
:
)
->
torch
.
nn
.
Linear
:
...
@@ -1165,60 +1107,3 @@ class Tacotron2(nn.Module):
...
@@ -1165,60 +1107,3 @@ class Tacotron2(nn.Module):
alignments
=
alignments
.
unfold
(
1
,
n_batch
,
n_batch
).
transpose
(
0
,
2
)
alignments
=
alignments
.
unfold
(
1
,
n_batch
,
n_batch
).
transpose
(
0
,
2
)
return
mel_outputs_postnet
,
mel_specgram_lengths
,
alignments
return
mel_outputs_postnet
,
mel_specgram_lengths
,
alignments
def tacotron2(checkpoint_name: str) -> Tacotron2:
    r"""Get pretrained Tacotron2 model.

    Args:
        checkpoint_name (str): The name of the checkpoint to load. Available checkpoints:

            - ``"tacotron2_english_characters_1500_epochs_ljspeech"``:
                Tacotron2 model trained with english characters as the input, with 1500 epochs,
                and on the LJSpeech dataset.
                The model is trained using the code of `examples/pipeline_tacotron2/main.py
                <https://github.com/pytorch/audio/tree/master/examples/pipeline_tacotron2>`_
                with default parameters.

            - ``"tacotron2_english_characters_1500_epochs_wavernn_ljspeech"``:
                Tacotron2 model trained with english characters as the input, with 1500 epochs,
                and on the LJSpeech dataset.
                The model is trained using the code of `examples/pipeline_tacotron2/main.py
                <https://github.com/pytorch/audio/tree/master/examples/pipeline_tacotron2>`_.
                For the parameters, the `win_length` is set to 1100, `hop_length` to 275,
                `n_fft` to 2048, `mel_fmin` to 40, and `mel_fmax` to 11025.
                The audio settings here matches the audio settings used for the pretrained
                checkpoint name `"wavernn_10k_epochs_8bits_ljspeech"` for WaveRNN.

            - ``"tacotron2_english_phonemes_1500_epochs_ljspeech"``:
                Tacotron2 model trained with english phonemes as the input, with 1500 epochs,
                and on the LJSpeech dataset.
                The model is trained using the code of `examples/pipeline_tacotron2/main.py
                <https://github.com/pytorch/audio/tree/master/examples/pipeline_tacotron2>`_.
                The text preprocessor is set to the `"english_phonemes"`.

            - ``"tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech"``:
                Tacotron2 model trained with english phonemes as the input, with 1500 epochs,
                and on the LJSpeech dataset.
                The model is trained using the code of `examples/pipeline_tacotron2/main.py
                <https://github.com/pytorch/audio/tree/master/examples/pipeline_tacotron2>`_.
                The text preprocessor is set to the `"english_phonemes"`,
                `win_length` is set to 1100, `hop_length` to 275, `n_fft` to 2048,
                `mel_fmin` to 40, and `mel_fmax` to 11025.
                The audio settings here matches the audio settings used for the pretrained
                checkpoint name `"wavernn_10k_epochs_8bits_ljspeech"` for WaveRNN.

    Returns:
        Tacotron2: The model with the pretrained weights loaded.

    Raises:
        ValueError: If ``checkpoint_name`` is not one of the available checkpoints.
    """
    if checkpoint_name not in _MODEL_CONFIG_AND_URLS:
        # Fail early with the list of valid names instead of letting the URL fetch fail.
        raise ValueError(
            f"Unexpected checkpoint_name: '{checkpoint_name}'. "
            f"Valid choices are: {list(_MODEL_CONFIG_AND_URLS.keys())}")
    url, configs = _MODEL_CONFIG_AND_URLS[checkpoint_name]
    model = Tacotron2(**configs)
    # progress=False suppresses the download progress display (see torch.hub docs).
    state_dict = load_state_dict_from_url(url, progress=False)
    model.load_state_dict(state_dict)
    return model
torchaudio/models/wavernn.py
View file @
217fb684
from
typing
import
List
,
Tuple
,
Dict
,
Any
,
Optional
from
typing
import
List
,
Tuple
,
Optional
import
math
import
math
import
torch
import
torch
from
torch
import
Tensor
from
torch
import
Tensor
from
torch
import
nn
from
torch
import
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch.hub
import
load_state_dict_from_url
__all__
=
[
__all__
=
[
"ResBlock"
,
"ResBlock"
,
...
@@ -14,29 +12,9 @@ __all__ = [
...
@@ -14,29 +12,9 @@ __all__ = [
"Stretch2d"
,
"Stretch2d"
,
"UpsampleNetwork"
,
"UpsampleNetwork"
,
"WaveRNN"
,
"WaveRNN"
,
"wavernn"
,
]
]
_MODEL_CONFIG_AND_URLS
:
Dict
[
str
,
Tuple
[
str
,
Dict
[
str
,
Any
]]]
=
{
'wavernn_10k_epochs_8bits_ljspeech'
:
(
'https://download.pytorch.org/models/audio/wavernn_10k_epochs_8bits_ljspeech.pth'
,
{
'upsample_scales'
:
[
5
,
5
,
11
],
'n_classes'
:
2
**
8
,
# n_bits = 8
'hop_length'
:
275
,
'n_res_block'
:
10
,
'n_rnn'
:
512
,
'n_fc'
:
512
,
'kernel_size'
:
5
,
'n_freq'
:
80
,
'n_hidden'
:
128
,
'n_output'
:
128
}
)
}
class
ResBlock
(
nn
.
Module
):
class
ResBlock
(
nn
.
Module
):
r
"""ResNet block based on *Efficient Neural Audio Synthesis* [:footcite:`kalchbrenner2018efficient`].
r
"""ResNet block based on *Efficient Neural Audio Synthesis* [:footcite:`kalchbrenner2018efficient`].
...
@@ -424,28 +402,3 @@ class WaveRNN(nn.Module):
...
@@ -424,28 +402,3 @@ class WaveRNN(nn.Module):
output
.
append
(
x
)
output
.
append
(
x
)
return
torch
.
stack
(
output
).
permute
(
1
,
2
,
0
),
lengths
return
torch
.
stack
(
output
).
permute
(
1
,
2
,
0
),
lengths
def wavernn(checkpoint_name: str) -> WaveRNN:
    r"""Get pretrained WaveRNN model.

    Args:
        checkpoint_name (str): The name of the checkpoint to load. Available checkpoints:

            - ``"wavernn_10k_epochs_8bits_ljspeech"``:
                WaveRNN model trained with 10k epochs and 8 bits depth waveform on the LJSpeech dataset.
                The model is trained using the default parameters and code of the
                `examples/pipeline_wavernn/main.py
                <https://github.com/pytorch/audio/tree/master/examples/pipeline_wavernn>`_.

    Returns:
        WaveRNN: The model with the pretrained weights loaded.

    Raises:
        ValueError: If ``checkpoint_name`` is not one of the available checkpoints.
    """
    if checkpoint_name not in _MODEL_CONFIG_AND_URLS:
        # Fail early with the list of valid names instead of letting the URL fetch fail.
        raise ValueError(
            f"Unexpected checkpoint_name: '{checkpoint_name}'. "
            f"Valid choices are: {list(_MODEL_CONFIG_AND_URLS.keys())}")
    url, configs = _MODEL_CONFIG_AND_URLS[checkpoint_name]
    model = WaveRNN(**configs)
    # progress=False suppresses the download progress display (see torch.hub docs).
    state_dict = load_state_dict_from_url(url, progress=False)
    model.load_state_dict(state_dict)
    return model
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment