Commit 57c8b97e authored by moto's avatar moto
Browse files

Fix vocoder interface (#1895)

parent c4fc8f90
......@@ -83,7 +83,7 @@ class _WaveRNNVocoder(torch.nn.Module, Tacotron2TTSBundle.Vocoder):
def sample_rate(self):
return self._sample_rate
def forward(self, mel_spec, lengths):
def forward(self, mel_spec, lengths=None):
mel_spec = torch.exp(mel_spec)
mel_spec = 20 * torch.log10(torch.clamp(mel_spec, min=1e-5))
if self._min_level_db is not None:
......@@ -120,7 +120,7 @@ class _GriffinLimVocoder(torch.nn.Module, Tacotron2TTSBundle.Vocoder):
def sample_rate(self):
return self._sample_rate
def forward(self, mel_spec, lengths):
def forward(self, mel_spec, lengths=None):
mel_spec = torch.exp(mel_spec)
mel_spec = mel_spec.clone().detach().requires_grad_(True)
spec = self._inv_mel(mel_spec)
......
......@@ -47,7 +47,7 @@ class _Vocoder(ABC):
"""
@abstractmethod
def __call__(self, specgrams: Tensor, lengths: Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]:
def __call__(self, specgrams: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
"""Generate waveform from the given input, such as spectrogram
See :func:`torchaudio.pipelines.Tacotron2TTSBundle.get_vocoder` for the usage.
......@@ -58,6 +58,7 @@ class _Vocoder(ABC):
The expected shape depends on the implementation.
lengths (Tensor, or None, optional):
The valid length of each sample in the batch. Shape: `(batch, )`.
(Default: `None`)
Returns:
(Tensor, Optional[Tensor]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment