Allow Tacotron2 to decode batch_size 1 examples (#2156)
Summary:
It seems to me that the current Tacotron2 model does not allow decoding examples with a batch size of 1.
For example, the following code fails. I may have a fix for that.
```python
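import torch
# _Decoder is a private torchaudio class; this import path is assumed
# so that the repro can run outside the tacotron2 module itself.
from torchaudio.models.tacotron2 import _Decoder

if __name__ == "__main__":
    max_length = 400
    n_batch = 1  # the problematic case: a batch containing a single example
    hdim = 32
    dec = _Decoder(
        encoder_embedding_dim=hdim,
        n_mels=hdim,
        n_frames_per_step=1,
        decoder_rnn_dim=1024,
        decoder_max_step=2000,
        decoder_dropout=0.1,
        decoder_early_stopping=True,
        attention_rnn_dim=1024,
        attention_hidden_dim=128,
        attention_location_n_filter=32,
        attention_location_kernel_size=31,
        attention_dropout=0.1,
        prenet_dim=256,
        gate_threshold=0.5,
    )
    # Encoder output of shape (n_batch, max_length, hdim)
    inp = torch.rand((n_batch, max_length, hdim))
    lengths = torch.tensor([max_length]).expand(n_batch).to(inp.device, inp.dtype)
    # Teacher-forced forward pass with a random target of shape (n_batch, n_mels, time)
    dec(inp, torch.rand((n_batch, hdim, max_length)), lengths)[0]
    # Autoregressive inference
    dec.infer(inp, lengths)[0]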
```
Pull Request resolved: https://github.com/pytorch/audio/pull/2156
Reviewed By: carolineechen
Differential Revision: D33744006
Pulled By: nateanl
fbshipit-source-id: 7d04726dfe7e45951ab0007f22f10f90f26379a7