Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Torchaudio
Commits
7d45851d
Unverified
Commit
7d45851d
authored
May 06, 2021
by
moto
Committed by
GitHub
May 06, 2021
Browse files
Merge test classes for complex (#1491)
parent
ddd2425c
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
52 additions
and
135 deletions
+52
-135
test/torchaudio_unittest/common_utils/case_utils.py
test/torchaudio_unittest/common_utils/case_utils.py
+8
-0
test/torchaudio_unittest/functional/functional_cpu_test.py
test/torchaudio_unittest/functional/functional_cpu_test.py
+1
-13
test/torchaudio_unittest/functional/functional_cuda_test.py
test/torchaudio_unittest/functional/functional_cuda_test.py
+1
-15
test/torchaudio_unittest/functional/functional_impl.py
test/torchaudio_unittest/functional/functional_impl.py
+1
-7
test/torchaudio_unittest/functional/torchscript_consistency_cpu_test.py
...o_unittest/functional/torchscript_consistency_cpu_test.py
+1
-13
test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py
..._unittest/functional/torchscript_consistency_cuda_test.py
+1
-15
test/torchaudio_unittest/functional/torchscript_consistency_impl.py
...audio_unittest/functional/torchscript_consistency_impl.py
+20
-21
test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py
...o_unittest/transforms/torchscript_consistency_cpu_test.py
+1
-13
test/torchaudio_unittest/transforms/torchscript_consistency_cuda_test.py
..._unittest/transforms/torchscript_consistency_cuda_test.py
+1
-15
test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
...audio_unittest/transforms/torchscript_consistency_impl.py
+17
-23
No files found.
test/torchaudio_unittest/common_utils/case_utils.py
View file @
7d45851d
...
@@ -81,6 +81,14 @@ class TestBaseMixin:
...
@@ -81,6 +81,14 @@ class TestBaseMixin:
super
().
setUp
()
super
().
setUp
()
set_audio_backend
(
self
.
backend
)
set_audio_backend
(
self
.
backend
)
@
property
def
complex_dtype
(
self
):
if
self
.
dtype
in
[
'float32'
,
'float'
,
torch
.
float
,
torch
.
float32
]:
return
torch
.
cfloat
if
self
.
dtype
in
[
'float64'
,
'double'
,
torch
.
double
,
torch
.
float64
]:
return
torch
.
cdouble
raise
ValueError
(
f
'No corresponding complex dtype for
{
self
.
dtype
}
'
)
class
TorchaudioTestCase
(
TestBaseMixin
,
PytorchTestCase
):
class
TorchaudioTestCase
(
TestBaseMixin
,
PytorchTestCase
):
pass
pass
...
...
test/torchaudio_unittest/functional/functional_cpu_test.py
View file @
7d45851d
...
@@ -4,7 +4,7 @@ import unittest
...
@@ -4,7 +4,7 @@ import unittest
from
parameterized
import
parameterized
from
parameterized
import
parameterized
from
torchaudio_unittest.common_utils
import
PytorchTestCase
,
TorchaudioTestCase
,
skipIfNoSox
from
torchaudio_unittest.common_utils
import
PytorchTestCase
,
TorchaudioTestCase
,
skipIfNoSox
from
.functional_impl
import
Functional
,
FunctionalComplex
,
FunctionalCPUOnly
from
.functional_impl
import
Functional
,
FunctionalCPUOnly
class
TestFunctionalFloat32
(
Functional
,
FunctionalCPUOnly
,
PytorchTestCase
):
class
TestFunctionalFloat32
(
Functional
,
FunctionalCPUOnly
,
PytorchTestCase
):
...
@@ -21,18 +21,6 @@ class TestFunctionalFloat64(Functional, PytorchTestCase):
...
@@ -21,18 +21,6 @@ class TestFunctionalFloat64(Functional, PytorchTestCase):
device
=
torch
.
device
(
'cpu'
)
device
=
torch
.
device
(
'cpu'
)
class
TestFunctionalComplex64
(
FunctionalComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex64
real_dtype
=
torch
.
float32
device
=
torch
.
device
(
'cpu'
)
class
TestFunctionalComplex128
(
FunctionalComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex128
real_dtype
=
torch
.
float64
device
=
torch
.
device
(
'cpu'
)
@
skipIfNoSox
@
skipIfNoSox
class
TestApplyCodec
(
TorchaudioTestCase
):
class
TestApplyCodec
(
TorchaudioTestCase
):
backend
=
"sox_io"
backend
=
"sox_io"
...
...
test/torchaudio_unittest/functional/functional_cuda_test.py
View file @
7d45851d
...
@@ -2,7 +2,7 @@ import torch
...
@@ -2,7 +2,7 @@ import torch
import
unittest
import
unittest
from
torchaudio_unittest.common_utils
import
PytorchTestCase
,
skipIfNoCuda
from
torchaudio_unittest.common_utils
import
PytorchTestCase
,
skipIfNoCuda
from
.functional_impl
import
Functional
,
FunctionalComplex
from
.functional_impl
import
Functional
@
skipIfNoCuda
@
skipIfNoCuda
...
@@ -19,17 +19,3 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
...
@@ -19,17 +19,3 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
class
TestLFilterFloat64
(
Functional
,
PytorchTestCase
):
class
TestLFilterFloat64
(
Functional
,
PytorchTestCase
):
dtype
=
torch
.
float64
dtype
=
torch
.
float64
device
=
torch
.
device
(
'cuda'
)
device
=
torch
.
device
(
'cuda'
)
@
skipIfNoCuda
class
TestFunctionalComplex64
(
FunctionalComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex64
real_dtype
=
torch
.
float32
device
=
torch
.
device
(
'cuda'
)
@
skipIfNoCuda
class
TestFunctionalComplex128
(
FunctionalComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex64
real_dtype
=
torch
.
float32
device
=
torch
.
device
(
'cuda'
)
test/torchaudio_unittest/functional/functional_impl.py
View file @
7d45851d
...
@@ -259,12 +259,6 @@ class Functional(TestBaseMixin):
...
@@ -259,12 +259,6 @@ class Functional(TestBaseMixin):
self
.
assertEqual
(
specgrams
,
specgrams_copy
)
self
.
assertEqual
(
specgrams
,
specgrams_copy
)
class
FunctionalComplex
(
TestBaseMixin
):
complex_dtype
=
None
real_dtype
=
None
device
=
None
@
nested_params
(
@
nested_params
(
[
0.5
,
1.01
,
1.3
],
[
0.5
,
1.01
,
1.3
],
[
True
,
False
],
[
True
,
False
],
...
@@ -286,7 +280,7 @@ class FunctionalComplex(TestBaseMixin):
...
@@ -286,7 +280,7 @@ class FunctionalComplex(TestBaseMixin):
0
,
0
,
np
.
pi
*
hop_length
,
np
.
pi
*
hop_length
,
num_freq
,
num_freq
,
dtype
=
self
.
real_
dtype
,
device
=
self
.
device
)[...,
None
]
dtype
=
self
.
dtype
,
device
=
self
.
device
)[...,
None
]
spec_stretch
=
F
.
phase_vocoder
(
spec
,
rate
=
rate
,
phase_advance
=
phase_advance
)
spec_stretch
=
F
.
phase_vocoder
(
spec
,
rate
=
rate
,
phase_advance
=
phase_advance
)
...
...
test/torchaudio_unittest/functional/torchscript_consistency_cpu_test.py
View file @
7d45851d
import
torch
import
torch
from
torchaudio_unittest.common_utils
import
PytorchTestCase
from
torchaudio_unittest.common_utils
import
PytorchTestCase
from
.torchscript_consistency_impl
import
Functional
,
FunctionalComplex
from
.torchscript_consistency_impl
import
Functional
class
TestFunctionalFloat32
(
Functional
,
PytorchTestCase
):
class
TestFunctionalFloat32
(
Functional
,
PytorchTestCase
):
...
@@ -12,15 +12,3 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
...
@@ -12,15 +12,3 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
class
TestFunctionalFloat64
(
Functional
,
PytorchTestCase
):
class
TestFunctionalFloat64
(
Functional
,
PytorchTestCase
):
dtype
=
torch
.
float64
dtype
=
torch
.
float64
device
=
torch
.
device
(
'cpu'
)
device
=
torch
.
device
(
'cpu'
)
class
TestFunctionalComplex64
(
FunctionalComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex64
real_dtype
=
torch
.
float32
device
=
torch
.
device
(
'cpu'
)
class
TestFunctionalComplex128
(
FunctionalComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex128
real_dtype
=
torch
.
float64
device
=
torch
.
device
(
'cpu'
)
test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py
View file @
7d45851d
import
torch
import
torch
from
torchaudio_unittest.common_utils
import
skipIfNoCuda
,
PytorchTestCase
from
torchaudio_unittest.common_utils
import
skipIfNoCuda
,
PytorchTestCase
from
.torchscript_consistency_impl
import
Functional
,
FunctionalComplex
from
.torchscript_consistency_impl
import
Functional
@
skipIfNoCuda
@
skipIfNoCuda
...
@@ -14,17 +14,3 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
...
@@ -14,17 +14,3 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
class
TestFunctionalFloat64
(
Functional
,
PytorchTestCase
):
class
TestFunctionalFloat64
(
Functional
,
PytorchTestCase
):
dtype
=
torch
.
float64
dtype
=
torch
.
float64
device
=
torch
.
device
(
'cuda'
)
device
=
torch
.
device
(
'cuda'
)
@
skipIfNoCuda
class
TestFunctionalComplex64
(
FunctionalComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex64
real_dtype
=
torch
.
float32
device
=
torch
.
device
(
'cuda'
)
@
skipIfNoCuda
class
TestFunctionalComplex128
(
FunctionalComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex128
real_dtype
=
torch
.
float64
device
=
torch
.
device
(
'cuda'
)
test/torchaudio_unittest/functional/torchscript_consistency_impl.py
View file @
7d45851d
...
@@ -32,6 +32,25 @@ class Functional(TempDirMixin, TestBaseMixin):
...
@@ -32,6 +32,25 @@ class Functional(TempDirMixin, TestBaseMixin):
output
=
output
.
shape
output
=
output
.
shape
self
.
assertEqual
(
ts_output
,
output
)
self
.
assertEqual
(
ts_output
,
output
)
def
_assert_consistency_complex
(
self
,
func
,
tensor
,
test_pseudo_complex
=
False
):
assert
tensor
.
is_complex
()
tensor
=
tensor
.
to
(
device
=
self
.
device
,
dtype
=
self
.
complex_dtype
)
path
=
self
.
get_temp_path
(
'func.zip'
)
torch
.
jit
.
script
(
func
).
save
(
path
)
ts_func
=
torch
.
jit
.
load
(
path
)
if
test_pseudo_complex
:
tensor
=
torch
.
view_as_real
(
tensor
)
torch
.
random
.
manual_seed
(
40
)
output
=
func
(
tensor
)
torch
.
random
.
manual_seed
(
40
)
ts_output
=
ts_func
(
tensor
)
self
.
assertEqual
(
ts_output
,
output
)
def
test_spectrogram
(
self
):
def
test_spectrogram
(
self
):
def
func
(
tensor
):
def
func
(
tensor
):
n_fft
=
400
n_fft
=
400
...
@@ -572,26 +591,6 @@ class Functional(TempDirMixin, TestBaseMixin):
...
@@ -572,26 +591,6 @@ class Functional(TempDirMixin, TestBaseMixin):
tensor
=
common_utils
.
get_whitenoise
(
sample_rate
=
44100
)
tensor
=
common_utils
.
get_whitenoise
(
sample_rate
=
44100
)
self
.
_assert_consistency
(
func
,
tensor
)
self
.
_assert_consistency
(
func
,
tensor
)
class
FunctionalComplex
(
TempDirMixin
,
TestBaseMixin
):
complex_dtype
=
None
real_dtype
=
None
device
=
None
def
_assert_consistency
(
self
,
func
,
tensor
,
test_pseudo_complex
=
False
):
assert
tensor
.
is_complex
()
tensor
=
tensor
.
to
(
device
=
self
.
device
,
dtype
=
self
.
complex_dtype
)
path
=
self
.
get_temp_path
(
'func.zip'
)
torch
.
jit
.
script
(
func
).
save
(
path
)
ts_func
=
torch
.
jit
.
load
(
path
)
if
test_pseudo_complex
:
tensor
=
torch
.
view_as_real
(
tensor
)
output
=
func
(
tensor
)
ts_output
=
ts_func
(
tensor
)
self
.
assertEqual
(
ts_output
,
output
)
@
parameterized
.
expand
([(
True
,
),
(
False
,
)])
@
parameterized
.
expand
([(
True
,
),
(
False
,
)])
def
test_phase_vocoder
(
self
,
test_paseudo_complex
):
def
test_phase_vocoder
(
self
,
test_paseudo_complex
):
def
func
(
tensor
):
def
func
(
tensor
):
...
@@ -610,4 +609,4 @@ class FunctionalComplex(TempDirMixin, TestBaseMixin):
...
@@ -610,4 +609,4 @@ class FunctionalComplex(TempDirMixin, TestBaseMixin):
return
F
.
phase_vocoder
(
tensor
,
rate
,
phase_advance
)
return
F
.
phase_vocoder
(
tensor
,
rate
,
phase_advance
)
tensor
=
torch
.
view_as_complex
(
torch
.
randn
(
2
,
1025
,
400
,
2
))
tensor
=
torch
.
view_as_complex
(
torch
.
randn
(
2
,
1025
,
400
,
2
))
self
.
_assert_consistency
(
func
,
tensor
,
test_paseudo_complex
)
self
.
_assert_consistency
_complex
(
func
,
tensor
,
test_paseudo_complex
)
test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py
View file @
7d45851d
import
torch
import
torch
from
torchaudio_unittest.common_utils
import
PytorchTestCase
from
torchaudio_unittest.common_utils
import
PytorchTestCase
from
.torchscript_consistency_impl
import
Transforms
,
TransformsComplex
from
.torchscript_consistency_impl
import
Transforms
class
TestTransformsFloat32
(
Transforms
,
PytorchTestCase
):
class
TestTransformsFloat32
(
Transforms
,
PytorchTestCase
):
...
@@ -12,15 +12,3 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
...
@@ -12,15 +12,3 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
class
TestTransformsFloat64
(
Transforms
,
PytorchTestCase
):
class
TestTransformsFloat64
(
Transforms
,
PytorchTestCase
):
dtype
=
torch
.
float64
dtype
=
torch
.
float64
device
=
torch
.
device
(
'cpu'
)
device
=
torch
.
device
(
'cpu'
)
class
TestTransformsComplex64
(
TransformsComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex64
real_dtype
=
torch
.
float32
device
=
torch
.
device
(
'cpu'
)
class
TestTransformsComplex128
(
TransformsComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex128
real_dtype
=
torch
.
float64
device
=
torch
.
device
(
'cpu'
)
test/torchaudio_unittest/transforms/torchscript_consistency_cuda_test.py
View file @
7d45851d
import
torch
import
torch
from
torchaudio_unittest.common_utils
import
skipIfNoCuda
,
PytorchTestCase
from
torchaudio_unittest.common_utils
import
skipIfNoCuda
,
PytorchTestCase
from
.torchscript_consistency_impl
import
Transforms
,
TransformsComplex
from
.torchscript_consistency_impl
import
Transforms
@
skipIfNoCuda
@
skipIfNoCuda
...
@@ -14,17 +14,3 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
...
@@ -14,17 +14,3 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
class
TestTransformsFloat64
(
Transforms
,
PytorchTestCase
):
class
TestTransformsFloat64
(
Transforms
,
PytorchTestCase
):
dtype
=
torch
.
float64
dtype
=
torch
.
float64
device
=
torch
.
device
(
'cuda'
)
device
=
torch
.
device
(
'cuda'
)
@
skipIfNoCuda
class
TestTransformsComplex64
(
TransformsComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex64
real_dtype
=
torch
.
float32
device
=
torch
.
device
(
'cuda'
)
@
skipIfNoCuda
class
TestTransformsComplex128
(
TransformsComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex128
real_dtype
=
torch
.
float64
device
=
torch
.
device
(
'cuda'
)
test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
View file @
7d45851d
...
@@ -26,6 +26,22 @@ class Transforms(TempDirMixin, TestBaseMixin):
...
@@ -26,6 +26,22 @@ class Transforms(TempDirMixin, TestBaseMixin):
ts_output
=
ts_transform
(
tensor
)
ts_output
=
ts_transform
(
tensor
)
self
.
assertEqual
(
ts_output
,
output
)
self
.
assertEqual
(
ts_output
,
output
)
def
_assert_consistency_complex
(
self
,
transform
,
tensor
,
test_pseudo_complex
=
False
):
assert
tensor
.
is_complex
()
tensor
=
tensor
.
to
(
device
=
self
.
device
,
dtype
=
self
.
complex_dtype
)
transform
=
transform
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
path
=
self
.
get_temp_path
(
'transform.zip'
)
torch
.
jit
.
script
(
transform
).
save
(
path
)
ts_transform
=
torch
.
jit
.
load
(
path
)
if
test_pseudo_complex
:
tensor
=
torch
.
view_as_real
(
tensor
)
output
=
transform
(
tensor
)
ts_output
=
ts_transform
(
tensor
)
self
.
assertEqual
(
ts_output
,
output
)
def
test_Spectrogram
(
self
):
def
test_Spectrogram
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
tensor
=
torch
.
rand
((
1
,
1000
))
self
.
_assert_consistency
(
T
.
Spectrogram
(),
tensor
)
self
.
_assert_consistency
(
T
.
Spectrogram
(),
tensor
)
...
@@ -104,35 +120,13 @@ class Transforms(TempDirMixin, TestBaseMixin):
...
@@ -104,35 +120,13 @@ class Transforms(TempDirMixin, TestBaseMixin):
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
sample_rate
)
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
sample_rate
)
self
.
_assert_consistency
(
T
.
SpectralCentroid
(
sample_rate
=
sample_rate
),
waveform
)
self
.
_assert_consistency
(
T
.
SpectralCentroid
(
sample_rate
=
sample_rate
),
waveform
)
class
TransformsComplex
(
TempDirMixin
,
TestBaseMixin
):
complex_dtype
=
None
real_dtype
=
None
device
=
None
def
_assert_consistency
(
self
,
transform
,
tensor
,
test_pseudo_complex
=
False
):
assert
tensor
.
is_complex
()
tensor
=
tensor
.
to
(
device
=
self
.
device
,
dtype
=
self
.
complex_dtype
)
transform
=
transform
.
to
(
device
=
self
.
device
,
dtype
=
self
.
real_dtype
)
path
=
self
.
get_temp_path
(
'transform.zip'
)
torch
.
jit
.
script
(
transform
).
save
(
path
)
ts_transform
=
torch
.
jit
.
load
(
path
)
if
test_pseudo_complex
:
tensor
=
torch
.
view_as_real
(
tensor
)
output
=
transform
(
tensor
)
ts_output
=
ts_transform
(
tensor
)
self
.
assertEqual
(
ts_output
,
output
)
@
parameterized
.
expand
([(
True
,
),
(
False
,
)])
@
parameterized
.
expand
([(
True
,
),
(
False
,
)])
def
test_TimeStretch
(
self
,
test_pseudo_complex
):
def
test_TimeStretch
(
self
,
test_pseudo_complex
):
n_freq
=
400
n_freq
=
400
hop_length
=
512
hop_length
=
512
fixed_rate
=
1.3
fixed_rate
=
1.3
tensor
=
torch
.
view_as_complex
(
torch
.
rand
((
10
,
2
,
n_freq
,
10
,
2
)))
tensor
=
torch
.
view_as_complex
(
torch
.
rand
((
10
,
2
,
n_freq
,
10
,
2
)))
self
.
_assert_consistency
(
self
.
_assert_consistency
_complex
(
T
.
TimeStretch
(
n_freq
=
n_freq
,
hop_length
=
hop_length
,
fixed_rate
=
fixed_rate
),
T
.
TimeStretch
(
n_freq
=
n_freq
,
hop_length
=
hop_length
,
fixed_rate
=
fixed_rate
),
tensor
,
tensor
,
test_pseudo_complex
test_pseudo_complex
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment