"...kompute/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "ecd2f176277db4f074e25a2c3646b04b51cec119"
Unverified Commit 19f53cf2 authored by moto's avatar moto Committed by GitHub
Browse files

Refactor WaveRNNInferenceWrapper (#1845)

parent 635a4a0a
...@@ -29,13 +29,7 @@ from torch import Tensor ...@@ -29,13 +29,7 @@ from torch import Tensor
from processing import normalized_waveform_to_bits from processing import normalized_waveform_to_bits
class WaveRNNInferenceWrapper(torch.nn.Module): def _fold_with_overlap(x: Tensor, timesteps: int, overlap: int) -> Tensor:
def __init__(self, wavernn: WaveRNN):
    """Wrap a pre-trained WaveRNN model for inference.

    Args:
        wavernn (WaveRNN): trained WaveRNN model to run inference with.
            Stored as a submodule, so registering it here puts its
            parameters under this wrapper's state dict.
    """
    # NOTE(review): indentation in this view is mangled (scraped diff);
    # this is a method of WaveRNNInferenceWrapper(torch.nn.Module).
    super().__init__()
    self.wavernn_model = wavernn
def _fold_with_overlap(self, x: Tensor, timesteps: int, overlap: int) -> Tensor:
r'''Fold the tensor with overlap for quick batched inference. r'''Fold the tensor with overlap for quick batched inference.
Overlap will be used for crossfading in xfade_and_unfold(). Overlap will be used for crossfading in xfade_and_unfold().
...@@ -66,7 +60,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module): ...@@ -66,7 +60,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
if remaining != 0: if remaining != 0:
n_folds += 1 n_folds += 1
padding = timesteps + 2 * overlap - remaining padding = timesteps + 2 * overlap - remaining
x = self._pad_tensor(x, padding, side='after') x = _pad_tensor(x, padding, side='after')
folded = torch.zeros((n_folds, channels, timesteps + 2 * overlap), device=x.device) folded = torch.zeros((n_folds, channels, timesteps + 2 * overlap), device=x.device)
...@@ -78,7 +72,8 @@ class WaveRNNInferenceWrapper(torch.nn.Module): ...@@ -78,7 +72,8 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
return folded return folded
def _xfade_and_unfold(self, y: Tensor, overlap: int) -> Tensor:
def _xfade_and_unfold(y: Tensor, overlap: int) -> Tensor:
r'''Applies a crossfade and unfolds into a 1d array. r'''Applies a crossfade and unfolds into a 1d array.
y = [[seq1], y = [[seq1],
...@@ -133,7 +128,8 @@ class WaveRNNInferenceWrapper(torch.nn.Module): ...@@ -133,7 +128,8 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
return unfolded return unfolded
def _pad_tensor(self, x: Tensor, pad: int, side: str = 'both') -> Tensor:
def _pad_tensor(x: Tensor, pad: int, side: str = 'both') -> Tensor:
r"""Pad the given tensor. r"""Pad the given tensor.
Args: Args:
...@@ -155,6 +151,13 @@ class WaveRNNInferenceWrapper(torch.nn.Module): ...@@ -155,6 +151,13 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
f"Valid choices are 'both', 'before' and 'after'.") f"Valid choices are 'both', 'before' and 'after'.")
return padded return padded
class WaveRNNInferenceWrapper(torch.nn.Module):
def __init__(self, wavernn: WaveRNN):
    """Wrap a pre-trained WaveRNN model for inference.

    Args:
        wavernn (WaveRNN): trained WaveRNN model to run inference with.
            Assigning it to an attribute registers it as a submodule of
            this torch.nn.Module wrapper.
    """
    # NOTE(review): indentation in this view is mangled (scraped diff);
    # this is a method of WaveRNNInferenceWrapper(torch.nn.Module).
    super().__init__()
    self.wavernn_model = wavernn
def forward(self, def forward(self,
specgram: Tensor, specgram: Tensor,
mulaw: bool = True, mulaw: bool = True,
...@@ -186,9 +189,9 @@ class WaveRNNInferenceWrapper(torch.nn.Module): ...@@ -186,9 +189,9 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
pad = (self.wavernn_model.kernel_size - 1) // 2 pad = (self.wavernn_model.kernel_size - 1) // 2
specgram = specgram.unsqueeze(0) specgram = specgram.unsqueeze(0)
specgram = self._pad_tensor(specgram, pad=pad, side='both') specgram = _pad_tensor(specgram, pad=pad, side='both')
if batched: if batched:
specgram = self._fold_with_overlap(specgram, timesteps, overlap) specgram = _fold_with_overlap(specgram, timesteps, overlap)
n_bits = int(torch.log2(torch.ones(1) * self.wavernn_model.n_classes)) n_bits = int(torch.log2(torch.ones(1) * self.wavernn_model.n_classes))
...@@ -199,7 +202,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module): ...@@ -199,7 +202,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
output = torchaudio.functional.mu_law_decoding(output, self.wavernn_model.n_classes) output = torchaudio.functional.mu_law_decoding(output, self.wavernn_model.n_classes)
if batched: if batched:
output = self._xfade_and_unfold(output, overlap) output = _xfade_and_unfold(output, overlap)
else: else:
output = output[0] output = output[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment