Unverified Commit 976f56e8 authored by moto's avatar moto Committed by GitHub
Browse files

Make `text_length` optional in `Tacotron2.infer` (#1839)

parent fd7fcf93
......@@ -1130,7 +1130,7 @@ class Tacotron2(nn.Module):
return mel_specgram, mel_specgram_postnet, gate_outputs, alignments
@torch.jit.export
def infer(self, text: Tensor, text_lengths: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
def infer(self, text: Tensor, text_lengths: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
r"""Using Tacotron2 for inference. The input is a batch of encoded
sentences (text) and its corresponding lengths (text_lengths). The
output is the generated mel spectrograms, its corresponding lengths, and
......@@ -1140,7 +1140,8 @@ class Tacotron2(nn.Module):
Args:
text (Tensor): The input text to Tacotron2 with shape (n_batch, max of ``text_lengths``).
text_lengths (Tensor): The length of each text with shape (n_batch, ).
text_lengths (Tensor or None, optional): The length of each text with shape `(n_batch, )`.
If ``None``, it is assumed that the all the texts are valid. Default: ``None``
Return:
mel_specgram (Tensor): The predicted mel spectrogram
......@@ -1150,6 +1151,11 @@ class Tacotron2(nn.Module):
alignments (Tensor): Sequence of attention weights from the decoder.
with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``).
"""
n_batch, max_length = text.shape
if text_lengths is None:
text_lengths = torch.tensor([max_length]).expand(n_batch).to(text.device, text.dtype)
assert text_lengths is not None # For TorchScript compiler
embedded_inputs = self.embedding(text).transpose(1, 2)
encoder_outputs = self.encoder(embedded_inputs, text_lengths)
......@@ -1160,7 +1166,6 @@ class Tacotron2(nn.Module):
mel_outputs_postnet = self.postnet(mel_specgram)
mel_outputs_postnet = mel_specgram + mel_outputs_postnet
n_batch = mel_outputs_postnet.size(0)
alignments = alignments.unfold(1, n_batch, n_batch).transpose(0, 2)
return mel_outputs_postnet, mel_specgram_lengths, alignments
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment