Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
hehl2
Torchaudio
Commits
7d45851d
"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "cbd571f28bd9dc78cc4f3668c9eeac5b820f5f09"
Unverified
Commit
7d45851d
authored
May 06, 2021
by
moto
Committed by
GitHub
May 06, 2021
Browse files
Merge test classes for complex (#1491)
parent
ddd2425c
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
52 additions
and
135 deletions
+52
-135
test/torchaudio_unittest/common_utils/case_utils.py
test/torchaudio_unittest/common_utils/case_utils.py
+8
-0
test/torchaudio_unittest/functional/functional_cpu_test.py
test/torchaudio_unittest/functional/functional_cpu_test.py
+1
-13
test/torchaudio_unittest/functional/functional_cuda_test.py
test/torchaudio_unittest/functional/functional_cuda_test.py
+1
-15
test/torchaudio_unittest/functional/functional_impl.py
test/torchaudio_unittest/functional/functional_impl.py
+1
-7
test/torchaudio_unittest/functional/torchscript_consistency_cpu_test.py
...o_unittest/functional/torchscript_consistency_cpu_test.py
+1
-13
test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py
..._unittest/functional/torchscript_consistency_cuda_test.py
+1
-15
test/torchaudio_unittest/functional/torchscript_consistency_impl.py
...audio_unittest/functional/torchscript_consistency_impl.py
+20
-21
test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py
...o_unittest/transforms/torchscript_consistency_cpu_test.py
+1
-13
test/torchaudio_unittest/transforms/torchscript_consistency_cuda_test.py
..._unittest/transforms/torchscript_consistency_cuda_test.py
+1
-15
test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
...audio_unittest/transforms/torchscript_consistency_impl.py
+17
-23
No files found.
test/torchaudio_unittest/common_utils/case_utils.py
View file @
7d45851d
...
@@ -81,6 +81,14 @@ class TestBaseMixin:
...
@@ -81,6 +81,14 @@ class TestBaseMixin:
super
().
setUp
()
super
().
setUp
()
set_audio_backend
(
self
.
backend
)
set_audio_backend
(
self
.
backend
)
@
property
def
complex_dtype
(
self
):
if
self
.
dtype
in
[
'float32'
,
'float'
,
torch
.
float
,
torch
.
float32
]:
return
torch
.
cfloat
if
self
.
dtype
in
[
'float64'
,
'double'
,
torch
.
double
,
torch
.
float64
]:
return
torch
.
cdouble
raise
ValueError
(
f
'No corresponding complex dtype for
{
self
.
dtype
}
'
)
class
TorchaudioTestCase
(
TestBaseMixin
,
PytorchTestCase
):
class
TorchaudioTestCase
(
TestBaseMixin
,
PytorchTestCase
):
pass
pass
...
...
test/torchaudio_unittest/functional/functional_cpu_test.py
View file @
7d45851d
...
@@ -4,7 +4,7 @@ import unittest
...
@@ -4,7 +4,7 @@ import unittest
from
parameterized
import
parameterized
from
parameterized
import
parameterized
from
torchaudio_unittest.common_utils
import
PytorchTestCase
,
TorchaudioTestCase
,
skipIfNoSox
from
torchaudio_unittest.common_utils
import
PytorchTestCase
,
TorchaudioTestCase
,
skipIfNoSox
from
.functional_impl
import
Functional
,
FunctionalComplex
,
FunctionalCPUOnly
from
.functional_impl
import
Functional
,
FunctionalCPUOnly
class
TestFunctionalFloat32
(
Functional
,
FunctionalCPUOnly
,
PytorchTestCase
):
class
TestFunctionalFloat32
(
Functional
,
FunctionalCPUOnly
,
PytorchTestCase
):
...
@@ -21,18 +21,6 @@ class TestFunctionalFloat64(Functional, PytorchTestCase):
...
@@ -21,18 +21,6 @@ class TestFunctionalFloat64(Functional, PytorchTestCase):
device
=
torch
.
device
(
'cpu'
)
device
=
torch
.
device
(
'cpu'
)
class
TestFunctionalComplex64
(
FunctionalComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex64
real_dtype
=
torch
.
float32
device
=
torch
.
device
(
'cpu'
)
class
TestFunctionalComplex128
(
FunctionalComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex128
real_dtype
=
torch
.
float64
device
=
torch
.
device
(
'cpu'
)
@
skipIfNoSox
@
skipIfNoSox
class
TestApplyCodec
(
TorchaudioTestCase
):
class
TestApplyCodec
(
TorchaudioTestCase
):
backend
=
"sox_io"
backend
=
"sox_io"
...
...
test/torchaudio_unittest/functional/functional_cuda_test.py
View file @
7d45851d
...
@@ -2,7 +2,7 @@ import torch
...
@@ -2,7 +2,7 @@ import torch
import
unittest
import
unittest
from
torchaudio_unittest.common_utils
import
PytorchTestCase
,
skipIfNoCuda
from
torchaudio_unittest.common_utils
import
PytorchTestCase
,
skipIfNoCuda
from
.functional_impl
import
Functional
,
FunctionalComplex
from
.functional_impl
import
Functional
@
skipIfNoCuda
@
skipIfNoCuda
...
@@ -19,17 +19,3 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
...
@@ -19,17 +19,3 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
class
TestLFilterFloat64
(
Functional
,
PytorchTestCase
):
class
TestLFilterFloat64
(
Functional
,
PytorchTestCase
):
dtype
=
torch
.
float64
dtype
=
torch
.
float64
device
=
torch
.
device
(
'cuda'
)
device
=
torch
.
device
(
'cuda'
)
@
skipIfNoCuda
class
TestFunctionalComplex64
(
FunctionalComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex64
real_dtype
=
torch
.
float32
device
=
torch
.
device
(
'cuda'
)
@
skipIfNoCuda
class
TestFunctionalComplex128
(
FunctionalComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex64
real_dtype
=
torch
.
float32
device
=
torch
.
device
(
'cuda'
)
test/torchaudio_unittest/functional/functional_impl.py
View file @
7d45851d
...
@@ -259,12 +259,6 @@ class Functional(TestBaseMixin):
...
@@ -259,12 +259,6 @@ class Functional(TestBaseMixin):
self
.
assertEqual
(
specgrams
,
specgrams_copy
)
self
.
assertEqual
(
specgrams
,
specgrams_copy
)
class
FunctionalComplex
(
TestBaseMixin
):
complex_dtype
=
None
real_dtype
=
None
device
=
None
@
nested_params
(
@
nested_params
(
[
0.5
,
1.01
,
1.3
],
[
0.5
,
1.01
,
1.3
],
[
True
,
False
],
[
True
,
False
],
...
@@ -286,7 +280,7 @@ class FunctionalComplex(TestBaseMixin):
...
@@ -286,7 +280,7 @@ class FunctionalComplex(TestBaseMixin):
0
,
0
,
np
.
pi
*
hop_length
,
np
.
pi
*
hop_length
,
num_freq
,
num_freq
,
dtype
=
self
.
real_
dtype
,
device
=
self
.
device
)[...,
None
]
dtype
=
self
.
dtype
,
device
=
self
.
device
)[...,
None
]
spec_stretch
=
F
.
phase_vocoder
(
spec
,
rate
=
rate
,
phase_advance
=
phase_advance
)
spec_stretch
=
F
.
phase_vocoder
(
spec
,
rate
=
rate
,
phase_advance
=
phase_advance
)
...
...
test/torchaudio_unittest/functional/torchscript_consistency_cpu_test.py
View file @
7d45851d
import
torch
import
torch
from
torchaudio_unittest.common_utils
import
PytorchTestCase
from
torchaudio_unittest.common_utils
import
PytorchTestCase
from
.torchscript_consistency_impl
import
Functional
,
FunctionalComplex
from
.torchscript_consistency_impl
import
Functional
class
TestFunctionalFloat32
(
Functional
,
PytorchTestCase
):
class
TestFunctionalFloat32
(
Functional
,
PytorchTestCase
):
...
@@ -12,15 +12,3 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
...
@@ -12,15 +12,3 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
class
TestFunctionalFloat64
(
Functional
,
PytorchTestCase
):
class
TestFunctionalFloat64
(
Functional
,
PytorchTestCase
):
dtype
=
torch
.
float64
dtype
=
torch
.
float64
device
=
torch
.
device
(
'cpu'
)
device
=
torch
.
device
(
'cpu'
)
class
TestFunctionalComplex64
(
FunctionalComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex64
real_dtype
=
torch
.
float32
device
=
torch
.
device
(
'cpu'
)
class
TestFunctionalComplex128
(
FunctionalComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex128
real_dtype
=
torch
.
float64
device
=
torch
.
device
(
'cpu'
)
test/torchaudio_unittest/functional/torchscript_consistency_cuda_test.py
View file @
7d45851d
import
torch
import
torch
from
torchaudio_unittest.common_utils
import
skipIfNoCuda
,
PytorchTestCase
from
torchaudio_unittest.common_utils
import
skipIfNoCuda
,
PytorchTestCase
from
.torchscript_consistency_impl
import
Functional
,
FunctionalComplex
from
.torchscript_consistency_impl
import
Functional
@
skipIfNoCuda
@
skipIfNoCuda
...
@@ -14,17 +14,3 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
...
@@ -14,17 +14,3 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
class
TestFunctionalFloat64
(
Functional
,
PytorchTestCase
):
class
TestFunctionalFloat64
(
Functional
,
PytorchTestCase
):
dtype
=
torch
.
float64
dtype
=
torch
.
float64
device
=
torch
.
device
(
'cuda'
)
device
=
torch
.
device
(
'cuda'
)
@
skipIfNoCuda
class
TestFunctionalComplex64
(
FunctionalComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex64
real_dtype
=
torch
.
float32
device
=
torch
.
device
(
'cuda'
)
@
skipIfNoCuda
class
TestFunctionalComplex128
(
FunctionalComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex128
real_dtype
=
torch
.
float64
device
=
torch
.
device
(
'cuda'
)
test/torchaudio_unittest/functional/torchscript_consistency_impl.py
View file @
7d45851d
...
@@ -32,6 +32,25 @@ class Functional(TempDirMixin, TestBaseMixin):
...
@@ -32,6 +32,25 @@ class Functional(TempDirMixin, TestBaseMixin):
output
=
output
.
shape
output
=
output
.
shape
self
.
assertEqual
(
ts_output
,
output
)
self
.
assertEqual
(
ts_output
,
output
)
def
_assert_consistency_complex
(
self
,
func
,
tensor
,
test_pseudo_complex
=
False
):
assert
tensor
.
is_complex
()
tensor
=
tensor
.
to
(
device
=
self
.
device
,
dtype
=
self
.
complex_dtype
)
path
=
self
.
get_temp_path
(
'func.zip'
)
torch
.
jit
.
script
(
func
).
save
(
path
)
ts_func
=
torch
.
jit
.
load
(
path
)
if
test_pseudo_complex
:
tensor
=
torch
.
view_as_real
(
tensor
)
torch
.
random
.
manual_seed
(
40
)
output
=
func
(
tensor
)
torch
.
random
.
manual_seed
(
40
)
ts_output
=
ts_func
(
tensor
)
self
.
assertEqual
(
ts_output
,
output
)
def
test_spectrogram
(
self
):
def
test_spectrogram
(
self
):
def
func
(
tensor
):
def
func
(
tensor
):
n_fft
=
400
n_fft
=
400
...
@@ -572,26 +591,6 @@ class Functional(TempDirMixin, TestBaseMixin):
...
@@ -572,26 +591,6 @@ class Functional(TempDirMixin, TestBaseMixin):
tensor
=
common_utils
.
get_whitenoise
(
sample_rate
=
44100
)
tensor
=
common_utils
.
get_whitenoise
(
sample_rate
=
44100
)
self
.
_assert_consistency
(
func
,
tensor
)
self
.
_assert_consistency
(
func
,
tensor
)
class
FunctionalComplex
(
TempDirMixin
,
TestBaseMixin
):
complex_dtype
=
None
real_dtype
=
None
device
=
None
def
_assert_consistency
(
self
,
func
,
tensor
,
test_pseudo_complex
=
False
):
assert
tensor
.
is_complex
()
tensor
=
tensor
.
to
(
device
=
self
.
device
,
dtype
=
self
.
complex_dtype
)
path
=
self
.
get_temp_path
(
'func.zip'
)
torch
.
jit
.
script
(
func
).
save
(
path
)
ts_func
=
torch
.
jit
.
load
(
path
)
if
test_pseudo_complex
:
tensor
=
torch
.
view_as_real
(
tensor
)
output
=
func
(
tensor
)
ts_output
=
ts_func
(
tensor
)
self
.
assertEqual
(
ts_output
,
output
)
@
parameterized
.
expand
([(
True
,
),
(
False
,
)])
@
parameterized
.
expand
([(
True
,
),
(
False
,
)])
def
test_phase_vocoder
(
self
,
test_paseudo_complex
):
def
test_phase_vocoder
(
self
,
test_paseudo_complex
):
def
func
(
tensor
):
def
func
(
tensor
):
...
@@ -610,4 +609,4 @@ class FunctionalComplex(TempDirMixin, TestBaseMixin):
...
@@ -610,4 +609,4 @@ class FunctionalComplex(TempDirMixin, TestBaseMixin):
return
F
.
phase_vocoder
(
tensor
,
rate
,
phase_advance
)
return
F
.
phase_vocoder
(
tensor
,
rate
,
phase_advance
)
tensor
=
torch
.
view_as_complex
(
torch
.
randn
(
2
,
1025
,
400
,
2
))
tensor
=
torch
.
view_as_complex
(
torch
.
randn
(
2
,
1025
,
400
,
2
))
self
.
_assert_consistency
(
func
,
tensor
,
test_paseudo_complex
)
self
.
_assert_consistency
_complex
(
func
,
tensor
,
test_paseudo_complex
)
test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py
View file @
7d45851d
import
torch
import
torch
from
torchaudio_unittest.common_utils
import
PytorchTestCase
from
torchaudio_unittest.common_utils
import
PytorchTestCase
from
.torchscript_consistency_impl
import
Transforms
,
TransformsComplex
from
.torchscript_consistency_impl
import
Transforms
class
TestTransformsFloat32
(
Transforms
,
PytorchTestCase
):
class
TestTransformsFloat32
(
Transforms
,
PytorchTestCase
):
...
@@ -12,15 +12,3 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
...
@@ -12,15 +12,3 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
class
TestTransformsFloat64
(
Transforms
,
PytorchTestCase
):
class
TestTransformsFloat64
(
Transforms
,
PytorchTestCase
):
dtype
=
torch
.
float64
dtype
=
torch
.
float64
device
=
torch
.
device
(
'cpu'
)
device
=
torch
.
device
(
'cpu'
)
class
TestTransformsComplex64
(
TransformsComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex64
real_dtype
=
torch
.
float32
device
=
torch
.
device
(
'cpu'
)
class
TestTransformsComplex128
(
TransformsComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex128
real_dtype
=
torch
.
float64
device
=
torch
.
device
(
'cpu'
)
test/torchaudio_unittest/transforms/torchscript_consistency_cuda_test.py
View file @
7d45851d
import
torch
import
torch
from
torchaudio_unittest.common_utils
import
skipIfNoCuda
,
PytorchTestCase
from
torchaudio_unittest.common_utils
import
skipIfNoCuda
,
PytorchTestCase
from
.torchscript_consistency_impl
import
Transforms
,
TransformsComplex
from
.torchscript_consistency_impl
import
Transforms
@
skipIfNoCuda
@
skipIfNoCuda
...
@@ -14,17 +14,3 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
...
@@ -14,17 +14,3 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
class
TestTransformsFloat64
(
Transforms
,
PytorchTestCase
):
class
TestTransformsFloat64
(
Transforms
,
PytorchTestCase
):
dtype
=
torch
.
float64
dtype
=
torch
.
float64
device
=
torch
.
device
(
'cuda'
)
device
=
torch
.
device
(
'cuda'
)
@
skipIfNoCuda
class
TestTransformsComplex64
(
TransformsComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex64
real_dtype
=
torch
.
float32
device
=
torch
.
device
(
'cuda'
)
@
skipIfNoCuda
class
TestTransformsComplex128
(
TransformsComplex
,
PytorchTestCase
):
complex_dtype
=
torch
.
complex128
real_dtype
=
torch
.
float64
device
=
torch
.
device
(
'cuda'
)
test/torchaudio_unittest/transforms/torchscript_consistency_impl.py
View file @
7d45851d
...
@@ -26,6 +26,22 @@ class Transforms(TempDirMixin, TestBaseMixin):
...
@@ -26,6 +26,22 @@ class Transforms(TempDirMixin, TestBaseMixin):
ts_output
=
ts_transform
(
tensor
)
ts_output
=
ts_transform
(
tensor
)
self
.
assertEqual
(
ts_output
,
output
)
self
.
assertEqual
(
ts_output
,
output
)
def
_assert_consistency_complex
(
self
,
transform
,
tensor
,
test_pseudo_complex
=
False
):
assert
tensor
.
is_complex
()
tensor
=
tensor
.
to
(
device
=
self
.
device
,
dtype
=
self
.
complex_dtype
)
transform
=
transform
.
to
(
device
=
self
.
device
,
dtype
=
self
.
dtype
)
path
=
self
.
get_temp_path
(
'transform.zip'
)
torch
.
jit
.
script
(
transform
).
save
(
path
)
ts_transform
=
torch
.
jit
.
load
(
path
)
if
test_pseudo_complex
:
tensor
=
torch
.
view_as_real
(
tensor
)
output
=
transform
(
tensor
)
ts_output
=
ts_transform
(
tensor
)
self
.
assertEqual
(
ts_output
,
output
)
def
test_Spectrogram
(
self
):
def
test_Spectrogram
(
self
):
tensor
=
torch
.
rand
((
1
,
1000
))
tensor
=
torch
.
rand
((
1
,
1000
))
self
.
_assert_consistency
(
T
.
Spectrogram
(),
tensor
)
self
.
_assert_consistency
(
T
.
Spectrogram
(),
tensor
)
...
@@ -104,35 +120,13 @@ class Transforms(TempDirMixin, TestBaseMixin):
...
@@ -104,35 +120,13 @@ class Transforms(TempDirMixin, TestBaseMixin):
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
sample_rate
)
waveform
=
common_utils
.
get_whitenoise
(
sample_rate
=
sample_rate
)
self
.
_assert_consistency
(
T
.
SpectralCentroid
(
sample_rate
=
sample_rate
),
waveform
)
self
.
_assert_consistency
(
T
.
SpectralCentroid
(
sample_rate
=
sample_rate
),
waveform
)
class
TransformsComplex
(
TempDirMixin
,
TestBaseMixin
):
complex_dtype
=
None
real_dtype
=
None
device
=
None
def
_assert_consistency
(
self
,
transform
,
tensor
,
test_pseudo_complex
=
False
):
assert
tensor
.
is_complex
()
tensor
=
tensor
.
to
(
device
=
self
.
device
,
dtype
=
self
.
complex_dtype
)
transform
=
transform
.
to
(
device
=
self
.
device
,
dtype
=
self
.
real_dtype
)
path
=
self
.
get_temp_path
(
'transform.zip'
)
torch
.
jit
.
script
(
transform
).
save
(
path
)
ts_transform
=
torch
.
jit
.
load
(
path
)
if
test_pseudo_complex
:
tensor
=
torch
.
view_as_real
(
tensor
)
output
=
transform
(
tensor
)
ts_output
=
ts_transform
(
tensor
)
self
.
assertEqual
(
ts_output
,
output
)
@
parameterized
.
expand
([(
True
,
),
(
False
,
)])
@
parameterized
.
expand
([(
True
,
),
(
False
,
)])
def
test_TimeStretch
(
self
,
test_pseudo_complex
):
def
test_TimeStretch
(
self
,
test_pseudo_complex
):
n_freq
=
400
n_freq
=
400
hop_length
=
512
hop_length
=
512
fixed_rate
=
1.3
fixed_rate
=
1.3
tensor
=
torch
.
view_as_complex
(
torch
.
rand
((
10
,
2
,
n_freq
,
10
,
2
)))
tensor
=
torch
.
view_as_complex
(
torch
.
rand
((
10
,
2
,
n_freq
,
10
,
2
)))
self
.
_assert_consistency
(
self
.
_assert_consistency
_complex
(
T
.
TimeStretch
(
n_freq
=
n_freq
,
hop_length
=
hop_length
,
fixed_rate
=
fixed_rate
),
T
.
TimeStretch
(
n_freq
=
n_freq
,
hop_length
=
hop_length
,
fixed_rate
=
fixed_rate
),
tensor
,
tensor
,
test_pseudo_complex
test_pseudo_complex
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment