Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
hehl2
Torchaudio
Commits
15bc554f
"vscode:/vscode.git/clone" did not exist on "5fb96255dcaea27b70ac7cc86311503fa506a5bf"
Unverified
Commit
15bc554f
authored
Aug 10, 2021
by
yangarbiter
Committed by
GitHub
Aug 10, 2021
Browse files
Add Tacotron2 inference method (#1648)
parent
90c0edc5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
318 additions
and
13 deletions
+318
-13
test/torchaudio_unittest/models/tacotron2/model_test_impl.py
test/torchaudio_unittest/models/tacotron2/model_test_impl.py
+157
-10
torchaudio/prototype/tacotron2.py
torchaudio/prototype/tacotron2.py
+161
-3
No files found.
test/torchaudio_unittest/models/tacotron2/model_test_impl.py
View file @
15bc554f
from
typing
import
Tuple
import
torch
from
torch
import
Tensor
from
torchaudio.prototype.tacotron2
import
Tacotron2
,
_Encoder
,
_Decoder
from
torchaudio_unittest.common_utils
import
(
TestBaseMixin
,
...
...
@@ -6,6 +8,26 @@ from torchaudio_unittest.common_utils import (
)
class Tacotron2InferenceWrapper(torch.nn.Module):
    """Adapter that exposes a model's ``infer`` method through ``forward``.

    The wrapped ``model`` is stored as-is; calling the wrapper with
    ``(text, text_lengths)`` forwards both tensors to ``model.infer``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, text: Tensor, text_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Return ``self.model.infer(text, text_lengths)`` unchanged."""
        inference_result = self.model.infer(text, text_lengths)
        return inference_result
class Tacotron2DecoderInferenceWrapper(torch.nn.Module):
    """Adapter that exposes a decoder's ``infer`` method through ``forward``.

    The wrapped ``model`` is stored as-is; calling the wrapper with
    ``(memory, memory_lengths)`` forwards both tensors to ``model.infer``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, memory: Tensor, memory_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Return ``self.model.infer(memory, memory_lengths)`` unchanged."""
        inference_result = self.model.infer(memory, memory_lengths)
        return inference_result
class
TorchscriptConsistencyMixin
(
TempDirMixin
):
r
"""Mixin to provide easy access assert torchscript consistency"""
...
...
@@ -24,6 +46,7 @@ class TorchscriptConsistencyMixin(TempDirMixin):
class
Tacotron2EncoderTests
(
TestBaseMixin
,
TorchscriptConsistencyMixin
):
def
test_tacotron2_torchscript_consistency
(
self
):
r
"""Validate the torchscript consistency of a Encoder."""
n_batch
,
n_seq
,
encoder_embedding_dim
=
16
,
64
,
512
...
...
@@ -60,27 +83,29 @@ class Tacotron2EncoderTests(TestBaseMixin, TorchscriptConsistencyMixin):
assert
out
.
size
()
==
(
n_batch
,
n_seq
,
encoder_embedding_dim
)
def _get_decoder_model(n_mels=80, encoder_embedding_dim=512, decoder_max_step=2000, gate_threshold=0.5):
    """Build a ``_Decoder`` with fixed test hyper-parameters.

    Only the arguments below are configurable; every other constructor
    argument is pinned to the literal value passed in the call.

    Args:
        n_mels: Forwarded as ``n_mels``.
        encoder_embedding_dim: Forwarded as ``encoder_embedding_dim``.
        decoder_max_step: Forwarded as ``decoder_max_step``.
        gate_threshold: Forwarded as ``gate_threshold``.

    Returns:
        The constructed ``_Decoder`` instance.
    """
    # Each keyword appears exactly once: the removed side of the diff
    # (hard-coded decoder_max_step=2000, decoder_early_stopping=False,
    # gate_threshold=0.5) is dropped in favour of the added side.
    model = _Decoder(
        n_mels=n_mels,
        n_frames_per_step=1,
        encoder_embedding_dim=encoder_embedding_dim,
        decoder_rnn_dim=1024,
        decoder_max_step=decoder_max_step,
        decoder_dropout=0.1,
        decoder_early_stopping=True,
        attention_rnn_dim=1024,
        attention_hidden_dim=128,
        attention_location_n_filter=32,
        attention_location_kernel_size=31,
        attention_dropout=0.1,
        prenet_dim=256,
        gate_threshold=gate_threshold,
    )
    return model
class
Tacotron2DecoderTests
(
TestBaseMixin
,
TorchscriptConsistencyMixin
):
def
test_decoder_torchscript_consistency
(
self
):
r
"""Validate the torchscript consistency of a Decoder."""
n_batch
=
16
...
...
@@ -125,16 +150,81 @@ class Tacotron2DecoderTests(TestBaseMixin, TorchscriptConsistencyMixin):
)
memory_lengths
=
torch
.
ones
(
n_batch
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
mel_
outputs
,
gate_outputs
,
alignments
=
model
(
mel_
specgram
,
gate_outputs
,
alignments
=
model
(
memory
,
decoder_inputs
,
memory_lengths
)
assert
mel_
outputs
.
size
()
==
(
n_batch
,
n_mels
,
n_time_steps
)
assert
mel_
specgram
.
size
()
==
(
n_batch
,
n_mels
,
n_time_steps
)
assert
gate_outputs
.
size
()
==
(
n_batch
,
n_time_steps
)
assert
alignments
.
size
()
==
(
n_batch
,
n_time_steps
,
n_seq
)
def test_decoder_inference_torchscript_consistency(self):
    r"""Validate the torchscript consistency of a Decoder's ``infer`` method."""
    n_batch = 16
    n_mels = 80
    n_seq = 200
    encoder_embedding_dim = 256
    decoder_max_step = 300  # make inference more efficient
    gate_threshold = 0.505  # if set to 0.5, the model will only run one step

    model = _get_decoder_model(
        n_mels=n_mels,
        encoder_embedding_dim=encoder_embedding_dim,
        decoder_max_step=decoder_max_step,
        gate_threshold=gate_threshold,
    )
    model = model.to(self.device).eval()

    memory = torch.rand(
        n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device
    )
    memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device)

    # ``infer`` is reached through a wrapper whose ``forward`` delegates to it.
    model_wrapper = Tacotron2DecoderInferenceWrapper(model)
    self._assert_torchscript_consistency(model_wrapper, (memory, memory_lengths))
def test_decoder_inference_output_shape(self):
    r"""Feed tensors with specific shape to the Decoder's ``infer`` method and
    validate that every returned tensor has the expected shape.
    """
    n_batch = 16
    n_mels = 80
    n_seq = 200
    encoder_embedding_dim = 256
    decoder_max_step = 300  # make inference more efficient
    gate_threshold = 0.505  # if set to 0.5, the model will only run one step

    model = _get_decoder_model(
        n_mels=n_mels,
        encoder_embedding_dim=encoder_embedding_dim,
        decoder_max_step=decoder_max_step,
        gate_threshold=gate_threshold,
    )
    model = model.to(self.device).eval()

    memory = torch.rand(
        n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device
    )
    memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device)

    mel_specgram, mel_specgram_lengths, gate_outputs, alignments = model.infer(
        memory, memory_lengths
    )

    # The number of generated frames is data dependent, so shapes are
    # checked against the longest predicted length.
    max_length = mel_specgram_lengths.max().item()

    assert len(mel_specgram.size()) == 3
    assert mel_specgram.size()[:-1] == (n_batch, n_mels)
    assert mel_specgram.size()[2] == max_length

    assert len(mel_specgram_lengths.size()) == 1
    assert mel_specgram_lengths.size()[0] == n_batch
    assert max_length <= model.decoder_max_step

    assert len(gate_outputs.size()) == 2
    assert gate_outputs.size()[0] == n_batch
    assert gate_outputs.size()[1] == max_length

    # The decoder returns the alignments as a 2-D tensor with the batch
    # folded into the second dimension.
    assert len(alignments.size()) == 2
    assert alignments.size()[0] == n_seq
    assert alignments.size()[1] == max_length * n_batch
def
_get_tacotron2_model
(
n_mels
,
decoder_max_step
=
2000
,
gate_threshold
=
0.5
):
return
Tacotron2
(
mask_padding
=
False
,
n_mels
=
n_mels
,
...
...
@@ -145,7 +235,7 @@ def _get_tacotron2_model(n_mels):
encoder_n_convolution
=
3
,
encoder_kernel_size
=
5
,
decoder_rnn_dim
=
1024
,
decoder_max_step
=
2000
,
decoder_max_step
=
decoder_max_step
,
decoder_dropout
=
0.1
,
decoder_early_stopping
=
True
,
attention_rnn_dim
=
1024
,
...
...
@@ -157,13 +247,14 @@ def _get_tacotron2_model(n_mels):
postnet_n_convolution
=
5
,
postnet_kernel_size
=
5
,
postnet_embedding_dim
=
512
,
gate_threshold
=
0.5
,
gate_threshold
=
gate_threshold
,
)
class
Tacotron2Tests
(
TestBaseMixin
,
TorchscriptConsistencyMixin
):
def
_get_inputs
(
self
,
n_mels
,
n_batch
:
int
,
max_mel_specgram_length
:
int
,
max_text_length
:
int
self
,
n_mels
:
int
,
n_batch
:
int
,
max_mel_specgram_length
:
int
,
max_text_length
:
int
):
text
=
torch
.
randint
(
0
,
148
,
(
n_batch
,
max_text_length
),
dtype
=
torch
.
int32
,
device
=
self
.
device
...
...
@@ -236,3 +327,59 @@ class Tacotron2Tests(TestBaseMixin, TorchscriptConsistencyMixin):
mel_out
.
sum
().
backward
(
retain_graph
=
True
)
mel_out_postnet
.
sum
().
backward
(
retain_graph
=
True
)
gate_outputs
.
sum
().
backward
()
def _get_inference_inputs(self, n_batch: int, max_text_length: int):
    """Create random inference inputs on ``self.device``.

    Args:
        n_batch: Number of sequences in the batch.
        max_text_length: Length of every generated sequence.

    Returns:
        ``(text, text_lengths)``: ``text`` is an int32 tensor of shape
        ``(n_batch, max_text_length)`` drawn from ``[0, 148)``;
        ``text_lengths`` is an int32 tensor of shape ``(n_batch,)`` whose
        entries all equal ``max_text_length``.
    """
    text_shape = (n_batch, max_text_length)
    text = torch.randint(0, 148, text_shape, dtype=torch.int32, device=self.device)
    # Every sequence is reported as having the full length.
    text_lengths = torch.full(
        (n_batch,), max_text_length, dtype=torch.int32, device=self.device
    )
    return text, text_lengths
def test_tacotron2_inference_torchscript_consistency(self):
    r"""Validate the torchscript consistency of Tacotron2 inference function."""
    n_batch, n_mels, max_text_length = 16, 40, 100
    # make inference more efficient
    decoder_max_step = 200
    # if set to 0.5, the model will only run one step
    gate_threshold = 0.51

    model = _get_tacotron2_model(
        n_mels, decoder_max_step=decoder_max_step, gate_threshold=gate_threshold
    )
    model = model.to(self.device).eval()
    inputs = self._get_inference_inputs(n_batch, max_text_length)

    # ``infer`` is reached through a wrapper whose ``forward`` delegates to it.
    self._assert_torchscript_consistency(Tacotron2InferenceWrapper(model), inputs)
def test_tacotron2_inference_output_shape(self):
    r"""Feed tensors with specific shape to Tacotron2 inference function and validate
    that it outputs with a tensor with expected shape.
    """
    n_batch, n_mels, max_text_length = 16, 40, 100
    # make inference more efficient
    decoder_max_step = 200
    # if set to 0.5, the model will only run one step
    gate_threshold = 0.51

    model = _get_tacotron2_model(
        n_mels, decoder_max_step=decoder_max_step, gate_threshold=gate_threshold
    )
    model = model.to(self.device).eval()
    inputs = self._get_inference_inputs(n_batch, max_text_length)

    mel_out, mel_specgram_lengths, alignments = model.infer(*inputs)

    # There is no guarantee on exactly what max_mel_specgram_length should be.
    # We only know that it is no larger than model.decoder.decoder_max_step.
    longest = mel_specgram_lengths.max().item()

    assert len(mel_out.size()) == 3
    assert mel_out.size()[:2] == (n_batch, n_mels)
    assert mel_out.size()[2] == longest

    assert len(mel_specgram_lengths.size()) == 1
    assert mel_specgram_lengths.size()[0] == n_batch
    assert longest <= model.decoder.decoder_max_step

    assert len(alignments.size()) == 3
    assert alignments.size()[0] == n_batch
    assert alignments.size()[1] == longest
    assert alignments.size()[2] == max_text_length
torchaudio/prototype/tacotron2.py
View file @
15bc554f
...
...
@@ -25,6 +25,7 @@
#
# *****************************************************************************
import
warnings
from
math
import
sqrt
from
typing
import
Tuple
,
List
,
Optional
,
Union
...
...
@@ -614,12 +615,12 @@ class _Decoder(nn.Module):
return
decoder_inputs
def
_parse_decoder_outputs
(
self
,
mel_
outputs
:
Tensor
,
gate_outputs
:
Tensor
,
alignments
:
Tensor
self
,
mel_
specgram
:
Tensor
,
gate_outputs
:
Tensor
,
alignments
:
Tensor
)
->
Tuple
[
Tensor
,
Tensor
,
Tensor
]:
r
"""Prepares decoder outputs for output
Args:
mel_
outputs
(Tensor): mel spectrogram with shape (max of ``mel_specgram_lengths``, n_batch, ``n_mels``)
mel_
specgram
(Tensor): mel spectrogram with shape (max of ``mel_specgram_lengths``, n_batch, ``n_mels``)
gate_outputs (Tensor): predicted stop token with shape (max of ``mel_specgram_lengths``, n_batch)
alignments (Tensor): sequence of attention weights from the decoder
with shape (max of ``mel_specgram_lengths``, n_batch, max of ``text_lengths``)
...
...
@@ -636,7 +637,7 @@ class _Decoder(nn.Module):
# (mel_specgram_lengths.max(), n_batch) -> (n_batch, mel_specgram_lengths.max())
gate_outputs
=
gate_outputs
.
transpose
(
0
,
1
).
contiguous
()
# (mel_specgram_lengths.max(), n_batch, n_mels) -> (n_batch, mel_specgram_lengths.max(), n_mels)
mel_specgram
=
mel_
outputs
.
transpose
(
0
,
1
).
contiguous
()
mel_specgram
=
mel_
specgram
.
transpose
(
0
,
1
).
contiguous
()
# decouple frames per step
shape
=
(
mel_specgram
.
shape
[
0
],
-
1
,
self
.
n_mels
)
mel_specgram
=
mel_specgram
.
view
(
*
shape
)
...
...
@@ -805,6 +806,128 @@ class _Decoder(nn.Module):
return
mel_specgram
,
gate_outputs
,
alignments
def _get_go_frame(self, memory: Tensor) -> Tensor:
    r"""Build the all-zero frame used as the first decoder input.

    Args:
        memory (Tensor): Encoder outputs
            with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).

    Returns:
        decoder_input (Tensor): All-zero frame with shape
            (n_batch, ``n_mels`` * ``n_frames_per_step``), created with the
            dtype and device of ``memory``.
    """
    frame_width = self.n_mels * self.n_frames_per_step
    return torch.zeros(
        memory.size(0), frame_width, dtype=memory.dtype, device=memory.device
    )
@torch.jit.export
def infer(self, memory: Tensor, memory_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    r"""Decoder inference

    Runs the decoder one step at a time, feeding each predicted frame back
    in as the next input, until every sequence in the batch has crossed the
    gate threshold (when early stopping is enabled) or ``decoder_max_step``
    steps have been generated.

    Args:
        memory (Tensor): Encoder outputs
            with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
        memory_lengths (Tensor): Encoder output lengths for attention masking
            (the same as ``text_lengths``) with shape (n_batch, ).

    Returns:
        mel_specgram (Tensor): Predicted mel spectrogram
            with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
        mel_specgram_lengths (Tensor): The length of the predicted mel spectrogram
            with shape (n_batch, ).
        gate_outputs (Tensor): Predicted stop token for each timestep
            with shape (n_batch, max of ``mel_specgram_lengths``).
        alignments (Tensor): Sequence of attention weights from the decoder.
            The per-step weights are concatenated along dim 0 before parsing, so
            the batch is folded into the step dimension; the unit test for this
            method expects a 2-D tensor with shape
            (max of ``text_lengths``, max of ``mel_specgram_lengths`` * n_batch).
    """
    decoder_input = self._get_go_frame(memory)
    mask = _get_mask_from_lengths(memory_lengths)
    (
        attention_hidden,
        attention_cell,
        decoder_hidden,
        decoder_cell,
        attention_weights,
        attention_weights_cum,
        attention_context,
        processed_memory,
    ) = self._initialize_decoder_states(memory)

    n_batch = memory.size(0)
    mel_specgram_lengths = torch.ones([n_batch], dtype=torch.int32, device=memory.device)
    # 1 while a sequence is still generating, 0 once its gate has fired.
    not_finished = torch.ones([n_batch], dtype=torch.int32, device=memory.device)

    # Per-step outputs are collected in lists and concatenated once after the
    # loop; concatenating inside the loop re-copies all earlier steps each time.
    mel_specgram_steps: List[Tensor] = []
    gate_output_steps: List[Tensor] = []
    alignment_steps: List[Tensor] = []

    while True:
        decoder_input = self.prenet(decoder_input)
        (
            mel_specgram,
            gate_output,
            attention_hidden,
            attention_cell,
            decoder_hidden,
            decoder_cell,
            attention_weights,
            attention_weights_cum,
            attention_context,
        ) = self.decode(
            decoder_input,
            attention_hidden,
            attention_cell,
            decoder_hidden,
            decoder_cell,
            attention_weights,
            attention_weights_cum,
            attention_context,
            memory,
            processed_memory,
            mask,
        )

        mel_specgram_steps.append(mel_specgram.unsqueeze(0))
        gate_output_steps.append(gate_output.transpose(0, 1))
        alignment_steps.append(attention_weights)

        # A sequence keeps going only while sigmoid(gate) <= gate_threshold.
        dec = torch.le(torch.sigmoid(gate_output), self.gate_threshold).to(torch.int32).squeeze(1)
        not_finished = not_finished * dec

        if self.decoder_early_stopping and torch.sum(not_finished) == 0:
            break

        # ``>=`` rather than ``==`` so that a non-positive limit cannot
        # disable the step cap.
        if len(mel_specgram_steps) >= self.decoder_max_step:
            warnings.warn("Reached max decoder steps")
            break

        # Only sequences that are still running grow by one frame.
        mel_specgram_lengths += not_finished
        decoder_input = mel_specgram

    mel_specgrams, gate_outputs, alignments = self._parse_decoder_outputs(
        torch.cat(mel_specgram_steps, dim=0),
        torch.cat(gate_output_steps, dim=0),
        torch.cat(alignment_steps, dim=0),
    )

    return mel_specgrams, mel_specgram_lengths, gate_outputs, alignments
class
Tacotron2
(
nn
.
Module
):
r
"""Tacotron2 model based on the implementation from
...
...
@@ -947,3 +1070,38 @@ class Tacotron2(nn.Module):
gate_outputs
.
masked_fill_
(
mask
[:,
0
,
:],
1e3
)
return
mel_specgram
,
mel_specgram_postnet
,
gate_outputs
,
alignments
@torch.jit.export
def infer(self, text: Tensor, text_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    r"""Using Tacotron2 for inference. The input is a batch of encoded
    sentences (text) and its corresponding lengths (text_lengths). The
    output is the generated mel spectrograms, its corresponding lengths, and
    the attention weights from the decoder.

    The input `text` should be padded with zeros to length max of ``text_lengths``.

    Args:
        text (Tensor): the input text to Tacotron2. (n_batch, max of ``text_lengths``)
        text_lengths (Tensor): the length of each text (n_batch)

    Return:
        mel_specgram (Tensor): the predicted mel spectrogram
            with shape (n_batch, n_mels, max of ``mel_specgram_lengths``)
        mel_specgram_lengths (Tensor): the length of the predicted mel spectrogram (n_batch, )
        alignments (Tensor): Sequence of attention weights from the decoder.
            with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``).
    """
    encoder_outputs = self.encoder(self.embedding(text).transpose(1, 2), text_lengths)
    mel_specgram, mel_specgram_lengths, _, flat_alignments = self.decoder.infer(
        encoder_outputs, text_lengths
    )

    # The postnet output is added to the decoder output.
    mel_specgram_postnet = mel_specgram + self.postnet(mel_specgram)

    # Split dim 1 of the decoder alignments into windows of ``n_batch``
    # entries, then swap the first and last dims so batch comes first.
    n_batch = mel_specgram_postnet.size(0)
    alignments = flat_alignments.unfold(1, n_batch, n_batch).transpose(0, 2)

    return mel_specgram_postnet, mel_specgram_lengths, alignments
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment