Commit 9f9b6537 authored by moto

Replace `text` with `token` in Tacotron2 API (#1844)

parent cb77a86c
@@ -1079,20 +1079,20 @@ class Tacotron2(nn.Module):
     def forward(
         self,
-        text: Tensor,
-        text_lengths: Tensor,
+        tokens: Tensor,
+        token_lengths: Tensor,
         mel_specgram: Tensor,
         mel_specgram_lengths: Tensor,
     ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
         r"""Pass the input through the Tacotron2 model. This is in teacher
         forcing mode, which is generally used for training.
-        The input ``text`` should be padded with zeros to length max of ``text_lengths``.
+        The input ``tokens`` should be padded with zeros to length max of ``token_lengths``.
         The input ``mel_specgram`` should be padded with zeros to length max of ``mel_specgram_lengths``.
         Args:
-            text (Tensor): The input text to Tacotron2 with shape `(n_batch, max of text_lengths)`.
-            text_lengths (Tensor): The length of each text with shape `(n_batch, )`.
+            tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of token_lengths)`.
+            token_lengths (Tensor): The valid length of each sample in ``tokens`` with shape `(n_batch, )`.
             mel_specgram (Tensor): The target mel spectrogram
                 with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
             mel_specgram_lengths (Tensor): The length of each mel spectrogram with shape `(n_batch, )`.
@@ -1107,14 +1107,14 @@ class Tacotron2(nn.Module):
                 The output for stop token at each time step with shape `(n_batch, max of mel_specgram_lengths)`.
             Tensor
                 Sequence of attention weights from the decoder with
-                shape `(n_batch, max of mel_specgram_lengths, max of text_lengths)`.
+                shape `(n_batch, max of mel_specgram_lengths, max of token_lengths)`.
         """
-        embedded_inputs = self.embedding(text).transpose(1, 2)
+        embedded_inputs = self.embedding(tokens).transpose(1, 2)
-        encoder_outputs = self.encoder(embedded_inputs, text_lengths)
+        encoder_outputs = self.encoder(embedded_inputs, token_lengths)
         mel_specgram, gate_outputs, alignments = self.decoder(
-            encoder_outputs, mel_specgram, memory_lengths=text_lengths
+            encoder_outputs, mel_specgram, memory_lengths=token_lengths
         )
         mel_specgram_postnet = self.postnet(mel_specgram)
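For reference, a minimal sketch of calling forward() under the new argument names in teacher-forcing mode. This is not part of the commit; the constructor defaults (n_symbol=148, n_mels=80) and every shape below are illustrative assumptions.

import torch
from torchaudio.models import Tacotron2

model = Tacotron2()  # assumes the default hyperparameters

# Zero-padded token IDs, with the valid length of each sample in ``tokens``
# (sorted in decreasing order so the encoder can pack the sequences).
tokens = torch.randint(0, 148, (2, 50))
token_lengths = torch.tensor([50, 42])

# Target mel spectrograms, zero-padded along time, with their valid lengths.
mel_specgram = torch.rand(2, 80, 300)
mel_specgram_lengths = torch.tensor([300, 260])

mel_out, mel_out_postnet, gate_outputs, alignments = model(
    tokens, token_lengths, mel_specgram, mel_specgram_lengths
)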
@@ -1132,18 +1132,19 @@ class Tacotron2(nn.Module):
         return mel_specgram, mel_specgram_postnet, gate_outputs, alignments
     @torch.jit.export
-    def infer(self, text: Tensor, text_lengths: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
+    def infer(self, tokens: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
         r"""Using Tacotron2 for inference. The input is a batch of encoded
-        sentences (text) and its corresponding lengths (text_lengths). The
+        sentences (``tokens``) and its corresponding lengths (``lengths``). The
         output is the generated mel spectrograms, its corresponding lengths, and
         the attention weights from the decoder.
-        The input `text` should be padded with zeros to length max of ``text_lengths``.
+        The input ``tokens`` should be padded with zeros to length max of ``lengths``.
         Args:
-            text (Tensor): The input text to Tacotron2 with shape `(n_batch, max of text_lengths)`.
-            text_lengths (Tensor or None, optional): The length of each text with shape `(n_batch, )`.
-                If ``None``, it is assumed that the all the texts are valid. Default: ``None``
+            tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of lengths)`.
+            lengths (Tensor or None, optional):
+                The valid length of each sample in ``tokens`` with shape `(n_batch, )`.
+                If ``None``, it is assumed that all the tokens are valid. Default: ``None``
         Returns:
             Tensor, Tensor, and Tensor:
@@ -1153,18 +1154,18 @@ class Tacotron2(nn.Module):
                 The length of the predicted mel spectrogram with shape `(n_batch, )`.
             Tensor
                 Sequence of attention weights from the decoder with shape
-                `(n_batch, max of mel_specgram_lengths, max of text_lengths)`.
+                `(n_batch, max of mel_specgram_lengths, max of lengths)`.
         """
-        n_batch, max_length = text.shape
-        if text_lengths is None:
-            text_lengths = torch.tensor([max_length]).expand(n_batch).to(text.device, text.dtype)
+        n_batch, max_length = tokens.shape
+        if lengths is None:
+            lengths = torch.tensor([max_length]).expand(n_batch).to(tokens.device, tokens.dtype)
-        assert text_lengths is not None  # For TorchScript compiler
+        assert lengths is not None  # For TorchScript compiler
-        embedded_inputs = self.embedding(text).transpose(1, 2)
-        encoder_outputs = self.encoder(embedded_inputs, text_lengths)
+        embedded_inputs = self.embedding(tokens).transpose(1, 2)
+        encoder_outputs = self.encoder(embedded_inputs, lengths)
         mel_specgram, mel_specgram_lengths, _, alignments = self.decoder.infer(
-            encoder_outputs, text_lengths
+            encoder_outputs, lengths
        )
        mel_outputs_postnet = self.postnet(mel_specgram)
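Likewise, a hypothetical inference sketch with the renamed arguments. The symbol count and shapes are again assumptions; since infer() carries @torch.jit.export, the same call should remain available on a scripted module.

import torch
from torchaudio.models import Tacotron2

model = Tacotron2().eval()
tokens = torch.randint(0, 148, (1, 50))  # one zero-padded encoded sentence
lengths = torch.tensor([50])             # optional; None treats all tokens as valid

with torch.no_grad():
    mel_specgram, mel_specgram_lengths, alignments = model.infer(tokens, lengths)

# infer() is exported for TorchScript, so it should still be callable here:
scripted = torch.jit.script(model)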