Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
6321adcf
"graphbolt/vscode:/vscode.git/clone" did not exist on "a2234d60752631d92f46fed9d8be1612c4acbbfd"
Commit
6321adcf
authored
Oct 10, 2021
by
moto
Browse files
Replace custom padding with torch's native impl (#1846)
parent
498722b5
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
25 deletions
+2
-25
examples/pipeline_wavernn/wavernn_inference_wrapper.py
examples/pipeline_wavernn/wavernn_inference_wrapper.py
+2
-25
No files found.
examples/pipeline_wavernn/wavernn_inference_wrapper.py
View file @
6321adcf
...
@@ -60,7 +60,7 @@ def _fold_with_overlap(x: Tensor, timesteps: int, overlap: int) -> Tensor:
...
@@ -60,7 +60,7 @@ def _fold_with_overlap(x: Tensor, timesteps: int, overlap: int) -> Tensor:
if
remaining
!=
0
:
if
remaining
!=
0
:
n_folds
+=
1
n_folds
+=
1
padding
=
timesteps
+
2
*
overlap
-
remaining
padding
=
timesteps
+
2
*
overlap
-
remaining
x
=
_pad_tensor
(
x
,
padding
,
side
=
'after'
)
x
=
torch
.
nn
.
functional
.
pad
(
x
,
(
0
,
padding
)
)
folded
=
torch
.
zeros
((
n_folds
,
channels
,
timesteps
+
2
*
overlap
),
device
=
x
.
device
)
folded
=
torch
.
zeros
((
n_folds
,
channels
,
timesteps
+
2
*
overlap
),
device
=
x
.
device
)
...
@@ -129,29 +129,6 @@ def _xfade_and_unfold(y: Tensor, overlap: int) -> Tensor:
...
@@ -129,29 +129,6 @@ def _xfade_and_unfold(y: Tensor, overlap: int) -> Tensor:
return
unfolded
return
unfolded
def
_pad_tensor
(
x
:
Tensor
,
pad
:
int
,
side
:
str
=
'both'
)
->
Tensor
:
r
"""Pad the given tensor.
Args:
x (Tensor): The tensor to pad of size (n_batch, n_mels, time).
pad (int): The amount of padding applied to the input.
Return:
padded (Tensor): The padded tensor of size (n_batch, n_mels, time).
"""
b
,
c
,
t
=
x
.
size
()
total
=
t
+
2
*
pad
if
side
==
'both'
else
t
+
pad
padded
=
torch
.
zeros
((
b
,
c
,
total
),
device
=
x
.
device
)
if
side
==
'before'
or
side
==
'both'
:
padded
[:,
:,
pad
:
pad
+
t
]
=
x
elif
side
==
'after'
:
padded
[:,
:,
:
t
]
=
x
else
:
raise
ValueError
(
f
"Unexpected side: '
{
side
}
'. "
f
"Valid choices are 'both', 'before' and 'after'."
)
return
padded
class
WaveRNNInferenceWrapper
(
torch
.
nn
.
Module
):
class
WaveRNNInferenceWrapper
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
wavernn
:
WaveRNN
):
def
__init__
(
self
,
wavernn
:
WaveRNN
):
...
@@ -189,7 +166,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
...
@@ -189,7 +166,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
pad
=
(
self
.
wavernn_model
.
kernel_size
-
1
)
//
2
pad
=
(
self
.
wavernn_model
.
kernel_size
-
1
)
//
2
specgram
=
specgram
.
unsqueeze
(
0
)
specgram
=
specgram
.
unsqueeze
(
0
)
specgram
=
_pad_tensor
(
specgram
,
pad
=
pad
,
side
=
'both'
)
specgram
=
torch
.
nn
.
functional
.
pad
(
specgram
,
(
pad
,
pad
)
)
if
batched
:
if
batched
:
specgram
=
_fold_with_overlap
(
specgram
,
timesteps
,
overlap
)
specgram
=
_fold_with_overlap
(
specgram
,
timesteps
,
overlap
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment