Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
b5b6d30f
Unverified
Commit
b5b6d30f
authored
Apr 09, 2020
by
moto
Committed by
GitHub
Apr 09, 2020
Browse files
Separate CPU and GPU tests for Transforms torchscript test (#520)
parent
c29598d5
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
44 additions
and
48 deletions
+44
-48
test/test_torchscript_consistency.py
test/test_torchscript_consistency.py
+44
-48
No files found.
test/test_torchscript_consistency.py
View file @
b5b6d30f
...
@@ -5,11 +5,20 @@ import unittest
...
@@ -5,11 +5,20 @@ import unittest
import
torch
import
torch
import
torchaudio
import
torchaudio
import
torchaudio.functional
as
F
import
torchaudio.functional
as
F
import
torchaudio.transforms
import
torchaudio.transforms
as
T
import
common_utils
import
common_utils
def
_assert_transforms_consistency
(
transform
,
tensor
,
device
):
tensor
=
tensor
.
to
(
device
)
transform
=
transform
.
to
(
device
)
ts_transform
=
torch
.
jit
.
script
(
transform
)
output
=
transform
(
tensor
)
ts_output
=
ts_transform
(
tensor
)
torch
.
testing
.
assert_allclose
(
ts_output
,
output
)
def
_assert_functional_consistency
(
py_method
,
*
args
,
shape_only
=
False
,
**
kwargs
):
def
_assert_functional_consistency
(
py_method
,
*
args
,
shape_only
=
False
,
**
kwargs
):
jit_method
=
torch
.
jit
.
script
(
py_method
)
jit_method
=
torch
.
jit
.
script
(
py_method
)
...
@@ -301,85 +310,64 @@ class TestFunctional(unittest.TestCase):
...
@@ -301,85 +310,64 @@ class TestFunctional(unittest.TestCase):
_assert_functional_consistency
(
F
.
lfilter
,
waveform
,
a
,
b
)
_assert_functional_consistency
(
F
.
lfilter
,
waveform
,
a
,
b
)
RUN_CUDA
=
torch
.
cuda
.
is_available
()
class
_TransformsTestMixin
:
print
(
"Run test with cuda:"
,
RUN_CUDA
)
"""Implements test for Transforms that are performed for different devices"""
device
=
None
def
_test_script_module
(
f
,
tensor
,
*
args
,
**
kwargs
):
py_method
=
f
(
*
args
,
**
kwargs
)
jit_method
=
torch
.
jit
.
script
(
py_method
)
py_out
=
py_method
(
tensor
)
jit_out
=
jit_method
(
tensor
)
torch
.
testing
.
assert_allclose
(
jit_out
,
py_out
)
if
RUN_CUDA
:
tensor
=
tensor
.
to
(
"cuda"
)
py_method
=
py_method
.
cuda
()
def
_assert_consistency
(
self
,
transform
,
tensor
):
jit_method
=
torch
.
jit
.
script
(
py_method
)
_assert_transforms_consistency
(
transform
,
tensor
,
self
.
device
)
py_out
=
py_method
(
tensor
)
jit_out
=
jit_method
(
tensor
)
torch
.
testing
.
assert_allclose
(
jit_out
,
py_out
)
class
TestTransforms
(
unittest
.
TestCase
):
def
test_Spectrogram
(
self
):
def
test_Spectrogram
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
tensor
=
torch
.
rand
((
1
,
1000
))
_test_script_module
(
torchaudio
.
transforms
.
Spectrogram
,
tensor
)
self
.
_assert_consistency
(
T
.
Spectrogram
()
,
tensor
)
def
test_GriffinLim
(
self
):
def
test_GriffinLim
(
self
):
tensor
=
torch
.
rand
((
1
,
201
,
6
))
tensor
=
torch
.
rand
((
1
,
201
,
6
))
_test_script_module
(
torchaudio
.
transforms
.
GriffinLim
,
tensor
,
length
=
1000
,
rand_init
=
False
)
self
.
_assert_consistency
(
T
.
GriffinLim
(
length
=
1000
,
rand_init
=
False
)
,
tensor
)
def
test_AmplitudeToDB
(
self
):
def
test_AmplitudeToDB
(
self
):
spec
=
torch
.
rand
((
6
,
201
))
spec
=
torch
.
rand
((
6
,
201
))
_test_script_module
(
torchaudio
.
transforms
.
AmplitudeToDB
,
spec
)
self
.
_assert_consistency
(
T
.
AmplitudeToDB
()
,
spec
)
def
test_MelScale
(
self
):
def
test_MelScale
(
self
):
spec_f
=
torch
.
rand
((
1
,
6
,
201
))
spec_f
=
torch
.
rand
((
1
,
6
,
201
))
_test_script_module
(
torchaudio
.
transforms
.
MelScale
,
spec_f
)
self
.
_assert_consistency
(
T
.
MelScale
()
,
spec_f
)
def
test_MelSpectrogram
(
self
):
def
test_MelSpectrogram
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
tensor
=
torch
.
rand
((
1
,
1000
))
_test_script_module
(
torchaudio
.
transforms
.
MelSpectrogram
,
tensor
)
self
.
_assert_consistency
(
T
.
MelSpectrogram
()
,
tensor
)
def
test_MFCC
(
self
):
def
test_MFCC
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
tensor
=
torch
.
rand
((
1
,
1000
))
_test_script_module
(
torchaudio
.
transforms
.
MFCC
,
tensor
)
self
.
_assert_consistency
(
T
.
MFCC
()
,
tensor
)
def
test_Resample
(
self
):
def
test_Resample
(
self
):
tensor
=
torch
.
rand
((
2
,
1000
))
tensor
=
torch
.
rand
((
2
,
1000
))
sample_rate
=
100.
sample_rate
=
100.
sample_rate_2
=
50.
sample_rate_2
=
50.
self
.
_assert_consistency
(
T
.
Resample
(
sample_rate
,
sample_rate_2
),
tensor
)
_test_script_module
(
torchaudio
.
transforms
.
Resample
,
tensor
,
sample_rate
,
sample_rate_2
)
def
test_ComplexNorm
(
self
):
def
test_ComplexNorm
(
self
):
tensor
=
torch
.
rand
((
1
,
2
,
201
,
2
))
tensor
=
torch
.
rand
((
1
,
2
,
201
,
2
))
_test_script_module
(
torchaudio
.
transforms
.
ComplexNorm
,
tensor
)
self
.
_assert_consistency
(
T
.
ComplexNorm
()
,
tensor
)
def
test_MuLawEncoding
(
self
):
def
test_MuLawEncoding
(
self
):
tensor
=
torch
.
rand
((
1
,
10
))
tensor
=
torch
.
rand
((
1
,
10
))
_test_script_module
(
torchaudio
.
transforms
.
MuLawEncoding
,
tensor
)
self
.
_assert_consistency
(
T
.
MuLawEncoding
()
,
tensor
)
def
test_MuLawDecoding
(
self
):
def
test_MuLawDecoding
(
self
):
tensor
=
torch
.
rand
((
1
,
10
))
tensor
=
torch
.
rand
((
1
,
10
))
_test_script_module
(
torchaudio
.
transforms
.
MuLawDecoding
,
tensor
)
self
.
_assert_consistency
(
T
.
MuLawDecoding
()
,
tensor
)
def
test_TimeStretch
(
self
):
def
test_TimeStretch
(
self
):
n_freq
=
400
n_freq
=
400
hop_length
=
512
hop_length
=
512
fixed_rate
=
1.3
fixed_rate
=
1.3
tensor
=
torch
.
rand
((
10
,
2
,
n_freq
,
10
,
2
))
tensor
=
torch
.
rand
((
10
,
2
,
n_freq
,
10
,
2
))
_test_script_module
(
self
.
_assert_consistency
(
torchaudio
.
transforms
.
TimeStretch
,
T
.
TimeStretch
(
n_freq
=
n_freq
,
hop_length
=
hop_length
,
fixed_rate
=
fixed_rate
),
tensor
,
n_freq
=
n_freq
,
hop_length
=
hop_length
,
fixed_rate
=
fixed_rate
)
tensor
,
)
def
test_Fade
(
self
):
def
test_Fade
(
self
):
test_filepath
=
os
.
path
.
join
(
test_filepath
=
os
.
path
.
join
(
...
@@ -387,24 +375,32 @@ class TestTransforms(unittest.TestCase):
...
@@ -387,24 +375,32 @@ class TestTransforms(unittest.TestCase):
waveform
,
_
=
torchaudio
.
load
(
test_filepath
)
waveform
,
_
=
torchaudio
.
load
(
test_filepath
)
fade_in_len
=
3000
fade_in_len
=
3000
fade_out_len
=
3000
fade_out_len
=
3000
self
.
_assert_consistency
(
T
.
Fade
(
fade_in_len
,
fade_out_len
),
waveform
)
_test_script_module
(
torchaudio
.
transforms
.
Fade
,
waveform
,
fade_in_len
,
fade_out_len
)
def
test_FrequencyMasking
(
self
):
def
test_FrequencyMasking
(
self
):
tensor
=
torch
.
rand
((
10
,
2
,
50
,
10
,
2
))
tensor
=
torch
.
rand
((
10
,
2
,
50
,
10
,
2
))
_test_script_module
(
self
.
_assert_consistency
(
T
.
FrequencyMasking
(
freq_mask_param
=
60
,
iid_masks
=
False
),
tensor
)
torchaudio
.
transforms
.
FrequencyMasking
,
tensor
,
freq_mask_param
=
60
,
iid_masks
=
False
)
def
test_TimeMasking
(
self
):
def
test_TimeMasking
(
self
):
tensor
=
torch
.
rand
((
10
,
2
,
50
,
10
,
2
))
tensor
=
torch
.
rand
((
10
,
2
,
50
,
10
,
2
))
_test_script_module
(
self
.
_assert_consistency
(
T
.
TimeMasking
(
time_mask_param
=
30
,
iid_masks
=
False
),
tensor
)
torchaudio
.
transforms
.
TimeMasking
,
tensor
,
time_mask_param
=
30
,
iid_masks
=
False
)
def
test_Vol
(
self
):
def
test_Vol
(
self
):
test_filepath
=
os
.
path
.
join
(
test_filepath
=
os
.
path
.
join
(
common_utils
.
TEST_DIR_PATH
,
'assets'
,
'steam-train-whistle-daniel_simon.wav'
)
common_utils
.
TEST_DIR_PATH
,
'assets'
,
'steam-train-whistle-daniel_simon.wav'
)
waveform
,
_
=
torchaudio
.
load
(
test_filepath
)
waveform
,
_
=
torchaudio
.
load
(
test_filepath
)
_test_script_module
(
torchaudio
.
transforms
.
Vol
,
waveform
,
1.1
)
self
.
_assert_consistency
(
T
.
Vol
(
1.1
),
waveform
)
class
TestTransformsCPU
(
_TransformsTestMixin
,
unittest
.
TestCase
):
"""Test suite for Transforms module on CPU"""
device
=
torch
.
device
(
'cpu'
)
@
unittest
.
skipIf
(
not
torch
.
cuda
.
is_available
(),
'CUDA not available'
)
class
TestTransformsCUDA
(
_TransformsTestMixin
,
unittest
.
TestCase
):
"""Test suite for Transforms module on GPU"""
device
=
torch
.
device
(
'cuda'
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment