Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
1b52e720
Unverified
Commit
1b52e720
authored
Jul 26, 2021
by
yangarbiter
Committed by
GitHub
Jul 26, 2021
Browse files
Add Tacotron2 loss function (#1625)
parent
37dbf29f
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
242 additions
and
0 deletions
+242
-0
examples/pipeline_tacotron2/README.md
examples/pipeline_tacotron2/README.md
+1
-0
examples/pipeline_tacotron2/loss.py
examples/pipeline_tacotron2/loss.py
+82
-0
test/torchaudio_unittest/example/tacotron2/__init__.py
test/torchaudio_unittest/example/tacotron2/__init__.py
+0
-0
test/torchaudio_unittest/example/tacotron2/tacotron2_loss_cpu_test.py
...dio_unittest/example/tacotron2/tacotron2_loss_cpu_test.py
+23
-0
test/torchaudio_unittest/example/tacotron2/tacotron2_loss_gpu_test.py
...dio_unittest/example/tacotron2/tacotron2_loss_gpu_test.py
+26
-0
test/torchaudio_unittest/example/tacotron2/tacotron2_loss_impl.py
...chaudio_unittest/example/tacotron2/tacotron2_loss_impl.py
+110
-0
No files found.
examples/pipeline_tacotron2/README.md
0 → 100644
View file @
1b52e720
This is an example pipeline for text-to-speech using Tacotron2.
examples/pipeline_tacotron2/loss.py
0 → 100644
View file @
1b52e720
# *****************************************************************************
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of the NVIDIA CORPORATION nor the
# names of its contributors may be used to endorse or promote products
# derived from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# *****************************************************************************
from
typing
import
Tuple
from
torch
import
nn
,
Tensor
class Tacotron2Loss(nn.Module):
    """Loss module for Tacotron2, adapted from the NVIDIA reference implementation:
    https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/Tacotron2/tacotron2/loss_function.py

    Two mean-reduced criteria are registered as submodules: ``mse_loss``
    (used for both mel spectrogram predictions) and ``bce_loss`` (binary
    cross entropy on logits, used for the stop token prediction).
    """

    def __init__(self):
        super().__init__()
        self.mse_loss = nn.MSELoss(reduction="mean")
        self.bce_loss = nn.BCEWithLogitsLoss(reduction="mean")

    def forward(
        self,
        model_outputs: Tuple[Tensor, Tensor, Tensor],
        targets: Tuple[Tensor, Tensor],
    ) -> Tuple[Tensor, Tensor, Tensor]:
        r"""Compute the three Tacotron2 loss terms.

        The original implementation was introduced in
        *Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions*
        [:footcite:`shen2018natural`].

        Args:
            model_outputs (tuple of three Tensors): The outputs of the
                Tacotron2, in this order:
                (1) ``mel_specgram``, the predicted mel spectrogram before the
                postnet, with shape (batch, mel, time),
                (2) ``mel_specgram_postnet``, the predicted mel spectrogram after
                the postnet, with shape (batch, mel, time), and
                (3) ``gate_out``, the stop token prediction, with shape (batch, ).
            targets (tuple of two Tensors): The ground truth mel spectrogram
                with shape (batch, mel, time) and the ground truth stop token
                with shape (batch, ).

        Returns:
            mel_loss (Tensor): Mean MSE between ``mel_specgram`` and the ground
                truth mel spectrogram, with shape ``torch.Size([])``.
            mel_postnet_loss (Tensor): Mean MSE between ``mel_specgram_postnet``
                and the ground truth mel spectrogram, with shape ``torch.Size([])``.
            gate_loss (Tensor): Mean binary cross entropy of the stop token
                prediction, with shape ``torch.Size([])``.
        """
        mel_target, gate_target = targets
        # Stop-token tensors are flattened into a single column so that the
        # prediction and the target line up element by element.
        gate_labels = gate_target.view(-1, 1)

        mel_specgram, mel_specgram_postnet, gate_out = model_outputs
        gate_logits = gate_out.view(-1, 1)

        return (
            self.mse_loss(mel_specgram, mel_target),
            self.mse_loss(mel_specgram_postnet, mel_target),
            self.bce_loss(gate_logits, gate_labels),
        )
test/torchaudio_unittest/example/tacotron2/__init__.py
0 → 100644
View file @
1b52e720
test/torchaudio_unittest/example/tacotron2/tacotron2_loss_cpu_test.py
0 → 100644
View file @
1b52e720
import
torch
from
.tacotron2_loss_impl
import
(
Tacotron2LossShapeTests
,
Tacotron2LossTorchscriptTests
,
Tacotron2LossGradcheckTests
,
)
from
torchaudio_unittest.common_utils
import
PytorchTestCase
class TestTacotron2LossShapeFloat32CPU(PytorchTestCase, Tacotron2LossShapeTests):
    """Binds the Tacotron2Loss shape tests to float32 tensors on the CPU."""

    device = torch.device("cpu")
    dtype = torch.float32
class TestTacotron2TorchsciptFloat32CPU(PytorchTestCase, Tacotron2LossTorchscriptTests):
    """Binds the Tacotron2Loss TorchScript tests to float32 tensors on the CPU."""

    device = torch.device("cpu")
    dtype = torch.float32
class TestTacotron2GradcheckFloat64CPU(PytorchTestCase, Tacotron2LossGradcheckTests):
    """Binds the Tacotron2Loss gradient-check tests to float64 tensors on the CPU."""

    device = torch.device("cpu")
    # gradcheck needs a higher numerical accuracy, hence double precision
    dtype = torch.float64
test/torchaudio_unittest/example/tacotron2/tacotron2_loss_gpu_test.py
0 → 100644
View file @
1b52e720
import
torch
from
.tacotron2_loss_impl
import
(
Tacotron2LossShapeTests
,
Tacotron2LossTorchscriptTests
,
Tacotron2LossGradcheckTests
,
)
from
torchaudio_unittest.common_utils
import
skipIfNoCuda
,
PytorchTestCase
@skipIfNoCuda
class TestTacotron2LossShapeFloat32CUDA(PytorchTestCase, Tacotron2LossShapeTests):
    """Binds the Tacotron2Loss shape tests to float32 tensors on a CUDA device."""

    device = torch.device("cuda")
    dtype = torch.float32
@skipIfNoCuda
class TestTacotron2TorchsciptFloat32CUDA(PytorchTestCase, Tacotron2LossTorchscriptTests):
    """Binds the Tacotron2Loss TorchScript tests to float32 tensors on a CUDA device."""

    device = torch.device("cuda")
    dtype = torch.float32
@skipIfNoCuda
class TestTacotron2GradcheckFloat64CUDA(PytorchTestCase, Tacotron2LossGradcheckTests):
    """Binds the Tacotron2Loss gradient-check tests to float64 tensors on a CUDA device."""

    device = torch.device("cuda")
    # gradcheck needs a higher numerical accuracy, hence double precision
    dtype = torch.float64
test/torchaudio_unittest/example/tacotron2/tacotron2_loss_impl.py
0 → 100644
View file @
1b52e720
import
torch
from
torch.autograd
import
gradcheck
,
gradgradcheck
from
pipeline_tacotron2.loss
import
Tacotron2Loss
from
torchaudio_unittest.common_utils
import
TempDirMixin
class Tacotron2LossInputMixin(TempDirMixin):
    """Mixin that builds random inputs for the Tacotron2Loss tests.

    The concrete test class is expected to define ``dtype`` and ``device``
    attributes; every generated tensor is created with them.
    """

    def _get_inputs(self, n_mel=80, n_batch=16, max_mel_specgram_length=300):
        """Create random predictions and ground truths for Tacotron2Loss.

        Args:
            n_mel: Size of the mel dimension of the spectrogram tensors.
            n_batch: Batch size shared by all tensors.
            max_mel_specgram_length: Size of the time dimension of the
                spectrogram tensors.

        Returns:
            A 5-tuple ``(mel_specgram, mel_specgram_postnet, gate_out,
            truth_mel_specgram, truth_gate_out)``. The spectrogram tensors have
            shape ``(n_batch, n_mel, max_mel_specgram_length)`` and the gate
            tensors have shape ``(n_batch,)``.
        """
        tensor_opts = {"dtype": self.dtype, "device": self.device}
        specgram_shape = (n_batch, n_mel, max_mel_specgram_length)

        # The tensors are drawn in a fixed order so the random number
        # generator is consumed in the same sequence on every call.
        mel_specgram = torch.rand(*specgram_shape, **tensor_opts)
        mel_specgram_postnet = torch.rand(*specgram_shape, **tensor_opts)
        gate_out = torch.rand(n_batch, **tensor_opts)
        truth_mel_specgram = torch.rand(*specgram_shape, **tensor_opts)
        truth_gate_out = torch.rand(n_batch, **tensor_opts)

        # Ground truths are explicitly marked as not requiring gradients.
        for ground_truth in (truth_mel_specgram, truth_gate_out):
            ground_truth.requires_grad = False

        return (
            mel_specgram,
            mel_specgram_postnet,
            gate_out,
            truth_mel_specgram,
            truth_gate_out,
        )
class Tacotron2LossShapeTests(Tacotron2LossInputMixin):
    """Shape checks for the three losses returned by Tacotron2Loss."""

    def test_tacotron2_loss_shape(self):
        """Validate the output shape of Tacotron2Loss."""
        n_batch = 16
        inputs = self._get_inputs(n_batch=n_batch)
        # The first three tensors are model predictions, the rest ground truths.
        predictions, ground_truths = inputs[:3], inputs[3:]

        mel_loss, mel_postnet_loss, gate_loss = Tacotron2Loss()(predictions, ground_truths)

        # Each loss must be a zero-dimensional (scalar) tensor.
        scalar_shape = torch.Size([])
        for loss in (mel_loss, mel_postnet_loss, gate_loss):
            self.assertEqual(loss.size(), scalar_shape)
class Tacotron2LossTorchscriptTests(Tacotron2LossInputMixin):
    """Checks that a scripted Tacotron2Loss matches the eager module."""

    def _assert_torchscript_consistency(self, fn, tensors):
        """Script ``fn``, round-trip it through a file and compare outputs.

        Args:
            fn: Callable to script; it is invoked as ``fn(tensors[:3], tensors[3:])``.
            tensors: Sequence whose first three items are the model outputs and
                whose remaining items are the targets.
        """
        # get_temp_path is not defined in this class; presumably it comes from
        # TempDirMixin via the input mixin -- confirm against common_utils.
        path = self.get_temp_path("func.zip")
        scripted = torch.jit.script(fn)
        scripted.save(path)
        reloaded = torch.jit.load(path)

        model_outputs, targets = tensors[:3], tensors[3:]
        expected = fn(model_outputs, targets)
        actual = reloaded(model_outputs, targets)

        self.assertEqual(actual, expected)

    def test_tacotron2_loss_torchscript_consistency(self):
        """Validate the torchscript consistency of Tacotron2Loss."""
        self._assert_torchscript_consistency(Tacotron2Loss(), self._get_inputs())
class Tacotron2LossGradcheckTests(Tacotron2LossInputMixin):
    """First- and second-order gradient checks for Tacotron2Loss."""

    def test_tacotron2_loss_gradcheck(self):
        """Performing gradient check on Tacotron2Loss."""
        inputs = self._get_inputs()

        # Only the three model predictions are differentiated; the two ground
        # truth tensors keep requires_grad=False.
        for prediction in inputs[:3]:
            prediction.requires_grad_(True)

        def _fn(mel_specgram, mel_specgram_postnet, gate_out, truth_mel_specgram, truth_gate_out):
            # Adapter: the checkers pass tensors positionally, while the loss
            # takes a (model_outputs, targets) pair of tuples.
            model_outputs = (mel_specgram, mel_specgram_postnet, gate_out)
            targets = (truth_mel_specgram, truth_gate_out)
            return Tacotron2Loss()(model_outputs, targets)

        for check in (gradcheck, gradgradcheck):
            check(_fn, inputs, fast_mode=True)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment