OpenDAS / Torchaudio, commit 3bb5feb5 (unverified)

Refactor WaveRNN infer and move it to the codebase (#1704)

Authored Aug 23, 2021 by yangarbiter; committed via GitHub on Aug 23, 2021
Parent commit: 63f0614b
Changes: 5 changed files, with 119 additions and 85 deletions.
docs/source/models.rst                                    +2  -0
examples/pipeline_wavernn/inference.py                    +3  -8
examples/pipeline_wavernn/wavernn_inference_wrapper.py    +24 -77
test/torchaudio_unittest/models/models_test.py            +25 -0
torchaudio/models/wavernn.py                              +65 -0
docs/source/models.rst (view file @ 3bb5feb5)

@@ -106,6 +106,8 @@ WaveRNN
 ...
       .. automethod:: forward
+
+      .. automethod:: infer

 Factory Functions
 -----------------
 ...
examples/pipeline_wavernn/inference.py (view file @ 3bb5feb5)

@@ -21,10 +21,6 @@ def parse_args():
         "--jit", default=False, action="store_true",
         help="If used, the model and inference function is jitted."
     )
-    parser.add_argument(
-        "--loss", default="crossentropy", choices=["crossentropy"], type=str,
-        help="The type of loss the pretrained model is trained on.",
-    )
     parser.add_argument(
         "--no-batch-inference", default=False, action="store_true",
         help="Don't use batch inference."
@@ -39,11 +35,11 @@ def parse_args():
         help="Select the WaveRNN checkpoint."
     )
     parser.add_argument(
-        "--batch-timesteps", default=11000, type=int,
+        "--batch-timesteps", default=100, type=int,
         help="The time steps for each batch. Only used when batch inference is used",
     )
     parser.add_argument(
-        "--batch-overlap", default=550, type=int,
+        "--batch-overlap", default=5, type=int,
         help="The overlapping time steps between batches. Only used when batch inference is used",
     )
     args = parser.parse_args()
 ...
@@ -79,13 +75,12 @@ def main(args):
     with torch.no_grad():
         output = wavernn_inference_model(mel_specgram.to(device),
-                                         loss_name=args.loss,
                                          mulaw=(not args.no_mulaw),
                                          batched=(not args.no_batch_inference),
                                          timesteps=args.batch_timesteps,
                                          overlap=args.batch_overlap,)

-    torchaudio.save(args.output_wav_path, output.reshape(1, -1), sample_rate=sample_rate)
+    torchaudio.save(args.output_wav_path, output, sample_rate=sample_rate)


 if __name__ == "__main__":
 ...
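Note: with `--loss` removed, the wrapper call in `main` is simpler, and its output is already channel-first, so it feeds `torchaudio.save` directly. A minimal sketch of the updated call (variable names as in the example script):

    with torch.no_grad():
        output = wavernn_inference_model(
            mel_specgram.to(device),
            mulaw=(not args.no_mulaw),
            batched=(not args.no_batch_inference),
            timesteps=args.batch_timesteps,  # default is now 100
            overlap=args.batch_overlap,      # default is now 5
        )
    # output has size (1, n_time), so no reshape(1, -1) is needed before saving
    torchaudio.save(args.output_wav_path, output, sample_rate=sample_rate)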
examples/pipeline_wavernn/wavernn_inference_wrapper.py (view file @ 3bb5feb5)

@@ -21,18 +21,12 @@
 # *****************************************************************************
-from typing import List
-
 from torchaudio.models.wavernn import WaveRNN

 import torch
-import torch.nn.functional as F
 import torchaudio
 from torch import Tensor
-from processing import (
-    normalized_waveform_to_bits,
-    bits_to_normalized_waveform,
-)
+from processing import normalized_waveform_to_bits


 class WaveRNNInferenceWrapper(torch.nn.Module):
 ...
@@ -53,12 +47,12 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
             [h7, h8, h9, h10]]

         Args:
-            x (tensor): Upsampled conditioning channels with shape (1, timesteps, channel).
+            x (tensor): Upsampled conditioning channels of size (1, timesteps, channel).
             timesteps (int): Timesteps for each index of batch.
             overlap (int): Timesteps for both xfade and rnn warmup.

         Return:
-            folded (tensor): folded tensor with shape (n_folds, timesteps + 2 * overlap, channel).
+            folded (tensor): folded tensor of size (n_folds, timesteps + 2 * overlap, channel).
         '''
         _, channels, total_len = x.size()
 ...
@@ -98,15 +92,15 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
         [seq1_in, seq1_timesteps, (seq1_out + seq2_in), seq2_timesteps, ...]

         Args:
-            y (Tensor): Batched sequences of audio samples with shape
-                (num_folds, timesteps + 2 * overlap).
+            y (Tensor): Batched sequences of audio samples of size
+                (num_folds, channels, timesteps + 2 * overlap).
             overlap (int): Timesteps for both xfade and rnn warmup.

         Returns:
-            unfolded waveform (Tensor) : waveform in a 1d tensor with shape (total_len).
+            unfolded waveform (Tensor) : waveform in a 1d tensor of size (channels, total_len).
         '''
-        num_folds, length = y.shape
+        num_folds, channels, length = y.shape
         timesteps = length - 2 * overlap
         total_len = num_folds * (timesteps + overlap) + overlap
 ...
@@ -126,16 +120,16 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
         fade_out = torch.cat([linear, fade_out])

         # Apply the gain to the overlap samples
-        y[:, :overlap] *= fade_in
-        y[:, -overlap:] *= fade_out
+        y[:, :, :overlap] *= fade_in
+        y[:, :, -overlap:] *= fade_out

-        unfolded = torch.zeros((total_len), dtype=y.dtype, device=y.device)
+        unfolded = torch.zeros((channels, total_len), dtype=y.dtype, device=y.device)

         # Loop to add up all the samples
         for i in range(num_folds):
             start = i * (timesteps + overlap)
             end = start + timesteps + 2 * overlap
-            unfolded[start:end] += y[i]
+            unfolded[:, start:end] += y[i]

         return unfolded
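The fold/xfade/unfold path now carries an explicit channel axis. A toy sketch of the unfold arithmetic with hypothetical sizes (crossfade gains omitted), showing why `total_len = num_folds * (timesteps + overlap) + overlap`:

    import torch

    num_folds, channels, timesteps, overlap = 3, 1, 100, 5
    y = torch.ones(num_folds, channels, timesteps + 2 * overlap)  # folded input
    total_len = num_folds * (timesteps + overlap) + overlap       # 3 * 105 + 5 = 320
    unfolded = torch.zeros((channels, total_len), dtype=y.dtype)
    for i in range(num_folds):
        start = i * (timesteps + overlap)
        end = start + timesteps + 2 * overlap
        unfolded[:, start:end] += y[i]                            # overlapping regions sum
    print(unfolded.shape)                                         # torch.Size([1, 320])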
 ...
@@ -143,11 +137,11 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
         r"""Pad the given tensor.

         Args:
-            x (Tensor): The tensor to pad with shape (n_batch, n_mels, time).
+            x (Tensor): The tensor to pad of size (n_batch, n_mels, time).
             pad (int): The amount of padding applied to the input.

         Return:
-            padded (Tensor): The padded tensor with shape (n_batch, n_mels, time).
+            padded (Tensor): The padded tensor of size (n_batch, n_mels, time).
         """
         b, c, t = x.size()
         total = t + 2 * pad if side == 'both' else t + pad
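For orientation, `_pad_tensor` maps (n_batch, n_mels, t) to (n_batch, n_mels, t + 2 * pad) when `side == 'both'`. One plausible way to get the same shape with `torch.nn.functional.pad` (illustrative only; the example file keeps its own helper):

    import torch
    import torch.nn.functional as F

    x = torch.zeros(1, 80, 10)      # hypothetical (n_batch, n_mels, t)
    pad = 2
    padded = F.pad(x, (pad, pad))   # pads the last (time) dimension on both sides
    print(padded.shape)             # torch.Size([1, 80, 14])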
 ...
@@ -163,89 +157,42 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
     def forward(self,
                 specgram: Tensor,
-                loss_name: str = "crossentropy",
                 mulaw: bool = True,
                 batched: bool = True,
-                timesteps: int = 11000,
-                overlap: int = 550) -> Tensor:
+                timesteps: int = 100,
+                overlap: int = 5) -> Tensor:
         r"""Inference function for WaveRNN.

         Based on the implementation from
         https://github.com/fatchord/WaveRNN/blob/master/models/fatchord_version.py.

+        Currently only supports multinomial sampling.
+
         Args:
-            specgram (Tensor): spectrogram with shape (n_mels, n_time)
+            specgram (Tensor): spectrogram of size (n_mels, n_time)
-            loss_name (str): The loss function used to train the WaveRNN model.
-                Available `loss_name` includes `'mol'` and `'crossentropy'`.
             mulaw (bool): Whether to perform mulaw decoding (Default: ``True``).
             batched (bool): Whether to perform batch prediction. Using batch prediction
                 will significantly increase the inference speed (Default: ``True``).
             timesteps (int): The time steps for each batch. Only used when `batched`
-                is set to True (Default: ``11000``).
+                is set to True (Default: ``100``).
             overlap (int): The overlapping time steps between batches. Only used when `batched`
-                is set to True (Default: ``550``).
+                is set to True (Default: ``5``).

         Returns:
-            waveform (Tensor): Reconstructed waveform with shape (n_time, ).
+            waveform (Tensor): Reconstructed waveform of size (1, n_time,).
+                1 represents single channel.
         """
         pad = (self.wavernn_model.kernel_size - 1) // 2
         specgram = specgram.unsqueeze(0)
         specgram = self._pad_tensor(specgram, pad=pad, side='both')
-        specgram, aux = self.wavernn_model.upsample(specgram)

         if batched:
             specgram = self._fold_with_overlap(specgram, timesteps, overlap)
-            aux = self._fold_with_overlap(aux, timesteps, overlap)
-
-        device = specgram.device
-        dtype = specgram.dtype

         # make it compatible with torchscript
         n_bits = int(torch.log2(torch.ones(1) * self.wavernn_model.n_classes))

-        output: List[Tensor] = []
-        b_size, _, seq_len = specgram.size()
-
-        h1 = torch.zeros((1, b_size, self.wavernn_model.n_rnn), device=device, dtype=dtype)
-        h2 = torch.zeros((1, b_size, self.wavernn_model.n_rnn), device=device, dtype=dtype)
-        x = torch.zeros((b_size, 1), device=device, dtype=dtype)
-
-        d = self.wavernn_model.n_aux
-        aux_split = [aux[:, d * i: d * (i + 1), :] for i in range(4)]
-
-        for i in range(seq_len):
-            m_t = specgram[:, :, i]
-
-            a1_t, a2_t, a3_t, a4_t = [a[:, :, i] for a in aux_split]
-
-            x = torch.cat([x, m_t, a1_t], dim=1)
-            x = self.wavernn_model.fc(x)
-            _, h1 = self.wavernn_model.rnn1(x.unsqueeze(1), h1)
-
-            x = x + h1[0]
-            inp = torch.cat([x, a2_t], dim=1)
-            _, h2 = self.wavernn_model.rnn2(inp.unsqueeze(1), h2)
-
-            x = x + h2[0]
-            x = torch.cat([x, a3_t], dim=1)
-            x = F.relu(self.wavernn_model.fc1(x))
-
-            x = torch.cat([x, a4_t], dim=1)
-            x = F.relu(self.wavernn_model.fc2(x))
-
-            logits = self.wavernn_model.fc3(x)
-
-            if loss_name == "crossentropy":
-                posterior = F.softmax(logits, dim=1)
-
-                x = torch.multinomial(posterior, 1).float()
-                x = bits_to_normalized_waveform(x, n_bits)
-
-                output.append(x.squeeze(-1))
-            else:
-                raise ValueError(f"Unexpected loss_name: '{loss_name}'. "
-                                 f"Valid choices are 'crossentropy'.")
-
-        output = torch.stack(output).transpose(0, 1).cpu()
+        output = self.wavernn_model.infer(specgram).cpu()

         if mulaw:
             output = normalized_waveform_to_bits(output, n_bits)
 ...
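End-to-end, the refactored wrapper now delegates sampling to `WaveRNN.infer`. A hedged usage sketch, assuming it runs inside examples/pipeline_wavernn and that the wrapper is constructed from a WaveRNN instance as in the example's `main`; the checkpoint name is one of torchaudio's published WaveRNN checkpoints and the mel input sizes are hypothetical:

    import torch
    from torchaudio.models import wavernn
    from wavernn_inference_wrapper import WaveRNNInferenceWrapper

    model = wavernn("wavernn_10k_epochs_8bits_ljspeech")
    wrapper = WaveRNNInferenceWrapper(model)

    mel_specgram = torch.rand(80, 128)  # hypothetical (n_mels, n_time) input
    with torch.no_grad():
        waveform = wrapper(mel_specgram, mulaw=True, batched=True,
                           timesteps=100, overlap=5)
    print(waveform.shape)               # (1, n_time)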
test/torchaudio_unittest/models/models_test.py (view file @ 3bb5feb5)

@@ -120,6 +120,31 @@ class TestWaveRNN(common_utils.TorchaudioTestCase):
         assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1), n_classes)

+    def test_infer_waveform(self):
+        """Validate the output dimensions of a WaveRNN model's infer method.
+        """
+        upsample_scales = [5, 5, 8]
+        n_rnn = 512
+        n_fc = 512
+        n_classes = 512
+        hop_length = 200
+        n_batch = 2
+        n_time = 200
+        n_freq = 100
+        n_output = 256
+        n_res_block = 10
+        n_hidden = 128
+        kernel_size = 5
+
+        model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block,
+                        n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output)
+
+        x = torch.rand(n_batch, n_freq, n_time)
+        out = model.infer(x)
+
+        assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1))
+

 _ConvTasNetParams = namedtuple(
     '_ConvTasNetParams',
 ...
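For orientation, the asserted output length in the new test works out as follows (a worked sketch using the test's own constants):

    hop_length, n_time, kernel_size = 200, 200, 5
    expected_len = hop_length * (n_time - kernel_size + 1)  # 200 * 196 = 39200
    # so infer on an input of size (2, 100, 200) returns size (2, 1, 39200)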
torchaudio/models/wavernn.py (view file @ 3bb5feb5)

@@ -3,6 +3,7 @@ from typing import List, Tuple, Dict, Any
 import torch
 from torch import Tensor
 from torch import nn
+import torch.nn.functional as F
 from torch.hub import load_state_dict_from_url
 ...
@@ -347,6 +348,70 @@ class WaveRNN(nn.Module):
         # bring back channel dimension
         return x.unsqueeze(1)

+    @torch.jit.export
+    def infer(self, specgram: Tensor) -> Tensor:
+        r"""Inference method of WaveRNN.
+
+        This function currently only supports multinomial sampling, which assumes the
+        network is trained on cross entropy loss.
+
+        Args:
+            specgram (Tensor): The input spectrogram to the WaveRNN of size (n_batch, n_freq, n_time).
+
+        Return:
+            waveform (Tensor): The inferred waveform of size (n_batch, 1, n_time).
+                1 stands for a single channel.
+        """
+        device = specgram.device
+        dtype = specgram.dtype
+
+        # make it compatible with torchscript
+        n_bits = int(torch.log2(torch.ones(1) * self.n_classes))
+
+        specgram, aux = self.upsample(specgram)
+
+        output: List[Tensor] = []
+        b_size, _, seq_len = specgram.size()
+
+        h1 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
+        h2 = torch.zeros((1, b_size, self.n_rnn), device=device, dtype=dtype)
+        x = torch.zeros((b_size, 1), device=device, dtype=dtype)
+
+        aux_split = [aux[:, self.n_aux * i: self.n_aux * (i + 1), :] for i in range(4)]
+
+        for i in range(seq_len):
+            m_t = specgram[:, :, i]
+
+            a1_t, a2_t, a3_t, a4_t = [a[:, :, i] for a in aux_split]
+
+            x = torch.cat([x, m_t, a1_t], dim=1)
+            x = self.fc(x)
+            _, h1 = self.rnn1(x.unsqueeze(1), h1)
+
+            x = x + h1[0]
+            inp = torch.cat([x, a2_t], dim=1)
+            _, h2 = self.rnn2(inp.unsqueeze(1), h2)
+
+            x = x + h2[0]
+            x = torch.cat([x, a3_t], dim=1)
+            x = F.relu(self.fc1(x))
+
+            x = torch.cat([x, a4_t], dim=1)
+            x = F.relu(self.fc2(x))
+
+            logits = self.fc3(x)
+
+            posterior = F.softmax(logits, dim=1)
+
+            x = torch.multinomial(posterior, 1).float()
+            # Transform label [0, 2 ** n_bits - 1] to waveform [-1, 1]
+            x = 2 * x / (2 ** n_bits - 1.0) - 1.0
+
+            output.append(x)
+        return torch.stack(output).permute(1, 2, 0)
+

 def wavernn(checkpoint_name: str) -> WaveRNN:
     r"""Get pretrained WaveRNN model.
 ...
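The exported `infer` can be called directly on a `WaveRNN` instance, and `@torch.jit.export` keeps it available after scripting. A sketch using the same constructor arguments and input sizes as the new unit test:

    import torch
    from torchaudio.models import WaveRNN

    model = WaveRNN(upsample_scales=[5, 5, 8], n_classes=512, hop_length=200,
                    n_res_block=10, n_rnn=512, n_fc=512, kernel_size=5,
                    n_freq=100, n_hidden=128, n_output=256)
    specgram = torch.rand(2, 100, 200)  # (n_batch, n_freq, n_time)
    with torch.no_grad():
        waveform = model.infer(specgram)
    print(waveform.shape)               # (2, 1, 39200), i.e. hop_length * (n_time - kernel_size + 1)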