Commit 19f53cf2 (Unverified)
Authored Oct 08, 2021 by moto; committed by GitHub on Oct 08, 2021

Refactor WaveRNNInferenceWrapper (#1845)

parent 635a4a0a
Showing 1 changed file with 126 additions and 123 deletions

examples/pipeline_wavernn/wavernn_inference_wrapper.py (+126, -123) @ 19f53cf2
@@ -29,13 +29,7 @@ from torch import Tensor
 from processing import normalized_waveform_to_bits
 
 
-class WaveRNNInferenceWrapper(torch.nn.Module):
-
-    def __init__(self, wavernn: WaveRNN):
-        super().__init__()
-        self.wavernn_model = wavernn
-
-    def _fold_with_overlap(self, x: Tensor, timesteps: int, overlap: int) -> Tensor:
+def _fold_with_overlap(x: Tensor, timesteps: int, overlap: int) -> Tensor:
     r'''Fold the tensor with overlap for quick batched inference.
 
     Overlap will be used for crossfading in xfade_and_unfold().
@@ -66,7 +60,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
     if remaining != 0:
         n_folds += 1
         padding = timesteps + 2 * overlap - remaining
-        x = self._pad_tensor(x, padding, side='after')
+        x = _pad_tensor(x, padding, side='after')
 
     folded = torch.zeros((n_folds, channels, timesteps + 2 * overlap), device=x.device)
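As a note on the fold bookkeeping in this hunk, here is a minimal sketch of the shape arithmetic. The sizes and the leftover `remaining` are made-up placeholders, not values from this file:

import torch

# Hypothetical sizes, chosen only to make the arithmetic concrete.
timesteps, overlap, channels = 100, 12, 80
n_folds, remaining = 9, 40   # pretend 9 full folds fit and 40 samples are left over

if remaining != 0:
    n_folds += 1
    # Pad so the final, partial fold also spans timesteps + 2 * overlap samples.
    padding = timesteps + 2 * overlap - remaining

# The buffer allocated in the hunk above: one row per fold, each carrying both overlaps.
folded = torch.zeros((n_folds, channels, timesteps + 2 * overlap))
print(folded.shape)  # torch.Size([10, 80, 124])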
@@ -78,7 +72,8 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
     return folded
 
-    def _xfade_and_unfold(self, y: Tensor, overlap: int) -> Tensor:
+
+def _xfade_and_unfold(y: Tensor, overlap: int) -> Tensor:
     r'''Applies a crossfade and unfolds into a 1d array.
 
     y = [[seq1],
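The crossfade itself is not visible in this hunk. As an illustration only, a generic linear crossfade of two overlapping segments (not necessarily the fade curve used by _xfade_and_unfold) looks like this:

import torch

overlap = 4
a = torch.ones(10)          # tail of one folded segment
b = torch.full((10,), 2.0)  # head of the next segment

fade_out = torch.linspace(1.0, 0.0, overlap)
fade_in = 1.0 - fade_out

# Blend the overlapping regions, then concatenate into a single 1-D sequence.
blended = a[-overlap:] * fade_out + b[:overlap] * fade_in
unfolded = torch.cat([a[:-overlap], blended, b[overlap:]])
print(unfolded.shape)  # torch.Size([16])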
@@ -133,7 +128,8 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
     return unfolded
 
-    def _pad_tensor(self, x: Tensor, pad: int, side: str = 'both') -> Tensor:
+
+def _pad_tensor(x: Tensor, pad: int, side: str = 'both') -> Tensor:
     r"""Pad the given tensor.
 
     Args:
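The body of _pad_tensor is outside this hunk; the signature suggests one-sided or two-sided padding along the time axis, which can be approximated (an approximation, not this file's implementation) with torch.nn.functional.pad:

import torch
import torch.nn.functional as F

x = torch.arange(6.0).reshape(1, 1, 6)  # (batch, channels, time), made-up sizes
pad = 2

both = F.pad(x, (pad, pad))    # side='both'
before = F.pad(x, (pad, 0))    # side='before'
after = F.pad(x, (0, pad))     # side='after'
print(both.shape, before.shape, after.shape)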
@@ -155,6 +151,13 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
                          f"Valid choices are 'both', 'before' and 'after'.")
     return padded
 
+
+class WaveRNNInferenceWrapper(torch.nn.Module):
+
+    def __init__(self, wavernn: WaveRNN):
+        super().__init__()
+        self.wavernn_model = wavernn
+
     def forward(self,
                 specgram: Tensor,
                 mulaw: bool = True,
@@ -186,9 +189,9 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
         pad = (self.wavernn_model.kernel_size - 1) // 2
         specgram = specgram.unsqueeze(0)
-        specgram = self._pad_tensor(specgram, pad=pad, side='both')
+        specgram = _pad_tensor(specgram, pad=pad, side='both')
 
         if batched:
-            specgram = self._fold_with_overlap(specgram, timesteps, overlap)
+            specgram = _fold_with_overlap(specgram, timesteps, overlap)
 
         n_bits = int(torch.log2(torch.ones(1) * self.wavernn_model.n_classes))
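The n_bits line in this hunk derives the bit depth from the number of output classes; for example, with the common 8-bit mu-law setting of 256 classes (256 is an assumed value here, not read from the model):

import torch

n_classes = 256  # assumed; the wrapped model supplies wavernn_model.n_classes
n_bits = int(torch.log2(torch.ones(1) * n_classes))
print(n_bits)  # 8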
@@ -199,7 +202,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
         output = torchaudio.functional.mu_law_decoding(output, self.wavernn_model.n_classes)
 
         if batched:
-            output = self._xfade_and_unfold(output, overlap)
+            output = _xfade_and_unfold(output, overlap)
         else:
             output = output[0]
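For reference, a rough usage sketch of the refactored wrapper. The WaveRNN hyper-parameters and spectrogram shape below are placeholders, and only the specgram and mulaw arguments of forward() are confirmed by the hunks above:

import torch
from torchaudio.models import WaveRNN

from wavernn_inference_wrapper import WaveRNNInferenceWrapper  # this example file

# Placeholder hyper-parameters; hop_length must equal the product of upsample_scales.
wavernn = WaveRNN(upsample_scales=[5, 5, 11], n_classes=2 ** 8, hop_length=275, n_freq=80)
wrapper = WaveRNNInferenceWrapper(wavernn).eval()

specgram = torch.rand(80, 200)  # (n_freq, n_time) mel spectrogram, made-up size
with torch.no_grad():
    waveform = wrapper(specgram, mulaw=True)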