Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
2c28b743
"...text-generation-inference.git" did not exist on "8ad20daf33617296aa3982f1db08bc8abaf40f83"
Unverified
Commit
2c28b743
authored
May 14, 2020
by
moto
Committed by
GitHub
May 14, 2020
Browse files
Adopt PyTorch's test util to torchscript test (#640)
parent
995b75f8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
36 additions
and
41 deletions
+36
-41
test/common_utils.py
test/common_utils.py
+7
-4
test/test_functional.py
test/test_functional.py
+1
-1
test/test_torchscript_consistency.py
test/test_torchscript_consistency.py
+28
-36
No files found.
test/common_utils.py
View file @
2c28b743
import
os
import
os
import
tempfile
import
tempfile
import
unittest
from
typing
import
Type
,
Iterable
from
typing
import
Type
,
Iterable
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
shutil
import
copytree
from
shutil
import
copytree
import
torch
import
torch
from
torch.testing._internal.common_utils
import
TestCase
import
torchaudio
import
torchaudio
import
pytest
_TEST_DIR_PATH
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
_TEST_DIR_PATH
=
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
))
BACKENDS
=
torchaudio
.
_backend
.
_audio_backends
BACKENDS
=
torchaudio
.
_backend
.
_audio_backends
...
@@ -87,6 +88,9 @@ class TestBaseMixin:
...
@@ -87,6 +88,9 @@ class TestBaseMixin:
device
=
None
device
=
None
_SKIP_IF_NO_CUDA
=
unittest
.
skipIf
(
not
torch
.
cuda
.
is_available
(),
reason
=
'CUDA not available'
)
def
define_test_suite
(
testbase
:
Type
[
TestBaseMixin
],
dtype
:
str
,
device
:
str
):
def
define_test_suite
(
testbase
:
Type
[
TestBaseMixin
],
dtype
:
str
,
device
:
str
):
if
dtype
not
in
[
'float32'
,
'float64'
]:
if
dtype
not
in
[
'float32'
,
'float64'
]:
raise
NotImplementedError
(
f
'Unexpected dtype:
{
dtype
}
'
)
raise
NotImplementedError
(
f
'Unexpected dtype:
{
dtype
}
'
)
...
@@ -95,11 +99,10 @@ def define_test_suite(testbase: Type[TestBaseMixin], dtype: str, device: str):
...
@@ -95,11 +99,10 @@ def define_test_suite(testbase: Type[TestBaseMixin], dtype: str, device: str):
name
=
f
'Test
{
testbase
.
__name__
}
_
{
device
.
upper
()
}
_
{
dtype
.
capitalize
()
}
'
name
=
f
'Test
{
testbase
.
__name__
}
_
{
device
.
upper
()
}
_
{
dtype
.
capitalize
()
}
'
attrs
=
{
'dtype'
:
getattr
(
torch
,
dtype
),
'device'
:
torch
.
device
(
device
)}
attrs
=
{
'dtype'
:
getattr
(
torch
,
dtype
),
'device'
:
torch
.
device
(
device
)}
testsuite
=
type
(
name
,
(
testbase
,),
attrs
)
testsuite
=
type
(
name
,
(
testbase
,
TestCase
),
attrs
)
if
device
==
'cuda'
:
if
device
==
'cuda'
:
testsuite
=
pytest
.
mark
.
skipif
(
testsuite
=
_SKIP_IF_NO_CUDA
(
testsuite
)
not
torch
.
cuda
.
is_available
(),
reason
=
'CUDA not available'
)(
testsuite
)
return
testsuite
return
testsuite
...
...
test/test_functional.py
View file @
2c28b743
...
@@ -23,7 +23,7 @@ class Lfilter(common_utils.TestBaseMixin):
...
@@ -23,7 +23,7 @@ class Lfilter(common_utils.TestBaseMixin):
a_coeffs
=
torch
.
tensor
([
1
,
0
,
0
,
0
],
dtype
=
self
.
dtype
,
device
=
self
.
device
)
a_coeffs
=
torch
.
tensor
([
1
,
0
,
0
,
0
],
dtype
=
self
.
dtype
,
device
=
self
.
device
)
output_waveform
=
F
.
lfilter
(
waveform
,
a_coeffs
,
b_coeffs
)
output_waveform
=
F
.
lfilter
(
waveform
,
a_coeffs
,
b_coeffs
)
torch
.
testing
.
assert_allclose
(
output_waveform
[:,
3
:],
waveform
[:,
0
:
-
3
],
atol
=
1e-5
,
rtol
=
1e-5
)
self
.
assertEqual
(
output_waveform
[:,
3
:],
waveform
[:,
0
:
-
3
],
atol
=
1e-5
,
rtol
=
1e-5
)
def
test_clamp
(
self
):
def
test_clamp
(
self
):
input_signal
=
torch
.
ones
(
1
,
44100
*
1
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
input_signal
=
torch
.
ones
(
1
,
44100
*
1
,
dtype
=
self
.
dtype
,
device
=
self
.
device
)
...
...
test/test_torchscript_consistency.py
View file @
2c28b743
"""Test suites for jit-ability and its numerical compatibility"""
"""Test suites for jit-ability and its numerical compatibility"""
import
unittest
import
unittest
import
pytest
import
torch
import
torch
import
torchaudio
import
torchaudio
...
@@ -10,29 +9,18 @@ import torchaudio.transforms as T
...
@@ -10,29 +9,18 @@ import torchaudio.transforms as T
import
common_utils
import
common_utils
def
_assert_functional_consistency
(
func
,
tensor
,
shape_only
=
False
):
ts_func
=
torch
.
jit
.
script
(
func
)
output
=
func
(
tensor
)
ts_output
=
ts_func
(
tensor
)
if
shape_only
:
assert
ts_output
.
shape
==
output
.
shape
,
(
ts_output
.
shape
,
output
.
shape
)
else
:
torch
.
testing
.
assert_allclose
(
ts_output
,
output
)
def
_assert_transforms_consistency
(
transform
,
tensor
):
ts_transform
=
torch
.
jit
.
script
(
transform
)
output
=
transform
(
tensor
)
ts_output
=
ts_transform
(
tensor
)
torch
.
testing
.
assert_allclose
(
ts_output
,
output
)
class
Functional
(
common_utils
.
TestBaseMixin
):
class
Functional
(
common_utils
.
TestBaseMixin
):
"""Implements test for `functinoal` modul that are performed for different devices"""
"""Implements test for `functinoal` modul that are performed for different devices"""
def
_assert_consistency
(
self
,
func
,
tensor
,
shape_only
=
False
):
def
_assert_consistency
(
self
,
func
,
tensor
,
shape_only
=
False
):
tensor
=
tensor
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
tensor
=
tensor
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
return
_assert_functional_consistency
(
func
,
tensor
,
shape_only
=
shape_only
)
ts_func
=
torch
.
jit
.
script
(
func
)
output
=
func
(
tensor
)
ts_output
=
ts_func
(
tensor
)
if
shape_only
:
ts_output
=
ts_output
.
shape
output
=
output
.
shape
self
.
assertEqual
(
ts_output
,
output
)
def
test_spectrogram
(
self
):
def
test_spectrogram
(
self
):
def
func
(
tensor
):
def
func
(
tensor
):
...
@@ -210,7 +198,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -210,7 +198,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_lfilter
(
self
):
def
test_lfilter
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
'whitenoise.wav'
)
filepath
=
common_utils
.
get_asset_path
(
'whitenoise.wav'
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -254,7 +242,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -254,7 +242,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_lowpass
(
self
):
def
test_lowpass
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
'whitenoise.wav'
)
filepath
=
common_utils
.
get_asset_path
(
'whitenoise.wav'
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -268,7 +256,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -268,7 +256,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_highpass
(
self
):
def
test_highpass
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
'whitenoise.wav'
)
filepath
=
common_utils
.
get_asset_path
(
'whitenoise.wav'
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -282,7 +270,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -282,7 +270,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_allpass
(
self
):
def
test_allpass
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
'whitenoise.wav'
)
filepath
=
common_utils
.
get_asset_path
(
'whitenoise.wav'
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -297,7 +285,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -297,7 +285,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_bandpass_with_csg
(
self
):
def
test_bandpass_with_csg
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -313,7 +301,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -313,7 +301,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_bandpass_without_csg
(
self
):
def
test_bandpass_without_csg
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -329,7 +317,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -329,7 +317,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_bandreject
(
self
):
def
test_bandreject
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -344,7 +332,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -344,7 +332,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_band_with_noise
(
self
):
def
test_band_with_noise
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -360,7 +348,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -360,7 +348,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_band_without_noise
(
self
):
def
test_band_without_noise
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -376,7 +364,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -376,7 +364,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_treble
(
self
):
def
test_treble
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -392,7 +380,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -392,7 +380,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_deemph
(
self
):
def
test_deemph
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -405,7 +393,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -405,7 +393,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_riaa
(
self
):
def
test_riaa
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -418,7 +406,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -418,7 +406,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_equalizer
(
self
):
def
test_equalizer
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -434,7 +422,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -434,7 +422,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_perf_biquad_filtering
(
self
):
def
test_perf_biquad_filtering
(
self
):
if
self
.
dtype
==
torch
.
float64
:
if
self
.
dtype
==
torch
.
float64
:
pytest
.
xfail
(
"This test is known to fail for float64"
)
raise
unittest
.
SkipTest
(
"This test is known to fail for float64"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
...
@@ -515,7 +503,7 @@ class Functional(common_utils.TestBaseMixin):
...
@@ -515,7 +503,7 @@ class Functional(common_utils.TestBaseMixin):
def
test_phaser
(
self
):
def
test_phaser
(
self
):
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
filepath
=
common_utils
.
get_asset_path
(
"whitenoise.wav"
)
waveform
,
sample_rate
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
waveform
,
_
=
torchaudio
.
load
(
filepath
,
normalization
=
True
)
def
func
(
tensor
):
def
func
(
tensor
):
gain_in
=
0.5
gain_in
=
0.5
...
@@ -534,7 +522,11 @@ class Transforms(common_utils.TestBaseMixin):
...
@@ -534,7 +522,11 @@ class Transforms(common_utils.TestBaseMixin):
def
_assert_consistency
(
self
,
transform
,
tensor
):
def
_assert_consistency
(
self
,
transform
,
tensor
):
tensor
=
tensor
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
tensor
=
tensor
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
transform
=
transform
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
transform
=
transform
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
_assert_transforms_consistency
(
transform
,
tensor
)
ts_transform
=
torch
.
jit
.
script
(
transform
)
output
=
transform
(
tensor
)
ts_output
=
ts_transform
(
tensor
)
self
.
assertEqual
(
ts_output
,
output
)
def
test_Spectrogram
(
self
):
def
test_Spectrogram
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
tensor
=
torch
.
rand
((
1
,
1000
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment