Commit 9f9b6537 authored by moto
Browse files

Replace `text` with `token` in Tacotron2 API (#1844)

parent cb77a86c
...@@ -1079,20 +1079,20 @@ class Tacotron2(nn.Module): ...@@ -1079,20 +1079,20 @@ class Tacotron2(nn.Module):
def forward( def forward(
self, self,
text: Tensor, tokens: Tensor,
text_lengths: Tensor, token_lengths: Tensor,
mel_specgram: Tensor, mel_specgram: Tensor,
mel_specgram_lengths: Tensor, mel_specgram_lengths: Tensor,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]: ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
r"""Pass the input through the Tacotron2 model. This is in teacher r"""Pass the input through the Tacotron2 model. This is in teacher
forcing mode, which is generally used for training. forcing mode, which is generally used for training.
The input ``text`` should be padded with zeros to length max of ``text_lengths``. The input ``tokens`` should be padded with zeros to length max of ``token_lengths``.
The input ``mel_specgram`` should be padded with zeros to length max of ``mel_specgram_lengths``. The input ``mel_specgram`` should be padded with zeros to length max of ``mel_specgram_lengths``.
Args: Args:
text (Tensor): The input text to Tacotron2 with shape `(n_batch, max of text_lengths)`. tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of token_lengths)`.
text_lengths (Tensor): The length of each text with shape `(n_batch, )`. token_lengths (Tensor): The valid length of each sample in ``tokens`` with shape `(n_batch, )`.
mel_specgram (Tensor): The target mel spectrogram mel_specgram (Tensor): The target mel spectrogram
with shape `(n_batch, n_mels, max of mel_specgram_lengths)`. with shape `(n_batch, n_mels, max of mel_specgram_lengths)`.
mel_specgram_lengths (Tensor): The length of each mel spectrogram with shape `(n_batch, )`. mel_specgram_lengths (Tensor): The length of each mel spectrogram with shape `(n_batch, )`.
...@@ -1107,14 +1107,14 @@ class Tacotron2(nn.Module): ...@@ -1107,14 +1107,14 @@ class Tacotron2(nn.Module):
The output for stop token at each time step with shape `(n_batch, max of mel_specgram_lengths)`. The output for stop token at each time step with shape `(n_batch, max of mel_specgram_lengths)`.
Tensor Tensor
Sequence of attention weights from the decoder with Sequence of attention weights from the decoder with
shape `(n_batch, max of mel_specgram_lengths, max of text_lengths)`. shape `(n_batch, max of mel_specgram_lengths, max of token_lengths)`.
""" """
embedded_inputs = self.embedding(text).transpose(1, 2) embedded_inputs = self.embedding(tokens).transpose(1, 2)
encoder_outputs = self.encoder(embedded_inputs, text_lengths) encoder_outputs = self.encoder(embedded_inputs, token_lengths)
mel_specgram, gate_outputs, alignments = self.decoder( mel_specgram, gate_outputs, alignments = self.decoder(
encoder_outputs, mel_specgram, memory_lengths=text_lengths encoder_outputs, mel_specgram, memory_lengths=token_lengths
) )
mel_specgram_postnet = self.postnet(mel_specgram) mel_specgram_postnet = self.postnet(mel_specgram)
...@@ -1132,18 +1132,19 @@ class Tacotron2(nn.Module): ...@@ -1132,18 +1132,19 @@ class Tacotron2(nn.Module):
return mel_specgram, mel_specgram_postnet, gate_outputs, alignments return mel_specgram, mel_specgram_postnet, gate_outputs, alignments
@torch.jit.export @torch.jit.export
def infer(self, text: Tensor, text_lengths: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]: def infer(self, tokens: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
r"""Using Tacotron2 for inference. The input is a batch of encoded r"""Using Tacotron2 for inference. The input is a batch of encoded
sentences (text) and its corresponding lengths (text_lengths). The sentences (``tokens``) and its corresponding lengths (``lengths``). The
output is the generated mel spectrograms, its corresponding lengths, and output is the generated mel spectrograms, its corresponding lengths, and
the attention weights from the decoder. the attention weights from the decoder.
The input `text` should be padded with zeros to length max of ``text_lengths``. The input `tokens` should be padded with zeros to length max of ``lengths``.
Args: Args:
text (Tensor): The input text to Tacotron2 with shape `(n_batch, max of text_lengths)`. tokens (Tensor): The input tokens to Tacotron2 with shape `(n_batch, max of lengths)`.
text_lengths (Tensor or None, optional): The length of each text with shape `(n_batch, )`. lengths (Tensor or None, optional):
If ``None``, it is assumed that the all the texts are valid. Default: ``None`` The valid length of each sample in ``tokens`` with shape `(n_batch, )`.
If ``None``, it is assumed that the all the tokens are valid. Default: ``None``
Returns: Returns:
Tensor, Tensor, and Tensor: Tensor, Tensor, and Tensor:
...@@ -1153,18 +1154,18 @@ class Tacotron2(nn.Module): ...@@ -1153,18 +1154,18 @@ class Tacotron2(nn.Module):
The length of the predicted mel spectrogram with shape `(n_batch, )`. The length of the predicted mel spectrogram with shape `(n_batch, )`.
Tensor Tensor
Sequence of attention weights from the decoder with shape Sequence of attention weights from the decoder with shape
`(n_batch, max of mel_specgram_lengths, max of text_lengths)`. `(n_batch, max of mel_specgram_lengths, max of lengths)`.
""" """
n_batch, max_length = text.shape n_batch, max_length = tokens.shape
if text_lengths is None: if lengths is None:
text_lengths = torch.tensor([max_length]).expand(n_batch).to(text.device, text.dtype) lengths = torch.tensor([max_length]).expand(n_batch).to(tokens.device, tokens.dtype)
assert text_lengths is not None # For TorchScript compiler assert lengths is not None # For TorchScript compiler
embedded_inputs = self.embedding(text).transpose(1, 2) embedded_inputs = self.embedding(tokens).transpose(1, 2)
encoder_outputs = self.encoder(embedded_inputs, text_lengths) encoder_outputs = self.encoder(embedded_inputs, lengths)
mel_specgram, mel_specgram_lengths, _, alignments = self.decoder.infer( mel_specgram, mel_specgram_lengths, _, alignments = self.decoder.infer(
encoder_outputs, text_lengths encoder_outputs, lengths
) )
mel_outputs_postnet = self.postnet(mel_specgram) mel_outputs_postnet = self.postnet(mel_specgram)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment