Unverified Commit 15bc554f authored by yangarbiter, committed by GitHub

Add Tacotron2 inference method (#1648)

parent 90c0edc5
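
A minimal usage sketch of the new inference method (hedged: assumes `model` is an already-constructed `torchaudio.prototype.tacotron2.Tacotron2`, e.g. built like `_get_tacotron2_model` in the tests below):

    import torch

    model = model.eval()
    # a batch of two encoded sentences, zero-padded to the longest length
    text = torch.randint(0, 148, (2, 100), dtype=torch.int32)
    text_lengths = torch.tensor([100, 80], dtype=torch.int32)
    with torch.no_grad():
        mel_specgram, mel_specgram_lengths, alignments = model.infer(text, text_lengths)
    # mel_specgram: (n_batch, n_mels, max of mel_specgram_lengths)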
from typing import Tuple
import torch
from torch import Tensor
from torchaudio.prototype.tacotron2 import Tacotron2, _Encoder, _Decoder
from torchaudio_unittest.common_utils import (
TestBaseMixin,
@@ -6,6 +8,26 @@ from torchaudio_unittest.common_utils import (
)
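# The wrappers below expose the TorchScript-exported ``infer`` methods as
# ``forward`` so the torchscript consistency helper can script and invoke them.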
class Tacotron2InferenceWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, text: Tensor, text_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
return self.model.infer(text, text_lengths)
class Tacotron2DecoderInferenceWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.model = model
def forward(self, memory: Tensor, memory_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
return self.model.infer(memory, memory_lengths)
class TorchscriptConsistencyMixin(TempDirMixin):
r"""Mixin to provide easy access assert torchscript consistency"""
@@ -24,6 +46,7 @@ class TorchscriptConsistencyMixin(TempDirMixin):
class Tacotron2EncoderTests(TestBaseMixin, TorchscriptConsistencyMixin):
def test_tacotron2_torchscript_consistency(self):
r"""Validate the torchscript consistency of a Encoder."""
n_batch, n_seq, encoder_embedding_dim = 16, 64, 512
@@ -60,27 +83,29 @@ class Tacotron2EncoderTests(TestBaseMixin, TorchscriptConsistencyMixin):
assert out.size() == (n_batch, n_seq, encoder_embedding_dim)
-def _get_decoder_model(n_mels=80, encoder_embedding_dim=512):
+def _get_decoder_model(n_mels=80, encoder_embedding_dim=512,
+                       decoder_max_step=2000, gate_threshold=0.5):
model = _Decoder(
n_mels=n_mels,
n_frames_per_step=1,
encoder_embedding_dim=encoder_embedding_dim,
decoder_rnn_dim=1024,
-        decoder_max_step=2000,
+        decoder_max_step=decoder_max_step,
decoder_dropout=0.1,
-        decoder_early_stopping=False,
+        decoder_early_stopping=True,
attention_rnn_dim=1024,
attention_hidden_dim=128,
attention_location_n_filter=32,
attention_location_kernel_size=31,
attention_dropout=0.1,
prenet_dim=256,
-        gate_threshold=0.5,
+        gate_threshold=gate_threshold,
)
return model
class Tacotron2DecoderTests(TestBaseMixin, TorchscriptConsistencyMixin):
def test_decoder_torchscript_consistency(self):
r"""Validate the torchscript consistency of a Decoder."""
n_batch = 16
@@ -125,16 +150,81 @@ class Tacotron2DecoderTests(TestBaseMixin, TorchscriptConsistencyMixin):
)
memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device)
-        mel_outputs, gate_outputs, alignments = model(
+        mel_specgram, gate_outputs, alignments = model(
memory, decoder_inputs, memory_lengths
)
-        assert mel_outputs.size() == (n_batch, n_mels, n_time_steps)
+        assert mel_specgram.size() == (n_batch, n_mels, n_time_steps)
assert gate_outputs.size() == (n_batch, n_time_steps)
assert alignments.size() == (n_batch, n_time_steps, n_seq)
def test_decoder_inference_torchscript_consistency(self):
r"""Validate the torchscript consistency of a Decoder."""
n_batch = 16
n_mels = 80
n_seq = 200
encoder_embedding_dim = 256
decoder_max_step = 300 # make inference more efficient
        gate_threshold = 0.505  # if set to 0.5, the model will only run one step
model = _get_decoder_model(
n_mels=n_mels,
encoder_embedding_dim=encoder_embedding_dim,
decoder_max_step=decoder_max_step,
gate_threshold=gate_threshold,
)
model = model.to(self.device).eval()
memory = torch.rand(
n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device
)
memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device)
model_wrapper = Tacotron2DecoderInferenceWrapper(model)
self._assert_torchscript_consistency(model_wrapper, (memory, memory_lengths))
def test_decoder_inference_output_shape(self):
r"""Validate the torchscript consistency of a Decoder."""
n_batch = 16
n_mels = 80
n_seq = 200
encoder_embedding_dim = 256
decoder_max_step = 300 # make inference more efficient
gate_threshold = 0.505 # if set to 0.5, the model will only run one step
model = _get_decoder_model(
n_mels=n_mels,
encoder_embedding_dim=encoder_embedding_dim,
decoder_max_step=decoder_max_step,
gate_threshold=gate_threshold,
)
model = model.to(self.device).eval()
memory = torch.rand(
n_batch, n_seq, encoder_embedding_dim, dtype=self.dtype, device=self.device
)
memory_lengths = torch.ones(n_batch, dtype=torch.int32, device=self.device)
mel_specgram, mel_specgram_lengths, gate_outputs, alignments = model.infer(
memory, memory_lengths
)
assert len(mel_specgram.size()) == 3
assert mel_specgram.size()[:-1] == (n_batch, n_mels, )
assert mel_specgram.size()[2] == mel_specgram_lengths.max().item()
assert len(mel_specgram_lengths.size()) == 1
assert mel_specgram_lengths.size()[0] == n_batch
assert mel_specgram_lengths.max().item() <= model.decoder_max_step
assert len(gate_outputs.size()) == 2
assert gate_outputs.size()[0] == n_batch
assert gate_outputs.size()[1] == mel_specgram_lengths.max().item()
assert len(alignments.size()) == 2
assert alignments.size()[0] == n_seq
assert alignments.size()[1] == mel_specgram_lengths.max().item() * n_batch
-def _get_tacotron2_model(n_mels):
+def _get_tacotron2_model(n_mels, decoder_max_step=2000, gate_threshold=0.5):
return Tacotron2(
mask_padding=False,
n_mels=n_mels,
@@ -145,7 +235,7 @@ def _get_tacotron2_model(n_mels):
encoder_n_convolution=3,
encoder_kernel_size=5,
decoder_rnn_dim=1024,
-        decoder_max_step=2000,
+        decoder_max_step=decoder_max_step,
decoder_dropout=0.1,
decoder_early_stopping=True,
attention_rnn_dim=1024,
@@ -157,13 +247,14 @@ def _get_tacotron2_model(n_mels):
postnet_n_convolution=5,
postnet_kernel_size=5,
postnet_embedding_dim=512,
-        gate_threshold=0.5,
+        gate_threshold=gate_threshold,
)
class Tacotron2Tests(TestBaseMixin, TorchscriptConsistencyMixin):
def _get_inputs(
-        self, n_mels, n_batch: int, max_mel_specgram_length: int, max_text_length: int
+        self, n_mels: int, n_batch: int, max_mel_specgram_length: int, max_text_length: int
):
text = torch.randint(
0, 148, (n_batch, max_text_length), dtype=torch.int32, device=self.device
@@ -236,3 +327,59 @@ class Tacotron2Tests(TestBaseMixin, TorchscriptConsistencyMixin):
mel_out.sum().backward(retain_graph=True)
mel_out_postnet.sum().backward(retain_graph=True)
gate_outputs.sum().backward()
def _get_inference_inputs(self, n_batch: int, max_text_length: int):
text = torch.randint(
0, 148, (n_batch, max_text_length), dtype=torch.int32, device=self.device
)
text_lengths = max_text_length * torch.ones(
(n_batch,), dtype=torch.int32, device=self.device
)
return text, text_lengths
def test_tacotron2_inference_torchscript_consistency(self):
r"""Validate the torchscript consistency of Tacotron2 inference function."""
n_batch = 16
n_mels = 40
max_text_length = 100
decoder_max_step = 200 # make inference more efficient
gate_threshold = 0.51 # if set to 0.5, the model will only run one step
model = _get_tacotron2_model(
n_mels, decoder_max_step=decoder_max_step, gate_threshold=gate_threshold
).to(self.device).eval()
inputs = self._get_inference_inputs(n_batch, max_text_length)
model_wrapper = Tacotron2InferenceWrapper(model)
self._assert_torchscript_consistency(model_wrapper, inputs)
def test_tacotron2_inference_output_shape(self):
r"""Feed tensors with specific shape to Tacotron2 inference function and validate
that it outputs with a tensor with expected shape.
"""
n_batch = 16
n_mels = 40
max_text_length = 100
decoder_max_step = 200 # make inference more efficient
gate_threshold = 0.51 # if set to 0.5, the model will only run one step
model = _get_tacotron2_model(
n_mels, decoder_max_step=decoder_max_step, gate_threshold=gate_threshold
).to(self.device).eval()
inputs = self._get_inference_inputs(n_batch, max_text_length)
mel_out, mel_specgram_lengths, alignments = model.infer(*inputs)
        # There is no guarantee on exactly what max_mel_specgram_length should be.
        # We only know that it should be no greater than model.decoder.decoder_max_step.
assert len(mel_out.size()) == 3
assert mel_out.size()[:2] == (n_batch, n_mels, )
assert mel_out.size()[2] == mel_specgram_lengths.max().item()
assert len(mel_specgram_lengths.size()) == 1
assert mel_specgram_lengths.size()[0] == n_batch
assert mel_specgram_lengths.max().item() <= model.decoder.decoder_max_step
assert len(alignments.size()) == 3
assert alignments.size()[0] == n_batch
assert alignments.size()[1] == mel_specgram_lengths.max().item()
assert alignments.size()[2] == max_text_length
@@ -25,6 +25,7 @@
#
# *****************************************************************************
import warnings
from math import sqrt
from typing import Tuple, List, Optional, Union
@@ -614,12 +615,12 @@ class _Decoder(nn.Module):
return decoder_inputs
def _parse_decoder_outputs(
-        self, mel_outputs: Tensor, gate_outputs: Tensor, alignments: Tensor
+        self, mel_specgram: Tensor, gate_outputs: Tensor, alignments: Tensor
) -> Tuple[Tensor, Tensor, Tensor]:
r"""Prepares decoder outputs for output
Args:
-            mel_outputs (Tensor): mel spectrogram with shape (max of ``mel_specgram_lengths``, n_batch, ``n_mels``)
+            mel_specgram (Tensor): mel spectrogram with shape (max of ``mel_specgram_lengths``, n_batch, ``n_mels``)
gate_outputs (Tensor): predicted stop token with shape (max of ``mel_specgram_lengths``, n_batch)
alignments (Tensor): sequence of attention weights from the decoder
with shape (max of ``mel_specgram_lengths``, n_batch, max of ``text_lengths``)
@@ -636,7 +637,7 @@ class _Decoder(nn.Module):
# (mel_specgram_lengths.max(), n_batch) -> (n_batch, mel_specgram_lengths.max())
gate_outputs = gate_outputs.transpose(0, 1).contiguous()
# (mel_specgram_lengths.max(), n_batch, n_mels) -> (n_batch, mel_specgram_lengths.max(), n_mels)
-        mel_specgram = mel_outputs.transpose(0, 1).contiguous()
+        mel_specgram = mel_specgram.transpose(0, 1).contiguous()
# decouple frames per step
shape = (mel_specgram.shape[0], -1, self.n_mels)
mel_specgram = mel_specgram.view(*shape)
@@ -805,6 +806,128 @@ class _Decoder(nn.Module):
return mel_specgram, gate_outputs, alignments
def _get_go_frame(self, memory: Tensor) -> Tensor:
"""Gets all zeros frames to use as the first decoder input
args:
memory (Tensor): Encoder outputs
with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
returns:
decoder_input (Tensor): All zeros frames with shape(n_batch, ``n_mels`` * ``n_frame_per_step``).
"""
n_batch = memory.size(0)
dtype = memory.dtype
device = memory.device
decoder_input = torch.zeros(
n_batch, self.n_mels * self.n_frames_per_step, dtype=dtype, device=device
)
return decoder_input
@torch.jit.export
def infer(self,
memory: Tensor,
memory_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
"""Decoder inference
Args:
memory (Tensor): Encoder outputs
with shape (n_batch, max of ``text_lengths``, ``encoder_embedding_dim``).
memory_lengths (Tensor): Encoder output lengths for attention masking
(the same as ``text_lengths``) with shape (n_batch, ).
Returns:
mel_specgram (Tensor): Predicted mel spectrogram
with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
            mel_specgram_lengths (Tensor): the length of each predicted mel spectrogram with shape (n_batch, ).
gate_outputs (Tensor): Predicted stop token for each timestep
with shape (n_batch, max of ``mel_specgram_lengths``).
alignments (Tensor): Sequence of attention weights from the decoder
with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``).
"""
decoder_input = self._get_go_frame(memory)
mask = _get_mask_from_lengths(memory_lengths)
(
attention_hidden,
attention_cell,
decoder_hidden,
decoder_cell,
attention_weights,
attention_weights_cum,
attention_context,
processed_memory,
) = self._initialize_decoder_states(memory)
mel_specgram_lengths = torch.ones(
[memory.size(0)], dtype=torch.int32, device=memory.device
)
not_finished = torch.ones(
[memory.size(0)], dtype=torch.int32, device=memory.device
)
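        # dummy tensors so the outputs are defined (and typed, as TorchScript
        # requires) before the loop; they are replaced on the first iteration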
mel_specgrams, gate_outputs, alignments = (
torch.zeros(1, dtype=memory.dtype),
torch.zeros(1, dtype=memory.dtype),
torch.zeros(1, dtype=memory.dtype),
)
first_iter = True
while True:
decoder_input = self.prenet(decoder_input)
(
mel_specgram,
gate_output,
attention_hidden,
attention_cell,
decoder_hidden,
decoder_cell,
attention_weights,
attention_weights_cum,
attention_context,
) = self.decode(
decoder_input,
attention_hidden,
attention_cell,
decoder_hidden,
decoder_cell,
attention_weights,
attention_weights_cum,
attention_context,
memory,
processed_memory,
mask,
)
if first_iter:
mel_specgrams = mel_specgram.unsqueeze(0)
gate_outputs = gate_output.transpose(0, 1)
alignments = attention_weights
first_iter = False
else:
mel_specgrams = torch.cat((mel_specgrams, mel_specgram.unsqueeze(0)), dim=0)
gate_outputs = torch.cat((gate_outputs, gate_output.transpose(0, 1)), dim=0)
alignments = torch.cat((alignments, attention_weights), dim=0)
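            # a sequence is finished once its predicted stop-token probability
            # exceeds gate_threshold; decoding continues while any sequence in
            # the batch is still unfinished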
dec = torch.le(torch.sigmoid(gate_output), self.gate_threshold).to(torch.int32).squeeze(1)
not_finished = not_finished * dec
if self.decoder_early_stopping and torch.sum(not_finished) == 0:
break
if len(mel_specgrams) == self.decoder_max_step:
warnings.warn("Reached max decoder steps")
break
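            # only sequences that are still running have their length extended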
mel_specgram_lengths += not_finished
decoder_input = mel_specgram
mel_specgrams, gate_outputs, alignments = self._parse_decoder_outputs(
mel_specgrams, gate_outputs, alignments
)
return mel_specgrams, mel_specgram_lengths, gate_outputs, alignments
class Tacotron2(nn.Module):
r"""Tacotron2 model based on the implementation from
@@ -947,3 +1070,38 @@ class Tacotron2(nn.Module):
gate_outputs.masked_fill_(mask[:, 0, :], 1e3)
return mel_specgram, mel_specgram_postnet, gate_outputs, alignments
@torch.jit.export
def infer(self, text: Tensor, text_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
r"""Using Tacotron2 for inference. The input is a batch of encoded
sentences (text) and its corresponding lengths (text_lengths). The
output is the generated mel spectrograms, its corresponding lengths, and
the attention weights from the decoder.
The input `text` should be padded with zeros to length max of ``text_lengths``.
Args:
text (Tensor): the input text to Tacotron2. (n_batch, max of ``text_lengths``)
text_lengths (Tensor): the length of each text (n_batch)
        Returns:
            mel_specgram (Tensor): the predicted mel spectrogram
                with shape (n_batch, ``n_mels``, max of ``mel_specgram_lengths``).
            mel_specgram_lengths (Tensor): the length of each predicted mel spectrogram with shape (n_batch, ).
            alignments (Tensor): sequence of attention weights from the decoder
                with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``).
"""
embedded_inputs = self.embedding(text).transpose(1, 2)
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
mel_specgram, mel_specgram_lengths, _, alignments = self.decoder.infer(
encoder_outputs, text_lengths
)
mel_outputs_postnet = self.postnet(mel_specgram)
mel_outputs_postnet = mel_specgram + mel_outputs_postnet
n_batch = mel_outputs_postnet.size(0)
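        # decoder.infer returns alignments stacked as (max of text_lengths, T * n_batch);
        # unfold splits the second dim into T chunks of n_batch each, and
        # transpose(0, 2) yields (n_batch, T, max of text_lengths)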
alignments = alignments.unfold(1, n_batch, n_batch).transpose(0, 2)
return mel_outputs_postnet, mel_specgram_lengths, alignments