Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
hehl2
Torchaudio
Commits
2e1df525
"git@developer.sourcefind.cn:change/sglang.git" did not exist on "75c09e1ffe9d449e9fe6cfff59a5fcc4e86b46db"
Unverified
Commit
2e1df525
authored
Aug 11, 2021
by
yangarbiter
Committed by
GitHub
Aug 11, 2021
Browse files
Add tacotron2 pretrained models (#1693)
parent
6e0af713
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
116 additions
and
1 deletion
+116
-1
torchaudio/prototype/tacotron2.py
torchaudio/prototype/tacotron2.py
+116
-1
No files found.
torchaudio/prototype/tacotron2.py
View file @
2e1df525
...
@@ -27,19 +27,77 @@
...
@@ -27,19 +27,77 @@
import
warnings
import
warnings
from
math
import
sqrt
from
math
import
sqrt
from
typing
import
Tuple
,
List
,
Optional
,
Union
from
typing
import
Tuple
,
List
,
Optional
,
Union
,
Any
,
Dict
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
from
torch
import
Tensor
from
torch
import
Tensor
from
torch.nn
import
functional
as
F
from
torch.nn
import
functional
as
F
from
torch.hub
import
load_state_dict_from_url
__all__
=
[
__all__
=
[
"Tacotron2"
,
"Tacotron2"
,
"tacotron2"
,
]
]
_DEFAULT_PARAMETERS
=
{
'mask_padding'
:
False
,
'n_mels'
:
80
,
'n_frames_per_step'
:
1
,
'symbol_embedding_dim'
:
512
,
'encoder_embedding_dim'
:
512
,
'encoder_n_convolution'
:
3
,
'encoder_kernel_size'
:
5
,
'decoder_rnn_dim'
:
1024
,
'decoder_max_step'
:
2000
,
'decoder_dropout'
:
0.1
,
'decoder_early_stopping'
:
True
,
'attention_rnn_dim'
:
1024
,
'attention_hidden_dim'
:
128
,
'attention_location_n_filter'
:
32
,
'attention_location_kernel_size'
:
31
,
'attention_dropout'
:
0.1
,
'prenet_dim'
:
256
,
'postnet_n_convolution'
:
5
,
'postnet_kernel_size'
:
5
,
'postnet_embedding_dim'
:
512
,
'gate_threshold'
:
0.5
,
}
_MODEL_CONFIG_AND_URLS
:
Dict
[
str
,
Tuple
[
str
,
Dict
[
str
,
Any
]]]
=
{
'tacotron2_english_characters_1500_epochs_ljspeech'
:
(
'https://download.pytorch.org/models/audio/tacotron2_english_characters_1500_epochs_ljspeech.pth'
,
dict
(
n_symbol
=
38
,
**
_DEFAULT_PARAMETERS
,
)
),
'tacotron2_english_characters_1500_epochs_wavernn_ljspeech'
:
(
'https://download.pytorch.org/models/audio/tacotron2_english_characters_1500_epochs_wavernn_ljspeech.pth'
,
dict
(
n_symbol
=
38
,
**
_DEFAULT_PARAMETERS
,
)
),
'tacotron2_english_phonemes_1500_epochs_ljspeech'
:
(
'https://download.pytorch.org/models/audio/tacotron2_english_phonemes_1500_epochs_ljspeech.pth'
,
dict
(
n_symbol
=
96
,
**
_DEFAULT_PARAMETERS
,
)
),
'tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech'
:
(
'https://download.pytorch.org/models/audio/tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech.pth'
,
dict
(
n_symbol
=
96
,
**
_DEFAULT_PARAMETERS
,
)
)
}
def
_get_linear_layer
(
def
_get_linear_layer
(
in_dim
:
int
,
out_dim
:
int
,
bias
:
bool
=
True
,
w_init_gain
:
str
=
"linear"
in_dim
:
int
,
out_dim
:
int
,
bias
:
bool
=
True
,
w_init_gain
:
str
=
"linear"
)
->
torch
.
nn
.
Linear
:
)
->
torch
.
nn
.
Linear
:
...
@@ -1105,3 +1163,60 @@ class Tacotron2(nn.Module):
...
@@ -1105,3 +1163,60 @@ class Tacotron2(nn.Module):
alignments
=
alignments
.
unfold
(
1
,
n_batch
,
n_batch
).
transpose
(
0
,
2
)
alignments
=
alignments
.
unfold
(
1
,
n_batch
,
n_batch
).
transpose
(
0
,
2
)
return
mel_outputs_postnet
,
mel_specgram_lengths
,
alignments
return
mel_outputs_postnet
,
mel_specgram_lengths
,
alignments
def
tacotron2
(
checkpoint_name
:
str
)
->
Tacotron2
:
r
"""Get pretrained Tacotron2 model.
Args:
checkpoint_name (str): The name of the checkpoint to load. Available checkpoints:
- ``"tacotron2_english_characters_1500_epochs_ljspeech"``:
Tacotron2 model trained with english characters as the input, with 1500 epochs,
and on the LJSpeech dataset.
The model is trained using the code of `examples/pipeline_tacotron2/main.py
<https://github.com/pytorch/audio/tree/master/examples/pipeline_tacotron2>`_
with default parameters.
- ``"tacotron2_english_characters_1500_epochs_wavernn_ljspeech"``:
Tacotron2 model trained with english characters as the input, with 1500 epochs,
and on the LJSpeech dataset.
The model is trained using the code of `examples/pipeline_tacotron2/main.py
<https://github.com/pytorch/audio/tree/master/examples/pipeline_tacotron2>`_.
For the parameters, the `win_length` is set to 1100, `hop_length` to 275,
`n_fft` to 2048, `mel_fmin` to 40, and `mel_fmax` to 11025.
The audio settings here matches the audio settings used for the pretrained
checkpoint name `"wavernn_10k_epochs_8bits_ljspeech"` for WaveRNN.
- ``"tacotron2_english_phonemes_1500_epochs_ljspeech"``:
Tacotron2 model trained with english characters as the input, with 1500 epochs,
and on the LJSpeech dataset.
The model is trained using the code of `examples/pipeline_tacotron2/main.py
<https://github.com/pytorch/audio/tree/master/examples/pipeline_tacotron2>`_.
The text preprocessor is set to the `"english_phonemes"`.
- ``"tacotron2_english_phonemes_1500_epochs_wavernn_ljspeech"``:
Tacotron2 model trained with english characters as the input, with 1500 epochs,
and on the LJSpeech dataset.
The model is trained using the code of `examples/pipeline_tacotron2/main.py
<https://github.com/pytorch/audio/tree/master/examples/pipeline_tacotron2>`_.
The text preprocessor is set to the `"english_phonemes"`,
`win_length` is set to 1100, `hop_length` to 275, `n_fft` to 2048,
`mel_fmin` to 40, and `mel_fmax` to 11025.
The audio settings here matches the audio settings used for the pretrained
checkpoint name `"wavernn_10k_epochs_8bits_ljspeech"` for WaveRNN.
"""
if
checkpoint_name
not
in
_MODEL_CONFIG_AND_URLS
:
raise
ValueError
(
f
"Unexpected checkpoint_name: '
{
checkpoint_name
}
'. "
f
"Valid choices are;
{
list
(
_MODEL_CONFIG_AND_URLS
.
keys
())
}
"
)
url
,
configs
=
_MODEL_CONFIG_AND_URLS
[
checkpoint_name
]
model
=
Tacotron2
(
**
configs
)
state_dict
=
load_state_dict_from_url
(
url
,
progress
=
False
)
model
.
load_state_dict
(
state_dict
)
return
model
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment