Unverified Commit 00d38203 authored by moto's avatar moto Committed by GitHub
Browse files

Introduce common utility for defining test matrix for device/dtype (#616)

* Introduce common utility for defining test matrix for device/dtype

* Make resample_waveform support float64

* Mark lfilter related test as xfail when float64

* fix
parent 7a0d4192
import os import os
import tempfile import tempfile
from typing import Type, Iterable
from contextlib import contextmanager from contextlib import contextmanager
from shutil import copytree from shutil import copytree
import torch import torch
import torchaudio import torchaudio
import pytest
_TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__)) _TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
BACKENDS = torchaudio._backend._audio_backends BACKENDS = torchaudio._backend._audio_backends
...@@ -78,3 +80,37 @@ def filter_backends_with_mp3(backends): ...@@ -78,3 +80,37 @@ def filter_backends_with_mp3(backends):
BACKENDS_MP3 = filter_backends_with_mp3(BACKENDS) BACKENDS_MP3 = filter_backends_with_mp3(BACKENDS)
class TestBaseMixin:
dtype = None
device = None
def define_test_suite(testbase: Type[TestBaseMixin], dtype: str, device: str):
if dtype not in ['float32', 'float64']:
raise NotImplementedError(f'Unexpected dtype: {dtype}')
if device not in ['cpu', 'cuda']:
raise NotImplementedError(f'Unexpected device: {device}')
name = f'Test{testbase.__name__}_{device.upper()}_{dtype.capitalize()}'
attrs = {'dtype': getattr(torch, dtype), 'device': torch.device(device)}
testsuite = type(name, (testbase,), attrs)
if device == 'cuda':
testsuite = pytest.mark.skipif(
not torch.cuda.is_available(), reason='CUDA not available')(testsuite)
return testsuite
def define_test_suites(
scope: dict,
testbases: Iterable[Type[TestBaseMixin]],
dtypes: Iterable[str] = ('float32', 'float64'),
devices: Iterable[str] = ('cpu', 'cuda'),
):
for suite in testbases:
for device in devices:
for dtype in dtypes:
t = define_test_suite(suite, dtype, device)
scope[t.__name__] = t
...@@ -9,10 +9,7 @@ import pytest ...@@ -9,10 +9,7 @@ import pytest
import common_utils import common_utils
class _LfilterMixin: class Lfilter(common_utils.TestBaseMixin):
device = None
dtype = None
def test_simple(self): def test_simple(self):
""" """
Create a very basic signal, Create a very basic signal,
...@@ -33,25 +30,12 @@ class _LfilterMixin: ...@@ -33,25 +30,12 @@ class _LfilterMixin:
b_coeffs = torch.tensor([1, 0], dtype=self.dtype, device=self.device) b_coeffs = torch.tensor([1, 0], dtype=self.dtype, device=self.device)
a_coeffs = torch.tensor([1, -0.95], dtype=self.dtype, device=self.device) a_coeffs = torch.tensor([1, -0.95], dtype=self.dtype, device=self.device)
output_signal = F.lfilter(input_signal, a_coeffs, b_coeffs, clamp=True) output_signal = F.lfilter(input_signal, a_coeffs, b_coeffs, clamp=True)
self.assertTrue(output_signal.max() <= 1) assert output_signal.max() <= 1
output_signal = F.lfilter(input_signal, a_coeffs, b_coeffs, clamp=False) output_signal = F.lfilter(input_signal, a_coeffs, b_coeffs, clamp=False)
self.assertTrue(output_signal.max() > 1) assert output_signal.max() > 1
class TestLfilterFloat32CPU(_LfilterMixin, unittest.TestCase):
device = torch.device('cpu')
dtype = torch.float32
class TestLfilterFloat64CPU(_LfilterMixin, unittest.TestCase):
device = torch.device('cpu')
dtype = torch.float64
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") common_utils.define_test_suites(globals(), [Lfilter])
class TestLfilterFloat32CUDA(_LfilterMixin, unittest.TestCase):
device = torch.device('cuda')
dtype = torch.float32
class TestComputeDeltas(unittest.TestCase): class TestComputeDeltas(unittest.TestCase):
......
"""Test suites for jit-ability and its numerical compatibility""" """Test suites for jit-ability and its numerical compatibility"""
import unittest import unittest
import pytest
import torch import torch
import torchaudio import torchaudio
...@@ -9,8 +10,7 @@ import torchaudio.transforms as T ...@@ -9,8 +10,7 @@ import torchaudio.transforms as T
import common_utils import common_utils
def _assert_functional_consistency(func, tensor, device, shape_only=False): def _assert_functional_consistency(func, tensor, shape_only=False):
tensor = tensor.to(device)
ts_func = torch.jit.script(func) ts_func = torch.jit.script(func)
output = func(tensor) output = func(tensor)
ts_output = ts_func(tensor) ts_output = ts_func(tensor)
...@@ -21,21 +21,18 @@ def _assert_functional_consistency(func, tensor, device, shape_only=False): ...@@ -21,21 +21,18 @@ def _assert_functional_consistency(func, tensor, device, shape_only=False):
torch.testing.assert_allclose(ts_output, output) torch.testing.assert_allclose(ts_output, output)
def _assert_transforms_consistency(transform, tensor, device): def _assert_transforms_consistency(transform, tensor):
tensor = tensor.to(device)
transform = transform.to(device)
ts_transform = torch.jit.script(transform) ts_transform = torch.jit.script(transform)
output = transform(tensor) output = transform(tensor)
ts_output = ts_transform(tensor) ts_output = ts_transform(tensor)
torch.testing.assert_allclose(ts_output, output) torch.testing.assert_allclose(ts_output, output)
class _FunctionalTestMixin: class Functional(common_utils.TestBaseMixin):
"""Implements test for `functinoal` modul that are performed for different devices""" """Implements test for `functinoal` modul that are performed for different devices"""
device = None
def _assert_consistency(self, func, tensor, shape_only=False): def _assert_consistency(self, func, tensor, shape_only=False):
return _assert_functional_consistency(func, tensor, self.device, shape_only=shape_only) tensor = tensor.to(device=self.device, dtype=self.dtype)
return _assert_functional_consistency(func, tensor, shape_only=shape_only)
def test_spectrogram(self): def test_spectrogram(self):
def func(tensor): def func(tensor):
...@@ -159,7 +156,7 @@ class _FunctionalTestMixin: ...@@ -159,7 +156,7 @@ class _FunctionalTestMixin:
return F.complex_norm(tensor, power) return F.complex_norm(tensor, power)
tensor = torch.randn(1, 2, 1025, 400, 2) tensor = torch.randn(1, 2, 1025, 400, 2)
_assert_functional_consistency(func, tensor, self.device) self._assert_consistency(func, tensor)
def test_mask_along_axis(self): def test_mask_along_axis(self):
def func(tensor): def func(tensor):
...@@ -211,6 +208,9 @@ class _FunctionalTestMixin: ...@@ -211,6 +208,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, tensor, shape_only=True) self._assert_consistency(func, tensor, shape_only=True)
def test_lfilter(self): def test_lfilter(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path('whitenoise.wav') filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
...@@ -252,6 +252,9 @@ class _FunctionalTestMixin: ...@@ -252,6 +252,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_lowpass(self): def test_lowpass(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path('whitenoise.wav') filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
...@@ -263,6 +266,9 @@ class _FunctionalTestMixin: ...@@ -263,6 +266,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_highpass(self): def test_highpass(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path('whitenoise.wav') filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
...@@ -274,6 +280,9 @@ class _FunctionalTestMixin: ...@@ -274,6 +280,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_allpass(self): def test_allpass(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path('whitenoise.wav') filepath = common_utils.get_asset_path('whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
...@@ -286,6 +295,9 @@ class _FunctionalTestMixin: ...@@ -286,6 +295,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_bandpass_with_csg(self): def test_bandpass_with_csg(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav") filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
...@@ -298,7 +310,10 @@ class _FunctionalTestMixin: ...@@ -298,7 +310,10 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_bandpass_withou_csg(self): def test_bandpass_without_csg(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav") filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
...@@ -312,6 +327,9 @@ class _FunctionalTestMixin: ...@@ -312,6 +327,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_bandreject(self): def test_bandreject(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav") filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
...@@ -324,6 +342,9 @@ class _FunctionalTestMixin: ...@@ -324,6 +342,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_band_with_noise(self): def test_band_with_noise(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav") filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
...@@ -337,6 +358,9 @@ class _FunctionalTestMixin: ...@@ -337,6 +358,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_band_without_noise(self): def test_band_without_noise(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav") filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
...@@ -350,6 +374,9 @@ class _FunctionalTestMixin: ...@@ -350,6 +374,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_treble(self): def test_treble(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav") filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
...@@ -363,6 +390,9 @@ class _FunctionalTestMixin: ...@@ -363,6 +390,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_deemph(self): def test_deemph(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav") filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
...@@ -373,6 +403,9 @@ class _FunctionalTestMixin: ...@@ -373,6 +403,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_riaa(self): def test_riaa(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav") filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
...@@ -383,6 +416,9 @@ class _FunctionalTestMixin: ...@@ -383,6 +416,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_equalizer(self): def test_equalizer(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav") filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
...@@ -396,6 +432,9 @@ class _FunctionalTestMixin: ...@@ -396,6 +432,9 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
def test_perf_biquad_filtering(self): def test_perf_biquad_filtering(self):
if self.dtype == torch.float64:
pytest.xfail("This test is known to fail for float64")
filepath = common_utils.get_asset_path("whitenoise.wav") filepath = common_utils.get_asset_path("whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
...@@ -489,12 +528,12 @@ class _FunctionalTestMixin: ...@@ -489,12 +528,12 @@ class _FunctionalTestMixin:
self._assert_consistency(func, waveform) self._assert_consistency(func, waveform)
class _TransformsTestMixin: class Transforms(common_utils.TestBaseMixin):
"""Implements test for Transforms that are performed for different devices""" """Implements test for Transforms that are performed for different devices"""
device = None
def _assert_consistency(self, transform, tensor): def _assert_consistency(self, transform, tensor):
_assert_transforms_consistency(transform, tensor, self.device) tensor = tensor.to(device=self.device, dtype=self.dtype)
transform = transform.to(device=self.device, dtype=self.dtype)
_assert_transforms_consistency(transform, tensor)
def test_Spectrogram(self): def test_Spectrogram(self):
tensor = torch.rand((1, 1000)) tensor = torch.rand((1, 1000))
...@@ -578,27 +617,4 @@ class _TransformsTestMixin: ...@@ -578,27 +617,4 @@ class _TransformsTestMixin:
self._assert_consistency(T.Vad(sample_rate=sample_rate), waveform) self._assert_consistency(T.Vad(sample_rate=sample_rate), waveform)
class TestFunctionalCPU(_FunctionalTestMixin, unittest.TestCase): common_utils.define_test_suites(globals(), [Functional, Transforms])
"""Test suite for Functional module on CPU"""
device = torch.device('cpu')
@unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available')
class TestFunctionalCUDA(_FunctionalTestMixin, unittest.TestCase):
"""Test suite for Functional module on GPU"""
device = torch.device('cuda')
class TestTransformsCPU(_TransformsTestMixin, unittest.TestCase):
"""Test suite for Transforms module on CPU"""
device = torch.device('cpu')
@unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available')
class TestTransformsCUDA(_TransformsTestMixin, unittest.TestCase):
"""Test suite for Transforms module on GPU"""
device = torch.device('cuda')
if __name__ == '__main__':
unittest.main()
...@@ -890,6 +890,8 @@ def resample_waveform(waveform: Tensor, ...@@ -890,6 +890,8 @@ def resample_waveform(waveform: Tensor,
Returns: Returns:
Tensor: The waveform at the new frequency Tensor: The waveform at the new frequency
""" """
device, dtype = waveform.device, waveform.dtype
assert waveform.dim() == 2 assert waveform.dim() == 2
assert orig_freq > 0.0 and new_freq > 0.0 assert orig_freq > 0.0 and new_freq > 0.0
...@@ -905,7 +907,7 @@ def resample_waveform(waveform: Tensor, ...@@ -905,7 +907,7 @@ def resample_waveform(waveform: Tensor,
window_width = lowpass_filter_width / (2.0 * lowpass_cutoff) window_width = lowpass_filter_width / (2.0 * lowpass_cutoff)
first_indices, weights = _get_LR_indices_and_weights(orig_freq, new_freq, output_samples_in_unit, first_indices, weights = _get_LR_indices_and_weights(orig_freq, new_freq, output_samples_in_unit,
window_width, lowpass_cutoff, lowpass_filter_width) window_width, lowpass_cutoff, lowpass_filter_width)
weights = weights.to(waveform.device) # TODO Create weights on device directly weights = weights.to(device=device, dtype=dtype) # TODO Create weights on device directly
assert first_indices.dim() == 1 assert first_indices.dim() == 1
# TODO figure a better way to do this. conv1d reaches every element i*stride + padding # TODO figure a better way to do this. conv1d reaches every element i*stride + padding
...@@ -918,9 +920,9 @@ def resample_waveform(waveform: Tensor, ...@@ -918,9 +920,9 @@ def resample_waveform(waveform: Tensor,
window_size = weights.size(1) window_size = weights.size(1)
tot_output_samp = _get_num_LR_output_samples(wave_len, orig_freq, new_freq) tot_output_samp = _get_num_LR_output_samples(wave_len, orig_freq, new_freq)
output = torch.zeros((num_channels, tot_output_samp), output = torch.zeros((num_channels, tot_output_samp),
device=waveform.device) device=device, dtype=dtype)
# eye size: (num_channels, num_channels, 1) # eye size: (num_channels, num_channels, 1)
eye = torch.eye(num_channels, device=waveform.device).unsqueeze(2) eye = torch.eye(num_channels, device=device, dtype=dtype).unsqueeze(2)
for i in range(first_indices.size(0)): for i in range(first_indices.size(0)):
wave_to_conv = waveform wave_to_conv = waveform
first_index = int(first_indices[i].item()) first_index = int(first_indices[i].item())
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment