Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
b5b6d30f
Unverified
Commit
b5b6d30f
authored
Apr 09, 2020
by
moto
Committed by
GitHub
Apr 09, 2020
Browse files
Separate CPU and GPU tests for Transforms torchscript test (#520)
parent
c29598d5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
44 additions
and
48 deletions
+44
-48
test/test_torchscript_consistency.py
test/test_torchscript_consistency.py
+44
-48
No files found.
test/test_torchscript_consistency.py
View file @
b5b6d30f
...
...
@@ -5,11 +5,20 @@ import unittest
import
torch
import
torchaudio
import
torchaudio.functional
as
F
import
torchaudio.transforms
import
torchaudio.transforms
as
T
import
common_utils
def
_assert_transforms_consistency
(
transform
,
tensor
,
device
):
tensor
=
tensor
.
to
(
device
)
transform
=
transform
.
to
(
device
)
ts_transform
=
torch
.
jit
.
script
(
transform
)
output
=
transform
(
tensor
)
ts_output
=
ts_transform
(
tensor
)
torch
.
testing
.
assert_allclose
(
ts_output
,
output
)
def
_assert_functional_consistency
(
py_method
,
*
args
,
shape_only
=
False
,
**
kwargs
):
jit_method
=
torch
.
jit
.
script
(
py_method
)
...
...
@@ -301,85 +310,64 @@ class TestFunctional(unittest.TestCase):
_assert_functional_consistency
(
F
.
lfilter
,
waveform
,
a
,
b
)
RUN_CUDA
=
torch
.
cuda
.
is_available
()
print
(
"Run test with cuda:"
,
RUN_CUDA
)
def
_test_script_module
(
f
,
tensor
,
*
args
,
**
kwargs
):
py_method
=
f
(
*
args
,
**
kwargs
)
jit_method
=
torch
.
jit
.
script
(
py_method
)
py_out
=
py_method
(
tensor
)
jit_out
=
jit_method
(
tensor
)
torch
.
testing
.
assert_allclose
(
jit_out
,
py_out
)
if
RUN_CUDA
:
tensor
=
tensor
.
to
(
"cuda"
)
py_method
=
py_method
.
cuda
()
jit_method
=
torch
.
jit
.
script
(
py_method
)
class
_TransformsTestMixin
:
"""Implements test for Transforms that are performed for different devices"""
device
=
None
py_out
=
py_method
(
tensor
)
jit_out
=
jit_method
(
tensor
)
def
_assert_consistency
(
self
,
transform
,
tensor
)
:
_assert_transforms_consistency
(
transform
,
tensor
,
self
.
device
)
torch
.
testing
.
assert_allclose
(
jit_out
,
py_out
)
class
TestTransforms
(
unittest
.
TestCase
):
def
test_Spectrogram
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
_test_script_module
(
torchaudio
.
transforms
.
Spectrogram
,
tensor
)
self
.
_assert_consistency
(
T
.
Spectrogram
()
,
tensor
)
def
test_GriffinLim
(
self
):
tensor
=
torch
.
rand
((
1
,
201
,
6
))
_test_script_module
(
torchaudio
.
transforms
.
GriffinLim
,
tensor
,
length
=
1000
,
rand_init
=
False
)
self
.
_assert_consistency
(
T
.
GriffinLim
(
length
=
1000
,
rand_init
=
False
)
,
tensor
)
def
test_AmplitudeToDB
(
self
):
spec
=
torch
.
rand
((
6
,
201
))
_test_script_module
(
torchaudio
.
transforms
.
AmplitudeToDB
,
spec
)
self
.
_assert_consistency
(
T
.
AmplitudeToDB
()
,
spec
)
def
test_MelScale
(
self
):
spec_f
=
torch
.
rand
((
1
,
6
,
201
))
_test_script_module
(
torchaudio
.
transforms
.
MelScale
,
spec_f
)
self
.
_assert_consistency
(
T
.
MelScale
()
,
spec_f
)
def
test_MelSpectrogram
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
_test_script_module
(
torchaudio
.
transforms
.
MelSpectrogram
,
tensor
)
self
.
_assert_consistency
(
T
.
MelSpectrogram
()
,
tensor
)
def
test_MFCC
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
_test_script_module
(
torchaudio
.
transforms
.
MFCC
,
tensor
)
self
.
_assert_consistency
(
T
.
MFCC
()
,
tensor
)
def
test_Resample
(
self
):
tensor
=
torch
.
rand
((
2
,
1000
))
sample_rate
=
100.
sample_rate_2
=
50.
_test_script_module
(
torchaudio
.
transforms
.
Resample
,
tensor
,
sample_rate
,
sample_rate_2
)
self
.
_assert_consistency
(
T
.
Resample
(
sample_rate
,
sample_rate_2
),
tensor
)
def
test_ComplexNorm
(
self
):
tensor
=
torch
.
rand
((
1
,
2
,
201
,
2
))
_test_script_module
(
torchaudio
.
transforms
.
ComplexNorm
,
tensor
)
self
.
_assert_consistency
(
T
.
ComplexNorm
()
,
tensor
)
def
test_MuLawEncoding
(
self
):
tensor
=
torch
.
rand
((
1
,
10
))
_test_script_module
(
torchaudio
.
transforms
.
MuLawEncoding
,
tensor
)
self
.
_assert_consistency
(
T
.
MuLawEncoding
()
,
tensor
)
def
test_MuLawDecoding
(
self
):
tensor
=
torch
.
rand
((
1
,
10
))
_test_script_module
(
torchaudio
.
transforms
.
MuLawDecoding
,
tensor
)
self
.
_assert_consistency
(
T
.
MuLawDecoding
()
,
tensor
)
def
test_TimeStretch
(
self
):
n_freq
=
400
hop_length
=
512
fixed_rate
=
1.3
tensor
=
torch
.
rand
((
10
,
2
,
n_freq
,
10
,
2
))
_test_script_module
(
torchaudio
.
transforms
.
TimeStretch
,
tensor
,
n_freq
=
n_freq
,
hop_length
=
hop_length
,
fixed_rate
=
fixed_rate
)
self
.
_assert_consistency
(
T
.
TimeStretch
(
n_freq
=
n_freq
,
hop_length
=
hop_length
,
fixed_rate
=
fixed_rate
),
tensor
,
)
def
test_Fade
(
self
):
test_filepath
=
os
.
path
.
join
(
...
...
@@ -387,24 +375,32 @@ class TestTransforms(unittest.TestCase):
waveform
,
_
=
torchaudio
.
load
(
test_filepath
)
fade_in_len
=
3000
fade_out_len
=
3000
_test_script_module
(
torchaudio
.
transforms
.
Fade
,
waveform
,
fade_in_len
,
fade_out_len
)
self
.
_assert_consistency
(
T
.
Fade
(
fade_in_len
,
fade_out_len
),
waveform
)
def
test_FrequencyMasking
(
self
):
tensor
=
torch
.
rand
((
10
,
2
,
50
,
10
,
2
))
_test_script_module
(
torchaudio
.
transforms
.
FrequencyMasking
,
tensor
,
freq_mask_param
=
60
,
iid_masks
=
False
)
self
.
_assert_consistency
(
T
.
FrequencyMasking
(
freq_mask_param
=
60
,
iid_masks
=
False
),
tensor
)
def
test_TimeMasking
(
self
):
tensor
=
torch
.
rand
((
10
,
2
,
50
,
10
,
2
))
_test_script_module
(
torchaudio
.
transforms
.
TimeMasking
,
tensor
,
time_mask_param
=
30
,
iid_masks
=
False
)
self
.
_assert_consistency
(
T
.
TimeMasking
(
time_mask_param
=
30
,
iid_masks
=
False
),
tensor
)
def
test_Vol
(
self
):
test_filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
'assets'
,
'steam-train-whistle-daniel_simon.wav'
)
waveform
,
_
=
torchaudio
.
load
(
test_filepath
)
_test_script_module
(
torchaudio
.
transforms
.
Vol
,
waveform
,
1.1
)
self
.
_assert_consistency
(
T
.
Vol
(
1.1
),
waveform
)
class
TestTransformsCPU
(
_TransformsTestMixin
,
unittest
.
TestCase
):
"""Test suite for Transforms module on CPU"""
device
=
torch
.
device
(
'cpu'
)
@
unittest
.
skipIf
(
not
torch
.
cuda
.
is_available
(),
'CUDA not available'
)
class
TestTransformsCUDA
(
_TransformsTestMixin
,
unittest
.
TestCase
):
"""Test suite for Transforms module on GPU"""
device
=
torch
.
device
(
'cuda'
)
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment