Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
7d45851d
Unverified
Commit
7d45851d
authored
May 06, 2021
by
moto
Committed by
GitHub
May 06, 2021
Browse files
Merge test classes for complex (#1491)
parent
ddd2425c
Changes
10
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
52 additions
and
135 deletions
+52
-135
test/torchaudio_unittest/common_utils/case_utils.py
test/torchaudio_unittest/common_utils/case_utils.py
+8
-0
test/torchaudio_unittest/functional/functional_cpu_test.py
test/torchaudio_unittest/functional/functional_cpu_test.py
+1
-13
test/torchaudio_unittest/functional/functional_cuda_test.py
test/torchaudio_unittest/functional/functional_cuda_test.py
+1
-15
test/torchaudio_unittest/functional/functional_impl.py
test/torchaudio_unittest/functional/functional_impl.py
+1
-7
test/torchaudio_unittest/functional/torchscript_consistency_cpu_test.py
...o_unittest/functional/torchscript_consistency_cpu_test.py
+1
-13
test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py
..._unittest/functional/torchscript_consistency_cuda_test.py
+1
-15
test/torchaudio_unittest/functional/torchscript_consistency_impl.py
...audio_unittest/functional/torchscript_consistency_impl.py
+20
-21
test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py
...o_unittest/transforms/torchscript_consistency_cpu_test.py
+1
-13
test/torchaudio_unittest/transforms/torchscript_consistency_cuda_test.py
..._unittest/transforms/torchscript_consistency_cuda_test.py
+1
-15
test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
...audio_unittest/transforms/torchscript_consistency_impl.py
+17
-23
No files found.
test/torchaudio_unittest/common_utils/case_utils.py
View file @
7d45851d
...
...
@@ -81,6 +81,14 @@ class TestBaseMixin:
super
().
setUp
()
set_audio_backend
(
self
.
backend
)
@
property
def
complex_dtype
(
self
):
if
self
.
dtype
in
[
'float32'
,
'float'
,
torch
.
float
,
torch
.
float32
]:
return
torch
.
cfloat
if
self
.
dtype
in
[
'float64'
,
'double'
,
torch
.
double
,
torch
.
float64
]:
return
torch
.
cdouble
raise
ValueError
(
f
'No corresponding complex dtype for
{
self
.
dtype
}
'
)
class
TorchaudioTestCase
(
TestBaseMixin
,
PytorchTestCase
):
pass
...
...
test/torchaudio_unittest/functional/functional_cpu_test.py
View file @
7d45851d
...
...
@@ -4,7 +4,7 @@ import unittest
from
parameterized
import
parameterized
from
torchaudio_unittest.common_utils
import
PytorchTestCase
,
TorchaudioTestCase
,
skipIfNoSox
from
.functional_impl
import
Functional
,
FunctionalComplex
,
FunctionalCPUOnly
from
.functional_impl
import
Functional
,
FunctionalCPUOnly
class
TestFunctionalFloat32
(
Functional
,
FunctionalCPUOnly
,
PytorchTestCase
):
...
...
@@ -21,18 +21,6 @@ class TestFunctionalFloat64(Functional, PytorchTestCase):
device
=
torch
.
device
(
'cpu'
)
class
TestFunctionalComplex64
(
FunctionalComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex64
real_dtype
=
torch
.
float32
device
=
torch
.
device
(
'cpu'
)
class
TestFunctionalComplex128
(
FunctionalComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex128
real_dtype
=
torch
.
float64
device
=
torch
.
device
(
'cpu'
)
@
skipIfNoSox
class
TestApplyCodec
(
TorchaudioTestCase
):
backend
=
"sox_io"
...
...
test/torchaudio_unittest/functional/functional_cuda_test.py
View file @
7d45851d
...
...
@@ -2,7 +2,7 @@ import torch
import
unittest
from
torchaudio_unittest.common_utils
import
PytorchTestCase
,
skipIfNoCuda
from
.functional_impl
import
Functional
,
FunctionalComplex
from
.functional_impl
import
Functional
@
skipIfNoCuda
...
...
@@ -19,17 +19,3 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
class
TestLFilterFloat64
(
Functional
,
PytorchTestCase
):
dtype
=
torch
.
float64
device
=
torch
.
device
(
'cuda'
)
@
skipIfNoCuda
class
TestFunctionalComplex64
(
FunctionalComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex64
real_dtype
=
torch
.
float32
device
=
torch
.
device
(
'cuda'
)
@
skipIfNoCuda
class
TestFunctionalComplex128
(
FunctionalComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex64
real_dtype
=
torch
.
float32
device
=
torch
.
device
(
'cuda'
)
test/torchaudio_unittest/functional/functional_impl.py
View file @
7d45851d
...
...
@@ -259,12 +259,6 @@ class Functional(TestBaseMixin):
self
.
assertEqual
(
specgrams
,
specgrams_copy
)
class
FunctionalComplex
(
TestBaseMixin
):
complex_dtype
=
None
real_dtype
=
None
device
=
None
@
nested_params
(
[
0.5
,
1.01
,
1.3
],
[
True
,
False
],
...
...
@@ -286,7 +280,7 @@ class FunctionalComplex(TestBaseMixin):
0
,
np
.
pi
*
hop_length
,
num_freq
,
dtype
=
self
.
real_
dtype
,
device
=
self
.
device
)[...,
None
]
dtype
=
self
.
dtype
,
device
=
self
.
device
)[...,
None
]
spec_stretch
=
F
.
phase_vocoder
(
spec
,
rate
=
rate
,
phase_advance
=
phase_advance
)
...
...
test/torchaudio_unittest/functional/torchscript_consistency_cpu_test.py
View file @
7d45851d
import
torch
from
torchaudio_unittest.common_utils
import
PytorchTestCase
from
.torchscript_consistency_impl
import
Functional
,
FunctionalComplex
from
.torchscript_consistency_impl
import
Functional
class
TestFunctionalFloat32
(
Functional
,
PytorchTestCase
):
...
...
@@ -12,15 +12,3 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
class
TestFunctionalFloat64
(
Functional
,
PytorchTestCase
):
dtype
=
torch
.
float64
device
=
torch
.
device
(
'cpu'
)
class
TestFunctionalComplex64
(
FunctionalComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex64
real_dtype
=
torch
.
float32
device
=
torch
.
device
(
'cpu'
)
class
TestFunctionalComplex128
(
FunctionalComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex128
real_dtype
=
torch
.
float64
device
=
torch
.
device
(
'cpu'
)
test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py
View file @
7d45851d
import
torch
from
torchaudio_unittest.common_utils
import
skipIfNoCuda
,
PytorchTestCase
from
.torchscript_consistency_impl
import
Functional
,
FunctionalComplex
from
.torchscript_consistency_impl
import
Functional
@
skipIfNoCuda
...
...
@@ -14,17 +14,3 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
class
TestFunctionalFloat64
(
Functional
,
PytorchTestCase
):
dtype
=
torch
.
float64
device
=
torch
.
device
(
'cuda'
)
@
skipIfNoCuda
class
TestFunctionalComplex64
(
FunctionalComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex64
real_dtype
=
torch
.
float32
device
=
torch
.
device
(
'cuda'
)
@
skipIfNoCuda
class
TestFunctionalComplex128
(
FunctionalComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex128
real_dtype
=
torch
.
float64
device
=
torch
.
device
(
'cuda'
)
test/torchaudio_unittest/functional/torchscript_consistency_impl.py
View file @
7d45851d
...
...
@@ -32,6 +32,25 @@ class Functional(TempDirMixin, TestBaseMixin):
output
=
output
.
shape
self
.
assertEqual
(
ts_output
,
output
)
def
_assert_consistency_complex
(
self
,
func
,
tensor
,
test_pseudo_complex
=
False
):
assert
tensor
.
is_complex
()
tensor
=
tensor
.
to
(
device
=
self
.
device
,
dtype
=
self
.
complex_dtype
)
path
=
self
.
get_temp_path
(
'func.zip'
)
torch
.
jit
.
script
(
func
).
save
(
path
)
ts_func
=
torch
.
jit
.
load
(
path
)
if
test_pseudo_complex
:
tensor
=
torch
.
view_as_real
(
tensor
)
torch
.
random
.
manual_seed
(
40
)
output
=
func
(
tensor
)
torch
.
random
.
manual_seed
(
40
)
ts_output
=
ts_func
(
tensor
)
self
.
assertEqual
(
ts_output
,
output
)
def
test_spectrogram
(
self
):
def
func
(
tensor
):
n_fft
=
400
...
...
@@ -572,26 +591,6 @@ class Functional(TempDirMixin, TestBaseMixin):
tensor
=
common_utils
.
get_whitenoise
(
sample_rate
=
44100
)
self
.
_assert_consistency
(
func
,
tensor
)
class
FunctionalComplex
(
TempDirMixin
,
TestBaseMixin
):
complex_dtype
=
None
real_dtype
=
None
device
=
None
def
_assert_consistency
(
self
,
func
,
tensor
,
test_pseudo_complex
=
False
):
assert
tensor
.
is_complex
()
tensor
=
tensor
.
to
(
device
=
self
.
device
,
dtype
=
self
.
complex_dtype
)
path
=
self
.
get_temp_path
(
'func.zip'
)
torch
.
jit
.
script
(
func
).
save
(
path
)
ts_func
=
torch
.
jit
.
load
(
path
)
if
test_pseudo_complex
:
tensor
=
torch
.
view_as_real
(
tensor
)
output
=
func
(
tensor
)
ts_output
=
ts_func
(
tensor
)
self
.
assertEqual
(
ts_output
,
output
)
@
parameterized
.
expand
([(
True
,
),
(
False
,
)])
def
test_phase_vocoder
(
self
,
test_paseudo_complex
):
def
func
(
tensor
):
...
...
@@ -610,4 +609,4 @@ class FunctionalComplex(TempDirMixin, TestBaseMixin):
return
F
.
phase_vocoder
(
tensor
,
rate
,
phase_advance
)
tensor
=
torch
.
view_as_complex
(
torch
.
randn
(
2
,
1025
,
400
,
2
))
self
.
_assert_consistency
(
func
,
tensor
,
test_paseudo_complex
)
self
.
_assert_consistency
_complex
(
func
,
tensor
,
test_paseudo_complex
)
test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py
View file @
7d45851d
import
torch
from
torchaudio_unittest.common_utils
import
PytorchTestCase
from
.torchscript_consistency_impl
import
Transforms
,
TransformsComplex
from
.torchscript_consistency_impl
import
Transforms
class
TestTransformsFloat32
(
Transforms
,
PytorchTestCase
):
...
...
@@ -12,15 +12,3 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
class
TestTransformsFloat64
(
Transforms
,
PytorchTestCase
):
dtype
=
torch
.
float64
device
=
torch
.
device
(
'cpu'
)
class
TestTransformsComplex64
(
TransformsComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex64
real_dtype
=
torch
.
float32
device
=
torch
.
device
(
'cpu'
)
class
TestTransformsComplex128
(
TransformsComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex128
real_dtype
=
torch
.
float64
device
=
torch
.
device
(
'cpu'
)
test/torchaudio_unittest/transforms/torchscript_consistency_cuda_test.py
View file @
7d45851d
import
torch
from
torchaudio_unittest.common_utils
import
skipIfNoCuda
,
PytorchTestCase
from
.torchscript_consistency_impl
import
Transforms
,
TransformsComplex
from
.torchscript_consistency_impl
import
Transforms
@
skipIfNoCuda
...
...
@@ -14,17 +14,3 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
class
TestTransformsFloat64
(
Transforms
,
PytorchTestCase
):
dtype
=
torch
.
float64
device
=
torch
.
device
(
'cuda'
)
@
skipIfNoCuda
class
TestTransformsComplex64
(
TransformsComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex64
real_dtype
=
torch
.
float32
device
=
torch
.
device
(
'cuda'
)
@
skipIfNoCuda
class
TestTransformsComplex128
(
TransformsComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex128
real_dtype
=
torch
.
float64
device
=
torch
.
device
(
'cuda'
)
test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
View file @
7d45851d
...
...
@@ -26,6 +26,22 @@ class Transforms(TempDirMixin, TestBaseMixin):
ts_output
=
ts_transform
(
tensor
)
self
.
assertEqual
(
ts_output
,
output
)
def
_assert_consistency_complex
(
self
,
transform
,
tensor
,
test_pseudo_complex
=
False
):
assert
tensor
.
is_complex
()
tensor
=
tensor
.
to
(
device
=
self
.
device
,
dtype
=
self
.
complex_dtype
)
transform
=
transform
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
path
=
self
.
get_temp_path
(
'transform.zip'
)
torch
.
jit
.
script
(
transform
).
save
(
path
)
ts_transform
=
torch
.
jit
.
load
(
path
)
if
test_pseudo_complex
:
tensor
=
torch
.
view_as_real
(
tensor
)
output
=
transform
(
tensor
)
ts_output
=
ts_transform
(
tensor
)
self
.
assertEqual
(
ts_output
,
output
)
def
test_Spectrogram
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
self
.
_assert_consistency
(
T
.
Spectrogram
(),
tensor
)
...
...
@@ -104,35 +120,13 @@ class Transforms(TempDirMixin, TestBaseMixin):
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
sample_rate
)
self
.
_assert_consistency
(
T
.
SpectralCentroid
(
sample_rate
=
sample_rate
),
waveform
)
class
TransformsComplex
(
TempDirMixin
,
TestBaseMixin
):
complex_dtype
=
None
real_dtype
=
None
device
=
None
def
_assert_consistency
(
self
,
transform
,
tensor
,
test_pseudo_complex
=
False
):
assert
tensor
.
is_complex
()
tensor
=
tensor
.
to
(
device
=
self
.
device
,
dtype
=
self
.
complex_dtype
)
transform
=
transform
.
to
(
device
=
self
.
device
,
dtype
=
self
.
real_dtype
)
path
=
self
.
get_temp_path
(
'transform.zip'
)
torch
.
jit
.
script
(
transform
).
save
(
path
)
ts_transform
=
torch
.
jit
.
load
(
path
)
if
test_pseudo_complex
:
tensor
=
torch
.
view_as_real
(
tensor
)
output
=
transform
(
tensor
)
ts_output
=
ts_transform
(
tensor
)
self
.
assertEqual
(
ts_output
,
output
)
@
parameterized
.
expand
([(
True
,
),
(
False
,
)])
def
test_TimeStretch
(
self
,
test_pseudo_complex
):
n_freq
=
400
hop_length
=
512
fixed_rate
=
1.3
tensor
=
torch
.
view_as_complex
(
torch
.
rand
((
10
,
2
,
n_freq
,
10
,
2
)))
self
.
_assert_consistency
(
self
.
_assert_consistency
_complex
(
T
.
TimeStretch
(
n_freq
=
n_freq
,
hop_length
=
hop_length
,
fixed_rate
=
fixed_rate
),
tensor
,
test_pseudo_complex
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment