Unverified Commit 53ccba00 authored by moto's avatar moto Committed by GitHub
Browse files

Fix vocoder interface (#1895)

parent 1a7aec98
...@@ -83,7 +83,7 @@ class _WaveRNNVocoder(torch.nn.Module, Tacotron2TTSBundle.Vocoder): ...@@ -83,7 +83,7 @@ class _WaveRNNVocoder(torch.nn.Module, Tacotron2TTSBundle.Vocoder):
def sample_rate(self): def sample_rate(self):
return self._sample_rate return self._sample_rate
def forward(self, mel_spec, lengths): def forward(self, mel_spec, lengths=None):
mel_spec = torch.exp(mel_spec) mel_spec = torch.exp(mel_spec)
mel_spec = 20 * torch.log10(torch.clamp(mel_spec, min=1e-5)) mel_spec = 20 * torch.log10(torch.clamp(mel_spec, min=1e-5))
if self._min_level_db is not None: if self._min_level_db is not None:
...@@ -120,7 +120,7 @@ class _GriffinLimVocoder(torch.nn.Module, Tacotron2TTSBundle.Vocoder): ...@@ -120,7 +120,7 @@ class _GriffinLimVocoder(torch.nn.Module, Tacotron2TTSBundle.Vocoder):
def sample_rate(self): def sample_rate(self):
return self._sample_rate return self._sample_rate
def forward(self, mel_spec, lengths): def forward(self, mel_spec, lengths=None):
mel_spec = torch.exp(mel_spec) mel_spec = torch.exp(mel_spec)
mel_spec = mel_spec.clone().detach().requires_grad_(True) mel_spec = mel_spec.clone().detach().requires_grad_(True)
spec = self._inv_mel(mel_spec) spec = self._inv_mel(mel_spec)
......
...@@ -47,7 +47,7 @@ class _Vocoder(ABC): ...@@ -47,7 +47,7 @@ class _Vocoder(ABC):
""" """
@abstractmethod @abstractmethod
def __call__(self, specgrams: Tensor, lengths: Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]]: def __call__(self, specgrams: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
"""Generate waveform from the given input, such as spectrogram """Generate waveform from the given input, such as spectrogram
See :func:`torchaudio.pipelines.Tacotron2TTSBundle.get_vocoder` for the usage. See :func:`torchaudio.pipelines.Tacotron2TTSBundle.get_vocoder` for the usage.
...@@ -58,6 +58,7 @@ class _Vocoder(ABC): ...@@ -58,6 +58,7 @@ class _Vocoder(ABC):
The expected shape depends on the implementation. The expected shape depends on the implementation.
lengths (Tensor, or None, optional): lengths (Tensor, or None, optional):
The valid length of each sample in the batch. Shape: `(batch, )`. The valid length of each sample in the batch. Shape: `(batch, )`.
(Default: `None`)
Returns: Returns:
(Tensor, Optional[Tensor]): (Tensor, Optional[Tensor]):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment