"graphbolt/vscode:/vscode.git/clone" did not exist on "a2234d60752631d92f46fed9d8be1612c4acbbfd"
Commit 6321adcf authored by moto's avatar moto
Browse files

Replace custom padding with torch's native impl (#1846)

parent 498722b5
...@@ -60,7 +60,7 @@ def _fold_with_overlap(x: Tensor, timesteps: int, overlap: int) -> Tensor: ...@@ -60,7 +60,7 @@ def _fold_with_overlap(x: Tensor, timesteps: int, overlap: int) -> Tensor:
if remaining != 0: if remaining != 0:
n_folds += 1 n_folds += 1
padding = timesteps + 2 * overlap - remaining padding = timesteps + 2 * overlap - remaining
x = _pad_tensor(x, padding, side='after') x = torch.nn.functional.pad(x, (0, padding))
folded = torch.zeros((n_folds, channels, timesteps + 2 * overlap), device=x.device) folded = torch.zeros((n_folds, channels, timesteps + 2 * overlap), device=x.device)
...@@ -129,29 +129,6 @@ def _xfade_and_unfold(y: Tensor, overlap: int) -> Tensor: ...@@ -129,29 +129,6 @@ def _xfade_and_unfold(y: Tensor, overlap: int) -> Tensor:
return unfolded return unfolded
def _pad_tensor(x: Tensor, pad: int, side: str = 'both') -> Tensor:
r"""Pad the given tensor.
Args:
x (Tensor): The tensor to pad of size (n_batch, n_mels, time).
pad (int): The amount of padding applied to the input.
Return:
padded (Tensor): The padded tensor of size (n_batch, n_mels, time).
"""
b, c, t = x.size()
total = t + 2 * pad if side == 'both' else t + pad
padded = torch.zeros((b, c, total), device=x.device)
if side == 'before' or side == 'both':
padded[:, :, pad:pad + t] = x
elif side == 'after':
padded[:, :, :t] = x
else:
raise ValueError(f"Unexpected side: '{side}'. "
f"Valid choices are 'both', 'before' and 'after'.")
return padded
class WaveRNNInferenceWrapper(torch.nn.Module): class WaveRNNInferenceWrapper(torch.nn.Module):
def __init__(self, wavernn: WaveRNN): def __init__(self, wavernn: WaveRNN):
...@@ -189,7 +166,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module): ...@@ -189,7 +166,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
pad = (self.wavernn_model.kernel_size - 1) // 2 pad = (self.wavernn_model.kernel_size - 1) // 2
specgram = specgram.unsqueeze(0) specgram = specgram.unsqueeze(0)
specgram = _pad_tensor(specgram, pad=pad, side='both') specgram = torch.nn.functional.pad(specgram, (pad, pad))
if batched: if batched:
specgram = _fold_with_overlap(specgram, timesteps, overlap) specgram = _fold_with_overlap(specgram, timesteps, overlap)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment