Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
137600d0
Commit
137600d0
authored
Oct 13, 2021
by
moto
Browse files
Add `lengths` param to WaveRNN.infer (#1851)
parent
ddc49548
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
41 additions
and
28 deletions
+41
-28
examples/pipeline_wavernn/wavernn_inference_wrapper.py
examples/pipeline_wavernn/wavernn_inference_wrapper.py
+0
-3
test/torchaudio_unittest/models/models_test.py
test/torchaudio_unittest/models/models_test.py
+21
-18
torchaudio/models/wavernn.py
torchaudio/models/wavernn.py
+20
-7
No files found.
examples/pipeline_wavernn/wavernn_inference_wrapper.py
View file @
137600d0
...
@@ -163,10 +163,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
...
@@ -163,10 +163,7 @@ class WaveRNNInferenceWrapper(torch.nn.Module):
waveform (Tensor): Reconstructed waveform of size (1, n_time, ).
waveform (Tensor): Reconstructed waveform of size (1, n_time, ).
1 represents single channel.
1 represents single channel.
"""
"""
pad
=
(
self
.
wavernn_model
.
kernel_size
-
1
)
//
2
specgram
=
specgram
.
unsqueeze
(
0
)
specgram
=
specgram
.
unsqueeze
(
0
)
specgram
=
torch
.
nn
.
functional
.
pad
(
specgram
,
(
pad
,
pad
))
if
batched
:
if
batched
:
specgram
=
_fold_with_overlap
(
specgram
,
timesteps
,
overlap
)
specgram
=
_fold_with_overlap
(
specgram
,
timesteps
,
overlap
)
...
...
test/torchaudio_unittest/models/models_test.py
View file @
137600d0
...
@@ -126,40 +126,43 @@ class TestWaveRNN(common_utils.TorchaudioTestCase):
...
@@ -126,40 +126,43 @@ class TestWaveRNN(common_utils.TorchaudioTestCase):
"""
"""
upsample_scales
=
[
5
,
5
,
8
]
upsample_scales
=
[
5
,
5
,
8
]
n_rnn
=
5
12
n_rnn
=
12
8
n_fc
=
5
12
n_fc
=
12
8
n_classes
=
5
12
n_classes
=
12
8
hop_length
=
200
hop_length
=
200
n_batch
=
2
n_batch
=
2
n_time
=
20
0
n_time
=
5
0
n_freq
=
100
n_freq
=
25
n_output
=
25
6
n_output
=
6
4
n_res_block
=
10
n_res_block
=
2
n_hidden
=
128
n_hidden
=
32
kernel_size
=
5
kernel_size
=
5
model
=
WaveRNN
(
upsample_scales
,
n_classes
,
hop_length
,
n_res_block
,
model
=
WaveRNN
(
upsample_scales
,
n_classes
,
hop_length
,
n_res_block
,
n_rnn
,
n_fc
,
kernel_size
,
n_freq
,
n_hidden
,
n_output
)
n_rnn
,
n_fc
,
kernel_size
,
n_freq
,
n_hidden
,
n_output
)
x
=
torch
.
rand
(
n_batch
,
n_freq
,
n_time
)
x
=
torch
.
rand
(
n_batch
,
n_freq
,
n_time
)
out
=
model
.
infer
(
x
)
lengths
=
torch
.
tensor
([
n_time
,
n_time
//
2
])
out
,
waveform_lengths
=
model
.
infer
(
x
,
lengths
)
assert
out
.
size
()
==
(
n_batch
,
1
,
hop_length
*
(
n_time
-
kernel_size
+
1
))
assert
out
.
size
()
==
(
n_batch
,
1
,
hop_length
*
n_time
)
assert
waveform_lengths
[
0
]
==
hop_length
*
n_time
assert
waveform_lengths
[
1
]
==
hop_length
*
n_time
//
2
def
test_torchscript_infer
(
self
):
def
test_torchscript_infer
(
self
):
"""Scripted model outputs the same as eager mode"""
"""Scripted model outputs the same as eager mode"""
upsample_scales
=
[
5
,
5
,
8
]
upsample_scales
=
[
5
,
5
,
8
]
n_rnn
=
5
12
n_rnn
=
12
8
n_fc
=
5
12
n_fc
=
12
8
n_classes
=
5
12
n_classes
=
12
8
hop_length
=
200
hop_length
=
200
n_batch
=
2
n_batch
=
2
n_time
=
20
0
n_time
=
5
0
n_freq
=
100
n_freq
=
25
n_output
=
25
6
n_output
=
6
4
n_res_block
=
10
n_res_block
=
2
n_hidden
=
128
n_hidden
=
32
kernel_size
=
5
kernel_size
=
5
model
=
WaveRNN
(
upsample_scales
,
n_classes
,
hop_length
,
n_res_block
,
model
=
WaveRNN
(
upsample_scales
,
n_classes
,
hop_length
,
n_res_block
,
...
...
torchaudio/models/wavernn.py
View file @
137600d0
from
typing
import
List
,
Tuple
,
Dict
,
Any
from
typing
import
List
,
Tuple
,
Dict
,
Any
,
Optional
import
math
import
math
import
torch
import
torch
...
@@ -182,6 +182,7 @@ class UpsampleNetwork(nn.Module):
...
@@ -182,6 +182,7 @@ class UpsampleNetwork(nn.Module):
total_scale
=
1
total_scale
=
1
for
upsample_scale
in
upsample_scales
:
for
upsample_scale
in
upsample_scales
:
total_scale
*=
upsample_scale
total_scale
*=
upsample_scale
self
.
total_scale
:
int
=
total_scale
self
.
indent
=
(
kernel_size
-
1
)
//
2
*
total_scale
self
.
indent
=
(
kernel_size
-
1
)
//
2
*
total_scale
self
.
resnet
=
MelResNet
(
n_res_block
,
n_freq
,
n_hidden
,
n_output
,
kernel_size
)
self
.
resnet
=
MelResNet
(
n_res_block
,
n_freq
,
n_hidden
,
n_output
,
kernel_size
)
...
@@ -265,6 +266,7 @@ class WaveRNN(nn.Module):
...
@@ -265,6 +266,7 @@ class WaveRNN(nn.Module):
super
().
__init__
()
super
().
__init__
()
self
.
kernel_size
=
kernel_size
self
.
kernel_size
=
kernel_size
self
.
_pad
=
(
kernel_size
-
1
if
kernel_size
%
2
else
kernel_size
)
//
2
self
.
n_rnn
=
n_rnn
self
.
n_rnn
=
n_rnn
self
.
n_aux
=
n_output
//
4
self
.
n_aux
=
n_output
//
4
self
.
hop_length
=
hop_length
self
.
hop_length
=
hop_length
...
@@ -351,24 +353,35 @@ class WaveRNN(nn.Module):
...
@@ -351,24 +353,35 @@ class WaveRNN(nn.Module):
return
x
.
unsqueeze
(
1
)
return
x
.
unsqueeze
(
1
)
@
torch
.
jit
.
export
@
torch
.
jit
.
export
def
infer
(
self
,
specgram
:
Tensor
)
->
Tensor
:
def
infer
(
self
,
specgram
:
Tensor
,
lengths
:
Optional
[
Tensor
]
=
None
)
->
Tuple
[
Tensor
,
Optional
[
Tensor
]]
:
r
"""Inference method of WaveRNN.
r
"""Inference method of WaveRNN.
This function currently only supports multinomial sampling, which assumes the
This function currently only supports multinomial sampling, which assumes the
network is trained on cross entropy loss.
network is trained on cross entropy loss.
Args:
Args:
specgram (Tensor): The input spectrogram to the WaveRNN of size (n_batch, n_freq, n_time).
specgram (Tensor):
Batch of spectrograms. Shape: `(n_batch, n_freq, n_time)`.
Return:
lengths (Tensor or None, optional):
waveform (Tensor): The inferred waveform of size (n_batch, 1, n_time).
Indicates the valid length in of each spectrogram in time axis.
Shape: `(n_batch, )`.
Returns:
Tensor and optional Tensor:
Tensor
The inferred waveform of size `(n_batch, 1, n_time)`.
1 stands for a single channel.
1 stands for a single channel.
Tensor or None
The valid lengths of each waveform in the batch. Size `(n_batch, )`.
"""
"""
device
=
specgram
.
device
device
=
specgram
.
device
dtype
=
specgram
.
dtype
dtype
=
specgram
.
dtype
specgram
=
torch
.
nn
.
functional
.
pad
(
specgram
,
(
self
.
_pad
,
self
.
_pad
))
specgram
,
aux
=
self
.
upsample
(
specgram
)
specgram
,
aux
=
self
.
upsample
(
specgram
)
if
lengths
is
not
None
:
lengths
=
lengths
*
self
.
upsample
.
total_scale
output
:
List
[
Tensor
]
=
[]
output
:
List
[
Tensor
]
=
[]
b_size
,
_
,
seq_len
=
specgram
.
size
()
b_size
,
_
,
seq_len
=
specgram
.
size
()
...
@@ -410,7 +423,7 @@ class WaveRNN(nn.Module):
...
@@ -410,7 +423,7 @@ class WaveRNN(nn.Module):
output
.
append
(
x
)
output
.
append
(
x
)
return
torch
.
stack
(
output
).
permute
(
1
,
2
,
0
)
return
torch
.
stack
(
output
).
permute
(
1
,
2
,
0
)
,
lengths
def
wavernn
(
checkpoint_name
:
str
)
->
WaveRNN
:
def
wavernn
(
checkpoint_name
:
str
)
->
WaveRNN
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment