Unverified Commit e9726f08 authored by steveplazafb's avatar steveplazafb Committed by GitHub
Browse files

Refactors functional test (#1435)

Merges lfilter and spectrogram classes together in the common implementation and modifies the cpu and gpu test definitions accordingly
parent 20469cfe
......@@ -14,29 +14,19 @@ from torchaudio_unittest.common_utils import (
skipIfNoSox,
)
from .functional_impl import Lfilter, Spectrogram, FunctionalComplex
from .functional_impl import Functional, FunctionalComplex
class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
class TestFunctionalFloat32(Functional, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cpu')
@unittest.expectedFailure
def test_9th_order_filter_stability(self):
super().test_9th_order_filter_stability()
def test_lfilter_9th_order_filter_stability(self):
super().test_lfilter_9th_order_filter_stability()
class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')
class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cpu')
class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
class TestFunctionalFloat64(Functional, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')
......
......@@ -2,33 +2,21 @@ import torch
import unittest
from torchaudio_unittest import common_utils
from .functional_impl import Lfilter, Spectrogram, FunctionalComplex
from .functional_impl import Functional, FunctionalComplex
@common_utils.skipIfNoCuda
class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
class TestFunctionalloat32(Functional, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cuda')
@unittest.expectedFailure
def test_9th_order_filter_stability(self):
super().test_9th_order_filter_stability()
def test_lfilter_9th_order_filter_stability(self):
super().test_lfilter_9th_order_filter_stability()
@common_utils.skipIfNoCuda
class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')
@common_utils.skipIfNoCuda
class TestSpectrogramFloat32(Spectrogram, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cuda')
@common_utils.skipIfNoCuda
class TestSpectrogramFloat64(Spectrogram, common_utils.PytorchTestCase):
class TestLFilterFloat64(Functional, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')
......
......@@ -9,8 +9,8 @@ from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import nested_params
class Lfilter(common_utils.TestBaseMixin):
def test_simple(self):
class Functional(common_utils.TestBaseMixin):
def test_lfilter_simple(self):
"""
Create a very basic signal,
Then make a simple 4th order delay
......@@ -25,7 +25,7 @@ class Lfilter(common_utils.TestBaseMixin):
self.assertEqual(output_waveform[:, 3:], waveform[:, 0:-3], atol=1e-5, rtol=1e-5)
def test_clamp(self):
def test_lfilter_clamp(self):
input_signal = torch.ones(1, 44100 * 1, dtype=self.dtype, device=self.device)
b_coeffs = torch.tensor([1, 0], dtype=self.dtype, device=self.device)
a_coeffs = torch.tensor([1, -0.95], dtype=self.dtype, device=self.device)
......@@ -40,7 +40,7 @@ class Lfilter(common_utils.TestBaseMixin):
((2, 3, 44100),),
((1, 2, 3, 44100),)
])
def test_shape(self, shape):
def test_lfilter_shape(self, shape):
torch.random.manual_seed(42)
waveform = torch.rand(*shape, dtype=self.dtype, device=self.device)
b_coeffs = torch.tensor([0, 0, 0, 1], dtype=self.dtype, device=self.device)
......@@ -48,7 +48,7 @@ class Lfilter(common_utils.TestBaseMixin):
output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
assert shape == waveform.size() == output_waveform.size()
def test_9th_order_filter_stability(self):
def test_lfilter_9th_order_filter_stability(self):
"""
Validate the precision of lfilter against reference scipy implementation when using high order filter.
The reference implementation use cascaded second-order filters so is more numerically accurate.
......@@ -70,10 +70,8 @@ class Lfilter(common_utils.TestBaseMixin):
yhat = F.lfilter(x, a, b, False)
self.assertEqual(yhat, y, atol=1e-4, rtol=1e-5)
class Spectrogram(common_utils.TestBaseMixin):
@parameterized.expand([(0., ), (1., ), (2., ), (3., )])
def test_grad_at_zero(self, power):
def test_spectogram_grad_at_zero(self, power):
"""The gradient of power spectrogram should not be nan but zero near x=0
https://github.com/pytorch/audio/issues/993
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment