OpenDAS / Torchaudio · Commits

Commit 3bb5feb5 (unverified)
Authored by yangarbiter on Aug 23, 2021; committed via GitHub on Aug 23, 2021
Parent: 63f0614b

Refactor WaveRNN infer and move it to the codebase (#1704)

Showing 5 changed files with 119 additions and 85 deletions:
    docs/source/models.rst                                    +2   -0
    examples/pipeline_wavernn/inference.py                    +3   -8
    examples/pipeline_wavernn/wavernn_inference_wrapper.py    +24  -77
    test/torchaudio_unittest/models/models_test.py            +25  -0
    torchaudio/models/wavernn.py                              +65  -0
docs/source/models.rst

@@ -106,6 +106,8 @@ WaveRNN
 .. automethod:: forward

+.. automethod:: infer
+
 Factory Functions
 -----------------
examples/pipeline_wavernn/inference.py

@@ -21,10 +21,6 @@ def parse_args():
         "--jit", default=False, action="store_true",
         help="If used, the model and inference function is jitted.")
-    parser.add_argument(
-        "--loss", default="crossentropy", choices=["crossentropy"], type=str,
-        help="The type of loss the pretrained model is trained on.",
-    )
     parser.add_argument(
         "--no-batch-inference", default=False, action="store_true",
         help="Don't use batch inference."
@@ -39,11 +35,11 @@ def parse_args():
         help="Select the WaveRNN checkpoint.")
-    parser.add_argument("--batch-timesteps", default=100, type=int,
-                        help="The time steps for each batch. Only used when batch inference is used",)
-    parser.add_argument("--batch-overlap", default=5, type=int,
-                        help="The overlapping time steps between batches. Only used when batch inference is used",)
+    parser.add_argument("--batch-timesteps", default=11000, type=int,
+                        help="The time steps for each batch. Only used when batch inference is used",)
+    parser.add_argument("--batch-overlap", default=550, type=int,
+                        help="The overlapping time steps between batches. Only used when batch inference is used",)
     args = parser.parse_args()
@@ -79,13 +75,12 @@ def main(args):
     with torch.no_grad():
         output = wavernn_inference_model(mel_specgram.to(device),
-                                         loss_name=args.loss,
                                          mulaw=(not args.no_mulaw),
                                          batched=(not args.no_batch_inference),
                                          timesteps=args.batch_timesteps,
                                          overlap=args.batch_overlap,)

-    torchaudio.save(args.output_wav_path, output.reshape(1, -1), sample_rate=sample_rate)
+    torchaudio.save(args.output_wav_path, output, sample_rate=sample_rate)


 if __name__ == "__main__":
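The dropped reshape follows from the wrapper's new return shape: forward now yields a channels-first (1, n_time) tensor that torchaudio.save can write directly. A minimal sketch of the call that replaces the old reshape (the stand-in tensor and file name are illustrative, not from the diff):

    import torch
    import torchaudio

    output = torch.rand(1, 22050)  # stand-in for the wrapper's (1, n_time) output
    torchaudio.save("out.wav", output, sample_rate=22050)  # no reshape(1, -1) needed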
examples/pipeline_wavernn/wavernn_inference_wrapper.py

@@ -21,18 +21,12 @@
 # *****************************************************************************
-from typing import List
+from torchaudio.models.wavernn import WaveRNN

 import torch
-import torch.nn.functional as F
 import torchaudio
 from torch import Tensor

-from processing import (
-    normalized_waveform_to_bits,
-    bits_to_normalized_waveform,
-)
+from processing import normalized_waveform_to_bits


 class WaveRNNInferenceWrapper(torch.nn.Module):
@@ -53,12 +47,12 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
             [h7, h8, h9, h10]]

         Args:
-            x (tensor): Upsampled conditioning channels with shape
-                (1, timesteps, channel).
+            x (tensor): Upsampled conditioning channels of size
+                (1, timesteps, channel).
             timesteps (int): Timesteps for each index of batch.
             overlap (int): Timesteps for both xfade and rnn warmup.

         Return:
-            folded (tensor): folded tensor with shape
-                (n_folds, timesteps + 2 * overlap, channel).
+            folded (tensor): folded tensor of size
+                (n_folds, timesteps + 2 * overlap, channel).
         '''
         _, channels, total_len = x.size()
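The fold geometry in the docstring above can be verified with plain arithmetic: each fold spans timesteps + 2 * overlap samples, and consecutive folds advance by timesteps + overlap. A sketch with illustrative numbers (not taken from the diff):

    # Illustrative values only: two folds over a 215-sample input.
    timesteps, overlap, total_len = 100, 5, 215
    n_folds = (total_len - overlap) // (timesteps + overlap)  # -> 2
    fold_len = timesteps + 2 * overlap                        # -> 110
    # folded output would have size (n_folds, channel, fold_len)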
@@ -98,15 +92,15 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
             [seq1_in, seq1_timesteps, (seq1_out + seq2_in), seq2_timesteps, ...]

         Args:
-            y (Tensor): Batched sequences of audio samples with shape
-                (num_folds, timesteps + 2 * overlap).
+            y (Tensor): Batched sequences of audio samples of size
+                (num_folds, channels, timesteps + 2 * overlap).
             overlap (int): Timesteps for both xfade and rnn warmup.

         Returns:
-            unfolded waveform (Tensor) : waveform in a 1d tensor with shape (total_len).
+            unfolded waveform (Tensor) : waveform in a 1d tensor of size (channels, total_len).
         '''
-        num_folds, length = y.shape
+        num_folds, channels, length = y.shape
         timesteps = length - 2 * overlap
         total_len = num_folds * (timesteps + overlap) + overlap
@@ -126,16 +120,16 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
         fade_out = torch.cat([linear, fade_out])

         # Apply the gain to the overlap samples
-        y[:, :overlap] *= fade_in
-        y[:, -overlap:] *= fade_out
+        y[:, :, :overlap] *= fade_in
+        y[:, :, -overlap:] *= fade_out

-        unfolded = torch.zeros((total_len), dtype=y.dtype, device=y.device)
+        unfolded = torch.zeros((channels, total_len), dtype=y.dtype, device=y.device)

         # Loop to add up all the samples
         for i in range(num_folds):
             start = i * (timesteps + overlap)
             end = start + timesteps + 2 * overlap
-            unfolded[start:end] += y[i]
+            unfolded[:, start:end] += y[i]

         return unfolded
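This hunk threads the new channel dimension through the crossfade gains and the overlap-add loop. A small sketch of the same overlap-add pattern on a toy tensor, using the total_len formula from the code above (values are illustrative):

    import torch

    num_folds, channels, timesteps, overlap = 2, 1, 100, 5
    length = timesteps + 2 * overlap                          # 110
    total_len = num_folds * (timesteps + overlap) + overlap   # 215

    y = torch.ones(num_folds, channels, length)
    unfolded = torch.zeros(channels, total_len)
    for i in range(num_folds):
        start = i * (timesteps + overlap)
        unfolded[:, start:start + length] += y[i]
    # Without the fade gains, the `overlap` samples where adjacent folds
    # meet would simply be summed twice; fade_in/fade_out turn that
    # double-counted region into a crossfade.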
@@ -143,11 +137,11 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
         r"""Pad the given tensor.

         Args:
-            x (Tensor): The tensor to pad with shape (n_batch, n_mels, time).
+            x (Tensor): The tensor to pad of size (n_batch, n_mels, time).
             pad (int): The amount of padding applied to the input.

         Return:
-            padded (Tensor): The padded tensor with shape (n_batch, n_mels, time).
+            padded (Tensor): The padded tensor of size (n_batch, n_mels, time).
         """
         b, c, t = x.size()
         total = t + 2 * pad if side == 'both' else t + pad
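The padding semantics documented here match torch.nn.functional.pad applied to the trailing (time) axis. A sketch of the equivalent call, for illustration only (not the wrapper's actual implementation):

    import torch
    import torch.nn.functional as F

    x = torch.rand(1, 80, 10)      # (n_batch, n_mels, time); sizes are illustrative
    pad = 2
    padded = F.pad(x, (pad, pad))  # side='both': pad each end of the time axis
    assert padded.shape == (1, 80, 10 + 2 * pad)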
@@ -163,89 +157,42 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
     def forward(self,
                 specgram: Tensor,
-                loss_name: str = "crossentropy",
                 mulaw: bool = True,
                 batched: bool = True,
-                timesteps: int = 100,
-                overlap: int = 5) -> Tensor:
+                timesteps: int = 11000,
+                overlap: int = 550) -> Tensor:
         r"""Inference function for WaveRNN.

         Based on the implementation from
         https://github.com/fatchord/WaveRNN/blob/master/models/fatchord_version.py.

         Currently only supports multinomial sampling.

         Args:
-            specgram (Tensor): spectrogram with shape (n_mels, n_time)
-            loss_name (str): The loss function used to train the WaveRNN model.
-                Available `loss_name` includes `'mol'` and `'crossentropy'`.
+            specgram (Tensor): spectrogram of size (n_mels, n_time)
             mulaw (bool): Whether to perform mulaw decoding (Default: ``True``).
             batched (bool): Whether to perform batch prediction. Using batch prediction
                 will significantly increase the inference speed (Default: ``True``).
             timesteps (int): The time steps for each batch. Only used when `batched`
-                is set to True (Default: ``100``).
+                is set to True (Default: ``11000``).
             overlap (int): The overlapping time steps between batches. Only used when `batched`
-                is set to True (Default: ``5``).
+                is set to True (Default: ``550``).

         Returns:
-            waveform (Tensor): Reconstructed waveform with shape (n_time, ).
+            waveform (Tensor): Reconstructed waveform of size (1, n_time).
+                1 represents single channel.
         """
         pad = (self.wavernn_model.kernel_size - 1) // 2

         specgram = specgram.unsqueeze(0)
         specgram = self._pad_tensor(specgram, pad=pad, side='both')
-        specgram, aux = self.wavernn_model.upsample(specgram)

         if batched:
             specgram = self._fold_with_overlap(specgram, timesteps, overlap)
-            aux = self._fold_with_overlap(aux, timesteps, overlap)
-
-        device = specgram.device
-        dtype = specgram.dtype

         # make it compatible with torchscript
         n_bits = int(torch.log2(torch.ones(1) * self.wavernn_model.n_classes))

-        output: List[Tensor] = []
-        b_size, _, seq_len = specgram.size()
-
-        h1 = torch.zeros((1, b_size, self.wavernn_model.n_rnn), device=device, dtype=dtype)
-        h2 = torch.zeros((1, b_size, self.wavernn_model.n_rnn), device=device, dtype=dtype)
-        x = torch.zeros((b_size, 1), device=device, dtype=dtype)
-
-        d = self.wavernn_model.n_aux
-        aux_split = [aux[:, d * i: d * (i + 1), :] for i in range(4)]
-
-        for i in range(seq_len):
-            m_t = specgram[:, :, i]
-
-            a1_t, a2_t, a3_t, a4_t = [a[:, :, i] for a in aux_split]
-
-            x = torch.cat([x, m_t, a1_t], dim=1)
-            x = self.wavernn_model.fc(x)
-            _, h1 = self.wavernn_model.rnn1(x.unsqueeze(1), h1)
-            x = x + h1[0]
-
-            inp = torch.cat([x, a2_t], dim=1)
-            _, h2 = self.wavernn_model.rnn2(inp.unsqueeze(1), h2)
-            x = x + h2[0]
-
-            x = torch.cat([x, a3_t], dim=1)
-            x = F.relu(self.wavernn_model.fc1(x))
-
-            x = torch.cat([x, a4_t], dim=1)
-            x = F.relu(self.wavernn_model.fc2(x))
-
-            logits = self.wavernn_model.fc3(x)
-
-            if loss_name == "crossentropy":
-                posterior = F.softmax(logits, dim=1)
-
-                x = torch.multinomial(posterior, 1).float()
-                x = bits_to_normalized_waveform(x, n_bits)
-
-                output.append(x.squeeze(-1))
-            else:
-                raise ValueError(f"Unexpected loss_name: '{loss_name}'. "
-                                 f"Valid choices are 'crossentropy'.")
-
-        output = torch.stack(output).transpose(0, 1).cpu()
+        output = self.wavernn_model.infer(specgram).cpu()

         if mulaw:
             output = normalized_waveform_to_bits(output, n_bits)
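With the sampling loop moved into WaveRNN.infer, the wrapper now reduces to pad, optionally fold, delegate to infer, then decode. A hedged usage sketch; the checkpoint name, mel dimensions, and wrapper constructor follow the example pipeline in this directory and are assumptions, not taken from the diff:

    import torch
    from torchaudio.models import wavernn
    from wavernn_inference_wrapper import WaveRNNInferenceWrapper

    model = wavernn("wavernn_10k_epochs_8bits_ljspeech")  # assumed checkpoint name
    wrapper = WaveRNNInferenceWrapper(model)

    mel_specgram = torch.rand(80, 128)  # (n_mels, n_time); sizes are illustrative
    with torch.no_grad():
        waveform = wrapper(mel_specgram, mulaw=True, batched=True,
                           timesteps=11000, overlap=550)
    # waveform has size (1, n_time) and can be passed to torchaudio.save as-is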
test/torchaudio_unittest/models/models_test.py

@@ -120,6 +120,31 @@ class TestWaveRNN(common_utils.TorchaudioTestCase):
         assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), n_classes)

+    def test_infer_waveform(self):
+        """Validate the output dimensions of a WaveRNN model's infer method.
+        """
+        upsample_scales = [5, 5, 8]
+        n_rnn = 512
+        n_fc = 512
+        n_classes = 512
+        hop_length = 200
+        n_batch = 2
+        n_time = 200
+        n_freq = 100
+        n_output = 256
+        n_res_block = 10
+        n_hidden = 128
+        kernel_size = 5
+
+        model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block,
+                        n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output)
+
+        x = torch.rand(n_batch, n_freq, n_time)
+        out = model.infer(x)
+
+        assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1))
+

 _ConvTasNetParams = namedtuple('_ConvTasNetParams',
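The asserted output length follows from the model's valid (unpadded) convolution and upsampling: the kernel consumes kernel_size - 1 spectrogram frames, and each surviving frame expands to hop_length samples. A quick check with the test's values:

    hop_length, n_time, kernel_size = 200, 200, 5
    expected = hop_length * (n_time - kernel_size + 1)  # 39200 samples per batch item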
torchaudio/models/wavernn.py

@@ -3,6 +3,7 @@ from typing import List, Tuple, Dict, Any
 import torch
 from torch import Tensor
 from torch import nn
+import torch.nn.functional as F
 from torch.hub import load_state_dict_from_url
@@ -347,6 +348,70 @@ class WaveRNN(nn.Module):
         # bring back channel dimension
         return x.unsqueeze(1)

+    @torch.jit.export
+    def infer(self, specgram: Tensor) -> Tensor:
+        r"""Inference method of WaveRNN.
+
+        This function currently only supports multinomial sampling, which assumes the
+        network is trained on cross entropy loss.
+
+        Args:
+            specgram (Tensor): The input spectrogram to the WaveRNN of size (n_batch, n_freq, n_time).
+
+        Return:
+            waveform (Tensor): The inferred waveform of size (n_batch, 1, n_time).
+                1 stands for a single channel.
+        """
+        device = specgram.device
+        dtype = specgram.dtype
+
+        # make it compatible with torchscript
+        n_bits = int(torch.log2(torch.ones(1) * self.n_classes))
+
+        specgram, aux = self.upsample(specgram)
+
+        output: List[Tensor] = []
+        b_size, _, seq_len = specgram.size()
+
+        h1 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
+        h2 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
+        x = torch.zeros((b_size, 1), device=device, dtype=dtype)
+
+        aux_split = [aux[:, self.n_aux * i: self.n_aux * (i + 1), :] for i in range(4)]
+
+        for i in range(seq_len):
+            m_t = specgram[:, :, i]
+
+            a1_t, a2_t, a3_t, a4_t = [a[:, :, i] for a in aux_split]
+
+            x = torch.cat([x, m_t, a1_t], dim=1)
+            x = self.fc(x)
+            _, h1 = self.rnn1(x.unsqueeze(1), h1)
+            x = x + h1[0]
+
+            inp = torch.cat([x, a2_t], dim=1)
+            _, h2 = self.rnn2(inp.unsqueeze(1), h2)
+            x = x + h2[0]
+
+            x = torch.cat([x, a3_t], dim=1)
+            x = F.relu(self.fc1(x))
+
+            x = torch.cat([x, a4_t], dim=1)
+            x = F.relu(self.fc2(x))
+
+            logits = self.fc3(x)
+
+            posterior = F.softmax(logits, dim=1)
+
+            x = torch.multinomial(posterior, 1).float()
+            # Transform label [0, 2 ** n_bits - 1] to waveform [-1, 1]
+            x = 2 * x / (2 ** n_bits - 1.0) - 1.0
+
+            output.append(x)
+        return torch.stack(output).permute(1, 2, 0)
+

 def wavernn(checkpoint_name: str) -> WaveRNN:
     r"""Get pretrained WaveRNN model.
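Because infer carries @torch.jit.export, it stays callable after the module is scripted, even though it is not forward. A minimal sketch using the constructor defaults (n_freq=128 and kernel_size=5 are the documented defaults; the other sizes below are illustrative):

    import torch
    from torchaudio.models import WaveRNN

    model = WaveRNN(upsample_scales=[5, 5, 8], n_classes=512, hop_length=200)
    scripted = torch.jit.script(model)  # infer is preserved by @torch.jit.export

    specgram = torch.rand(2, 128, 10)   # (n_batch, n_freq, n_time)
    with torch.no_grad():
        waveform = scripted.infer(specgram)
    # size: (n_batch, 1, hop_length * (n_time - kernel_size + 1)) == (2, 1, 1200)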