Unverified Commit 6b1c712f authored by moto's avatar moto Committed by GitHub
Browse files

Fix the main loop of tacotron2 decoder inference (#1849)

To handle batched input properly.
parent ccc183da
......@@ -904,6 +904,8 @@ class _Decoder(nn.Module):
alignments (Tensor): Sequence of attention weights from the decoder
with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``).
"""
batch_size, device = memory.size(0), memory.device
decoder_input = self._get_go_frame(memory)
mask = _get_mask_from_lengths(memory_lengths)
......@@ -918,17 +920,12 @@ class _Decoder(nn.Module):
processed_memory,
) = self._initialize_decoder_states(memory)
mel_specgram_lengths = torch.ones(
[memory.size(0)], dtype=torch.int32, device=memory.device
)
not_finished = torch.ones(
[memory.size(0)], dtype=torch.int32, device=memory.device
)
mel_specgram_lengths = torch.zeros([batch_size], dtype=torch.int32, device=device)
finished = torch.zeros([batch_size], dtype=torch.bool, device=device)
mel_specgrams: List[Tensor] = []
gate_outputs: List[Tensor] = []
alignments: List[Tensor] = []
while True:
for _ in range(self.decoder_max_step):
decoder_input = self.prenet(decoder_input)
(
mel_specgram,
......@@ -957,20 +954,19 @@ class _Decoder(nn.Module):
mel_specgrams.append(mel_specgram.unsqueeze(0))
gate_outputs.append(gate_output.transpose(0, 1))
alignments.append(attention_weights)
mel_specgram_lengths[~finished] += 1
dec = torch.le(torch.sigmoid(gate_output), self.gate_threshold).to(torch.int32).squeeze(1)
not_finished = not_finished * dec
if self.decoder_early_stopping and torch.sum(not_finished) == 0:
break
if len(mel_specgrams) == self.decoder_max_step:
warnings.warn("Reached max decoder steps")
finished |= torch.sigmoid(gate_output.squeeze(1)) > self.gate_threshold
if self.decoder_early_stopping and torch.all(finished):
break
mel_specgram_lengths += not_finished
decoder_input = mel_specgram
if len(mel_specgrams) == self.decoder_max_step:
warnings.warn(
"Reached max decoder steps. The generated spectrogram might not cover "
"the whole transcript.")
mel_specgrams = torch.cat(mel_specgrams, dim=0)
gate_outputs = torch.cat(gate_outputs, dim=0)
alignments = torch.cat(alignments, dim=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment