"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "bae519ce1b351eb88efe5b248753fd8c59ac6203"
Unverified Commit 3bb5feb5, authored by yangarbiter, committed by GitHub

Refactor WaveRNN infer and move it to the codebase (#1704)

parent 63f0614b
@@ -106,6 +106,8 @@ WaveRNN
    .. automethod:: forward
 
+   .. automethod:: infer
+
 Factory Functions
 -----------------

...
@@ -21,10 +21,6 @@ def parse_args():
         "--jit", default=False, action="store_true",
         help="If used, the model and inference function is jitted."
     )
-    parser.add_argument(
-        "--loss", default="crossentropy", choices=["crossentropy"],
-        type=str, help="The type of loss the pretrained model is trained on.",
-    )
     parser.add_argument(
         "--no-batch-inference", default=False, action="store_true",
         help="Don't use batch inference."
@@ -39,11 +35,11 @@ def parse_args():
         help="Select the WaveRNN checkpoint."
     )
     parser.add_argument(
-        "--batch-timesteps", default=11000, type=int,
+        "--batch-timesteps", default=100, type=int,
         help="The time steps for each batch. Only used when batch inference is used",
     )
     parser.add_argument(
-        "--batch-overlap", default=550, type=int,
+        "--batch-overlap", default=5, type=int,
         help="The overlapping time steps between batches. Only used when batch inference is used",
     )
     args = parser.parse_args()
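
Note that with these new defaults, a bare run of the example script performs batch inference with 100-frame windows and a 5-frame overlap; the windows are now counted in spectrogram frames rather than audio samples, presumably because folding now happens before upsampling (see the wrapper changes below). A hypothetical invocation, assuming the script's file name and leaving aside its positional arguments since they are not shown in this hunk: `python inference.py --batch-timesteps 100 --batch-overlap 5`.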
@@ -79,13 +75,12 @@ def main(args):
     with torch.no_grad():
         output = wavernn_inference_model(mel_specgram.to(device),
-                                         loss_name=args.loss,
                                          mulaw=(not args.no_mulaw),
                                          batched=(not args.no_batch_inference),
                                          timesteps=args.batch_timesteps,
                                          overlap=args.batch_overlap,)
 
-    torchaudio.save(args.output_wav_path, output.reshape(1, -1), sample_rate=sample_rate)
+    torchaudio.save(args.output_wav_path, output, sample_rate=sample_rate)
 
 
 if __name__ == "__main__":
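
Note: the `reshape(1, -1)` before saving is no longer needed; the wrapper now returns a tensor of size (1, n_time), which already matches the (channel, time) layout that `torchaudio.save` expects.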
...
@@ -21,18 +21,12 @@
 # *****************************************************************************
 
-from typing import List
-
 from torchaudio.models.wavernn import WaveRNN
 import torch
-import torch.nn.functional as F
 import torchaudio
 from torch import Tensor
-from processing import (
-    normalized_waveform_to_bits,
-    bits_to_normalized_waveform,
-)
+from processing import normalized_waveform_to_bits
 
 
 class WaveRNNInferenceWrapper(torch.nn.Module):
@@ -53,12 +47,12 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
              [h7, h8, h9, h10]]
 
         Args:
-            x (tensor): Upsampled conditioning channels with shape (1, timesteps, channel).
+            x (tensor): Upsampled conditioning channels of size (1, timesteps, channel).
             timesteps (int): Timesteps for each index of batch.
             overlap (int): Timesteps for both xfade and rnn warmup.
 
         Return:
-            folded (tensor): folded tensor with shape (n_folds, timesteps + 2 * overlap, channel).
+            folded (tensor): folded tensor of size (n_folds, timesteps + 2 * overlap, channel).
         '''
         _, channels, total_len = x.size()
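
For readers unfamiliar with the fold trick, here is a sketch of the shape arithmetic, assuming the fatchord-style fold this wrapper is based on; every name and number below is illustrative rather than taken from the diff:

```python
import torch

# Illustrative values matching the new defaults (timesteps=100, overlap=5).
total_len, timesteps, overlap = 320, 100, 5

# Each fold covers timesteps + 2 * overlap frames and advances by
# timesteps + overlap, so consecutive folds share `overlap` frames
# that are later crossfaded back together.
num_folds = (total_len - overlap) // (timesteps + overlap)  # 3
extended_len = num_folds * (timesteps + overlap) + overlap  # 320
remaining = total_len - extended_len                        # 0 -> no extra padded fold

x = torch.rand(1, 80, total_len)           # (1, channel, total_len)
print(num_folds, timesteps + 2 * overlap)  # 3 folds of 110 frames each
```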
@@ -98,15 +92,15 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
             [seq1_in, seq1_timesteps, (seq1_out + seq2_in), seq2_timesteps, ...]
 
         Args:
-            y (Tensor): Batched sequences of audio samples with shape
-                (num_folds, timesteps + 2 * overlap).
+            y (Tensor): Batched sequences of audio samples of size
+                (num_folds, channels, timesteps + 2 * overlap).
             overlap (int): Timesteps for both xfade and rnn warmup.
 
         Returns:
-            unfolded waveform (Tensor) : waveform in a 1d tensor with shape (total_len).
+            unfolded waveform (Tensor) : waveform in a tensor of size (channels, total_len).
         '''
-        num_folds, length = y.shape
+        num_folds, channels, length = y.shape
         timesteps = length - 2 * overlap
         total_len = num_folds * (timesteps + overlap) + overlap
@@ -126,16 +120,16 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
             fade_out = torch.cat([linear, fade_out])
 
         # Apply the gain to the overlap samples
-        y[:, :overlap] *= fade_in
-        y[:, -overlap:] *= fade_out
+        y[:, :, :overlap] *= fade_in
+        y[:, :, -overlap:] *= fade_out
 
-        unfolded = torch.zeros((total_len), dtype=y.dtype, device=y.device)
+        unfolded = torch.zeros((channels, total_len), dtype=y.dtype, device=y.device)
 
         # Loop to add up all the samples
         for i in range(num_folds):
             start = i * (timesteps + overlap)
             end = start + timesteps + 2 * overlap
-            unfolded[start:end] += y[i]
+            unfolded[:, start:end] += y[i]
 
         return unfolded
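
The crossfade-and-sum above can be reproduced standalone. A minimal sketch using a plain linear fade (the actual code assembles its fade curves differently, and all names below are illustrative):

```python
import torch

num_folds, channels, overlap = 3, 1, 5
length = 110                                              # timesteps + 2 * overlap
timesteps = length - 2 * overlap                          # 100
total_len = num_folds * (timesteps + overlap) + overlap   # 320

y = torch.rand(num_folds, channels, length)

# Fade each fold's head in and its tail out, so that summing the shared
# regions of neighboring folds blends them smoothly.
fade_in = torch.linspace(0.0, 1.0, overlap)
fade_out = torch.linspace(1.0, 0.0, overlap)
y[:, :, :overlap] *= fade_in
y[:, :, -overlap:] *= fade_out

unfolded = torch.zeros((channels, total_len), dtype=y.dtype)
for i in range(num_folds):
    start = i * (timesteps + overlap)
    unfolded[:, start:start + length] += y[i]

print(unfolded.shape)  # torch.Size([1, 320])
```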
@@ -143,11 +137,11 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
         r"""Pad the given tensor.
 
         Args:
-            x (Tensor): The tensor to pad with shape (n_batch, n_mels, time).
+            x (Tensor): The tensor to pad of size (n_batch, n_mels, time).
             pad (int): The amount of padding applied to the input.
 
         Return:
-            padded (Tensor): The padded tensor with shape (n_batch, n_mels, time).
+            padded (Tensor): The padded tensor of size (n_batch, n_mels, time).
         """
         b, c, t = x.size()
         total = t + 2 * pad if side == 'both' else t + pad
@@ -163,89 +157,42 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
     def forward(self,
                 specgram: Tensor,
-                loss_name: str = "crossentropy",
                 mulaw: bool = True,
                 batched: bool = True,
-                timesteps: int = 11000,
-                overlap: int = 550) -> Tensor:
+                timesteps: int = 100,
+                overlap: int = 5) -> Tensor:
         r"""Inference function for WaveRNN.
 
         Based on the implementation from
         https://github.com/fatchord/WaveRNN/blob/master/models/fatchord_version.py.
 
+        Currently only supports multinomial sampling.
+
         Args:
-            specgram (Tensor): spectrogram with shape (n_mels, n_time)
-            loss_name (str): The loss function used to train the WaveRNN model.
-                Available `loss_name` includes `'mol'` and `'crossentropy'`.
+            specgram (Tensor): spectrogram of size (n_mels, n_time)
             mulaw (bool): Whether to perform mulaw decoding (Default: ``True``).
             batched (bool): Whether to perform batch prediction. Using batch prediction
                 will significantly increase the inference speed (Default: ``True``).
             timesteps (int): The time steps for each batch. Only used when `batched`
-                is set to True (Default: ``11000``).
+                is set to True (Default: ``100``).
             overlap (int): The overlapping time steps between batches. Only used when `batched`
-                is set to True (Default: ``550``).
+                is set to True (Default: ``5``).
 
         Returns:
-            waveform (Tensor): Reconstructed waveform with shape (n_time, ).
+            waveform (Tensor): Reconstructed waveform of size (1, n_time).
+                1 represents a single channel.
         """
         pad = (self.wavernn_model.kernel_size - 1) // 2
         specgram = specgram.unsqueeze(0)
         specgram = self._pad_tensor(specgram, pad=pad, side='both')
-        specgram, aux = self.wavernn_model.upsample(specgram)
 
         if batched:
             specgram = self._fold_with_overlap(specgram, timesteps, overlap)
-            aux = self._fold_with_overlap(aux, timesteps, overlap)
-
-        device = specgram.device
-        dtype = specgram.dtype
 
         # make it compatible with torchscript
         n_bits = int(torch.log2(torch.ones(1) * self.wavernn_model.n_classes))
-        output: List[Tensor] = []
-        b_size, _, seq_len = specgram.size()
-
-        h1 = torch.zeros((1, b_size, self.wavernn_model.n_rnn), device=device, dtype=dtype)
-        h2 = torch.zeros((1, b_size, self.wavernn_model.n_rnn), device=device, dtype=dtype)
-        x = torch.zeros((b_size, 1), device=device, dtype=dtype)
-
-        d = self.wavernn_model.n_aux
-        aux_split = [aux[:, d * i:d * (i + 1), :] for i in range(4)]
-
-        for i in range(seq_len):
-            m_t = specgram[:, :, i]
-            a1_t, a2_t, a3_t, a4_t = [a[:, :, i] for a in aux_split]
-
-            x = torch.cat([x, m_t, a1_t], dim=1)
-            x = self.wavernn_model.fc(x)
-            _, h1 = self.wavernn_model.rnn1(x.unsqueeze(1), h1)
-            x = x + h1[0]
-
-            inp = torch.cat([x, a2_t], dim=1)
-            _, h2 = self.wavernn_model.rnn2(inp.unsqueeze(1), h2)
-            x = x + h2[0]
-
-            x = torch.cat([x, a3_t], dim=1)
-            x = F.relu(self.wavernn_model.fc1(x))
-
-            x = torch.cat([x, a4_t], dim=1)
-            x = F.relu(self.wavernn_model.fc2(x))
-
-            logits = self.wavernn_model.fc3(x)
-
-            if loss_name == "crossentropy":
-                posterior = F.softmax(logits, dim=1)
-                x = torch.multinomial(posterior, 1).float()
-                x = bits_to_normalized_waveform(x, n_bits)
-                output.append(x.squeeze(-1))
-            else:
-                raise ValueError(f"Unexpected loss_name: '{loss_name}'. "
-                                 f"Valid choices are 'crossentropy'.")
-
-        output = torch.stack(output).transpose(0, 1).cpu()
+        output = self.wavernn_model.infer(specgram).cpu()
 
         if mulaw:
             output = normalized_waveform_to_bits(output, n_bits)
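
End to end, the refactored wrapper can be driven as below. This is a sketch: the model parameters and spectrogram are stand-ins, and `WaveRNNInferenceWrapper` is assumed to be in scope from this example module:

```python
import torch
from torchaudio.models.wavernn import WaveRNN

# Any WaveRNN instance works for the sketch; a pretrained checkpoint
# would normally be loaded instead. hop_length must equal the product
# of upsample_scales (5 * 5 * 11 = 275 here).
model = WaveRNN(upsample_scales=[5, 5, 11], n_classes=256, hop_length=275)
wrapper = WaveRNNInferenceWrapper(model)

mel_specgram = torch.rand(128, 500)  # (n_mels, n_time); 128 is WaveRNN's default n_freq
with torch.no_grad():
    waveform = wrapper(mel_specgram,
                       mulaw=True,
                       batched=True,
                       timesteps=100,  # the new defaults
                       overlap=5)
print(waveform.shape)                 # (1, n_time)
```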
...
@@ -120,6 +120,31 @@ class TestWaveRNN(common_utils.TorchaudioTestCase):
         assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), n_classes)
 
+    def test_infer_waveform(self):
+        """Validate the output dimensions of a WaveRNN model's infer method.
+        """
+        upsample_scales = [5, 5, 8]
+        n_rnn = 512
+        n_fc = 512
+        n_classes = 512
+        hop_length = 200
+        n_batch = 2
+        n_time = 200
+        n_freq = 100
+        n_output = 256
+        n_res_block = 10
+        n_hidden = 128
+        kernel_size = 5
+
+        model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block,
+                        n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output)
+
+        x = torch.rand(n_batch, n_freq, n_time)
+        out = model.infer(x)
+
+        assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1))
+
 
 _ConvTasNetParams = namedtuple(
     '_ConvTasNetParams',
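
For these parameters the asserted length works out to hop_length * (n_time - kernel_size + 1) = 200 * 196 = 39,200 samples per batch item: the size-5 convolution kernel trims the time axis to n_time - kernel_size + 1 frames, and upsampling multiplies the frame count by hop_length (the product of upsample_scales, 5 * 5 * 8 = 200).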
...
@@ -3,6 +3,7 @@ from typing import List, Tuple, Dict, Any
 import torch
 from torch import Tensor
 from torch import nn
+import torch.nn.functional as F
 
 from torch.hub import load_state_dict_from_url
@@ -347,6 +348,70 @@ class WaveRNN(nn.Module):
         # bring back channel dimension
         return x.unsqueeze(1)
 
+    @torch.jit.export
+    def infer(self, specgram: Tensor) -> Tensor:
+        r"""Inference method of WaveRNN.
+
+        This function currently only supports multinomial sampling, which assumes the
+        network is trained on cross entropy loss.
+
+        Args:
+            specgram (Tensor): The input spectrogram to the WaveRNN of size (n_batch, n_freq, n_time).
+
+        Return:
+            waveform (Tensor): The inferred waveform of size (n_batch, 1, n_time).
+                1 stands for a single channel.
+        """
+        device = specgram.device
+        dtype = specgram.dtype
+
+        # make it compatible with torchscript
+        n_bits = int(torch.log2(torch.ones(1) * self.n_classes))
+
+        specgram, aux = self.upsample(specgram)
+
+        output: List[Tensor] = []
+        b_size, _, seq_len = specgram.size()
+
+        h1 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
+        h2 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
+        x = torch.zeros((b_size, 1), device=device, dtype=dtype)
+
+        aux_split = [aux[:, self.n_aux * i: self.n_aux * (i + 1), :] for i in range(4)]
+
+        for i in range(seq_len):
+            m_t = specgram[:, :, i]
+            a1_t, a2_t, a3_t, a4_t = [a[:, :, i] for a in aux_split]
+
+            x = torch.cat([x, m_t, a1_t], dim=1)
+            x = self.fc(x)
+            _, h1 = self.rnn1(x.unsqueeze(1), h1)
+            x = x + h1[0]
+
+            inp = torch.cat([x, a2_t], dim=1)
+            _, h2 = self.rnn2(inp.unsqueeze(1), h2)
+            x = x + h2[0]
+
+            x = torch.cat([x, a3_t], dim=1)
+            x = F.relu(self.fc1(x))
+
+            x = torch.cat([x, a4_t], dim=1)
+            x = F.relu(self.fc2(x))
+
+            logits = self.fc3(x)
+
+            posterior = F.softmax(logits, dim=1)
+            x = torch.multinomial(posterior, 1).float()
+            # Transform label [0, 2 ** n_bits - 1] to waveform [-1, 1]
+            x = 2 * x / (2 ** n_bits - 1.0) - 1.0
+
+            output.append(x)
+
+        return torch.stack(output).permute(1, 2, 0)
+
+
 def wavernn(checkpoint_name: str) -> WaveRNN:
     r"""Get pretrained WaveRNN model.
...