Commit cea1dc66, authored by popcornell, committed by Facebook GitHub Bot

allow Tacotron2 decoding batch_size 1 examples (#2156)

Summary:
It seems to me that the current Tacotron2 model cannot decode examples with a batch size of 1.
For example, the following code fails. I may have a fix for that.

```python
import torch
from torchaudio.models.tacotron2 import _Decoder  # private class; import path assumed

if __name__ == "__main__":
    max_length = 400
    n_batch = 1  # a single example is what triggers the failure
    hdim = 32
    dec = _Decoder(
        encoder_embedding_dim=hdim,
        n_mels=hdim,
        n_frames_per_step=1,
        decoder_rnn_dim=1024,
        decoder_max_step=2000,
        decoder_dropout=0.1,
        decoder_early_stopping=True,
        attention_rnn_dim=1024,
        attention_hidden_dim=128,
        attention_location_n_filter=32,
        attention_location_kernel_size=31,
        attention_dropout=0.1,
        prenet_dim=256,
        gate_threshold=0.5,
    )

    # Encoder output: (n_batch, max_length, hdim)
    inp = torch.rand((n_batch, max_length, hdim))
    lengths = torch.tensor([max_length]).expand(n_batch).to(inp.device, inp.dtype)
    # Forward pass with a ground-truth mel spectrogram: (n_batch, n_mels, time)
    dec(inp, torch.rand((n_batch, hdim, max_length)), lengths)[0]
    # Inference pass
    dec.infer(inp, lengths)[0]
```

Pull Request resolved: https://github.com/pytorch/audio/pull/2156

Reviewed By: carolineechen

Differential Revision: D33744006

Pulled By: nateanl

fbshipit-source-id: 7d04726dfe7e45951ab0007f22f10f90f26379a7
parent 576b02b1
@@ -749,7 +749,7 @@ class _Decoder(nn.Module):
             )
             mel_outputs += [mel_output.squeeze(1)]
-            gate_outputs += [gate_output.squeeze()]
+            gate_outputs += [gate_output.squeeze(1)]
             alignments += [attention_weights]
         mel_specgram, gate_outputs, alignments = self._parse_decoder_outputs(
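
The one-line change from `squeeze()` to `squeeze(1)` matters only when the batch dimension has size 1. A minimal sketch of the difference, assuming the per-step gate output has shape `(n_batch, 1)`; that shape and the helper name `gate_shapes` are illustrative assumptions, not taken from the torchaudio source:

```python
import torch


def gate_shapes(n_batch):
    """Return the shapes produced by squeeze() and squeeze(1).

    The (n_batch, 1) tensor is a hypothetical stand-in for one
    decoder step's gate output.
    """
    gate_output = torch.zeros(n_batch, 1)
    bare = tuple(gate_output.squeeze().shape)      # drops every size-1 dim
    dim_one = tuple(gate_output.squeeze(1).shape)  # drops only dim 1
    return bare, dim_one


if __name__ == "__main__":
    for n_batch in (1, 4):
        bare, dim_one = gate_shapes(n_batch)
        print(f"n_batch={n_batch}: squeeze() -> {bare}, squeeze(1) -> {dim_one}")
```

With `n_batch = 1`, the bare `squeeze()` also removes the batch dimension and yields a 0-d tensor, whereas `squeeze(1)` keeps a shape of `(1,)`. For larger batches the two calls give the same result, which would explain why the problem shows up only for single examples.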