Unverified Commit 7d45851d authored by moto's avatar moto Committed by GitHub
Browse files

Merge test classes for complex (#1491)

parent ddd2425c
......@@ -81,6 +81,14 @@ class TestBaseMixin:
super().setUp()
set_audio_backend(self.backend)
@property
def complex_dtype(self):
    """Complex dtype with the same precision as ``self.dtype``.

    float32-family dtypes map to ``torch.cfloat`` (complex64) and
    float64-family dtypes map to ``torch.cdouble`` (complex128).

    Raises:
        ValueError: if ``self.dtype`` has no complex counterpart.
    """
    single_precision = ('float32', 'float', torch.float, torch.float32)
    double_precision = ('float64', 'double', torch.double, torch.float64)
    if self.dtype in single_precision:
        return torch.cfloat
    if self.dtype in double_precision:
        return torch.cdouble
    raise ValueError(f'No corresponding complex dtype for {self.dtype}')
class TorchaudioTestCase(TestBaseMixin, PytorchTestCase):
pass
......
......@@ -4,7 +4,7 @@ import unittest
from parameterized import parameterized
from torchaudio_unittest.common_utils import PytorchTestCase, TorchaudioTestCase, skipIfNoSox
from .functional_impl import Functional, FunctionalComplex, FunctionalCPUOnly
from .functional_impl import Functional, FunctionalCPUOnly
class TestFunctionalFloat32(Functional, FunctionalCPUOnly, PytorchTestCase):
......@@ -21,18 +21,6 @@ class TestFunctionalFloat64(Functional, PytorchTestCase):
device = torch.device('cpu')
class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
    """Complex functional tests on CPU, single-precision (complex64/float32)."""
    device = torch.device('cpu')
    real_dtype = torch.float32
    complex_dtype = torch.complex64
class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
    """Complex functional tests on CPU, double-precision (complex128/float64)."""
    device = torch.device('cpu')
    real_dtype = torch.float64
    complex_dtype = torch.complex128
@skipIfNoSox
class TestApplyCodec(TorchaudioTestCase):
backend = "sox_io"
......
......@@ -2,7 +2,7 @@ import torch
import unittest
from torchaudio_unittest.common_utils import PytorchTestCase, skipIfNoCuda
from .functional_impl import Functional, FunctionalComplex
from .functional_impl import Functional
@skipIfNoCuda
......@@ -19,17 +19,3 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
class TestLFilterFloat64(Functional, PytorchTestCase):
    """Functional tests on CUDA with float64 tensors."""
    device = torch.device('cuda')
    dtype = torch.float64
@skipIfNoCuda
class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
    """Complex functional tests on CUDA, single-precision (complex64/float32)."""
    device = torch.device('cuda')
    real_dtype = torch.float32
    complex_dtype = torch.complex64
@skipIfNoCuda
class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
    """Complex functional tests on CUDA, double-precision (complex128/float64).

    Bug fix: this class previously declared ``complex_dtype = torch.complex64``
    and ``real_dtype = torch.float32`` (a copy-paste of the Complex64 class),
    so the double-precision configuration its name promises was never actually
    exercised. Use the dtypes matching the class name, consistent with the
    CPU ``TestFunctionalComplex128`` variant.
    """
    complex_dtype = torch.complex128
    real_dtype = torch.float64
    device = torch.device('cuda')
......@@ -259,12 +259,6 @@ class Functional(TestBaseMixin):
self.assertEqual(specgrams, specgrams_copy)
class FunctionalComplex(TestBaseMixin):
complex_dtype = None
real_dtype = None
device = None
@nested_params(
[0.5, 1.01, 1.3],
[True, False],
......@@ -286,7 +280,7 @@ class FunctionalComplex(TestBaseMixin):
0,
np.pi * hop_length,
num_freq,
dtype=self.real_dtype, device=self.device)[..., None]
dtype=self.dtype, device=self.device)[..., None]
spec_stretch = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)
......
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_impl import Functional, FunctionalComplex
from .torchscript_consistency_impl import Functional
class TestFunctionalFloat32(Functional, PytorchTestCase):
......@@ -12,15 +12,3 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
class TestFunctionalFloat64(Functional, PytorchTestCase):
    """TorchScript-consistency functional tests on CPU with float64 tensors."""
    device = torch.device('cpu')
    dtype = torch.float64
class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
    """TorchScript-consistency complex tests on CPU, single-precision."""
    device = torch.device('cpu')
    real_dtype = torch.float32
    complex_dtype = torch.complex64
class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
    """TorchScript-consistency complex tests on CPU, double-precision."""
    device = torch.device('cpu')
    real_dtype = torch.float64
    complex_dtype = torch.complex128
import torch
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from .torchscript_consistency_impl import Functional, FunctionalComplex
from .torchscript_consistency_impl import Functional
@skipIfNoCuda
......@@ -14,17 +14,3 @@ class TestFunctionalFloat32(Functional, PytorchTestCase):
class TestFunctionalFloat64(Functional, PytorchTestCase):
    """TorchScript-consistency functional tests on CUDA with float64 tensors."""
    device = torch.device('cuda')
    dtype = torch.float64
@skipIfNoCuda
class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
    """TorchScript-consistency complex tests on CUDA, single-precision."""
    device = torch.device('cuda')
    real_dtype = torch.float32
    complex_dtype = torch.complex64
@skipIfNoCuda
class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
    """TorchScript-consistency complex tests on CUDA, double-precision."""
    device = torch.device('cuda')
    real_dtype = torch.float64
    complex_dtype = torch.complex128
......@@ -32,6 +32,25 @@ class Functional(TempDirMixin, TestBaseMixin):
output = output.shape
self.assertEqual(ts_output, output)
def _assert_consistency_complex(self, func, tensor, test_pseudo_complex=False):
    """Assert that scripting ``func`` does not change its output on a complex input.

    The function is scripted, round-tripped through disk, and run against the
    eager original on the same tensor; the two results must match. When
    ``test_pseudo_complex`` is True the input is first converted with
    ``torch.view_as_real`` to the real-valued "pseudo complex" layout.
    """
    assert tensor.is_complex()
    tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
    # Round-trip through serialization so the loaded artifact is what gets tested.
    path = self.get_temp_path('func.zip')
    torch.jit.script(func).save(path)
    scripted = torch.jit.load(path)
    if test_pseudo_complex:
        tensor = torch.view_as_real(tensor)
    # Re-seed before each call so any randomness inside `func` is identical
    # across the eager and the scripted run.
    torch.random.manual_seed(40)
    expected = func(tensor)
    torch.random.manual_seed(40)
    actual = scripted(tensor)
    self.assertEqual(actual, expected)
def test_spectrogram(self):
def func(tensor):
n_fft = 400
......@@ -572,26 +591,6 @@ class Functional(TempDirMixin, TestBaseMixin):
tensor = common_utils.get_whitenoise(sample_rate=44100)
self._assert_consistency(func, tensor)
class FunctionalComplex(TempDirMixin, TestBaseMixin):
complex_dtype = None
real_dtype = None
device = None
def _assert_consistency(self, func, tensor, test_pseudo_complex=False):
    """Assert that ``func`` gives the same result eagerly and after TorchScript scripting.

    When ``test_pseudo_complex`` is True, the complex input is converted with
    ``torch.view_as_real`` before being fed to both versions.
    """
    assert tensor.is_complex()
    tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
    # Script, save, and reload so the on-disk artifact is what gets exercised.
    path = self.get_temp_path('func.zip')
    torch.jit.script(func).save(path)
    scripted = torch.jit.load(path)
    if test_pseudo_complex:
        tensor = torch.view_as_real(tensor)
    expected = func(tensor)
    actual = scripted(tensor)
    self.assertEqual(actual, expected)
@parameterized.expand([(True, ), (False, )])
def test_phase_vocoder(self, test_paseudo_complex):
def func(tensor):
......@@ -610,4 +609,4 @@ class FunctionalComplex(TempDirMixin, TestBaseMixin):
return F.phase_vocoder(tensor, rate, phase_advance)
tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2))
self._assert_consistency(func, tensor, test_paseudo_complex)
self._assert_consistency_complex(func, tensor, test_paseudo_complex)
import torch
from torchaudio_unittest.common_utils import PytorchTestCase
from .torchscript_consistency_impl import Transforms, TransformsComplex
from .torchscript_consistency_impl import Transforms
class TestTransformsFloat32(Transforms, PytorchTestCase):
......@@ -12,15 +12,3 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
class TestTransformsFloat64(Transforms, PytorchTestCase):
    """TorchScript-consistency transform tests on CPU with float64 tensors."""
    device = torch.device('cpu')
    dtype = torch.float64
class TestTransformsComplex64(TransformsComplex, PytorchTestCase):
    """TorchScript-consistency complex transform tests on CPU, single-precision."""
    device = torch.device('cpu')
    real_dtype = torch.float32
    complex_dtype = torch.complex64
class TestTransformsComplex128(TransformsComplex, PytorchTestCase):
    """TorchScript-consistency complex transform tests on CPU, double-precision."""
    device = torch.device('cpu')
    real_dtype = torch.float64
    complex_dtype = torch.complex128
import torch
from torchaudio_unittest.common_utils import skipIfNoCuda, PytorchTestCase
from .torchscript_consistency_impl import Transforms, TransformsComplex
from .torchscript_consistency_impl import Transforms
@skipIfNoCuda
......@@ -14,17 +14,3 @@ class TestTransformsFloat32(Transforms, PytorchTestCase):
class TestTransformsFloat64(Transforms, PytorchTestCase):
    """TorchScript-consistency transform tests on CUDA with float64 tensors."""
    device = torch.device('cuda')
    dtype = torch.float64
@skipIfNoCuda
class TestTransformsComplex64(TransformsComplex, PytorchTestCase):
    """TorchScript-consistency complex transform tests on CUDA, single-precision."""
    device = torch.device('cuda')
    real_dtype = torch.float32
    complex_dtype = torch.complex64
@skipIfNoCuda
class TestTransformsComplex128(TransformsComplex, PytorchTestCase):
    """TorchScript-consistency complex transform tests on CUDA, double-precision."""
    device = torch.device('cuda')
    real_dtype = torch.float64
    complex_dtype = torch.complex128
......@@ -26,6 +26,22 @@ class Transforms(TempDirMixin, TestBaseMixin):
ts_output = ts_transform(tensor)
self.assertEqual(ts_output, output)
def _assert_consistency_complex(self, transform, tensor, test_pseudo_complex=False):
    """Assert that scripting ``transform`` does not change its output on a complex input.

    The transform's parameters are moved to ``self.dtype`` (real) while the
    input tensor is cast to ``self.complex_dtype``; the scripted module is
    round-tripped through disk and compared against the eager module.
    """
    assert tensor.is_complex()
    tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
    transform = transform.to(device=self.device, dtype=self.dtype)
    # Round-trip through serialization so the loaded module is what gets tested.
    path = self.get_temp_path('transform.zip')
    torch.jit.script(transform).save(path)
    scripted = torch.jit.load(path)
    if test_pseudo_complex:
        tensor = torch.view_as_real(tensor)
    expected = transform(tensor)
    actual = scripted(tensor)
    self.assertEqual(actual, expected)
def test_Spectrogram(self):
    """Spectrogram transform must be TorchScript-consistent on random input."""
    waveform = torch.rand((1, 1000))
    self._assert_consistency(T.Spectrogram(), waveform)
......@@ -104,35 +120,13 @@ class Transforms(TempDirMixin, TestBaseMixin):
waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
self._assert_consistency(T.SpectralCentroid(sample_rate=sample_rate), waveform)
class TransformsComplex(TempDirMixin, TestBaseMixin):
complex_dtype = None
real_dtype = None
device = None
def _assert_consistency(self, transform, tensor, test_pseudo_complex=False):
    """Assert that ``transform`` gives the same result eagerly and after scripting.

    The transform's parameters are moved to ``self.real_dtype`` while the
    input tensor is cast to ``self.complex_dtype``; when
    ``test_pseudo_complex`` is True the input is converted with
    ``torch.view_as_real`` before being fed to both modules.
    """
    assert tensor.is_complex()
    tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
    transform = transform.to(device=self.device, dtype=self.real_dtype)
    # Script, save, and reload so the on-disk artifact is what gets exercised.
    path = self.get_temp_path('transform.zip')
    torch.jit.script(transform).save(path)
    scripted = torch.jit.load(path)
    if test_pseudo_complex:
        tensor = torch.view_as_real(tensor)
    expected = transform(tensor)
    actual = scripted(tensor)
    self.assertEqual(actual, expected)
@parameterized.expand([(True, ), (False, )])
def test_TimeStretch(self, test_pseudo_complex):
n_freq = 400
hop_length = 512
fixed_rate = 1.3
tensor = torch.view_as_complex(torch.rand((10, 2, n_freq, 10, 2)))
self._assert_consistency(
self._assert_consistency_complex(
T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
tensor,
test_pseudo_complex
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment