OpenDAS / Torchaudio / Commits / 137600d0

Commit 137600d0, authored Oct 13, 2021 by moto

Add `lengths` param to WaveRNN.infer (#1851)

parent ddc49548
Showing 3 changed files with 41 additions and 28 deletions:

examples/pipeline_wavernn/wavernn_inference_wrapper.py   +0  -3
test/torchaudio_unittest/models/models_test.py            +21 -18
torchaudio/models/wavernn.py                              +20 -7
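In short, `WaveRNN.infer` now accepts an optional `lengths` tensor and returns a `(waveform, waveform_lengths)` pair instead of a bare waveform. A minimal sketch of the new call pattern, with parameter values borrowed from the updated test in this commit (any sensible configuration works the same way):

```python
import torch
from torchaudio.models import WaveRNN

# Values borrowed from the updated test below; they are illustrative only.
model = WaveRNN(upsample_scales=[5, 5, 8], n_classes=128, hop_length=200,
                n_res_block=2, n_rnn=128, n_fc=128, kernel_size=5,
                n_freq=25, n_hidden=32, n_output=64)

specgram = torch.rand(2, 25, 50)      # (n_batch, n_freq, n_time)
lengths = torch.tensor([50, 25])      # valid frames of each spectrogram

# New: pass lengths and receive the valid output lengths back.
waveform, waveform_lengths = model.infer(specgram, lengths)
# waveform.shape == (2, 1, 10000); waveform_lengths == tensor([10000, 5000])
```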
examples/pipeline_wavernn/wavernn_inference_wrapper.py
@@ -163,10 +163,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
             waveform (Tensor): Reconstructed waveform of size (1, n_time, ).
                 1 represents single channel.
         """
-        pad = (self.wavernn_model.kernel_size - 1) // 2
         specgram = specgram.unsqueeze(0)
-        specgram = torch.nn.functional.pad(specgram, (pad, pad))
         if batched:
             specgram = _fold_with_overlap(specgram, timesteps, overlap)
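The wrapper drops its manual padding because `WaveRNN.infer` now pads the spectrogram itself (see the `self._pad` addition in `torchaudio/models/wavernn.py` below). As a rough, standalone illustration of why padding by `(kernel_size - 1) // 2` on each side of the time axis preserves all frames, here is a sketch with a plain `Conv1d` standing in for the model's conditioning convolution (the layer and sizes are illustrative, not the model's internals):

```python
import torch

kernel_size, n_time = 5, 50
pad = (kernel_size - 1) // 2                         # 2, as the wrapper used to compute

x = torch.rand(1, 1, n_time)                         # (batch, channels, time)
x_padded = torch.nn.functional.pad(x, (pad, pad))    # pads the last (time) dimension

conv = torch.nn.Conv1d(1, 1, kernel_size)            # stand-in for the conditioning conv
y = conv(x_padded)

# Unpadded, the conv would return n_time - kernel_size + 1 frames;
# with symmetric padding it returns exactly n_time frames.
assert y.shape[-1] == n_time
```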
test/torchaudio_unittest/models/models_test.py
@@ -126,40 +126,43 @@ class TestWaveRNN(common_utils.TorchaudioTestCase):
         """
         upsample_scales = [5, 5, 8]
-        n_rnn = 512
-        n_fc = 512
-        n_classes = 512
+        n_rnn = 128
+        n_fc = 128
+        n_classes = 128
         hop_length = 200
         n_batch = 2
-        n_time = 200
-        n_freq = 100
-        n_output = 256
-        n_res_block = 10
-        n_hidden = 128
+        n_time = 50
+        n_freq = 25
+        n_output = 64
+        n_res_block = 2
+        n_hidden = 32
         kernel_size = 5

         model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block,
                         n_rnn, n_fc, kernel_size, n_freq, n_hidden, n_output)
         x = torch.rand(n_batch, n_freq, n_time)
-        out = model.infer(x)
+        lengths = torch.tensor([n_time, n_time // 2])
+        out, waveform_lengths = model.infer(x, lengths)

-        assert out.size() == (n_batch, 1, hop_length * (n_time - kernel_size + 1))
+        assert out.size() == (n_batch, 1, hop_length * n_time)
+        assert waveform_lengths[0] == hop_length * n_time
+        assert waveform_lengths[1] == hop_length * n_time // 2

     def test_torchscript_infer(self):
         """Scripted model outputs the same as eager mode"""
         upsample_scales = [5, 5, 8]
-        n_rnn = 512
-        n_fc = 512
-        n_classes = 512
+        n_rnn = 128
+        n_fc = 128
+        n_classes = 128
         hop_length = 200
         n_batch = 2
-        n_time = 200
-        n_freq = 100
-        n_output = 256
-        n_res_block = 10
-        n_hidden = 128
+        n_time = 50
+        n_freq = 25
+        n_output = 64
+        n_res_block = 2
+        n_hidden = 32
         kernel_size = 5
         model = WaveRNN(upsample_scales, n_classes, hop_length, n_res_block,
 ...
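With the new, smaller test configuration the asserted sizes work out as follows: `out.size() == (2, 1, 200 * 50) == (2, 1, 10000)`, and `lengths = [50, 25]` yields `waveform_lengths == [10000, 5000]`, i.e. `hop_length * n_time` and `hop_length * n_time // 2`.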
torchaudio/models/wavernn.py
-from typing import List, Tuple, Dict, Any
+from typing import List, Tuple, Dict, Any, Optional
 import math
 import torch
@@ -182,6 +182,7 @@ class UpsampleNetwork(nn.Module):
         total_scale = 1
         for upsample_scale in upsample_scales:
             total_scale *= upsample_scale
+        self.total_scale: int = total_scale
         self.indent = (kernel_size - 1) // 2 * total_scale
         self.resnet = MelResNet(n_res_block, n_freq, n_hidden, n_output, kernel_size)
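For the configuration exercised in the test, `total_scale = 5 * 5 * 8 = 200`, which matches `hop_length`: each spectrogram frame expands to 200 waveform samples. Exposing it as an `int` attribute is what lets `infer` below convert `lengths` from frames to samples via `lengths * self.upsample.total_scale`.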
@@ -265,6 +266,7 @@ class WaveRNN(nn.Module):
         super().__init__()

         self.kernel_size = kernel_size
+        self._pad = (kernel_size - 1 if kernel_size % 2 else kernel_size) // 2
         self.n_rnn = n_rnn
         self.n_aux = n_output // 4
         self.hop_length = hop_length
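With the odd `kernel_size = 5` used in the tests this gives `self._pad = (5 - 1) // 2 = 2`; for an even kernel size it falls back to `kernel_size // 2`. This is the same padding the example inference wrapper previously computed on its own, now applied inside `infer` instead.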
@@ -351,24 +353,35 @@ class WaveRNN(nn.Module):
         return x.unsqueeze(1)

     @torch.jit.export
-    def infer(self, specgram: Tensor) -> Tensor:
+    def infer(self, specgram: Tensor, lengths: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]:
         r"""Inference method of WaveRNN.

         This function currently only supports multinomial sampling, which assumes the
         network is trained on cross entropy loss.

         Args:
-            specgram (Tensor): The input spectrogram to the WaveRNN of size (n_batch, n_freq, n_time).
+            specgram (Tensor):
+                Batch of spectrograms. Shape: `(n_batch, n_freq, n_time)`.
+            lengths (Tensor or None, optional):
+                Indicates the valid length of each spectrogram in the time axis.
+                Shape: `(n_batch, )`.

-        Return:
-            waveform (Tensor): The inferred waveform of size (n_batch, 1, n_time).
+        Returns:
+            Tensor and optional Tensor:
+            Tensor
+                The inferred waveform of size `(n_batch, 1, n_time)`.
+                1 stands for a single channel.
+            Tensor or None
+                The valid lengths of each waveform in the batch. Size `(n_batch, )`.
         """
         device = specgram.device
         dtype = specgram.dtype

+        specgram = torch.nn.functional.pad(specgram, (self._pad, self._pad))
         specgram, aux = self.upsample(specgram)
+        if lengths is not None:
+            lengths = lengths * self.upsample.total_scale

         output: List[Tensor] = []
         b_size, _, seq_len = specgram.size()
@@ -410,7 +423,7 @@ class WaveRNN(nn.Module):
             output.append(x)
-        return torch.stack(output).permute(1, 2, 0)
+        return torch.stack(output).permute(1, 2, 0), lengths


 def wavernn(checkpoint_name: str) -> WaveRNN:
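The returned `waveform_lengths` lets a caller recover per-utterance waveforms from the padded batch. A hypothetical helper (not part of this commit, just an illustration of how the new return value might be consumed):

```python
from typing import List, Optional
from torch import Tensor

def split_batch(waveform: Tensor, waveform_lengths: Optional[Tensor]) -> List[Tensor]:
    """Trim each (1, n_time) waveform in the batch to its valid length."""
    if waveform_lengths is None:              # infer() was called without lengths
        return list(waveform)
    return [w[..., :int(n)] for w, n in zip(waveform, waveform_lengths)]

# e.g. trimmed = split_batch(*model.infer(specgram, lengths))
```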