Unverified Commit 00d38203 authored by moto's avatar moto Committed by GitHub
Browse files

Introduce common utility for defining test matrix for device/dtype (#616)

* Introduce common utility for defining test matrix for device/dtype

* Make resample_waveform support float64

* Mark lfilter related test as xfail when float64

* fix
parent 7a0d4192
import os
import tempfile
from typing import Type, Iterable
from contextlib import contextmanager
from shutil import copytree
import torch
import torchaudio
import pytest
# Absolute path of the directory containing this test module; used to locate test assets.
_TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
# Audio I/O backends available in this torchaudio build
# (NOTE(review): reads a private attribute `_audio_backends` — verify it still exists on upgrade).
BACKENDS = torchaudio._backend._audio_backends
......@@ -78,3 +80,37 @@ def filter_backends_with_mp3(backends):
# Subset of BACKENDS that can decode mp3, as determined by filter_backends_with_mp3.
BACKENDS_MP3 = filter_backends_with_mp3(BACKENDS)
class TestBaseMixin:
    """Base mixin for test suites parameterized over dtype and device.

    The placeholders below are replaced with concrete ``torch.dtype`` /
    ``torch.device`` objects by ``define_test_suite`` when the per-combination
    suite classes are generated.
    """
    # Concrete torch.dtype, e.g. torch.float32; filled in by define_test_suite.
    dtype = None
    # Concrete torch.device, e.g. torch.device('cpu'); filled in by define_test_suite.
    device = None
def define_test_suite(testbase: Type[TestBaseMixin], dtype: str, device: str):
    """Build a concrete test-suite class for one (dtype, device) combination.

    Args:
        testbase: Mixin class whose ``dtype``/``device`` placeholders are to
            be bound to concrete torch objects.
        dtype: Either ``'float32'`` or ``'float64'``.
        device: Either ``'cpu'`` or ``'cuda'``.

    Returns:
        A subclass of ``testbase`` named ``Test<Base>_<DEVICE>_<Dtype>``.
        CUDA suites are wrapped with ``pytest.mark.skipif`` so they are
        skipped when no GPU is available.

    Raises:
        NotImplementedError: If ``dtype`` or ``device`` is not supported.
    """
    if dtype not in ('float32', 'float64'):
        raise NotImplementedError(f'Unexpected dtype: {dtype}')
    if device not in ('cpu', 'cuda'):
        raise NotImplementedError(f'Unexpected device: {device}')

    suite_name = f'Test{testbase.__name__}_{device.upper()}_{dtype.capitalize()}'
    suite = type(
        suite_name,
        (testbase,),
        {'dtype': getattr(torch, dtype), 'device': torch.device(device)},
    )
    if device != 'cuda':
        return suite
    # GPU suites must not fail on CPU-only machines — mark them for skipping.
    skip_no_cuda = pytest.mark.skipif(
        not torch.cuda.is_available(), reason='CUDA not available')
    return skip_no_cuda(suite)
def define_test_suites(
        scope: dict,
        testbases: Iterable[Type[TestBaseMixin]],
        dtypes: Iterable[str] = ('float32', 'float64'),
        devices: Iterable[str] = ('cpu', 'cuda'),
):
    """Generate one test suite per (testbase, device, dtype) and inject each into ``scope``.

    Args:
        scope: Namespace (typically a module's ``globals()``) that receives the
            generated classes keyed by their class name, so test runners can
            discover them.
        testbases: Mixin classes to expand into concrete suites.
        dtypes: dtype names to cover.
        devices: device names to cover.
    """
    for base in testbases:
        for dev in devices:
            for dt in dtypes:
                generated = define_test_suite(base, dt, dev)
                scope[generated.__name__] = generated
......@@ -9,10 +9,7 @@ import pytest
import common_utils
class _LfilterMixin:
device = None
dtype = None
class Lfilter(common_utils.TestBaseMixin):
def test_simple(self):
"""
Create a very basic signal,
......@@ -33,25 +30,12 @@ class _LfilterMixin:
b_coeffs = torch.tensor([1, 0], dtype=self.dtype, device=self.device)
a_coeffs = torch.tensor([1, -0.95], dtype=self.dtype, device=self.device)
output_signal = F.lfilter(input_signal, a_coeffs, b_coeffs, clamp=True)
self.assertTrue(output_signal.max() <= 1)
assert output_signal.max() <= 1
output_signal = F.lfilter(input_signal, a_coeffs, b_coeffs, clamp=False)
self.assertTrue(output_signal.max() > 1)
class TestLfilterFloat32CPU(_LfilterMixin, unittest.TestCase):
device = torch.device('cpu')
dtype = torch.float32
class TestLfilterFloat64CPU(_LfilterMixin, unittest.TestCase):
device = torch.device('cpu')
dtype = torch.float64
assert output_signal.max() > 1
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
class TestLfilterFloat32CUDA(_LfilterMixin, unittest.TestCase):
device = torch.device('cuda')
dtype = torch.float32
common_utils.define_test_suites(globals(), [Lfilter])
class TestComputeDeltas(unittest.TestCase):
......
"""Test suites for jit-ability and its numerical compatibility"""
import unittest
import pytest
import torch
import torchaudio
......@@ -9,8 +10,7 @@ import torchaudio.transforms as T
import common_utils
def _assert_functional_consistency(func, tensor, device, shape_only=False):
tensor = tensor.to(device)
def _assert_functional_consistency(func, tensor, shape_only=False):
ts_func = torch.jit.script(func)
output = func(tensor)
ts_output = ts_func(tensor)
......@@ -21,21 +21,18 @@ def _assert_functional_consistency(func, tensor, device, shape_only=False):
torch.testing.assert_allclose(ts_output, output)
def _assert_transforms_consistency(transform, tensor, device):
tensor = tensor.to(device)
transform = transform.to(device)
def _assert_transforms_consistency(transform, tensor):
ts_transform = torch.jit.script(transform)
output = transform(tensor)
ts_output = ts_transform(tensor)
torch.testing.assert_allclose(ts_output, output)
class _FunctionalTestMixin:
class Functional(common_utils.TestBaseMixin):
"""Implements test for `functinoal` modul that are performed for different devices"""
device = None
def _assert_consistency(self, func, tensor, shape_only=False):
return _assert_functional_consistency(func, tensor, self.device, shape_only=shape_only)
tensor = tensor.to(device=self.device, dtype=self.dtype)
return _assert_functional_consistency(func, tensor, shape_only=shape_only)
def test_spectrogram(self):
def func(tensor):
......@@ -159,7 +156,7 @@ class _FunctionalTestMixin:
return F.complex_norm(tensor, power)
tensor = torch.randn(1, 2, 1025, 400, 2)
_assert_functional_consistency(func, tensor, self.device)
self._assert_consistency(func, tensor)
def test_mask_along_axis(self):
def func(tensor):
......@@ -211,6 +208,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, tensor, shape_only=True)
def test_lfilter(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -252,6 +252,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform)
def test_lowpass(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -263,6 +266,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform)
def test_highpass(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -274,6 +280,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform)
def test_allpass(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -286,6 +295,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform)
def test_bandpass_with_csg(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -298,7 +310,10 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform)
def test_bandpass_withou_csg(self):
def test_bandpass_without_csg(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -312,6 +327,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform)
def test_bandreject(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -324,6 +342,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform)
def test_band_with_noise(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -337,6 +358,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform)
def test_band_without_noise(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -350,6 +374,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform)
def test_treble(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -363,6 +390,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform)
def test_deemph(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -373,6 +403,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform)
def test_riaa(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -383,6 +416,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform)
def test_equalizer(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -396,6 +432,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform)
def test_perf_biquad_filtering(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
......@@ -489,12 +528,12 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform)
class _TransformsTestMixin:
class Transforms(common_utils.TestBaseMixin):
"""Implements test for Transforms that are performed for different devices"""
device = None
def _assert_consistency(self, transform, tensor):
_assert_transforms_consistency(transform, tensor, self.device)
tensor = tensor.to(device=self.device, dtype=self.dtype)
transform = transform.to(device=self.device, dtype=self.dtype)
_assert_transforms_consistency(transform, tensor)
def test_Spectrogram(self):
tensor = torch.rand((1, 1000))
......@@ -578,27 +617,4 @@ class _TransformsTestMixin:
self._assert_consistency(T.Vad(sample_rate=sample_rate), waveform)
class TestFunctionalCPU(_FunctionalTestMixin, unittest.TestCase):
"""Test suite for Functional module on CPU"""
device = torch.device('cpu')
@unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available')
class TestFunctionalCUDA(_FunctionalTestMixin, unittest.TestCase):
"""Test suite for Functional module on GPU"""
device = torch.device('cuda')
class TestTransformsCPU(_TransformsTestMixin, unittest.TestCase):
"""Test suite for Transforms module on CPU"""
device = torch.device('cpu')
@unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available')
class TestTransformsCUDA(_TransformsTestMixin, unittest.TestCase):
"""Test suite for Transforms module on GPU"""
device = torch.device('cuda')
if __name__ == '__main__':
unittest.main()
common_utils.define_test_suites(globals(), [Functional, Transforms])
......@@ -890,6 +890,8 @@ def resample_waveform(waveform: Tensor,
Returns:
Tensor: The waveform at the new frequency
"""
device, dtype = waveform.device, waveform.dtype
assert waveform.dim() == 2
assert orig_freq > 0.0 and new_freq > 0.0
......@@ -905,7 +907,7 @@ def resample_waveform(waveform: Tensor,
window_width = lowpass_filter_width / (2.0 * lowpass_cutoff)
first_indices, weights = _get_LR_indices_and_weights(orig_freq, new_freq, output_samples_in_unit,
window_width, lowpass_cutoff, lowpass_filter_width)
weights = weights.to(waveform.device) # TODO Create weights on device directly
weights = weights.to(device=device, dtype=dtype) # TODO Create weights on device directly
assert first_indices.dim() == 1
# TODO figure a better way to do this. conv1d reaches every element i*stride + padding
......@@ -918,9 +920,9 @@ def resample_waveform(waveform: Tensor,
window_size = weights.size(1)
tot_output_samp = _get_num_LR_output_samples(wave_len, orig_freq, new_freq)
output = torch.zeros((num_channels, tot_output_samp),
device=waveform.device)
device=device, dtype=dtype)
# eye size: (num_channels, num_channels, 1)
eye = torch.eye(num_channels, device=waveform.device).unsqueeze(2)
eye = torch.eye(num_channels, device=device, dtype=dtype).unsqueeze(2)
for i in range(first_indices.size(0)):
wave_to_conv = waveform
first_index = int(first_indices[i].item())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment