"...text-generation-inference.git" did not exist on "ed72e9212620d4de10fbe476f0b7af2ab94e4cd7"
Commit f18d01a0 authored by moto's avatar moto
Browse files

Avoid concatenation in loop (#1850)

parent 6321adcf
...@@ -925,12 +925,9 @@ class _Decoder(nn.Module): ...@@ -925,12 +925,9 @@ class _Decoder(nn.Module):
[memory.size(0)], dtype=torch.int32, device=memory.device [memory.size(0)], dtype=torch.int32, device=memory.device
) )
mel_specgrams, gate_outputs, alignments = ( mel_specgrams: List[Tensor] = []
torch.zeros(1, dtype=memory.dtype), gate_outputs: List[Tensor] = []
torch.zeros(1, dtype=memory.dtype), alignments: List[Tensor] = []
torch.zeros(1, dtype=memory.dtype),
)
first_iter = True
while True: while True:
decoder_input = self.prenet(decoder_input) decoder_input = self.prenet(decoder_input)
( (
...@@ -957,15 +954,9 @@ class _Decoder(nn.Module): ...@@ -957,15 +954,9 @@ class _Decoder(nn.Module):
mask, mask,
) )
if first_iter: mel_specgrams.append(mel_specgram.unsqueeze(0))
mel_specgrams = mel_specgram.unsqueeze(0) gate_outputs.append(gate_output.transpose(0, 1))
gate_outputs = gate_output.transpose(0, 1) alignments.append(attention_weights)
alignments = attention_weights
first_iter = False
else:
mel_specgrams = torch.cat((mel_specgrams, mel_specgram.unsqueeze(0)), dim=0)
gate_outputs = torch.cat((gate_outputs, gate_output.transpose(0, 1)), dim=0)
alignments = torch.cat((alignments, attention_weights), dim=0)
dec = torch.le(torch.sigmoid(gate_output), self.gate_threshold).to(torch.int32).squeeze(1) dec = torch.le(torch.sigmoid(gate_output), self.gate_threshold).to(torch.int32).squeeze(1)
...@@ -980,6 +971,10 @@ class _Decoder(nn.Module): ...@@ -980,6 +971,10 @@ class _Decoder(nn.Module):
mel_specgram_lengths += not_finished mel_specgram_lengths += not_finished
decoder_input = mel_specgram decoder_input = mel_specgram
mel_specgrams = torch.cat(mel_specgrams, dim=0)
gate_outputs = torch.cat(gate_outputs, dim=0)
alignments = torch.cat(alignments, dim=0)
mel_specgrams, gate_outputs, alignments = self._parse_decoder_outputs( mel_specgrams, gate_outputs, alignments = self._parse_decoder_outputs(
mel_specgrams, gate_outputs, alignments mel_specgrams, gate_outputs, alignments
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment