Unverified Commit e9726f08 authored by steveplazafb's avatar steveplazafb Committed by GitHub
Browse files

Refactors functional test (#1435)

Merges lfilter and spectrogram classes together in the common implementation and modifies the cpu and gpu test definitions accordingly
parent 20469cfe
...@@ -14,29 +14,19 @@ from torchaudio_unittest.common_utils import ( ...@@ -14,29 +14,19 @@ from torchaudio_unittest.common_utils import (
skipIfNoSox, skipIfNoSox,
) )
from .functional_impl import Lfilter, Spectrogram, FunctionalComplex from .functional_impl import Functional, FunctionalComplex
class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase): class TestFunctionalFloat32(Functional, common_utils.PytorchTestCase):
dtype = torch.float32 dtype = torch.float32
device = torch.device('cpu') device = torch.device('cpu')
@unittest.expectedFailure @unittest.expectedFailure
def test_9th_order_filter_stability(self): def test_lfilter_9th_order_filter_stability(self):
super().test_9th_order_filter_stability() super().test_lfilter_9th_order_filter_stability()
class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase): class TestFunctionalFloat64(Functional, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')
class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cpu')
class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
dtype = torch.float64 dtype = torch.float64
device = torch.device('cpu') device = torch.device('cpu')
......
...@@ -2,33 +2,21 @@ import torch ...@@ -2,33 +2,21 @@ import torch
import unittest import unittest
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
from .functional_impl import Lfilter, Spectrogram, FunctionalComplex from .functional_impl import Functional, FunctionalComplex
@common_utils.skipIfNoCuda @common_utils.skipIfNoCuda
class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase): class TestFunctionalloat32(Functional, common_utils.PytorchTestCase):
dtype = torch.float32 dtype = torch.float32
device = torch.device('cuda') device = torch.device('cuda')
@unittest.expectedFailure @unittest.expectedFailure
def test_9th_order_filter_stability(self): def test_lfilter_9th_order_filter_stability(self):
super().test_9th_order_filter_stability() super().test_lfilter_9th_order_filter_stability()
@common_utils.skipIfNoCuda @common_utils.skipIfNoCuda
class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase): class TestLFilterFloat64(Functional, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')
@common_utils.skipIfNoCuda
class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cuda')
@common_utils.skipIfNoCuda
class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
dtype = torch.float64 dtype = torch.float64
device = torch.device('cuda') device = torch.device('cuda')
......
...@@ -9,8 +9,8 @@ from torchaudio_unittest import common_utils ...@@ -9,8 +9,8 @@ from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import nested_params from torchaudio_unittest.common_utils import nested_params
class Lfilter(common_utils.TestBaseMixin): class Functional(common_utils.TestBaseMixin):
def test_simple(self): def test_lfilter_simple(self):
""" """
Create a very basic signal, Create a very basic signal,
Then make a simple 4th order delay Then make a simple 4th order delay
...@@ -25,7 +25,7 @@ class Lfilter(common_utils.TestBaseMixin): ...@@ -25,7 +25,7 @@ class Lfilter(common_utils.TestBaseMixin):
self.assertEqual(output_waveform[:, 3:], waveform[:, 0:-3], atol=1e-5, rtol=1e-5) self.assertEqual(output_waveform[:, 3:], waveform[:, 0:-3], atol=1e-5, rtol=1e-5)
def test_clamp(self): def test_lfilter_clamp(self):
input_signal = torch.ones(1, 44100 * 1, dtype=self.dtype, device=self.device) input_signal = torch.ones(1, 44100 * 1, dtype=self.dtype, device=self.device)
b_coeffs = torch.tensor([1, 0], dtype=self.dtype, device=self.device) b_coeffs = torch.tensor([1, 0], dtype=self.dtype, device=self.device)
a_coeffs = torch.tensor([1, -0.95], dtype=self.dtype, device=self.device) a_coeffs = torch.tensor([1, -0.95], dtype=self.dtype, device=self.device)
...@@ -40,7 +40,7 @@ class Lfilter(common_utils.TestBaseMixin): ...@@ -40,7 +40,7 @@ class Lfilter(common_utils.TestBaseMixin):
((2, 3, 44100),), ((2, 3, 44100),),
((1, 2, 3, 44100),) ((1, 2, 3, 44100),)
]) ])
def test_shape(self, shape): def test_lfilter_shape(self, shape):
torch.random.manual_seed(42) torch.random.manual_seed(42)
waveform = torch.rand(*shape, dtype=self.dtype, device=self.device) waveform = torch.rand(*shape, dtype=self.dtype, device=self.device)
b_coeffs = torch.tensor([0, 0, 0, 1], dtype=self.dtype, device=self.device) b_coeffs = torch.tensor([0, 0, 0, 1], dtype=self.dtype, device=self.device)
...@@ -48,7 +48,7 @@ class Lfilter(common_utils.TestBaseMixin): ...@@ -48,7 +48,7 @@ class Lfilter(common_utils.TestBaseMixin):
output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs) output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
assert shape == waveform.size() == output_waveform.size() assert shape == waveform.size() == output_waveform.size()
def test_9th_order_filter_stability(self): def test_lfilter_9th_order_filter_stability(self):
""" """
Validate the precision of lfilter against reference scipy implementation when using high order filter. Validate the precision of lfilter against reference scipy implementation when using high order filter.
The reference implementation use cascaded second-order filters so is more numerically accurate. The reference implementation use cascaded second-order filters so is more numerically accurate.
...@@ -70,10 +70,8 @@ class Lfilter(common_utils.TestBaseMixin): ...@@ -70,10 +70,8 @@ class Lfilter(common_utils.TestBaseMixin):
yhat = F.lfilter(x, a, b, False) yhat = F.lfilter(x, a, b, False)
self.assertEqual(yhat, y, atol=1e-4, rtol=1e-5) self.assertEqual(yhat, y, atol=1e-4, rtol=1e-5)
class Spectrogram(common_utils.TestBaseMixin):
@parameterized.expand([(0., ), (1., ), (2., ), (3., )]) @parameterized.expand([(0., ), (1., ), (2., ), (3., )])
def test_grad_at_zero(self, power): def test_spectogram_grad_at_zero(self, power):
"""The gradient of power spectrogram should not be nan but zero near x=0 """The gradient of power spectrogram should not be nan but zero near x=0
https://github.com/pytorch/audio/issues/993 https://github.com/pytorch/audio/issues/993
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment