Commit 49c48f93 authored by moto
Browse files

Fix the main loop of tacotron2 decoder inference (#1849)

To handle batched input properly.
parent ab97afa0
...@@ -904,6 +904,8 @@ class _Decoder(nn.Module): ...@@ -904,6 +904,8 @@ class _Decoder(nn.Module):
alignments (Tensor): Sequence of attention weights from the decoder alignments (Tensor): Sequence of attention weights from the decoder
with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``). with shape (n_batch, max of ``mel_specgram_lengths``, max of ``text_lengths``).
""" """
batch_size, device = memory.size(0), memory.device
decoder_input = self._get_go_frame(memory) decoder_input = self._get_go_frame(memory)
mask = _get_mask_from_lengths(memory_lengths) mask = _get_mask_from_lengths(memory_lengths)
...@@ -918,17 +920,12 @@ class _Decoder(nn.Module): ...@@ -918,17 +920,12 @@ class _Decoder(nn.Module):
processed_memory, processed_memory,
) = self._initialize_decoder_states(memory) ) = self._initialize_decoder_states(memory)
mel_specgram_lengths = torch.ones( mel_specgram_lengths = torch.zeros([batch_size], dtype=torch.int32, device=device)
[memory.size(0)], dtype=torch.int32, device=memory.device finished = torch.zeros([batch_size], dtype=torch.bool, device=device)
)
not_finished = torch.ones(
[memory.size(0)], dtype=torch.int32, device=memory.device
)
mel_specgrams: List[Tensor] = [] mel_specgrams: List[Tensor] = []
gate_outputs: List[Tensor] = [] gate_outputs: List[Tensor] = []
alignments: List[Tensor] = [] alignments: List[Tensor] = []
while True: for _ in range(self.decoder_max_step):
decoder_input = self.prenet(decoder_input) decoder_input = self.prenet(decoder_input)
( (
mel_specgram, mel_specgram,
...@@ -957,20 +954,19 @@ class _Decoder(nn.Module): ...@@ -957,20 +954,19 @@ class _Decoder(nn.Module):
mel_specgrams.append(mel_specgram.unsqueeze(0)) mel_specgrams.append(mel_specgram.unsqueeze(0))
gate_outputs.append(gate_output.transpose(0, 1)) gate_outputs.append(gate_output.transpose(0, 1))
alignments.append(attention_weights) alignments.append(attention_weights)
mel_specgram_lengths[~finished] += 1
dec = torch.le(torch.sigmoid(gate_output), self.gate_threshold).to(torch.int32).squeeze(1) finished |= torch.sigmoid(gate_output.squeeze(1)) > self.gate_threshold
if self.decoder_early_stopping and torch.all(finished):
not_finished = not_finished * dec
if self.decoder_early_stopping and torch.sum(not_finished) == 0:
break
if len(mel_specgrams) == self.decoder_max_step:
warnings.warn("Reached max decoder steps")
break break
mel_specgram_lengths += not_finished
decoder_input = mel_specgram decoder_input = mel_specgram
if len(mel_specgrams) == self.decoder_max_step:
warnings.warn(
"Reached max decoder steps. The generated spectrogram might not cover "
"the whole transcript.")
mel_specgrams = torch.cat(mel_specgrams, dim=0) mel_specgrams = torch.cat(mel_specgrams, dim=0)
gate_outputs = torch.cat(gate_outputs, dim=0) gate_outputs = torch.cat(gate_outputs, dim=0)
alignments = torch.cat(alignments, dim=0) alignments = torch.cat(alignments, dim=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment