Unverified commit d93322e8 authored by moto, committed by GitHub

Replace custom padding with torch's native impl (#1846)

parent 9637c6bf
@@ -60,7 +60,7 @@ def _fold_with_overlap(x: Tensor, timesteps: int, overlap: int) -> Tensor:
     if remaining != 0:
         n_folds += 1
         padding = timesteps + 2 * overlap - remaining
-        x = _pad_tensor(x, padding, side='after')
+        x = torch.nn.functional.pad(x, (0, padding))
     folded = torch.zeros((n_folds, channels, timesteps + 2 * overlap), device=x.device)
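For context, `torch.nn.functional.pad` interprets its pad spec as `(left, right)` amounts for the last dimension and defaults to constant zero padding, so `(0, padding)` reproduces the removed `side='after'` behavior. A minimal sketch with made-up shapes, not part of the diff:

```python
import torch
import torch.nn.functional as F

# Illustrative shapes only: (n_batch, n_mels, time).
x = torch.randn(1, 80, 100)
padding = 24

# (0, padding) appends `padding` zero frames after the last time step;
# nothing is added before the first one.
padded = F.pad(x, (0, padding))

print(padded.shape)                          # torch.Size([1, 80, 124])
print(torch.equal(padded[..., :100], x))     # True: original frames untouched
print(padded[..., 100:].abs().sum().item())  # 0.0: trailing frames are zeros
```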
@@ -129,29 +129,6 @@ def _xfade_and_unfold(y: Tensor, overlap: int) -> Tensor:
     return unfolded
 
 
-def _pad_tensor(x: Tensor, pad: int, side: str = 'both') -> Tensor:
-    r"""Pad the given tensor.
-
-    Args:
-        x (Tensor): The tensor to pad of size (n_batch, n_mels, time).
-        pad (int): The amount of padding applied to the input.
-
-    Return:
-        padded (Tensor): The padded tensor of size (n_batch, n_mels, time).
-    """
-    b, c, t = x.size()
-    total = t + 2 * pad if side == 'both' else t + pad
-    padded = torch.zeros((b, c, total), device=x.device)
-    if side == 'before' or side == 'both':
-        padded[:, :, pad:pad + t] = x
-    elif side == 'after':
-        padded[:, :, :t] = x
-    else:
-        raise ValueError(f"Unexpected side: '{side}'. "
-                         f"Valid choices are 'both', 'before' and 'after'.")
-    return padded
-
-
 class WaveRNNInferenceWrapper(torch.nn.Module):
 
     def __init__(self, wavernn: WaveRNN):
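As a sanity check on the removal (an illustrative sketch, not part of the commit), each `side` mode of the deleted helper maps onto a plain `F.pad` spec on the time axis, since `F.pad` defaults to constant zeros:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 80, 50)   # (n_batch, n_mels, time); shapes chosen arbitrarily
pad = 3

both = F.pad(x, (pad, pad))    # old side='both'
before = F.pad(x, (pad, 0))    # old side='before'
after = F.pad(x, (0, pad))     # old side='after'

assert both.shape == (2, 80, 50 + 2 * pad)
assert torch.equal(both[..., pad:pad + 50], x)
assert torch.equal(before[..., pad:], x)
assert torch.equal(after[..., :50], x)
```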
@@ -189,7 +166,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
         pad = (self.wavernn_model.kernel_size - 1) // 2
         specgram = specgram.unsqueeze(0)
-        specgram = _pad_tensor(specgram, pad=pad, side='both')
+        specgram = torch.nn.functional.pad(specgram, (pad, pad))
         if batched:
             specgram = _fold_with_overlap(specgram, timesteps, overlap)
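Here the spectrogram's time axis gets `(kernel_size - 1) // 2` zero frames on each side so the upsampling convolution sees full context at both edges. A hedged sketch of the equivalent call; the kernel size and shapes below are assumptions for illustration, not values taken from the diff:

```python
import torch
import torch.nn.functional as F

kernel_size = 5               # assumed odd kernel, as in a typical WaveRNN config
pad = (kernel_size - 1) // 2  # 2 frames of context on each side

specgram = torch.randn(80, 100).unsqueeze(0)  # (1, n_mels, time), illustrative only
specgram = F.pad(specgram, (pad, pad))        # zeros before and after along time

print(specgram.shape)  # torch.Size([1, 80, 104])
```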