OpenDAS / Torchaudio / Commits

Commit 15bc554f (unverified), authored Aug 10, 2021 by yangarbiter, committed via GitHub on Aug 10, 2021. Parent: 90c0edc5

Add Tacotron2 inference method (#1648)

Showing 2 changed files with 318 additions and 13 deletions:

- test/torchaudio_unittest/models/tacotron2/model_test_impl.py (+157, -10)
- torchaudio/prototype/tacotron2.py (+161, -3)
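For orientation before reading the diff: the commit exports an `infer` method on both `_Decoder` and `Tacotron2` via `@torch.jit.export`. Below is a hedged usage sketch. It borrows the `_get_tacotron2_model` factory defined in the test file shown below; the import path assumes the test package is importable, which may only hold inside the repository's test setup. Only the `infer` call signature itself is established by this commit.

    import torch

    # NOTE: importing the factory from the unittest package is an assumption
    # for illustration; it builds a small Tacotron2 with 40 mel channels.
    from torchaudio_unittest.models.tacotron2.model_test_impl import _get_tacotron2_model

    model = _get_tacotron2_model(n_mels=40, decoder_max_step=200, gate_threshold=0.51).eval()

    # A batch of 16 encoded sentences, zero-padded to a common length of 100.
    text = torch.randint(0, 148, (16, 100), dtype=torch.int32)
    text_lengths = 100 * torch.ones(16, dtype=torch.int32)

    with torch.no_grad():
        mel_specgram, mel_specgram_lengths, alignments = model.infer(text, text_lengths)

    # mel_specgram: (16, 40, mel_specgram_lengths.max()); alignments: (16, T, 100)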
test/torchaudio_unittest/models/tacotron2/model_test_impl.py (view file @ 15bc554f)
+from typing import Tuple
+
 import torch
 from torch import Tensor
 from torchaudio.prototype.tacotron2 import Tacotron2, _Encoder, _Decoder
 from torchaudio_unittest.common_utils import (
     TestBaseMixin,
     ...

@@ -6,6 +8,26 @@ from torchaudio_unittest.common_utils import (
 )
+class Tacotron2InferenceWrapper(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, text: Tensor, text_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+        return self.model.infer(text, text_lengths)
+
+
+class Tacotron2DecoderInferenceWrapper(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, memory: Tensor, memory_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+        return self.model.infer(memory, memory_lengths)
+
+
 class TorchscriptConsistencyMixin(TempDirMixin):
     r"""Mixin to provide easy access to assert torchscript consistency"""
 ...

@@ -24,6 +46,7 @@ class TorchscriptConsistencyMixin(TempDirMixin):

 class Tacotron2EncoderTests(TestBaseMixin, TorchscriptConsistencyMixin):
     def test_tacotron2_torchscript_consistency(self):
         r"""Validate the torchscript consistency of an Encoder."""
         n_batch, n_seq, encoder_embedding_dim = 16, 64, 512

 ...

@@ -60,27 +83,29 @@ class Tacotron2EncoderTests(TestBaseMixin, TorchscriptConsistencyMixin):
         assert out.size() == (n_batch, n_seq, encoder_embedding_dim)
-def _get_decoder_model(n_mels=80, encoder_embedding_dim=512):
+def _get_decoder_model(n_mels=80, encoder_embedding_dim=512, decoder_max_step=2000, gate_threshold=0.5):
     model = _Decoder(
         n_mels=n_mels,
         n_frames_per_step=1,
         encoder_embedding_dim=encoder_embedding_dim,
         decoder_rnn_dim=1024,
-        decoder_max_step=2000,
+        decoder_max_step=decoder_max_step,
         decoder_dropout=0.1,
-        decoder_early_stopping=False,
+        decoder_early_stopping=True,
         attention_rnn_dim=1024,
         attention_hidden_dim=128,
         attention_location_n_filter=32,
         attention_location_kernel_size=31,
         attention_dropout=0.1,
         prenet_dim=256,
-        gate_threshold=0.5,
+        gate_threshold=gate_threshold,
     )
     return model
 class Tacotron2DecoderTests(TestBaseMixin, TorchscriptConsistencyMixin):
     def test_decoder_torchscript_consistency(self):
         r"""Validate the torchscript consistency of a Decoder."""
         n_batch = 16

 ...

@@ -125,16 +150,81 @@ class Tacotron2DecoderTests(TestBaseMixin, TorchscriptConsistencyMixin):
         )
         memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device)

-        mel_outputs, gate_outputs, alignments = model(memory, decoder_inputs, memory_lengths)
+        mel_specgram, gate_outputs, alignments = model(memory, decoder_inputs, memory_lengths)

-        assert mel_outputs.size() == (n_batch, n_mels, n_time_steps)
+        assert mel_specgram.size() == (n_batch, n_mels, n_time_steps)
         assert gate_outputs.size() == (n_batch, n_time_steps)
         assert alignments.size() == (n_batch, n_time_steps, n_seq)
+    def test_decoder_inference_torchscript_consistency(self):
+        r"""Validate the torchscript consistency of the Decoder inference function."""
+        n_batch = 16
+        n_mels = 80
+        n_seq = 200
+        encoder_embedding_dim = 256
+        decoder_max_step = 300  # make inference more efficient
+        gate_threshold = 0.505  # make inference more efficient
+        model = _get_decoder_model(
+            n_mels=n_mels,
+            encoder_embedding_dim=encoder_embedding_dim,
+            decoder_max_step=decoder_max_step,
+            gate_threshold=gate_threshold,
+        )
+        model = model.to(self.device).eval()
+        memory = torch.rand(n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device)
+        memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device)
+
+        model_wrapper = Tacotron2DecoderInferenceWrapper(model)
+        self._assert_torchscript_consistency(model_wrapper, (memory, memory_lengths))
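Why the tests nudge `gate_threshold` just above 0.5: the decoder's stop rule (see `infer` in the second file) keeps a sequence running while `sigmoid(gate_output) <= gate_threshold`. Untrained weights emit gate outputs near zero, so the sigmoid hovers around 0.5, and a threshold of exactly 0.5 can end decoding after a single step; 0.505 keeps the loop alive until `decoder_max_step`, which the tests lower (here to 300) so the run stays fast. A worked check of the rule, with illustrative values:

    import torch

    gate_threshold = 0.505
    gate_output = torch.tensor([0.001, -0.002, 0.003])  # near zero, as with random init
    probs = torch.sigmoid(gate_output)                  # ~[0.50025, 0.49950, 0.50075]

    still_running = torch.le(probs, gate_threshold)     # all True -> keep decoding
    stopped_at_half = ~torch.le(probs, 0.5)             # [True, False, True]: two of the
                                                        # three sequences would stop at step one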
+    def test_decoder_inference_output_shape(self):
+        r"""Validate the output shapes of the Decoder inference function."""
+        n_batch = 16
+        n_mels = 80
+        n_seq = 200
+        encoder_embedding_dim = 256
+        decoder_max_step = 300  # make inference more efficient
+        gate_threshold = 0.505  # if set to 0.5, the model will only run one step
+        model = _get_decoder_model(
+            n_mels=n_mels,
+            encoder_embedding_dim=encoder_embedding_dim,
+            decoder_max_step=decoder_max_step,
+            gate_threshold=gate_threshold,
+        )
+        model = model.to(self.device).eval()
+        memory = torch.rand(n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device)
+        memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device)
+
+        mel_specgram, mel_specgram_lengths, gate_outputs, alignments = model.infer(memory, memory_lengths)
+
+        assert len(mel_specgram.size()) == 3
+        assert mel_specgram.size()[:-1] == (n_batch, n_mels,)
+        assert mel_specgram.size()[2] == mel_specgram_lengths.max().item()
+        assert len(mel_specgram_lengths.size()) == 1
+        assert mel_specgram_lengths.size()[0] == n_batch
+        assert mel_specgram_lengths.max().item() <= model.decoder_max_step
+        assert len(gate_outputs.size()) == 2
+        assert gate_outputs.size()[0] == n_batch
+        assert gate_outputs.size()[1] == mel_specgram_lengths.max().item()
+        assert len(alignments.size()) == 2
+        assert alignments.size()[0] == n_seq
+        assert alignments.size()[1] == mel_specgram_lengths.max().item() * n_batch
-def _get_tacotron2_model(n_mels):
+def _get_tacotron2_model(n_mels, decoder_max_step=2000, gate_threshold=0.5):
     return Tacotron2(
         mask_padding=False,
         n_mels=n_mels,
         ...
@@ -145,7 +235,7 @@ def _get_tacotron2_model(n_mels):
         encoder_n_convolution=3,
         encoder_kernel_size=5,
         decoder_rnn_dim=1024,
-        decoder_max_step=2000,
+        decoder_max_step=decoder_max_step,
         decoder_dropout=0.1,
         decoder_early_stopping=True,
         attention_rnn_dim=1024,
         ...
@@ -157,13 +247,14 @@ def _get_tacotron2_model(n_mels):
         postnet_n_convolution=5,
         postnet_kernel_size=5,
         postnet_embedding_dim=512,
-        gate_threshold=0.5,
+        gate_threshold=gate_threshold,
     )


 class Tacotron2Tests(TestBaseMixin, TorchscriptConsistencyMixin):
-    def _get_inputs(self, n_mels, n_batch: int, max_mel_specgram_length: int, max_text_length: int):
+    def _get_inputs(self, n_mels: int, n_batch: int, max_mel_specgram_length: int, max_text_length: int):
         text = torch.randint(0, 148, (n_batch, max_text_length), dtype=torch.int32, device=self.device)
         ...
@@ -236,3 +327,59 @@ class Tacotron2Tests(TestBaseMixin, TorchscriptConsistencyMixin):
         mel_out.sum().backward(retain_graph=True)
         mel_out_postnet.sum().backward(retain_graph=True)
         gate_outputs.sum().backward()
+
+    def _get_inference_inputs(self, n_batch: int, max_text_length: int):
+        text = torch.randint(0, 148, (n_batch, max_text_length), dtype=torch.int32, device=self.device)
+        text_lengths = max_text_length * torch.ones((n_batch,), dtype=torch.int32, device=self.device)
+        return text, text_lengths
+
+    def test_tacotron2_inference_torchscript_consistency(self):
+        r"""Validate the torchscript consistency of the Tacotron2 inference function."""
+        n_batch = 16
+        n_mels = 40
+        max_text_length = 100
+        decoder_max_step = 200  # make inference more efficient
+        gate_threshold = 0.51  # if set to 0.5, the model will only run one step
+        model = _get_tacotron2_model(
+            n_mels, decoder_max_step=decoder_max_step, gate_threshold=gate_threshold
+        ).to(self.device).eval()
+        inputs = self._get_inference_inputs(n_batch, max_text_length)
+
+        model_wrapper = Tacotron2InferenceWrapper(model)
+        self._assert_torchscript_consistency(model_wrapper, inputs)
+
+    def test_tacotron2_inference_output_shape(self):
+        r"""Feed tensors with specific shapes to the Tacotron2 inference function
+        and validate that it outputs tensors with the expected shapes.
+        """
+        n_batch = 16
+        n_mels = 40
+        max_text_length = 100
+        decoder_max_step = 200  # make inference more efficient
+        gate_threshold = 0.51  # if set to 0.5, the model will only run one step
+        model = _get_tacotron2_model(
+            n_mels, decoder_max_step=decoder_max_step, gate_threshold=gate_threshold
+        ).to(self.device).eval()
+        inputs = self._get_inference_inputs(n_batch, max_text_length)
+
+        mel_out, mel_specgram_lengths, alignments = model.infer(*inputs)
+
+        # There is no guarantee on exactly what max_mel_specgram_length should be
+        # We only know that it should be smaller than model.decoder.decoder_max_step
+        assert len(mel_out.size()) == 3
+        assert mel_out.size()[:2] == (n_batch, n_mels,)
+        assert mel_out.size()[2] == mel_specgram_lengths.max().item()
+        assert len(mel_specgram_lengths.size()) == 1
+        assert mel_specgram_lengths.size()[0] == n_batch
+        assert mel_specgram_lengths.max().item() <= model.decoder.decoder_max_step
+        assert len(alignments.size()) == 3
+        assert alignments.size()[0] == n_batch
+        assert alignments.size()[1] == mel_specgram_lengths.max().item()
+        assert alignments.size()[2] == max_text_length
torchaudio/prototype/tacotron2.py (view file @ 15bc554f)
 ...

@@ -25,6 +25,7 @@
 #
 # *****************************************************************************
+import warnings
 from math import sqrt
 from typing import Tuple, List, Optional, Union

 ...
@@ -614,12 +615,12 @@ class _Decoder(nn.Module):
         return decoder_inputs

     def _parse_decoder_outputs(
-        self, mel_outputs: Tensor, gate_outputs: Tensor, alignments: Tensor
+        self, mel_specgram: Tensor, gate_outputs: Tensor, alignments: Tensor
     ) -> Tuple[Tensor, Tensor, Tensor]:
         r"""Prepares decoder outputs for output
         Args:
-            mel_outputs (Tensor): mel spectrogram with shape (max of ``mel_specgram_lengths``, n_batch, ``n_mels``)
+            mel_specgram (Tensor): mel spectrogram with shape (max of ``mel_specgram_lengths``, n_batch, ``n_mels``)
             gate_outputs (Tensor): predicted stop token with shape (max of ``mel_specgram_lengths``, n_batch)
             alignments (Tensor): sequence of attention weights from the decoder
                 with shape (max of ``mel_specgram_lengths``, n_batch, max of ``text_lengths``)
 ...

@@ -636,7 +637,7 @@ class _Decoder(nn.Module):
         # (mel_specgram_lengths.max(), n_batch) -> (n_batch, mel_specgram_lengths.max())
         gate_outputs = gate_outputs.transpose(0, 1).contiguous()
         # (mel_specgram_lengths.max(), n_batch, n_mels) -> (n_batch, mel_specgram_lengths.max(), n_mels)
-        mel_specgram = mel_outputs.transpose(0, 1).contiguous()
+        mel_specgram = mel_specgram.transpose(0, 1).contiguous()
         # decouple frames per step
         shape = (mel_specgram.shape[0], -1, self.n_mels)
         mel_specgram = mel_specgram.view(*shape)
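The shape bookkeeping above is easy to lose track of, so here is a runnable trace of the documented flow. The final move to (n_batch, ``n_mels``, time) happens in the function's remaining lines, which the diff elides; that last transpose is assumed here from the docstring's return shapes:

    import torch

    max_len, n_batch, n_mels = 5, 2, 80
    mel = torch.rand(max_len, n_batch, n_mels)  # time-major decoder output

    mel = mel.transpose(0, 1).contiguous()      # (n_batch, max_len, n_mels)
    mel = mel.view(n_batch, -1, n_mels)         # decouple frames per step
    mel = mel.transpose(1, 2)                   # (n_batch, n_mels, max_len), per the docstring

    assert mel.shape == (n_batch, n_mels, max_len)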
 ...

@@ -805,6 +806,128 @@ class _Decoder(nn.Module):
         return mel_specgram, gate_outputs, alignments

+    def _get_go_frame(self, memory: Tensor) -> Tensor:
+        """Gets all-zeros frames to use as the first decoder input
+
+        Args:
+            memory (Tensor): Encoder outputs
+                with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
+
+        Returns:
+            decoder_input (Tensor): All-zeros frames with shape (n_batch, ``n_mels`` * ``n_frames_per_step``).
+        """
+        n_batch = memory.size(0)
+        dtype = memory.dtype
+        device = memory.device
+        decoder_input = torch.zeros(n_batch, self.n_mels * self.n_frames_per_step, dtype=dtype, device=device)
+        return decoder_input
+
+    @torch.jit.export
+    def infer(self, memory: Tensor, memory_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+        """Decoder inference
+
+        Args:
+            memory (Tensor): Encoder outputs
+                with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
+            memory_lengths (Tensor): Encoder output lengths for attention masking
+                (the same as ``text_lengths``) with shape (n_batch, ).
+
+        Returns:
+            mel_specgram (Tensor): Predicted mel spectrogram
+                with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
+            mel_specgram_lengths (Tensor): The length of the predicted mel spectrogram
+                with shape (n_batch, ).
+            gate_outputs (Tensor): Predicted stop token for each timestep
+                with shape (n_batch, max of ``mel_specgram_lengths``).
+            alignments (Tensor): Sequence of attention weights from the decoder
+                with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``).
+        """
+        decoder_input = self._get_go_frame(memory)
+
+        mask = _get_mask_from_lengths(memory_lengths)
+        (
+            attention_hidden,
+            attention_cell,
+            decoder_hidden,
+            decoder_cell,
+            attention_weights,
+            attention_weights_cum,
+            attention_context,
+            processed_memory,
+        ) = self._initialize_decoder_states(memory)
+
+        mel_specgram_lengths = torch.ones([memory.size(0)], dtype=torch.int32, device=memory.device)
+        not_finished = torch.ones([memory.size(0)], dtype=torch.int32, device=memory.device)
+        mel_specgrams, gate_outputs, alignments = (
+            torch.zeros(1, dtype=memory.dtype),
+            torch.zeros(1, dtype=memory.dtype),
+            torch.zeros(1, dtype=memory.dtype),
+        )
+        first_iter = True
+        while True:
+            decoder_input = self.prenet(decoder_input)
+            (
+                mel_specgram,
+                gate_output,
+                attention_hidden,
+                attention_cell,
+                decoder_hidden,
+                decoder_cell,
+                attention_weights,
+                attention_weights_cum,
+                attention_context,
+            ) = self.decode(
+                decoder_input,
+                attention_hidden,
+                attention_cell,
+                decoder_hidden,
+                decoder_cell,
+                attention_weights,
+                attention_weights_cum,
+                attention_context,
+                memory,
+                processed_memory,
+                mask,
+            )
+
+            if first_iter:
+                mel_specgrams = mel_specgram.unsqueeze(0)
+                gate_outputs = gate_output.transpose(0, 1)
+                alignments = attention_weights
+                first_iter = False
+            else:
+                mel_specgrams = torch.cat((mel_specgrams, mel_specgram.unsqueeze(0)), dim=0)
+                gate_outputs = torch.cat((gate_outputs, gate_output.transpose(0, 1)), dim=0)
+                alignments = torch.cat((alignments, attention_weights), dim=0)
+
+            dec = torch.le(torch.sigmoid(gate_output), self.gate_threshold).to(torch.int32).squeeze(1)
+            not_finished = not_finished * dec
+
+            if self.decoder_early_stopping and torch.sum(not_finished) == 0:
+                break
+            if len(mel_specgrams) == self.decoder_max_step:
+                warnings.warn("Reached max decoder steps")
+                break
+
+            mel_specgram_lengths += not_finished
+            decoder_input = mel_specgram
+
+        mel_specgrams, gate_outputs, alignments = self._parse_decoder_outputs(
+            mel_specgrams, gate_outputs, alignments)
+
+        return mel_specgrams, mel_specgram_lengths, gate_outputs, alignments
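Two details of the loop are worth spelling out. First, the `torch.zeros(1, ...)` placeholders plus the `first_iter` flag keep every loop variable a Tensor of stable type for TorchScript; the first iteration overwrites the placeholders instead of concatenating onto them. Second, `mel_specgram_lengths` grows by one per step only for sequences whose gate has not yet fired, which is how per-sequence lengths fall out of a single batched loop. A minimal illustration of that bookkeeping (the real loop also breaks on early stopping and on `decoder_max_step`):

    import torch

    not_finished = torch.ones(3, dtype=torch.int32)
    lengths = torch.ones(3, dtype=torch.int32)  # the first frame is counted up front

    # suppose sequence 1 fires its stop token on the second step
    step_decisions = [torch.tensor([1, 1, 1], dtype=torch.int32),
                      torch.tensor([1, 0, 1], dtype=torch.int32)]
    for dec in step_decisions:
        not_finished = not_finished * dec       # once stopped, stays stopped
        lengths += not_finished

    print(lengths.tolist())  # [3, 2, 3]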
 class Tacotron2(nn.Module):
     r"""Tacotron2 model based on the implementation from
     ...

@@ -947,3 +1070,38 @@ class Tacotron2(nn.Module):
             gate_outputs.masked_fill_(mask[:, 0, :], 1e3)

         return mel_specgram, mel_specgram_postnet, gate_outputs, alignments

+    @torch.jit.export
+    def infer(self, text: Tensor, text_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+        r"""Using Tacotron2 for inference. The input is a batch of encoded
+        sentences (``text``) and their corresponding lengths (``text_lengths``).
+        The outputs are the generated mel spectrograms, their corresponding
+        lengths, and the attention weights from the decoder.
+
+        The input ``text`` should be padded with zeros to length max of ``text_lengths``.
+
+        Args:
+            text (Tensor): The input text to Tacotron2 with shape (n_batch, max of ``text_lengths``).
+            text_lengths (Tensor): The length of each text with shape (n_batch, ).
+
+        Returns:
+            mel_specgram (Tensor): The predicted mel spectrogram
+                with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
+            mel_specgram_lengths (Tensor): The length of the predicted mel spectrogram
+                with shape (n_batch, ).
+            alignments (Tensor): Sequence of attention weights from the decoder
+                with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``).
+        """
+        embedded_inputs = self.embedding(text).transpose(1, 2)
+        encoder_outputs = self.encoder(embedded_inputs, text_lengths)
+        mel_specgram, mel_specgram_lengths, _, alignments = self.decoder.infer(
+            encoder_outputs, text_lengths)
+
+        mel_outputs_postnet = self.postnet(mel_specgram)
+        mel_outputs_postnet = mel_specgram + mel_outputs_postnet
+
+        n_batch = mel_outputs_postnet.size(0)
+        alignments = alignments.unfold(1, n_batch, n_batch).transpose(0, 2)
+
+        return mel_outputs_postnet, mel_specgram_lengths, alignments
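The `unfold` on the last line merits a comment: the decoder's `infer` concatenates each step's attention weights along dim 0, so after `_parse_decoder_outputs` the alignments arrive as a 2-D tensor of shape (max of ``text_lengths``, n_steps * n_batch), which is also why the decoder-level shape assertions in the test file above look unusual. Unfolding dim 1 with window and stride `n_batch` splits the steps back apart, and the transpose yields (n_batch, n_steps, max of ``text_lengths``). A shape check of that round trip (names here are illustrative):

    import torch

    n_batch, n_steps, n_seq = 4, 7, 25

    # per-step attention weights, each (n_batch, n_seq), concatenated along dim 0
    steps = [torch.rand(n_batch, n_seq) for _ in range(n_steps)]
    flat = torch.cat(steps, dim=0).transpose(0, 1)  # (n_seq, n_steps * n_batch)

    unfolded = flat.unfold(1, n_batch, n_batch)     # (n_seq, n_steps, n_batch)
    alignments = unfolded.transpose(0, 2)           # (n_batch, n_steps, n_seq)

    assert torch.equal(alignments[:, 3, :], steps[3])  # step 3 recovered intact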