Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
ad534c11
Unverified
Commit
ad534c11
authored
Apr 15, 2021
by
moto
Committed by
GitHub
Apr 15, 2021
Browse files
Add autograd test to T.TimeStretch (and F.phase_vocoder) (#1420)
parent
5c696b50
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
54 additions
and
2 deletions
+54
-2
test/torchaudio_unittest/transforms/autograd_test_impl.py
test/torchaudio_unittest/transforms/autograd_test_impl.py
+54
-2
No files found.
test/torchaudio_unittest/transforms/autograd_test_impl.py
View file @
ad534c11
from
typing
import
List
import
unittest
from
parameterized
import
parameterized
import
torch
...
...
@@ -35,10 +36,16 @@ class AutogradTestMixin(TestBaseMixin):
):
transform
=
transform
.
to
(
dtype
=
torch
.
float64
,
device
=
self
.
device
)
# gradcheck and gradgradcheck only pass if the input tensors are of dtype `torch.double` or
# `torch.cdouble`, when the default eps and tolerance values are used.
inputs_
=
[]
for
i
in
inputs
:
i
.
requires_grad
=
True
inputs_
.
append
(
i
.
to
(
dtype
=
torch
.
float64
,
device
=
self
.
device
))
if
torch
.
is_tensor
(
i
):
i
=
i
.
to
(
dtype
=
torch
.
cdouble
if
i
.
is_complex
()
else
torch
.
double
,
device
=
self
.
device
)
i
.
requires_grad
=
True
inputs_
.
append
(
i
)
assert
gradcheck
(
transform
,
inputs_
)
assert
gradgradcheck
(
transform
,
inputs_
,
nondet_tol
=
nondet_tol
)
...
...
@@ -129,3 +136,48 @@ class AutogradTestMixin(TestBaseMixin):
transform
=
T
.
AmplitudeToDB
()
waveform
=
get_whitenoise
(
sample_rate
=
sample_rate
,
duration
=
0.05
,
n_channels
=
2
)
self
.
assert_grad
(
transform
,
[
waveform
])
@
unittest
.
expectedFailure
def
test_timestretch_zeros_fail
(
self
):
"""Test that ``T.TimeStretch`` fails gradcheck at 0
This is because ``F.phase_vocoder`` converts data from cartesian to polar coordinate,
which performs ``atan2(img, real)``, and gradient is not defined at 0.
"""
n_fft
=
16
transform
=
T
.
TimeStretch
(
n_freq
=
n_fft
//
2
+
1
,
fixed_rate
=
0.99
)
waveform
=
torch
.
zeros
(
2
,
40
)
spectrogram
=
get_spectrogram
(
waveform
,
n_fft
=
n_fft
,
power
=
None
)
self
.
assert_grad
(
transform
,
[
spectrogram
])
@
nested_params
(
[
0.7
,
0.8
,
0.9
,
1.0
,
1.3
],
[
False
,
True
],
)
def
test_timestretch_non_zero
(
self
,
rate
,
test_pseudo_complex
):
"""Verify that ``T.TimeStretch`` does not fail if it's not close to 0
``T.TimeStrech`` is not differentiable around 0, so this test checks the differentiability
for cases where input is not zero.
As tested above, when spectrogram contains values close to zero, the gradients are unstable
and gradcheck fails.
In this test, we generate spectrogram from random signal, then we push the points around
zero away from the origin.
This process does not reflect the real use-case, and it is not practical for users, but
this helps us understand to what degree the function is differentiable and when not.
"""
n_fft
=
16
transform
=
T
.
TimeStretch
(
n_freq
=
n_fft
//
2
+
1
,
fixed_rate
=
rate
)
waveform
=
get_whitenoise
(
sample_rate
=
40
,
duration
=
1
,
n_channels
=
2
)
spectrogram
=
get_spectrogram
(
waveform
,
n_fft
=
n_fft
,
power
=
None
)
# 1e-3 is too small (on CPU)
epsilon
=
1e-2
too_close
=
spectrogram
.
abs
()
<
epsilon
spectrogram
[
too_close
]
=
epsilon
*
spectrogram
[
too_close
]
/
spectrogram
[
too_close
].
abs
()
if
test_pseudo_complex
:
spectrogram
=
torch
.
view_as_real
(
spectrogram
)
self
.
assert_grad
(
transform
,
[
spectrogram
])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment