Unverified Commit b5d80279 authored by Mark Saroufim's avatar Mark Saroufim Committed by GitHub
Browse files

Run functional tests on GPU as well as CPU (#1475)

parent bdd7b33b
...@@ -16,18 +16,18 @@ class TestFunctionalFloat32(Functional, FunctionalCPUOnly, PytorchTestCase): ...@@ -16,18 +16,18 @@ class TestFunctionalFloat32(Functional, FunctionalCPUOnly, PytorchTestCase):
super().test_lfilter_9th_order_filter_stability() super().test_lfilter_9th_order_filter_stability()
class TestFunctionalFloat64(Functional, FunctionalCPUOnly, PytorchTestCase): class TestFunctionalFloat64(Functional, PytorchTestCase):
dtype = torch.float64 dtype = torch.float64
device = torch.device('cpu') device = torch.device('cpu')
class TestFunctionalComplex64(FunctionalComplex, FunctionalCPUOnly, PytorchTestCase): class TestFunctionalComplex64(FunctionalComplex, PytorchTestCase):
complex_dtype = torch.complex64 complex_dtype = torch.complex64
real_dtype = torch.float32 real_dtype = torch.float32
device = torch.device('cpu') device = torch.device('cpu')
class TestFunctionalComplex128(FunctionalComplex, FunctionalCPUOnly, PytorchTestCase): class TestFunctionalComplex128(FunctionalComplex, PytorchTestCase):
complex_dtype = torch.complex128 complex_dtype = torch.complex128
real_dtype = torch.float64 real_dtype = torch.float64
device = torch.device('cpu') device = torch.device('cpu')
......
...@@ -93,73 +93,17 @@ class Functional(TestBaseMixin): ...@@ -93,73 +93,17 @@ class Functional(TestBaseMixin):
spec.sum().backward() spec.sum().backward()
assert not x.grad.isnan().sum() assert not x.grad.isnan().sum()
class FunctionalComplex(TestBaseMixin):
complex_dtype = None
real_dtype = None
device = None
@nested_params(
[0.5, 1.01, 1.3],
[True, False],
)
def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
"""Verify the output shape of phase vocoder"""
hop_length = 256
num_freq = 1025
num_frames = 400
batch_size = 2
torch.random.manual_seed(42)
spec = torch.randn(
batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device)
if test_pseudo_complex:
spec = torch.view_as_real(spec)
phase_advance = torch.linspace(
0,
np.pi * hop_length,
num_freq,
dtype=self.real_dtype, device=self.device)[..., None]
spec_stretch = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)
assert spec.dim() == spec_stretch.dim()
expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))])
output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape
assert output_shape == expected_shape
class FunctionalCPUOnly(TestBaseMixin):
def test_create_fb_matrix_no_warning_high_n_freq(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
F.create_fb_matrix(288, 0, 8000, 128, 16000)
assert len(w) == 0
def test_create_fb_matrix_no_warning_low_n_mels(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
F.create_fb_matrix(201, 0, 8000, 89, 16000)
assert len(w) == 0
def test_create_fb_matrix_warning(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
F.create_fb_matrix(201, 0, 8000, 128, 16000)
assert len(w) == 1
def test_compute_deltas_one_channel(self): def test_compute_deltas_one_channel(self):
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]]) specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]], dtype=self.dtype, device=self.device)
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]]) expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]], dtype=self.dtype, device=self.device)
computed = F.compute_deltas(specgram, win_length=3) computed = F.compute_deltas(specgram, win_length=3)
self.assertEqual(computed, expected) self.assertEqual(computed, expected)
def test_compute_deltas_two_channels(self): def test_compute_deltas_two_channels(self):
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0], specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0],
[1.0, 2.0, 3.0, 4.0]]]) [1.0, 2.0, 3.0, 4.0]]], dtype=self.dtype, device=self.device)
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5], expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
[0.5, 1.0, 1.0, 0.5]]]) [0.5, 1.0, 1.0, 0.5]]], dtype=self.dtype, device=self.device)
computed = F.compute_deltas(specgram, win_length=3) computed = F.compute_deltas(specgram, win_length=3)
self.assertEqual(computed, expected) self.assertEqual(computed, expected)
...@@ -190,7 +134,7 @@ class FunctionalCPUOnly(TestBaseMixin): ...@@ -190,7 +134,7 @@ class FunctionalCPUOnly(TestBaseMixin):
db_mult = math.log10(max(amin, ref)) db_mult = math.log10(max(amin, ref))
torch.manual_seed(0) torch.manual_seed(0)
spec = torch.rand(*shape) * 200 spec = torch.rand(*shape, dtype=self.dtype, device=self.device) * 200
# Spectrogram amplitude -> DB -> amplitude # Spectrogram amplitude -> DB -> amplitude
db = F.amplitude_to_DB(spec, amplitude_mult, amin, db_mult, top_db=None) db = F.amplitude_to_DB(spec, amplitude_mult, amin, db_mult, top_db=None)
...@@ -218,7 +162,7 @@ class FunctionalCPUOnly(TestBaseMixin): ...@@ -218,7 +162,7 @@ class FunctionalCPUOnly(TestBaseMixin):
# each spectrogram still need to be predictable. The max determines the # each spectrogram still need to be predictable. The max determines the
# decibel cutoff, and the distance from the min must be large enough # decibel cutoff, and the distance from the min must be large enough
# that it triggers a clamp. # that it triggers a clamp.
spec = torch.rand(*shape) spec = torch.rand(*shape, dtype=self.dtype, device=self.device)
# Ensure each spectrogram has a min of 0 and a max of 1. # Ensure each spectrogram has a min of 0 and a max of 1.
spec -= spec.amin([-2, -1])[..., None, None] spec -= spec.amin([-2, -1])[..., None, None]
spec /= spec.amax([-2, -1])[..., None, None] spec /= spec.amax([-2, -1])[..., None, None]
...@@ -245,7 +189,7 @@ class FunctionalCPUOnly(TestBaseMixin): ...@@ -245,7 +189,7 @@ class FunctionalCPUOnly(TestBaseMixin):
) )
def test_complex_norm(self, shape, power): def test_complex_norm(self, shape, power):
torch.random.manual_seed(42) torch.random.manual_seed(42)
complex_tensor = torch.randn(*shape) complex_tensor = torch.randn(*shape, dtype=self.dtype, device=self.device)
expected_norm_tensor = complex_tensor.pow(2).sum(-1).pow(power / 2) expected_norm_tensor = complex_tensor.pow(2).sum(-1).pow(power / 2)
norm_tensor = F.complex_norm(complex_tensor, power) norm_tensor = F.complex_norm(complex_tensor, power)
self.assertEqual(norm_tensor, expected_norm_tensor, atol=1e-5, rtol=1e-5) self.assertEqual(norm_tensor, expected_norm_tensor, atol=1e-5, rtol=1e-5)
...@@ -255,7 +199,7 @@ class FunctionalCPUOnly(TestBaseMixin): ...@@ -255,7 +199,7 @@ class FunctionalCPUOnly(TestBaseMixin):
) )
def test_mask_along_axis(self, shape, mask_param, mask_value, axis): def test_mask_along_axis(self, shape, mask_param, mask_value, axis):
torch.random.manual_seed(42) torch.random.manual_seed(42)
specgram = torch.randn(*shape) specgram = torch.randn(*shape, dtype=self.dtype, device=self.device)
mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis) mask_specgram = F.mask_along_axis(specgram, mask_param, mask_value, axis)
other_axis = 1 if axis == 2 else 2 other_axis = 1 if axis == 2 else 2
...@@ -271,7 +215,7 @@ class FunctionalCPUOnly(TestBaseMixin): ...@@ -271,7 +215,7 @@ class FunctionalCPUOnly(TestBaseMixin):
@parameterized.expand(list(itertools.product([100], [0., 30.], [2, 3]))) @parameterized.expand(list(itertools.product([100], [0., 30.], [2, 3])))
def test_mask_along_axis_iid(self, mask_param, mask_value, axis): def test_mask_along_axis_iid(self, mask_param, mask_value, axis):
torch.random.manual_seed(42) torch.random.manual_seed(42)
specgrams = torch.randn(4, 2, 1025, 400) specgrams = torch.randn(4, 2, 1025, 400, dtype=self.dtype, device=self.device)
mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis) mask_specgrams = F.mask_along_axis_iid(specgrams, mask_param, mask_value, axis)
...@@ -282,3 +226,59 @@ class FunctionalCPUOnly(TestBaseMixin): ...@@ -282,3 +226,59 @@ class FunctionalCPUOnly(TestBaseMixin):
assert mask_specgrams.size() == specgrams.size() assert mask_specgrams.size() == specgrams.size()
assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel() assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()
class FunctionalComplex(TestBaseMixin):
complex_dtype = None
real_dtype = None
device = None
@nested_params(
[0.5, 1.01, 1.3],
[True, False],
)
def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
"""Verify the output shape of phase vocoder"""
hop_length = 256
num_freq = 1025
num_frames = 400
batch_size = 2
torch.random.manual_seed(42)
spec = torch.randn(
batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device)
if test_pseudo_complex:
spec = torch.view_as_real(spec)
phase_advance = torch.linspace(
0,
np.pi * hop_length,
num_freq,
dtype=self.real_dtype, device=self.device)[..., None]
spec_stretch = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)
assert spec.dim() == spec_stretch.dim()
expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))])
output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape
assert output_shape == expected_shape
class FunctionalCPUOnly(TestBaseMixin):
def test_create_fb_matrix_no_warning_high_n_freq(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
F.create_fb_matrix(288, 0, 8000, 128, 16000)
assert len(w) == 0
def test_create_fb_matrix_no_warning_low_n_mels(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
F.create_fb_matrix(201, 0, 8000, 89, 16000)
assert len(w) == 0
def test_create_fb_matrix_warning(self):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
F.create_fb_matrix(201, 0, 8000, 128, 16000)
assert len(w) == 1
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment