Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
2c28b743
Unverified
Commit
2c28b743
authored
May 14, 2020
by
moto
Committed by
GitHub
May 14, 2020
Browse files
Adopt PyTorch's test util to torchscript test (#640)
parent
995b75f8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
36 additions
and
41 deletions
+36
-41
test/common_utils.py
test/common_utils.py
+7
-4
test/test_functional.py
test/test_functional.py
+1
-1
test/test_torchscript_consistency.py
test/test_torchscript_consistency.py
+28
-36
No files found.
test/common_utils.py
View file @
2c28b743
import
os
import
os
import
tempfile
import
tempfile
import
unittest
from
typing
import
Type
,
Iterable
from
typing
import
Type
,
Iterable
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
shutil
import
copytree
from
shutil
import
copytree
import
torch
import
torch
from
torch.testing._internal.common_utils
import
TestCase
import
torchaudio
import
torchaudio
import
pytest
_TEST_DIR_PATH
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
_TEST_DIR_PATH
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
BACKENDS
=
torchaudio
.
_backend
.
_audio_backends
BACKENDS
=
torchaudio
.
_backend
.
_audio_backends
...
@@ -87,6 +88,9 @@ class TestBaseMixin:
...
@@ -87,6 +88,9 @@ class TestBaseMixin:
device
=
None
device
=
None
_SKIP_IF_NO_CUDA
=
unittest
.
skipIf
(
not
torch
.
cuda
.
is_available
(),
reason
=
'CUDA not available'
)
def
define_test_suite
(
testbase
:
Type
[
TestBaseMixin
],
dtype
:
str
,
device
:
str
):
def
define_test_suite
(
testbase
:
Type
[
TestBaseMixin
],
dtype
:
str
,
device
:
str
):
if
dtype
not
in
[
'float32'
,
'float64'
]:
if
dtype
not
in
[
'float32'
,
'float64'
]:
raise
NotImplementedError
(
f
'Unexpected dtype:
{
dtype
}
'
)
raise
NotImplementedError
(
f
'Unexpected dtype:
{
dtype
}
'
)
...
@@ -95,11 +99,10 @@ def define_test_suite(testbase: Type[TestBaseMixin], dtype: str, device: str):
...
@@ -95,11 +99,10 @@ def define_test_suite(testbase: Type[TestBaseMixin], dtype: str, device: str):
name
=
f
'Test
{
testbase
.
__name__
}
_
{
device
.
upper
()
}
_
{
dtype
.
capitalize
()
}
'
name
=
f
'Test
{
testbase
.
__name__
}
_
{
device
.
upper
()
}
_
{
dtype
.
capitalize
()
}
'
attrs
=
{
'dtype'
:
getattr
(
torch
,
dtype
),
'device'
:
torch
.
device
(
device
)}
attrs
=
{
'dtype'
:
getattr
(
torch
,
dtype
),
'device'
:
torch
.
device
(
device
)}
testsuite
=
type
(
name
,
(
testbase
,),
attrs
)
testsuite
=
type
(
name
,
(
testbase
,
TestCase
),
attrs
)
if
device
==
'cuda'
:
if
device
==
'cuda'
:
testsuite
=
pytest
.
mark
.
skipif
(
testsuite
=
_SKIP_IF_NO_CUDA
(
testsuite
)
not
torch
.
cuda
.
is_available
(),
reason
=
'CUDA not available'
)(
testsuite
)
return
testsuite
return
testsuite
...
...
test/test_functional.py
View file @
2c28b743
...
@@ -23,7 +23,7 @@ class Lfilter(common_utils.TestBaseMixin):
...
@@ -23,7 +23,7 @@ class Lfilter(common_utils.TestBaseMixin):
a_coeffs
=
torch
.
tensor
([
1
,
0
,
0
,
0
],
dtype
=
self
.
dtype
,
device
=
self
.
device
)
a_coeffs
=
torch
.
tensor
([
1
,
0
,
0
,
0
],
dtype
=
self
.
dtype
,
device
=
self
.
device
)
output_waveform
=
F
.
lfilter
(
waveform
,
a_coeffs
,
b_coeffs
)
output_waveform
=
F
.
lfilter
(
waveform
,
a_coeffs
,
b_coeffs
)
torch
.
testing
.
assert_allclose
(
output_waveform
[:,
3
:],
waveform
[:,
0
:
-
3
],
atol
=
1e-5
,
rtol
=
1e-5
)
self
.
assertEqual
(
output_waveform
[:,
3
:],
waveform
[:,
0
:
-
3
],
atol
=
1e-5
,
rtol
=
1e-5
)
def
test_clamp
(
self
):
def
test_clamp
(
self
):
input_signal
=
torch
.
ones
(
1
,
44100
*
1
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
input_signal
=
torch
.
ones
(
1
,
44100
*
1
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
...
...
test/test_torchscript_consistency.py
View file @
2c28b743
"""Test suites for jit-ability and its numerical compatibility"""
"""Test suites for jit-ability and its numerical compatibility"""
import
unittest
import
unittest
import
pytest
import
torch
import
torch
import
torchaudio
import
torchaudio
...
@@ -10,29 +9,18 @@ import torchaudio.transforms as T
...
@@ -10,29 +9,18 @@ import torchaudio.transforms as T
import
common_utils
import
common_utils
def
_assert_functional_consistency
(
func
,
tensor
,
shape_only
=
False
):
ts_func
=
torch
.
jit
.
script
(
func
)
output
=
func
(
tensor
)
ts_output
=
ts_func
(
tensor
)
if
shape_only
:
assert
ts_output
.
shape
==
output
.
shape
,
(
ts_output
.
shape
,
output
.
shape
)
else
:
torch
.
testing
.
assert_allclose
(
ts_output
,
output
)
def
_assert_transforms_consistency
(
transform
,
tensor
):
ts_transform
=
torch
.
jit
.
script
(
transform
)
output
=
transform
(
tensor
)
ts_output
=
ts_transform
(
tensor
)
torch
.
testing
.
assert_allclose
(
ts_output
,
output
)
class
Functional
(
common_utils
.
TestBaseMixin
):
class
Functional
(
common_utils
.
TestBaseMixin
):
"""Implements test for `functinoal` modul that are performed for different devices"""
"""Implements test for `functinoal` modul that are performed for different devices"""
def
_assert_consistency
(
self
,
func
,
tensor
,
shape_only
=
False
):
def
_assert_consistency
(
self
,
func
,
tensor
,
shape_only
=
False
):
tensor
=
tensor
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
tensor
=
tensor
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
return
_assert_functional_consistency
(
func
,
tensor
,
shape_only
=
shape_only
)
ts_func
=
torch
.
jit
.
script
(
func
)
output
=
func
(
tensor
)
ts_output
=
ts_func
(
tensor
)
if
shape_only
:
ts_output
=
ts_output
.
shape
output
=
output
.
shape
self
.
assertEqual
(
ts_output
,
output
)
def
test_spectrogram
(
self
):
def
test_spectrogram
(
self
):
def
func
(
tensor
):
def
func
(
tensor
):
...
@@ -210,7 +198,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -210,7 +198,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_lfilter
(
self
):
def
test_lfilter
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
'whitenoise.wav'
)
filepath
=
common_utils
.
get_asset_path
(
'whitenoise.wav'
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -254,7 +242,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -254,7 +242,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_lowpass
(
self
):
def
test_lowpass
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
'whitenoise.wav'
)
filepath
=
common_utils
.
get_asset_path
(
'whitenoise.wav'
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -268,7 +256,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -268,7 +256,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_highpass
(
self
):
def
test_highpass
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
'whitenoise.wav'
)
filepath
=
common_utils
.
get_asset_path
(
'whitenoise.wav'
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -282,7 +270,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -282,7 +270,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_allpass
(
self
):
def
test_allpass
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
'whitenoise.wav'
)
filepath
=
common_utils
.
get_asset_path
(
'whitenoise.wav'
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -297,7 +285,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -297,7 +285,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_bandpass_with_csg
(
self
):
def
test_bandpass_with_csg
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -313,7 +301,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -313,7 +301,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_bandpass_without_csg
(
self
):
def
test_bandpass_without_csg
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -329,7 +317,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -329,7 +317,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_bandreject
(
self
):
def
test_bandreject
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -344,7 +332,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -344,7 +332,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_band_with_noise
(
self
):
def
test_band_with_noise
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -360,7 +348,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -360,7 +348,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_band_without_noise
(
self
):
def
test_band_without_noise
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -376,7 +364,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -376,7 +364,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_treble
(
self
):
def
test_treble
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -392,7 +380,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -392,7 +380,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_deemph
(
self
):
def
test_deemph
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -405,7 +393,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -405,7 +393,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_riaa
(
self
):
def
test_riaa
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -418,7 +406,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -418,7 +406,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_equalizer
(
self
):
def
test_equalizer
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -434,7 +422,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -434,7 +422,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_perf_biquad_filtering
(
self
):
def
test_perf_biquad_filtering
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -515,7 +503,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -515,7 +503,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_phaser
(
self
):
def
test_phaser
(
self
):
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
waveform
,
sample_rate
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
def
func
(
tensor
):
def
func
(
tensor
):
gain_in
=
0.5
gain_in
=
0.5
...
@@ -534,7 +522,11 @@ class Transforms(common_utils.TestBaseMixin):
...
@@ -534,7 +522,11 @@ class Transforms(common_utils.TestBaseMixin):
def
_assert_consistency
(
self
,
transform
,
tensor
):
def
_assert_consistency
(
self
,
transform
,
tensor
):
tensor
=
tensor
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
tensor
=
tensor
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
transform
=
transform
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
transform
=
transform
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
_assert_transforms_consistency
(
transform
,
tensor
)
ts_transform
=
torch
.
jit
.
script
(
transform
)
output
=
transform
(
tensor
)
ts_output
=
ts_transform
(
tensor
)
self
.
assertEqual
(
ts_output
,
output
)
def
test_Spectrogram
(
self
):
def
test_Spectrogram
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
tensor
=
torch
.
rand
((
1
,
1000
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment