Unverified commit 19f53cf2, authored by moto and committed by GitHub
Browse files

Refactor WaveRNNInferenceWrapper (#1845)

parent 635a4a0a
......@@ -29,13 +29,7 @@ from torch import Tensor
from processing import normalized_waveform_to_bits
class WaveRNNInferenceWrapper(torch.nn.Module):
def __init__(self, wavernn: WaveRNN):
super().__init__()
self.wavernn_model = wavernn
def _fold_with_overlap(self, x: Tensor, timesteps: int, overlap: int) -> Tensor:
def _fold_with_overlap(x: Tensor, timesteps: int, overlap: int) -> Tensor:
r'''Fold the tensor with overlap for quick batched inference.
Overlap will be used for crossfading in xfade_and_unfold().
......@@ -66,7 +60,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
if remaining != 0:
n_folds += 1
padding = timesteps + 2 * overlap - remaining
x = self._pad_tensor(x, padding, side='after')
x = _pad_tensor(x, padding, side='after')
folded = torch.zeros((n_folds, channels, timesteps + 2 * overlap), device=x.device)
......@@ -78,7 +72,8 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
return folded
def _xfade_and_unfold(self, y: Tensor, overlap: int) -> Tensor:
def _xfade_and_unfold(y: Tensor, overlap: int) -> Tensor:
r'''Applies a crossfade and unfolds into a 1d array.
y = [[seq1],
......@@ -133,7 +128,8 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
return unfolded
def _pad_tensor(self, x: Tensor, pad: int, side: str = 'both') -> Tensor:
def _pad_tensor(x: Tensor, pad: int, side: str = 'both') -> Tensor:
r"""Pad the given tensor.
Args:
......@@ -155,6 +151,13 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
f"Valid choices are 'both', 'before' and 'after'.")
return padded
class WaveRNNInferenceWrapper(torch.nn.Module):
def __init__(self, wavernn: WaveRNN):
super().__init__()
self.wavernn_model = wavernn
def forward(self,
specgram: Tensor,
mulaw: bool = True,
......@@ -186,9 +189,9 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
pad = (self.wavernn_model.kernel_size - 1) // 2
specgram = specgram.unsqueeze(0)
specgram = self._pad_tensor(specgram, pad=pad, side='both')
specgram = _pad_tensor(specgram, pad=pad, side='both')
if batched:
specgram = self._fold_with_overlap(specgram, timesteps, overlap)
specgram = _fold_with_overlap(specgram, timesteps, overlap)
n_bits = int(torch.log2(torch.ones(1) * self.wavernn_model.n_classes))
......@@ -199,7 +202,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
output = torchaudio.functional.mu_law_decoding(output, self.wavernn_model.n_classes)
if batched:
output = self._xfade_and_unfold(output, overlap)
output = _xfade_and_unfold(output, overlap)
else:
output = output[0]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment