Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
hehl2
Torchaudio
Commits
202bc4f2
You need to sign in or sign up before continuing.
Commit
202bc4f2
authored
Oct 08, 2021
by
moto
Browse files
Refactor WaveRNNInferenceWrapper (#1845)
parent
9f9b6537
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
126 additions
and
123 deletions
+126
-123
examples/pipeline_wavernn/wavernn_inference_wrapper.py
examples/pipeline_wavernn/wavernn_inference_wrapper.py
+126
-123
No files found.
examples/pipeline_wavernn/wavernn_inference_wrapper.py
View file @
202bc4f2
...
@@ -29,13 +29,7 @@ from torch import Tensor
...
@@ -29,13 +29,7 @@ from torch import Tensor
from
processing
import
normalized_waveform_to_bits
from
processing
import
normalized_waveform_to_bits
class
WaveRNNInferenceWrapper
(
torch
.
nn
.
Module
):
def
_fold_with_overlap
(
x
:
Tensor
,
timesteps
:
int
,
overlap
:
int
)
->
Tensor
:
def
__init__
(
self
,
wavernn
:
WaveRNN
):
super
().
__init__
()
self
.
wavernn_model
=
wavernn
def
_fold_with_overlap
(
self
,
x
:
Tensor
,
timesteps
:
int
,
overlap
:
int
)
->
Tensor
:
r
'''Fold the tensor with overlap for quick batched inference.
r
'''Fold the tensor with overlap for quick batched inference.
Overlap will be used for crossfading in xfade_and_unfold().
Overlap will be used for crossfading in xfade_and_unfold().
...
@@ -66,7 +60,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
...
@@ -66,7 +60,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
if
remaining
!=
0
:
if
remaining
!=
0
:
n_folds
+=
1
n_folds
+=
1
padding
=
timesteps
+
2
*
overlap
-
remaining
padding
=
timesteps
+
2
*
overlap
-
remaining
x
=
self
.
_pad_tensor
(
x
,
padding
,
side
=
'after'
)
x
=
_pad_tensor
(
x
,
padding
,
side
=
'after'
)
folded
=
torch
.
zeros
((
n_folds
,
channels
,
timesteps
+
2
*
overlap
),
device
=
x
.
device
)
folded
=
torch
.
zeros
((
n_folds
,
channels
,
timesteps
+
2
*
overlap
),
device
=
x
.
device
)
...
@@ -78,7 +72,8 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
...
@@ -78,7 +72,8 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
return
folded
return
folded
def
_xfade_and_unfold
(
self
,
y
:
Tensor
,
overlap
:
int
)
->
Tensor
:
def
_xfade_and_unfold
(
y
:
Tensor
,
overlap
:
int
)
->
Tensor
:
r
'''Applies a crossfade and unfolds into a 1d array.
r
'''Applies a crossfade and unfolds into a 1d array.
y = [[seq1],
y = [[seq1],
...
@@ -133,7 +128,8 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
...
@@ -133,7 +128,8 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
return
unfolded
return
unfolded
def
_pad_tensor
(
self
,
x
:
Tensor
,
pad
:
int
,
side
:
str
=
'both'
)
->
Tensor
:
def
_pad_tensor
(
x
:
Tensor
,
pad
:
int
,
side
:
str
=
'both'
)
->
Tensor
:
r
"""Pad the given tensor.
r
"""Pad the given tensor.
Args:
Args:
...
@@ -155,6 +151,13 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
...
@@ -155,6 +151,13 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
f
"Valid choices are 'both', 'before' and 'after'."
)
f
"Valid choices are 'both', 'before' and 'after'."
)
return
padded
return
padded
class
WaveRNNInferenceWrapper
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
wavernn
:
WaveRNN
):
super
().
__init__
()
self
.
wavernn_model
=
wavernn
def
forward
(
self
,
def
forward
(
self
,
specgram
:
Tensor
,
specgram
:
Tensor
,
mulaw
:
bool
=
True
,
mulaw
:
bool
=
True
,
...
@@ -186,9 +189,9 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
...
@@ -186,9 +189,9 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
pad
=
(
self
.
wavernn_model
.
kernel_size
-
1
)
//
2
pad
=
(
self
.
wavernn_model
.
kernel_size
-
1
)
//
2
specgram
=
specgram
.
unsqueeze
(
0
)
specgram
=
specgram
.
unsqueeze
(
0
)
specgram
=
self
.
_pad_tensor
(
specgram
,
pad
=
pad
,
side
=
'both'
)
specgram
=
_pad_tensor
(
specgram
,
pad
=
pad
,
side
=
'both'
)
if
batched
:
if
batched
:
specgram
=
self
.
_fold_with_overlap
(
specgram
,
timesteps
,
overlap
)
specgram
=
_fold_with_overlap
(
specgram
,
timesteps
,
overlap
)
n_bits
=
int
(
torch
.
log2
(
torch
.
ones
(
1
)
*
self
.
wavernn_model
.
n_classes
))
n_bits
=
int
(
torch
.
log2
(
torch
.
ones
(
1
)
*
self
.
wavernn_model
.
n_classes
))
...
@@ -199,7 +202,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
...
@@ -199,7 +202,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
output
=
torchaudio
.
functional
.
mu_law_decoding
(
output
,
self
.
wavernn_model
.
n_classes
)
output
=
torchaudio
.
functional
.
mu_law_decoding
(
output
,
self
.
wavernn_model
.
n_classes
)
if
batched
:
if
batched
:
output
=
self
.
_xfade_and_unfold
(
output
,
overlap
)
output
=
_xfade_and_unfold
(
output
,
overlap
)
else
:
else
:
output
=
output
[
0
]
output
=
output
[
0
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment