Commit 137600d0 authored by moto

Add `lengths` param to WaveRNN.infer (#1851)

parent ddc49548
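In short: `WaveRNN.infer` now takes an optional `lengths` tensor (valid frames per spectrogram) and returns a `(waveform, waveform_lengths)` tuple instead of a bare tensor. A minimal sketch of the new call pattern; the hyperparameters are copied from the updated tests below, not recommended values:

```python
import torch
from torchaudio.models import WaveRNN

# Constructor arguments mirror the updated tests below.
model = WaveRNN(upsample_scales=[5, 5, 8], n_classes=128, hop_length=200,
                n_res_block=2, n_rnn=128, n_fc=128, kernel_size=5,
                n_freq=25, n_hidden=32, n_output=64)

specgram = torch.rand(2, 25, 50)   # (n_batch, n_freq, n_time)
lengths = torch.tensor([50, 25])   # valid frames per item in the batch

# New return type: (waveform, waveform_lengths);
# waveform_lengths is None when `lengths` is not supplied.
waveform, waveform_lengths = model.infer(specgram, lengths)
```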
@@ -163,10 +163,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
             waveform (Tensor): Reconstructed waveform of size (1, n_time, ).
                 1 represents single channel.
         """
-        pad = (self.wavernn_model.kernel_size - 1) // 2
         specgram = specgram.unsqueeze(0)
-        specgram = torch.nn.functional.pad(specgram, (pad, pad))
         if batched:
             specgram = _fold_with_overlap(specgram, timesteps, overlap)
...
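The manual padding leaves the wrapper because `WaveRNN.infer` now applies it internally (see the `self._pad` addition in `wavernn.py` further down). For the odd kernel sizes used here, the two expressions agree:

```python
kernel_size = 5  # the value used throughout the tests

wrapper_pad = (kernel_size - 1) // 2                                    # old wrapper logic
model_pad = (kernel_size - 1 if kernel_size % 2 else kernel_size) // 2  # new WaveRNN._pad
assert wrapper_pad == model_pad == 2
```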
@@ -126,40 +126,43 @@ class TestWaveRNN(common_utils.TorchaudioTestCase):
         """
         upsample_scales = [5, 5, 8]
-        n_rnn = 512
-        n_fc = 512
-        n_classes = 512
+        n_rnn = 128
+        n_fc = 128
+        n_classes = 128
         hop_length = 200
         n_batch = 2
-        n_time = 200
-        n_freq = 100
-        n_output = 256
-        n_res_block = 10
-        n_hidden = 128
+        n_time = 50
+        n_freq = 25
+        n_output = 64
+        n_res_block = 2
+        n_hidden = 32
         kernel_size = 5

         model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block,
                         n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output)

         x = torch.rand(n_batch, n_freq, n_time)
-        out = model.infer(x)
-        assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1))
+        lengths = torch.tensor([n_time, n_time // 2])
+        out, waveform_lengths = model.infer(x, lengths)
+        assert out.size() == (n_batch, 1, hop_length * n_time)
+        assert waveform_lengths[0] == hop_length * n_time
+        assert waveform_lengths[1] == hop_length * n_time // 2
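The expected size changes because `infer` now pads before the convolutional frontend, so the time axis is no longer shortened by `kernel_size - 1` frames. Checking the arithmetic with the test's values:

```python
hop_length, n_time, kernel_size = 200, 50, 5

old_n_samples = hop_length * (n_time - kernel_size + 1)  # 200 * 46 = 9200
new_n_samples = hop_length * n_time                      # 200 * 50 = 10000
assert (old_n_samples, new_n_samples) == (9200, 10000)

# The second batch item declares n_time // 2 = 25 valid frames,
# hence a reported length of 200 * 25 = 5000 samples.
```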
     def test_torchscript_infer(self):
         """Scripted model outputs the same as eager mode"""
         upsample_scales = [5, 5, 8]
-        n_rnn = 512
-        n_fc = 512
-        n_classes = 512
+        n_rnn = 128
+        n_fc = 128
+        n_classes = 128
         hop_length = 200
         n_batch = 2
-        n_time = 200
-        n_freq = 100
-        n_output = 256
-        n_res_block = 10
-        n_hidden = 128
+        n_time = 50
+        n_freq = 25
+        n_output = 64
+        n_res_block = 2
+        n_hidden = 32
         kernel_size = 5

         model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block,
...
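For reference, a sketch of what the TorchScript test above exercises: scripting the model and comparing against eager output. This assumes re-seeding makes the multinomial sampling deterministic in both modes; `model`, `x`, and `lengths` are as constructed in the test:

```python
torch.manual_seed(0)
ref_out, ref_lengths = model.infer(x, lengths)

scripted = torch.jit.script(model)  # infer is exported via @torch.jit.export
torch.manual_seed(0)
js_out, js_lengths = scripted.infer(x, lengths)

torch.testing.assert_allclose(ref_out, js_out)
```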
-from typing import List, Tuple, Dict, Any
+from typing import List, Tuple, Dict, Any, Optional

 import math

 import torch
@@ -182,6 +182,7 @@ class UpsampleNetwork(nn.Module):
         total_scale = 1
         for upsample_scale in upsample_scales:
             total_scale *= upsample_scale
+        self.total_scale: int = total_scale

         self.indent = (kernel_size - 1) // 2 * total_scale
         self.resnet = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)
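`total_scale` is the product of the upsampling factors, i.e. how many waveform samples one spectrogram frame expands to. Exposing it as an attribute lets `infer` convert frame counts into sample counts. With the scales from the tests:

```python
upsample_scales = [5, 5, 8]

total_scale = 1
for upsample_scale in upsample_scales:
    total_scale *= upsample_scale
assert total_scale == 200  # matches hop_length in the tests

# Inside infer: waveform_lengths = lengths * total_scale,
# e.g. a length of 50 frames becomes 50 * 200 = 10000 samples.
```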
@@ -265,6 +266,7 @@ class WaveRNN(nn.Module):
         super().__init__()

         self.kernel_size = kernel_size
+        self._pad = (kernel_size - 1 if kernel_size % 2 else kernel_size) // 2
         self.n_rnn = n_rnn
         self.n_aux = n_output // 4
         self.hop_length = hop_length
@@ -351,24 +353,35 @@ class WaveRNN(nn.Module):
         return x.unsqueeze(1)

     @torch.jit.export
-    def infer(self, specgram: Tensor) -> Tensor:
+    def infer(self, specgram: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
         r"""Inference method of WaveRNN.

         This function currently only supports multinomial sampling, which assumes the
         network is trained on cross entropy loss.

         Args:
-            specgram (Tensor): The input spectrogram to the WaveRNN of size (n_batch, n_freq, n_time).
-
-        Return:
-            waveform (Tensor): The inferred waveform of size (n_batch, 1, n_time).
+            specgram (Tensor):
+                Batch of spectrograms. Shape: `(n_batch, n_freq, n_time)`.
+            lengths (Tensor or None, optional):
+                Indicates the valid length of each spectrogram in the time axis.
+                Shape: `(n_batch, )`.
+
+        Returns:
+            Tensor and optional Tensor:
+            Tensor
+                The inferred waveform of size `(n_batch, 1, n_time)`.
                 1 stands for a single channel.
+            Tensor or None
+                The valid lengths of each waveform in the batch. Size `(n_batch, )`.
         """
         device = specgram.device
         dtype = specgram.dtype

+        specgram = torch.nn.functional.pad(specgram, (self._pad, self._pad))
         specgram, aux = self.upsample(specgram)
+        if lengths is not None:
+            lengths = lengths * self.upsample.total_scale

         output: List[Tensor] = []
         b_size, _, seq_len = specgram.size()
@@ -410,7 +423,7 @@ class WaveRNN(nn.Module):
             output.append(x)

-        return torch.stack(output).permute(1, 2, 0)
+        return torch.stack(output).permute(1, 2, 0), lengths


 def wavernn(checkpoint_name: str) -> WaveRNN:
...
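One way downstream code might consume the new return value is to trim each padded waveform back to its valid samples. A minimal sketch, with `waveform` and `waveform_lengths` as returned by `infer` above; the trimming itself is not part of this commit:

```python
waveform, waveform_lengths = model.infer(specgram, lengths)

# waveform: (n_batch, 1, n_time); waveform_lengths: (n_batch,) or None.
if waveform_lengths is None:
    trimmed = list(waveform)
else:
    trimmed = [w[:, :int(n)] for w, n in zip(waveform, waveform_lengths)]
```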