Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
9d621fd3
"examples/vscode:/vscode.git/clone" did not exist on "e0004f7199598c390b73553e84964467303e15f6"
Unverified
Commit
9d621fd3
authored
May 12, 2021
by
Kirill Ignatev
Committed by
GitHub
May 12, 2021
Browse files
Add autograd tests for TimeMasking/FrequencyMasking (#1498)
parent
1f136671
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
0 deletions
+25
-0
test/torchaudio_unittest/transforms/autograd_test_impl.py
test/torchaudio_unittest/transforms/autograd_test_impl.py
+25
-0
No files found.
test/torchaudio_unittest/transforms/autograd_test_impl.py
View file @
9d621fd3
...
@@ -125,6 +125,31 @@ class AutogradTestMixin(TestBaseMixin):
...
@@ -125,6 +125,31 @@ class AutogradTestMixin(TestBaseMixin):
waveform
=
get_whitenoise
(
sample_rate
=
8000
,
duration
=
0.05
,
n_channels
=
2
)
waveform
=
get_whitenoise
(
sample_rate
=
8000
,
duration
=
0.05
,
n_channels
=
2
)
self
.
assert_grad
(
transform
,
[
waveform
],
nondet_tol
=
1e-10
)
self
.
assert_grad
(
transform
,
[
waveform
],
nondet_tol
=
1e-10
)
@
parameterized
.
expand
([(
T
.
TimeMasking
,),
(
T
.
FrequencyMasking
,)])
def
test_masking
(
self
,
masking_transform
):
sample_rate
=
8000
n_fft
=
400
spectrogram
=
get_spectrogram
(
get_whitenoise
(
sample_rate
=
sample_rate
,
duration
=
0.05
,
n_channels
=
2
),
n_fft
=
n_fft
,
power
=
1
)
deterministic_transform
=
_DeterministicWrapper
(
masking_transform
(
400
))
self
.
assert_grad
(
deterministic_transform
,
[
spectrogram
])
@
parameterized
.
expand
([(
T
.
TimeMasking
,),
(
T
.
FrequencyMasking
,)])
def
test_masking_iid
(
self
,
masking_transform
):
sample_rate
=
8000
n_fft
=
400
specs
=
[
get_spectrogram
(
get_whitenoise
(
sample_rate
=
sample_rate
,
duration
=
0.05
,
n_channels
=
2
,
seed
=
i
),
n_fft
=
n_fft
,
power
=
1
)
for
i
in
range
(
3
)
]
batch
=
torch
.
stack
(
specs
)
assert
batch
.
ndim
==
4
deterministic_transform
=
_DeterministicWrapper
(
masking_transform
(
400
,
True
))
self
.
assert_grad
(
deterministic_transform
,
[
batch
])
def
test_spectral_centroid
(
self
):
def
test_spectral_centroid
(
self
):
sample_rate
=
8000
sample_rate
=
8000
transform
=
T
.
SpectralCentroid
(
sample_rate
=
sample_rate
)
transform
=
T
.
SpectralCentroid
(
sample_rate
=
sample_rate
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment