Unverified Commit 19f53cf2 authored by moto, committed by GitHub

Refactor WaveRNNInferenceWrapper (#1845)

parent 635a4a0a
@@ -29,132 +29,135 @@ from torch import Tensor

from processing import normalized_waveform_to_bits
def _fold_with_overlap(x: Tensor, timesteps: int, overlap: int) -> Tensor:
    r'''Fold the tensor with overlap for quick batched inference.
    Overlap will be used for crossfading in _xfade_and_unfold().

        x = [[h1, h2, ... hn]]

    where each h is a vector of conditioning channels.

    Eg: timesteps=2, overlap=1 with x.size(-1)=10

        folded = [[h1, h2, h3, h4],
                  [h4, h5, h6, h7],
                  [h7, h8, h9, h10]]

    Args:
        x (Tensor): Upsampled conditioning channels of size (1, channel, total_len).
        timesteps (int): Timesteps for each index of the batch.
        overlap (int): Timesteps for both xfade and rnn warmup.

    Return:
        folded (Tensor): Folded tensor of size (n_folds, channel, timesteps + 2 * overlap).
    '''
    _, channels, total_len = x.size()

    # Calculate variables needed
    n_folds = (total_len - overlap) // (timesteps + overlap)
    extended_len = n_folds * (overlap + timesteps) + overlap
    remaining = total_len - extended_len

    # Pad if some time steps are poking out
    if remaining != 0:
        n_folds += 1
        padding = timesteps + 2 * overlap - remaining
        x = _pad_tensor(x, padding, side='after')

    folded = torch.zeros((n_folds, channels, timesteps + 2 * overlap),
                         dtype=x.dtype, device=x.device)

    # Get the values for the folded tensor
    for i in range(n_folds):
        start = i * (timesteps + overlap)
        end = start + timesteps + 2 * overlap
        folded[i] = x[0, :, start:end]

    return folded
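A quick way to sanity-check the shape arithmetic is to reproduce the docstring example. The sketch below is editorial, not part of the commit, and assumes the function above is in scope: with timesteps=2 and overlap=1, consecutive folds are strided by timesteps + overlap = 3 frames, so 10 frames fold into three rows of length timesteps + 2 * overlap = 4.

import torch

# Hypothetical check of _fold_with_overlap; mirrors the docstring example.
x = torch.arange(1., 11.).reshape(1, 1, 10)            # (1, channel, total_len)
folded = _fold_with_overlap(x, timesteps=2, overlap=1)
print(folded.shape)    # torch.Size([3, 1, 4])
print(folded[:, 0])    # rows [1, 2, 3, 4], [4, 5, 6, 7], [7, 8, 9, 10]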
def _xfade_and_unfold(y: Tensor, overlap: int) -> Tensor:
    r'''Apply a crossfade and unfold into a single waveform.

        y = [[seq1],
             [seq2],
             [seq3]]

    Apply a gain envelope at both ends of the sequences:

        y = [[seq1_in, seq1_timesteps, seq1_out],
             [seq2_in, seq2_timesteps, seq2_out],
             [seq3_in, seq3_timesteps, seq3_out]]

    Stagger and add up the groups of samples:

        [seq1_in, seq1_timesteps, (seq1_out + seq2_in), seq2_timesteps, ...]

    Args:
        y (Tensor): Batched sequences of audio samples of size
            (num_folds, channels, timesteps + 2 * overlap).
        overlap (int): Timesteps for both xfade and rnn warmup.

    Returns:
        unfolded (Tensor): Unfolded waveform of size (channels, total_len).
    '''
    num_folds, channels, length = y.shape
    timesteps = length - 2 * overlap
    total_len = num_folds * (timesteps + overlap) + overlap

    # Need some silence for the rnn warmup
    silence_len = overlap // 2
    fade_len = overlap - silence_len
    silence = torch.zeros((silence_len,), dtype=y.dtype, device=y.device)
    linear = torch.ones((silence_len,), dtype=y.dtype, device=y.device)

    # Equal power crossfade
    t = torch.linspace(-1, 1, fade_len, dtype=y.dtype, device=y.device)
    fade_in = torch.sqrt(0.5 * (1 + t))
    fade_out = torch.sqrt(0.5 * (1 - t))

    # Concat the silence to the fades
    fade_in = torch.cat([silence, fade_in])
    fade_out = torch.cat([linear, fade_out])

    # Apply the gain to the overlap samples
    y[:, :, :overlap] *= fade_in
    y[:, :, -overlap:] *= fade_out

    unfolded = torch.zeros((channels, total_len), dtype=y.dtype, device=y.device)

    # Loop to add up all the samples
    for i in range(num_folds):
        start = i * (timesteps + overlap)
        end = start + timesteps + 2 * overlap
        unfolded[:, start:end] += y[i]

    return unfolded
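Because adjacent folds carry identical samples in their overlap region, _xfade_and_unfold approximately inverts _fold_with_overlap away from the first and last `overlap` samples, which are faded in and out. A small editorial round-trip sketch, assuming both functions above are in scope; with overlap=4 the crossfade is only fade_len=2 samples long, so a constant signal is recovered exactly in the interior:

import torch

timesteps, overlap = 4, 4                 # overlap=4 -> silence_len=2, fade_len=2
x = torch.ones(1, 1, 28)                  # 28 = 3 * (timesteps + overlap) + overlap
folded = _fold_with_overlap(x, timesteps, overlap)   # (3, 1, 12)
unfolded = _xfade_and_unfold(folded, overlap)        # (1, 28)
print(unfolded.shape)                     # torch.Size([1, 28])
print(unfolded[0, overlap:-overlap])      # all ones: each seam sums back to 1.0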
def _pad_tensor(x: Tensor, pad: int, side: str = 'both') -> Tensor:
    r"""Pad the given tensor along the time axis.

    Args:
        x (Tensor): The tensor to pad of size (n_batch, n_mels, time).
        pad (int): The amount of padding applied to each padded side.
        side (str): Which side to pad: 'before', 'after' or 'both'. (Default: 'both')

    Return:
        padded (Tensor): The padded tensor of size (n_batch, n_mels, time + pad),
            or (n_batch, n_mels, time + 2 * pad) when side='both'.
    """
    b, c, t = x.size()
    total = t + 2 * pad if side == 'both' else t + pad
    padded = torch.zeros((b, c, total), dtype=x.dtype, device=x.device)
    if side == 'before' or side == 'both':
        padded[:, :, pad:pad + t] = x
    elif side == 'after':
        padded[:, :, :t] = x
    else:
        raise ValueError(f"Unexpected side: '{side}'. "
                         f"Valid choices are 'both', 'before' and 'after'.")
    return padded
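_pad_tensor zero-fills only the time axis, on whichever side is requested. A brief editorial sketch of the three modes, assuming the function above is in scope:

import torch

x = torch.ones(1, 2, 5)                              # (n_batch, n_mels, time)
print(_pad_tensor(x, pad=2, side='both').shape)      # torch.Size([1, 2, 9])
print(_pad_tensor(x, pad=2, side='before').shape)    # torch.Size([1, 2, 7])
print(_pad_tensor(x, pad=2, side='after')[0, 0])     # tensor([1., 1., 1., 1., 1., 0., 0.])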
class WaveRNNInferenceWrapper(torch.nn.Module):

    def __init__(self, wavernn: WaveRNN):
        super().__init__()
        self.wavernn_model = wavernn
    def _fold_with_overlap(self, x: Tensor, timesteps: int, overlap: int) -> Tensor:
        r'''Fold the tensor with overlap for quick batched inference.
        Overlap will be used for crossfading in _xfade_and_unfold().

            x = [[h1, h2, ... hn]]

        where each h is a vector of conditioning channels.

        Eg: timesteps=2, overlap=1 with x.size(-1)=10

            folded = [[h1, h2, h3, h4],
                      [h4, h5, h6, h7],
                      [h7, h8, h9, h10]]

        Args:
            x (Tensor): Upsampled conditioning channels of size (1, channel, total_len).
            timesteps (int): Timesteps for each index of the batch.
            overlap (int): Timesteps for both xfade and rnn warmup.

        Return:
            folded (Tensor): Folded tensor of size (n_folds, channel, timesteps + 2 * overlap).
        '''
        _, channels, total_len = x.size()

        # Calculate variables needed
        n_folds = (total_len - overlap) // (timesteps + overlap)
        extended_len = n_folds * (overlap + timesteps) + overlap
        remaining = total_len - extended_len

        # Pad if some time steps are poking out
        if remaining != 0:
            n_folds += 1
            padding = timesteps + 2 * overlap - remaining
            x = self._pad_tensor(x, padding, side='after')

        folded = torch.zeros((n_folds, channels, timesteps + 2 * overlap),
                             dtype=x.dtype, device=x.device)

        # Get the values for the folded tensor
        for i in range(n_folds):
            start = i * (timesteps + overlap)
            end = start + timesteps + 2 * overlap
            folded[i] = x[0, :, start:end]

        return folded

    def _xfade_and_unfold(self, y: Tensor, overlap: int) -> Tensor:
        r'''Apply a crossfade and unfold into a single waveform.

            y = [[seq1],
                 [seq2],
                 [seq3]]

        Apply a gain envelope at both ends of the sequences:

            y = [[seq1_in, seq1_timesteps, seq1_out],
                 [seq2_in, seq2_timesteps, seq2_out],
                 [seq3_in, seq3_timesteps, seq3_out]]

        Stagger and add up the groups of samples:

            [seq1_in, seq1_timesteps, (seq1_out + seq2_in), seq2_timesteps, ...]

        Args:
            y (Tensor): Batched sequences of audio samples of size
                (num_folds, channels, timesteps + 2 * overlap).
            overlap (int): Timesteps for both xfade and rnn warmup.

        Returns:
            unfolded (Tensor): Unfolded waveform of size (channels, total_len).
        '''
        num_folds, channels, length = y.shape
        timesteps = length - 2 * overlap
        total_len = num_folds * (timesteps + overlap) + overlap

        # Need some silence for the rnn warmup
        silence_len = overlap // 2
        fade_len = overlap - silence_len
        silence = torch.zeros((silence_len,), dtype=y.dtype, device=y.device)
        linear = torch.ones((silence_len,), dtype=y.dtype, device=y.device)

        # Equal power crossfade
        t = torch.linspace(-1, 1, fade_len, dtype=y.dtype, device=y.device)
        fade_in = torch.sqrt(0.5 * (1 + t))
        fade_out = torch.sqrt(0.5 * (1 - t))

        # Concat the silence to the fades
        fade_in = torch.cat([silence, fade_in])
        fade_out = torch.cat([linear, fade_out])

        # Apply the gain to the overlap samples
        y[:, :, :overlap] *= fade_in
        y[:, :, -overlap:] *= fade_out

        unfolded = torch.zeros((channels, total_len), dtype=y.dtype, device=y.device)

        # Loop to add up all the samples
        for i in range(num_folds):
            start = i * (timesteps + overlap)
            end = start + timesteps + 2 * overlap
            unfolded[:, start:end] += y[i]

        return unfolded

    def _pad_tensor(self, x: Tensor, pad: int, side: str = 'both') -> Tensor:
        r"""Pad the given tensor along the time axis.

        Args:
            x (Tensor): The tensor to pad of size (n_batch, n_mels, time).
            pad (int): The amount of padding applied to each padded side.
            side (str): Which side to pad: 'before', 'after' or 'both'. (Default: 'both')

        Return:
            padded (Tensor): The padded tensor of size (n_batch, n_mels, time + pad),
                or (n_batch, n_mels, time + 2 * pad) when side='both'.
        """
        b, c, t = x.size()
        total = t + 2 * pad if side == 'both' else t + pad
        padded = torch.zeros((b, c, total), dtype=x.dtype, device=x.device)
        if side == 'before' or side == 'both':
            padded[:, :, pad:pad + t] = x
        elif side == 'after':
            padded[:, :, :t] = x
        else:
            raise ValueError(f"Unexpected side: '{side}'. "
                             f"Valid choices are 'both', 'before' and 'after'.")
        return padded
    def forward(self,
                specgram: Tensor,
                mulaw: bool = True,

@@ -186,9 +189,9 @@ class WaveRNNInferenceWrapper(torch.nn.Module):

        pad = (self.wavernn_model.kernel_size - 1) // 2

        specgram = specgram.unsqueeze(0)
        specgram = _pad_tensor(specgram, pad=pad, side='both')

        if batched:
            specgram = _fold_with_overlap(specgram, timesteps, overlap)

        n_bits = int(torch.log2(torch.ones(1) * self.wavernn_model.n_classes))

@@ -199,7 +202,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module):

        output = torchaudio.functional.mu_law_decoding(output, self.wavernn_model.n_classes)

        if batched:
            output = _xfade_and_unfold(output, overlap)
        else:
            output = output[0]

...
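For context, a hedged editorial sketch of driving the wrapper end to end. The forward() signature is partly elided in this diff, so treating batched, timesteps, and overlap as keyword arguments is an assumption based on the body shown above; the WaveRNN constructor values and timesteps=100 / overlap=5 are illustrative, not taken from the commit.

import torch
from torchaudio.models import WaveRNN

# Illustrative model: hop_length must equal the product of upsample_scales.
wavernn = WaveRNN(upsample_scales=[5, 5, 11], n_classes=256, hop_length=275)
wrapper = WaveRNNInferenceWrapper(wavernn)

specgram = torch.rand(128, 100)           # (n_freq, time); 128 is WaveRNN's default n_freq
with torch.no_grad():
    waveform = wrapper(specgram, mulaw=True, batched=True,
                       timesteps=100, overlap=5)   # assumed kwargs, see note above
print(waveform.shape)                     # synthesized waveform tensor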