Unverified Commit 3bb5feb5 authored by yangarbiter, committed by GitHub

Refactor WaveRNN infer and move it to the codebase (#1704)

parent 63f0614b
@@ -106,6 +106,8 @@ WaveRNN
.. automethod:: forward
.. automethod:: infer
Factory Functions
-----------------
......
@@ -21,10 +21,6 @@ def parse_args():
"--jit", default=False, action="store_true",
help="If used, the model and inference function is jitted."
)
parser.add_argument(
"--loss", default="crossentropy", choices=["crossentropy"],
type=str, help="The type of loss the pretrained model is trained on.",
)
parser.add_argument(
"--no-batch-inference", default=False, action="store_true",
help="Don't use batch inference."
@@ -39,11 +35,11 @@ def parse_args():
help="Select the WaveRNN checkpoint."
)
parser.add_argument(
"--batch-timesteps", default=11000, type=int,
"--batch-timesteps", default=100, type=int,
help="The time steps for each batch. Only used when batch inference is used",
)
parser.add_argument(
"--batch-overlap", default=550, type=int,
"--batch-overlap", default=5, type=int,
help="The overlapping time steps between batches. Only used when batch inference is used",
)
args = parser.parse_args()
@@ -79,13 +75,12 @@ def main(args):
with torch.no_grad():
output = wavernn_inference_model(mel_specgram.to(device),
loss_name=args.loss,
mulaw=(not args.no_mulaw),
batched=(not args.no_batch_inference),
timesteps=args.batch_timesteps,
overlap=args.batch_overlap,)
torchaudio.save(args.output_wav_path, output.reshape(1, -1), sample_rate=sample_rate)
torchaudio.save(args.output_wav_path, output, sample_rate=sample_rate)
if __name__ == "__main__":
......
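Since `loss_name` is gone from the call and the output now comes back as `(1, n_time)`, the call site simplifies to something like the minimal sketch below. The checkpoint name, sample rate, mel input, and the wrapper's module path are illustrative assumptions, not taken from this diff:

```python
import torch
import torchaudio
from torchaudio.models import wavernn
# Assumed module path for the wrapper shipped next to this example script.
from wavernn_inference_wrapper import WaveRNNInferenceWrapper

model = wavernn("wavernn_10k_epochs_8bits_ljspeech")  # hypothetical checkpoint name
wrapper = WaveRNNInferenceWrapper(model).eval()

# Placeholder conditioning input; n_mels must match the checkpoint's n_freq.
mel_specgram = torch.rand(80, 128)
with torch.no_grad():
    # loss_name is dropped: inference always uses multinomial sampling now.
    waveform = wrapper(mel_specgram, mulaw=True, batched=True,
                       timesteps=100, overlap=5)
# The wrapper returns (1, n_time), so no reshape is needed before saving.
torchaudio.save("output.wav", waveform, sample_rate=22050)
```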
@@ -21,18 +21,12 @@
# *****************************************************************************
from typing import List
from torchaudio.models.wavernn import WaveRNN
import torch
import torch.nn.functional as F
import torchaudio
from torch import Tensor
from processing import (
normalized_waveform_to_bits,
bits_to_normalized_waveform,
)
from processing import normalized_waveform_to_bits
class WaveRNNInferenceWrapper(torch.nn.Module):
@@ -53,12 +47,12 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
[h7, h8, h9, h10]]
Args:
x (tensor): Upsampled conditioning channels with shape (1, timesteps, channel).
x (tensor): Upsampled conditioning channels of size (1, timesteps, channel).
timesteps (int): Timesteps for each index of batch.
overlap (int): Timesteps for both xfade and rnn warmup.
Return:
folded (tensor): folded tensor with shape (n_folds, timesteps + 2 * overlap, channel).
folded (tensor): folded tensor of size (n_folds, timesteps + 2 * overlap, channel).
'''
_, channels, total_len = x.size()
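For intuition about the fold geometry: each fold covers `timesteps + 2 * overlap` steps and consecutive folds start `timesteps + overlap` steps apart, so neighbours share exactly `overlap` steps for RNN warmup and crossfading. A worked sketch with the new defaults; the `num_folds` formula is assumed from the upstream fatchord implementation this code is based on:

```python
timesteps, overlap = 100, 5          # new defaults
total_len = 320                      # hypothetical upsampled conditioning length
num_folds = (total_len - overlap) // (timesteps + overlap)  # (320 - 5) // 105 = 3
fold_len = timesteps + 2 * overlap   # 110 steps per fold
stride = timesteps + overlap         # 105 steps between fold starts
# fold 0: [0, 110), fold 1: [105, 215), fold 2: [210, 320) -> 5-step overlaps
```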
@@ -98,15 +92,15 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
[seq1_in, seq1_timesteps, (seq1_out + seq2_in), seq2_timesteps, ...]
Args:
y (Tensor): Batched sequences of audio samples with shape
(num_folds, timesteps + 2 * overlap).
y (Tensor): Batched sequences of audio samples of size
(num_folds, channels, timesteps + 2 * overlap).
overlap (int): Timesteps for both xfade and rnn warmup.
Returns:
unfolded waveform (Tensor) : waveform in a 1d tensor with shape (total_len).
unfolded waveform (Tensor): waveform of size (channels, total_len).
'''
num_folds, length = y.shape
num_folds, channels, length = y.shape
timesteps = length - 2 * overlap
total_len = num_folds * (timesteps + overlap) + overlap
@@ -126,16 +120,16 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
fade_out = torch.cat([linear, fade_out])
# Apply the gain to the overlap samples
y[:, :overlap] *= fade_in
y[:, -overlap:] *= fade_out
y[:, :, :overlap] *= fade_in
y[:, :, -overlap:] *= fade_out
unfolded = torch.zeros((total_len), dtype=y.dtype, device=y.device)
unfolded = torch.zeros((channels, total_len), dtype=y.dtype, device=y.device)
# Loop to add up all the samples
for i in range(num_folds):
start = i * (timesteps + overlap)
end = start + timesteps + 2 * overlap
unfolded[start:end] += y[i]
unfolded[:, start:end] += y[i]
return unfolded
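The unfold is a plain overlap-add: the fade-in and fade-out gains are complementary, so the `overlap` steps shared by neighbouring folds sum back to unit gain, and each fold is accumulated at a stride of `timesteps + overlap`. A toy sketch of just the accumulation loop, with made-up sizes:

```python
import torch

num_folds, channels, timesteps, overlap = 2, 1, 4, 2
length = timesteps + 2 * overlap                          # 8 steps per fold
y = torch.ones(num_folds, channels, length)
total_len = num_folds * (timesteps + overlap) + overlap   # 2 * 6 + 2 = 14
unfolded = torch.zeros((channels, total_len))
for i in range(num_folds):
    start = i * (timesteps + overlap)
    unfolded[:, start:start + length] += y[i]             # folds overlap on 2 steps
```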
@@ -143,11 +137,11 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
r"""Pad the given tensor.
Args:
x (Tensor): The tensor to pad with shape (n_batch, n_mels, time).
x (Tensor): The tensor to pad of size (n_batch, n_mels, time).
pad (int): The amount of padding applied to the input.
Return:
padded (Tensor): The padded tensor with shape (n_batch, n_mels, time).
padded (Tensor): The padded tensor of size (n_batch, n_mels, time + 2 * pad) when side is 'both', otherwise (n_batch, n_mels, time + pad).
"""
b, c, t = x.size()
total = t + 2 * pad if side == 'both' else t + pad
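As a quick check of the padding arithmetic with the kernel size used elsewhere in this example:

```python
kernel_size = 5
pad = (kernel_size - 1) // 2   # 2, matching the wrapper's forward()
t = 200                        # hypothetical input length
total = t + 2 * pad            # 204 frames when side == 'both'
```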
@@ -163,89 +157,42 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
def forward(self,
specgram: Tensor,
loss_name: str = "crossentropy",
mulaw: bool = True,
batched: bool = True,
timesteps: int = 11000,
overlap: int = 550) -> Tensor:
timesteps: int = 100,
overlap: int = 5) -> Tensor:
r"""Inference function for WaveRNN.
Based on the implementation from
https://github.com/fatchord/WaveRNN/blob/master/models/fatchord_version.py.
Currently only supports multinomial sampling.
Args:
specgram (Tensor): spectrogram with shape (n_mels, n_time)
loss_name (str): The loss function used to train the WaveRNN model.
Available `loss_name` includes `'mol'` and `'crossentropy'`.
specgram (Tensor): spectrogram of size (n_mels, n_time)
mulaw (bool): Whether to perform mulaw decoding (Default: ``True``).
batched (bool): Whether to perform batch prediction. Using batch prediction
will significantly increase the inference speed (Default: ``True``).
timesteps (int): The time steps for each batch. Only used when `batched`
is set to True (Default: ``11000``).
is set to True (Default: ``100``).
overlap (int): The overlapping time steps between batches. Only used when `batched`
is set to True (Default: ``550``).
is set to True (Default: ``5``).
Returns:
waveform (Tensor): Reconstructed waveform with shape (n_time, ).
waveform (Tensor): Reconstructed waveform of size (1, n_time).
The 1 represents a single channel.
"""
pad = (self.wavernn_model.kernel_size - 1) // 2
specgram = specgram.unsqueeze(0)
specgram = self._pad_tensor(specgram, pad=pad, side='both')
specgram, aux = self.wavernn_model.upsample(specgram)
if batched:
specgram = self._fold_with_overlap(specgram, timesteps, overlap)
aux = self._fold_with_overlap(aux, timesteps, overlap)
device = specgram.device
dtype = specgram.dtype
# make it compatible with torchscript
n_bits = int(torch.log2(torch.ones(1) * self.wavernn_model.n_classes))
output: List[Tensor] = []
b_size, _, seq_len = specgram.size()
h1 = torch.zeros((1, b_size, self.wavernn_model.n_rnn), device=device, dtype=dtype)
h2 = torch.zeros((1, b_size, self.wavernn_model.n_rnn), device=device, dtype=dtype)
x = torch.zeros((b_size, 1), device=device, dtype=dtype)
d = self.wavernn_model.n_aux
aux_split = [aux[:, d * i:d * (i + 1), :] for i in range(4)]
for i in range(seq_len):
m_t = specgram[:, :, i]
a1_t, a2_t, a3_t, a4_t = [a[:, :, i] for a in aux_split]
x = torch.cat([x, m_t, a1_t], dim=1)
x = self.wavernn_model.fc(x)
_, h1 = self.wavernn_model.rnn1(x.unsqueeze(1), h1)
x = x + h1[0]
inp = torch.cat([x, a2_t], dim=1)
_, h2 = self.wavernn_model.rnn2(inp.unsqueeze(1), h2)
x = x + h2[0]
x = torch.cat([x, a3_t], dim=1)
x = F.relu(self.wavernn_model.fc1(x))
x = torch.cat([x, a4_t], dim=1)
x = F.relu(self.wavernn_model.fc2(x))
logits = self.wavernn_model.fc3(x)
if loss_name == "crossentropy":
posterior = F.softmax(logits, dim=1)
x = torch.multinomial(posterior, 1).float()
x = bits_to_normalized_waveform(x, n_bits)
output.append(x.squeeze(-1))
else:
raise ValueError(f"Unexpected loss_name: '{loss_name}'. "
f"Valid choices are 'crossentropy'.")
output = torch.stack(output).transpose(0, 1).cpu()
output = self.wavernn_model.infer(specgram).cpu()
if mulaw:
output = normalized_waveform_to_bits(output, n_bits)
......
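Because `infer` is decorated with `@torch.jit.export` (see the model hunk below) and the example script keeps its `--jit` flag, the refactored path should remain scriptable. A minimal sketch, assuming the model scripts cleanly with default arguments:

```python
import torch
from torchaudio.models import WaveRNN

model = WaveRNN(upsample_scales=[2, 2, 5], n_classes=256, hop_length=20).eval()
scripted = torch.jit.script(model)  # infer survives scripting via @torch.jit.export
specgram = torch.rand(1, 128, 8)    # (n_batch, n_freq, n_time); n_freq defaults to 128
with torch.no_grad():
    waveform = scripted.infer(specgram)  # same contract as eager: (n_batch, 1, n_time)
```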
@@ -120,6 +120,31 @@ class TestWaveRNN(common_utils.TorchaudioTestCase):
assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), n_classes)
def test_infer_waveform(self):
"""Validate the output dimensions of a WaveRNN model's infer method.
"""
upsample_scales = [5, 5, 8]
n_rnn = 512
n_fc = 512
n_classes = 512
hop_length = 200
n_batch = 2
n_time = 200
n_freq = 100
n_output = 256
n_res_block = 10
n_hidden = 128
kernel_size = 5
model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block,
n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output)
x = torch.rand(n_batch, n_freq, n_time)
out = model.infer(x)
assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1))
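With these settings the asserted waveform length works out to hop_length * (n_time - kernel_size + 1) = 200 * (200 - 5 + 1) = 39200 samples per batch item.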
_ConvTasNetParams = namedtuple(
'_ConvTasNetParams',
......
@@ -3,6 +3,7 @@ from typing import List, Tuple, Dict, Any
import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
from torch.hub import load_state_dict_from_url
@@ -347,6 +348,70 @@ class WaveRNN(nn.Module):
# bring back channel dimension
return x.unsqueeze(1)
@torch.jit.export
def infer(self, specgram: Tensor) -> Tensor:
r"""Inference method of WaveRNN.
This function currently only supports multinomial sampling, which assumes the
network is trained on cross-entropy loss.
Args:
specgram (Tensor): The input spectrogram of size (n_batch, n_freq, n_time).
Return:
waveform (Tensor): The inferred waveform of size (n_batch, 1, n_time).
The 1 represents a single channel.
"""
device = specgram.device
dtype = specgram.dtype
# compute n_bits = log2(n_classes) with tensor ops so TorchScript can compile it
n_bits = int(torch.log2(torch.ones(1) * self.n_classes))
specgram, aux = self.upsample(specgram)
output: List[Tensor] = []
b_size, _, seq_len = specgram.size()
h1 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
h2 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
x = torch.zeros((b_size, 1), device=device, dtype=dtype)
aux_split = [aux[:, self.n_aux * i: self.n_aux * (i + 1), :] for i in range(4)]
for i in range(seq_len):
m_t = specgram[:, :, i]
a1_t, a2_t, a3_t, a4_t = [a[:, :, i] for a in aux_split]
x = torch.cat([x, m_t, a1_t], dim=1)
x = self.fc(x)
_, h1 = self.rnn1(x.unsqueeze(1), h1)
x = x + h1[0]
inp = torch.cat([x, a2_t], dim=1)
_, h2 = self.rnn2(inp.unsqueeze(1), h2)
x = x + h2[0]
x = torch.cat([x, a3_t], dim=1)
x = F.relu(self.fc1(x))
x = torch.cat([x, a4_t], dim=1)
x = F.relu(self.fc2(x))
logits = self.fc3(x)
posterior = F.softmax(logits, dim=1)
x = torch.multinomial(posterior, 1).float()
# map the class index in [0, 2 ** n_bits - 1] to a waveform value in [-1., 1.]
x = 2 * x / (2 ** n_bits - 1.0) - 1.0
output.append(x)
return torch.stack(output).permute(1, 2, 0)
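A minimal eager-mode sketch of the new public entry point; the tiny sizes are illustrative only (the test above uses production-scale values):

```python
import torch
from torchaudio.models import WaveRNN

model = WaveRNN(upsample_scales=[2, 2, 5], n_classes=256, hop_length=20,
                n_res_block=2, n_rnn=32, n_fc=32, kernel_size=5,
                n_freq=16, n_hidden=16, n_output=16).eval()
specgram = torch.rand(1, 16, 8)   # (n_batch, n_freq, n_time)
with torch.no_grad():
    waveform = model.infer(specgram)
print(waveform.shape)             # (1, 1, 20 * (8 - 5 + 1)) == (1, 1, 80)
```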
def wavernn(checkpoint_name: str) -> WaveRNN:
r"""Get pretrained WaveRNN model.
......