Commit 137600d0 authored by moto's avatar moto
Browse files

Add `lengths` param to WaveRNN.infer (#1851)

parent ddc49548
......@@ -163,10 +163,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
waveform (Tensor): Reconstructed waveform of size (1, n_time, ).
1 represents single channel.
"""
pad = (self.wavernn_model.kernel_size - 1) // 2
specgram = specgram.unsqueeze(0)
specgram = torch.nn.functional.pad(specgram, (pad, pad))
if batched:
specgram = _fold_with_overlap(specgram, timesteps, overlap)
......
......@@ -126,40 +126,43 @@ class TestWaveRNN(common_utils.TorchaudioTestCase):
"""
upsample_scales = [5, 5, 8]
n_rnn = 512
n_fc = 512
n_classes = 512
n_rnn = 128
n_fc = 128
n_classes = 128
hop_length = 200
n_batch = 2
n_time = 200
n_freq = 100
n_output = 256
n_res_block = 10
n_hidden = 128
n_time = 50
n_freq = 25
n_output = 64
n_res_block = 2
n_hidden = 32
kernel_size = 5
model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block,
n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output)
x = torch.rand(n_batch, n_freq, n_time)
out = model.infer(x)
lengths = torch.tensor([n_time, n_time // 2])
out, waveform_lengths = model.infer(x, lengths)
assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1))
assert out.size() == (n_batch, 1, hop_length * n_time)
assert waveform_lengths[0] == hop_length * n_time
assert waveform_lengths[1] == hop_length * n_time // 2
def test_torchscript_infer(self):
"""Scripted model outputs the same as eager mode"""
upsample_scales = [5, 5, 8]
n_rnn = 512
n_fc = 512
n_classes = 512
n_rnn = 128
n_fc = 128
n_classes = 128
hop_length = 200
n_batch = 2
n_time = 200
n_freq = 100
n_output = 256
n_res_block = 10
n_hidden = 128
n_time = 50
n_freq = 25
n_output = 64
n_res_block = 2
n_hidden = 32
kernel_size = 5
model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block,
......
from typing import List, Tuple, Dict, Any
from typing import List, Tuple, Dict, Any, Optional
import math
import torch
......@@ -182,6 +182,7 @@ class UpsampleNetwork(nn.Module):
total_scale = 1
for upsample_scale in upsample_scales:
total_scale *= upsample_scale
self.total_scale: int = total_scale
self.indent = (kernel_size - 1) // 2 * total_scale
self.resnet = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)
......@@ -265,6 +266,7 @@ class WaveRNN(nn.Module):
super().__init__()
self.kernel_size = kernel_size
self._pad = (kernel_size - 1 if kernel_size % 2 else kernel_size) // 2
self.n_rnn = n_rnn
self.n_aux = n_output // 4
self.hop_length = hop_length
......@@ -351,24 +353,35 @@ class WaveRNN(nn.Module):
return x.unsqueeze(1)
@torch.jit.export
def infer(self, specgram: Tensor) -> Tensor:
def infer(self, specgram: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
r"""Inference method of WaveRNN.
This function currently only supports multinomial sampling, which assumes the
network is trained on cross entropy loss.
Args:
specgram (Tensor): The input spectrogram to the WaveRNN of size (n_batch, n_freq, n_time).
Return:
waveform (Tensor): The inferred waveform of size (n_batch, 1, n_time).
specgram (Tensor):
Batch of spectrograms. Shape: `(n_batch, n_freq, n_time)`.
lengths (Tensor or None, optional):
Indicates the valid length of each spectrogram along the time axis.
Shape: `(n_batch, )`.
Returns:
Tensor and optional Tensor:
Tensor
The inferred waveform of size `(n_batch, 1, n_time)`.
1 stands for a single channel.
Tensor or None
The valid lengths of each waveform in the batch. Size `(n_batch, )`.
"""
device = specgram.device
dtype = specgram.dtype
specgram = torch.nn.functional.pad(specgram, (self._pad, self._pad))
specgram, aux = self.upsample(specgram)
if lengths is not None:
lengths = lengths * self.upsample.total_scale
output: List[Tensor] = []
b_size, _, seq_len = specgram.size()
......@@ -410,7 +423,7 @@ class WaveRNN(nn.Module):
output.append(x)
return torch.stack(output).permute(1, 2, 0)
return torch.stack(output).permute(1, 2, 0), lengths
def wavernn(checkpoint_name: str) -> WaveRNN:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment