Unverified Commit b17da7a4 authored by moto's avatar moto Committed by GitHub
Browse files

Make TestCases backend-aware (#719)

* Make tests backend aware by introducing TorchaudioTestCase and reset backend for each TestCase.

* Set backends for the test cases that require specific backend.
parent 03da871f
......@@ -44,6 +44,16 @@ The following test modules are defined for corresponding `torchaudio` module/fun
## Adding test
The following is the current practice of torchaudio test suite.
1. Unless the tests are related to I/O, use synthetic data. [`common_utils`](./common_utils.py) has some data generator functions.
1. When you add a new test case, use `common_utils.TorchaudioTestCase` as base class unless you are writing tests that are common to CPU / CUDA.
- Set class memeber `dtype`, `device` and `backend` for the desired behavior.
- If you do not set `backend` value in your test suite, then I/O functions will be unassigned and attempt to load/save file will fail.
- For `backend` value, in addition to available backends, you can also provide the value "default" and backend will be picked automatically based on availability.
1. If you are writing tests that should pass on diffrent dtype/devices, write a common class inheriting `common_utils.TestBaseMixin`, then inherit `common_utils.PytorchTestCase` and define class attributes (`dtype` / `device` / `backend`) there. See [Torchscript consistency test implementation](./torchscript_consistency_impl.py) and test definitions for [CPU](./torchscript_consistency_cpu_test.py) and [CUDA](./torchscript_consistency_cuda_test.py) devices.
1. For numerically comparing Tensors, use `assertEqual` method from `common_utils.PytorchTestCase` class. This method has a better support for a wide variety of Tensor types.
When you add a new feature(functional/transform), consider the following
1. When you add a new feature, please make it Torchscript-able and batch-consistent unless it degrades the performance. Please add the tests to see if the new feature meet these requirements.
......
import os
import tempfile
import unittest
from typing import Iterable, Union
from contextlib import contextmanager
from typing import Union
from shutil import copytree
import torch
from torch.testing._internal.common_utils import TestCase
from torch.testing._internal.common_utils import TestCase as PytorchTestCase
import torchaudio
_TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
......@@ -55,24 +54,14 @@ def random_float_tensor(seed, size, a=22695477, c=1, m=2 ** 32):
return torch.tensor(arr).float().view(size) / m
@contextmanager
def AudioBackendScope(new_backend):
previous_backend = torchaudio.get_audio_backend()
try:
torchaudio.set_audio_backend(new_backend)
yield
finally:
torchaudio.set_audio_backend(previous_backend)
def filter_backends_with_mp3(backends):
# Filter out backends that do not support mp3
test_filepath = get_asset_path('steam-train-whistle-daniel_simon.mp3')
def supports_mp3(backend):
torchaudio.set_audio_backend(backend)
try:
with AudioBackendScope(backend):
torchaudio.load(test_filepath)
torchaudio.load(test_filepath)
return True
except (RuntimeError, ImportError):
return False
......@@ -83,21 +72,38 @@ def filter_backends_with_mp3(backends):
BACKENDS_MP3 = filter_backends_with_mp3(BACKENDS)
def set_audio_backend(backend):
"""Allow additional backend value, 'default'"""
if backend == 'default':
if 'sox' in BACKENDS:
be = 'sox'
elif 'soundfile' in BACKENDS:
be = 'soundfile'
else:
raise unittest.SkipTest('No default backend available')
else:
be = backend
torchaudio.set_audio_backend(be)
class TestBaseMixin:
"""Mixin to provide consistent way to define device/dtype/backend aware TestCase"""
dtype = None
device = None
backend = None
def setUp(self):
super().setUp()
set_audio_backend(self.backend)
skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')
class TorchaudioTestCase(TestBaseMixin, PytorchTestCase):
pass
def common_test_class_parameters(
dtypes: Iterable[str] = ("float32", "float64"),
devices: Iterable[str] = ("cpu", "cuda"),
):
for device in devices:
for dtype in dtypes:
yield {"device": torch.device(device), "dtype": getattr(torch, dtype)}
skipIfNoSoxBackend = unittest.skipIf('sox' not in BACKENDS, 'Sox backend not available')
skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')
def get_whitenoise(
......
......@@ -10,17 +10,17 @@ from . import common_utils
from .functional_impl import Lfilter
class TestLFilterFloat32(Lfilter, common_utils.TestCase):
class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cpu')
class TestLFilterFloat64(Lfilter, common_utils.TestCase):
class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')
class TestComputeDeltas(unittest.TestCase):
class TestComputeDeltas(common_utils.TorchaudioTestCase):
"""Test suite for correctness of compute_deltas"""
def test_one_channel(self):
specgram = torch.tensor([[[1.0, 2.0, 3.0, 4.0]]])
......@@ -57,7 +57,7 @@ def _test_istft_is_inverse_of_stft(kwargs):
_compare_estimate(sound, estimate)
class TestIstft(unittest.TestCase):
class TestIstft(common_utils.TorchaudioTestCase):
"""Test suite for correctness of istft with various input"""
number_of_trials = 100
......@@ -273,7 +273,9 @@ class TestIstft(unittest.TestCase):
self._test_linearity_of_istft(data_size, kwargs4, atol=1e-5, rtol=1e-8)
class TestDetectPitchFrequency(unittest.TestCase):
class TestDetectPitchFrequency(common_utils.TorchaudioTestCase):
backend = 'default'
def test_pitch(self):
test_filepath_100 = common_utils.get_asset_path("100Hz_44100Hz_16bit_05sec.wav")
test_filepath_440 = common_utils.get_asset_path("440Hz_44100Hz_16bit_05sec.wav")
......@@ -294,7 +296,7 @@ class TestDetectPitchFrequency(unittest.TestCase):
self.assertFalse(s)
class TestDB_to_amplitude(unittest.TestCase):
class TestDB_to_amplitude(common_utils.TorchaudioTestCase):
def test_DB_to_amplitude(self):
# Make some noise
x = torch.rand(1000)
......
......@@ -5,12 +5,12 @@ from .functional_impl import Lfilter
@common_utils.skipIfNoCuda
class TestLFilterFloat32(Lfilter, common_utils.TestCase):
class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cuda')
@common_utils.skipIfNoCuda
class TestLFilterFloat64(Lfilter, common_utils.TestCase):
class TestLFilterFloat64(Lfilter, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')
......@@ -4,11 +4,11 @@ from . import common_utils
from .kaldi_compatibility_impl import Kaldi
class TestKaldiFloat32(Kaldi, common_utils.TestCase):
class TestKaldiFloat32(Kaldi, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cpu')
class TestKaldiFloat64(Kaldi, common_utils.TestCase):
class TestKaldiFloat64(Kaldi, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')
......@@ -5,12 +5,12 @@ from .kaldi_compatibility_impl import Kaldi
@common_utils.skipIfNoCuda
class TestKaldiFloat32(Kaldi, common_utils.TestCase):
class TestKaldiFloat32(Kaldi, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cuda')
@common_utils.skipIfNoCuda
class TestKaldiFloat64(Kaldi, common_utils.TestCase):
class TestKaldiFloat64(Kaldi, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')
......@@ -55,6 +55,8 @@ def _load_params(path):
class Kaldi(common_utils.TestBaseMixin):
backend = 'sox'
def assert_equal(self, output, *, expected, rtol=None, atol=None):
expected = expected.to(dtype=self.dtype, device=self.device)
self.assertEqual(output, expected, rtol=rtol, atol=atol)
......
......@@ -3,8 +3,10 @@ import unittest
import torchaudio
from torchaudio._internal.module_utils import is_module_available
from . import common_utils
class BackendSwitch:
class BackendSwitchMixin:
"""Test set/get_audio_backend works"""
backend = None
backend_module = None
......@@ -21,7 +23,7 @@ class BackendSwitch:
assert torchaudio.info == self.backend_module.info
class TestBackendSwitch_NoBackend(BackendSwitch, unittest.TestCase):
class TestBackendSwitch_NoBackend(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = None
backend_module = torchaudio.backend.no_backend
......@@ -29,12 +31,12 @@ class TestBackendSwitch_NoBackend(BackendSwitch, unittest.TestCase):
@unittest.skipIf(
not is_module_available('torchaudio._torchaudio'),
'torchaudio C++ extension not available')
class TestBackendSwitch_SoX(BackendSwitch, unittest.TestCase):
class TestBackendSwitch_SoX(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'sox'
backend_module = torchaudio.backend.sox_backend
@unittest.skipIf(not is_module_available('soundfile'), '"soundfile" not available')
class TestBackendSwitch_soundfile(BackendSwitch, unittest.TestCase):
class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'soundfile'
backend_module = torchaudio.backend.soundfile_backend
......@@ -2,14 +2,14 @@
import unittest
import torch
from torch.testing._internal.common_utils import TestCase
import torchaudio
import torchaudio.functional as F
from . import common_utils
class TestFunctional(TestCase):
class TestFunctional(common_utils.TorchaudioTestCase):
backend = 'default'
"""Test functions defined in `functional` module"""
def assert_batch_consistency(
self, functional, tensor, *args, batch_size=1, atol=1e-8, rtol=1e-5, seed=42, **kwargs):
......@@ -98,12 +98,15 @@ class TestFunctional(TestCase):
self.assert_batch_consistencies(F.sliding_window_cmn, waveform, center=False, norm_vars=False)
def test_vad(self):
common_utils.set_audio_backend('default')
filepath = common_utils.get_asset_path("vad-go-mono-32000.wav")
waveform, sample_rate = torchaudio.load(filepath)
self.assert_batch_consistencies(F.vad, waveform, sample_rate=sample_rate)
class TestTransforms(TestCase):
class TestTransforms(common_utils.TorchaudioTestCase):
backend = 'default'
"""Test suite for classes defined in `transforms` module"""
def test_batch_AmplitudeToDB(self):
spec = torch.rand((6, 201))
......
import math
import os
import math
import unittest
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import unittest
from . import common_utils
from .compliance import utils as compliance_utils
from .common_utils import AudioBackendScope, BACKENDS
def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
......@@ -46,7 +46,10 @@ def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
window[f, s] = wave[s_in_wave]
class Test_Kaldi(unittest.TestCase):
@common_utils.skipIfNoSoxBackend
class Test_Kaldi(common_utils.TorchaudioTestCase):
backend = 'sox'
test_filepath = common_utils.get_asset_path('kaldi_file.wav')
test_8000_filepath = common_utils.get_asset_path('kaldi_file_8000.wav')
kaldi_output_dir = common_utils.get_asset_path('kaldi')
......@@ -162,8 +165,6 @@ class Test_Kaldi(unittest.TestCase):
# Passing in an empty tensor should result in an error
self.assertRaises(AssertionError, kaldi.mfcc, torch.empty(0))
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_resample_waveform(self):
def get_output_fn(sound, args):
output = kaldi.resample_waveform(sound, args[1], args[2])
......
......@@ -4,7 +4,6 @@ import torchaudio
from torch.utils.data import Dataset, DataLoader
from . import common_utils
from .common_utils import AudioBackendScope, BACKENDS
class TORCHAUDIODS(Dataset):
......@@ -28,9 +27,10 @@ class TORCHAUDIODS(Dataset):
return len(self.data)
class Test_DataLoader(unittest.TestCase):
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
class Test_DataLoader(common_utils.TorchaudioTestCase):
backend = 'sox'
@common_utils.skipIfNoSoxBackend
def test_1(self):
expected_size = (2, 1, 16000)
ds = TORCHAUDIODS()
......
......@@ -13,7 +13,8 @@ from torchaudio.datasets.cmuarctic import CMUARCTIC
from . import common_utils
class TestDatasets(unittest.TestCase):
class TestDatasets(common_utils.TorchaudioTestCase):
backend = 'default'
path = common_utils.get_asset_path()
def test_yesno(self):
......@@ -28,14 +29,32 @@ class TestDatasets(unittest.TestCase):
data = LIBRISPEECH(self.path, "dev-clean")
data[0]
@unittest.skipIf("sox" not in common_utils.BACKENDS, "sox not available")
@common_utils.AudioBackendScope('sox')
def test_ljspeech(self):
data = LJSPEECH(self.path)
data[0]
def test_speechcommands(self):
data = SPEECHCOMMANDS(self.path)
data[0]
def test_gtzan(self):
data = GTZAN(self.path)
data[0]
def test_cmuarctic(self):
data = CMUARCTIC(self.path)
data[0]
@common_utils.skipIfNoSoxBackend
class TestCommonVoice(common_utils.TorchaudioTestCase):
backend = 'sox'
path = common_utils.get_asset_path()
def test_commonvoice(self):
data = COMMONVOICE(self.path, url="tatar")
data[0]
@unittest.skipIf("sox" not in common_utils.BACKENDS, "sox not available")
@common_utils.AudioBackendScope('sox')
def test_commonvoice_diskcache(self):
data = COMMONVOICE(self.path, url="tatar")
data = diskcache_iterator(data)
......@@ -44,29 +63,12 @@ class TestDatasets(unittest.TestCase):
# Load
data[0]
@unittest.skipIf("sox" not in common_utils.BACKENDS, "sox not available")
@common_utils.AudioBackendScope('sox')
def test_commonvoice_bg(self):
data = COMMONVOICE(self.path, url="tatar")
data = bg_iterator(data, 5)
for _ in data:
pass
def test_ljspeech(self):
data = LJSPEECH(self.path)
data[0]
def test_speechcommands(self):
data = SPEECHCOMMANDS(self.path)
data[0]
def test_gtzan(self):
data = GTZAN(self.path)
data[0]
def test_cmuarctic(self):
data = CMUARCTIC(self.path)
data[0]
if __name__ == "__main__":
unittest.main()
import os
import math
import unittest
import torch
import torchaudio
import math
import os
from .common_utils import AudioBackendScope, BACKENDS, BACKENDS_MP3, create_temp_assets_dir
from .common_utils import BACKENDS, BACKENDS_MP3, create_temp_assets_dir
class Test_LoadSave(unittest.TestCase):
......@@ -16,13 +18,13 @@ class Test_LoadSave(unittest.TestCase):
def test_1_save(self):
for backend in BACKENDS_MP3:
with self.subTest():
with AudioBackendScope(backend):
self._test_1_save(self.test_filepath, False)
torchaudio.set_audio_backend(backend)
self._test_1_save(self.test_filepath, False)
for backend in BACKENDS:
with self.subTest():
with AudioBackendScope(backend):
self._test_1_save(self.test_filepath_wav, True)
torchaudio.set_audio_backend(backend)
self._test_1_save(self.test_filepath_wav, True)
def _test_1_save(self, test_filepath, normalization):
# load signal
......@@ -67,8 +69,8 @@ class Test_LoadSave(unittest.TestCase):
def test_1_save_sine(self):
for backend in BACKENDS:
with self.subTest():
with AudioBackendScope(backend):
self._test_1_save_sine()
torchaudio.set_audio_backend(backend)
self._test_1_save_sine()
def _test_1_save_sine(self):
......@@ -100,13 +102,13 @@ class Test_LoadSave(unittest.TestCase):
def test_2_load(self):
for backend in BACKENDS_MP3:
with self.subTest():
with AudioBackendScope(backend):
self._test_2_load(self.test_filepath, 278756)
torchaudio.set_audio_backend(backend)
self._test_2_load(self.test_filepath, 278756)
for backend in BACKENDS:
with self.subTest():
with AudioBackendScope(backend):
self._test_2_load(self.test_filepath_wav, 276858)
torchaudio.set_audio_backend(backend)
self._test_2_load(self.test_filepath_wav, 276858)
def _test_2_load(self, test_filepath, length):
# check normal loading
......@@ -141,8 +143,8 @@ class Test_LoadSave(unittest.TestCase):
def test_2_load_nonormalization(self):
for backend in BACKENDS_MP3:
with self.subTest():
with AudioBackendScope(backend):
self._test_2_load_nonormalization(self.test_filepath, 278756)
torchaudio.set_audio_backend(backend)
self._test_2_load_nonormalization(self.test_filepath, 278756)
def _test_2_load_nonormalization(self, test_filepath, length):
......@@ -158,8 +160,8 @@ class Test_LoadSave(unittest.TestCase):
def test_3_load_and_save_is_identity(self):
for backend in BACKENDS:
with self.subTest():
with AudioBackendScope(backend):
self._test_3_load_and_save_is_identity()
torchaudio.set_audio_backend(backend)
self._test_3_load_and_save_is_identity()
def _test_3_load_and_save_is_identity(self):
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
......@@ -179,16 +181,15 @@ class Test_LoadSave(unittest.TestCase):
self._test_3_load_and_save_is_identity_across_backend("soundfile", "sox")
def _test_3_load_and_save_is_identity_across_backend(self, backend1, backend2):
with AudioBackendScope(backend1):
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
tensor1, sample_rate1 = torchaudio.load(input_path)
torchaudio.set_audio_backend(backend1)
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
tensor1, sample_rate1 = torchaudio.load(input_path)
output_path = os.path.join(self.test_dirpath, 'test.wav')
torchaudio.save(output_path, tensor1, sample_rate1)
output_path = os.path.join(self.test_dirpath, 'test.wav')
torchaudio.save(output_path, tensor1, sample_rate1)
with AudioBackendScope(backend2):
tensor2, sample_rate2 = torchaudio.load(output_path)
torchaudio.set_audio_backend(backend2)
tensor2, sample_rate2 = torchaudio.load(output_path)
self.assertTrue(tensor1.allclose(tensor2))
self.assertEqual(sample_rate1, sample_rate2)
......@@ -197,8 +198,8 @@ class Test_LoadSave(unittest.TestCase):
def test_4_load_partial(self):
for backend in BACKENDS_MP3:
with self.subTest():
with AudioBackendScope(backend):
self._test_4_load_partial()
torchaudio.set_audio_backend(backend)
self._test_4_load_partial()
def _test_4_load_partial(self):
num_frames = 101
......@@ -239,8 +240,8 @@ class Test_LoadSave(unittest.TestCase):
def test_5_get_info(self):
for backend in BACKENDS:
with self.subTest():
with AudioBackendScope(backend):
self._test_5_get_info()
torchaudio.set_audio_backend(backend)
self._test_5_get_info()
def _test_5_get_info(self):
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
......
......@@ -6,7 +6,7 @@ import torchaudio.kaldi_io as kio
from . import common_utils
class Test_KaldiIO(unittest.TestCase):
class Test_KaldiIO(common_utils.TorchaudioTestCase):
data1 = [[1, 2, 3], [11, 12, 13], [21, 22, 23]]
data2 = [[31, 32, 33], [41, 42, 43], [51, 52, 53]]
......
......@@ -4,7 +4,6 @@ import unittest
from distutils.version import StrictVersion
import torch
from torch.testing._internal.common_utils import TestCase
import torchaudio
import torchaudio.functional as F
from torchaudio._internal.module_utils import is_module_available
......@@ -22,7 +21,7 @@ from . import common_utils
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class TestFunctional(TestCase):
class TestFunctional(common_utils.TorchaudioTestCase):
"""Test suite for functions in `functional` module."""
def test_griffinlim(self):
# NOTE: This test is flaky without a fixed random seed
......@@ -157,9 +156,10 @@ def _load_audio_asset(*asset_paths, **kwargs):
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class TestTransforms(TestCase):
class TestTransforms(common_utils.TorchaudioTestCase):
"""Test suite for functions in `transforms` module."""
def assert_compatibilities(self, n_fft, hop_length, power, n_mels, n_mfcc, sample_rate):
common_utils.set_audio_backend('default')
sound, sample_rate = _load_audio_asset('sinewave.wav')
sound_librosa = sound.cpu().numpy().squeeze() # (64000)
......@@ -269,8 +269,7 @@ class TestTransforms(TestCase):
}
self.assert_compatibilities(**kwargs)
@unittest.skipIf("sox" not in common_utils.BACKENDS, "sox not available")
@common_utils.AudioBackendScope("sox")
@unittest.skipIf(not common_utils.BACKENDS_MP3, 'no backend to read mp3')
def test_MelScale(self):
"""MelScale transform is comparable to that of librosa"""
n_fft = 2048
......@@ -278,6 +277,7 @@ class TestTransforms(TestCase):
hop_length = n_fft // 4
# Prepare spectrogram input. We use torchaudio to compute one.
common_utils.set_audio_backend('default')
sound, sample_rate = _load_audio_asset('whitenoise_1min.mp3')
sound = sound.mean(dim=0, keepdim=True)
spec_ta = F.spectrogram(
......@@ -300,6 +300,7 @@ class TestTransforms(TestCase):
hop_length = n_fft // 4
# Prepare mel spectrogram input. We use torchaudio to compute one.
common_utils.set_audio_backend('default')
sound, sample_rate = _load_audio_asset(
'steam-train-whistle-daniel_simon.wav', offset=2**10, num_frames=2**14)
sound = sound.mean(dim=0, keepdim=True)
......
import unittest
import torch
from torch.testing._internal.common_utils import TestCase
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T
from . import common_utils
from .common_utils import AudioBackendScope, BACKENDS
class TestFunctionalFiltering(TestCase):
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
@common_utils.skipIfNoSoxBackend
class TestFunctionalFiltering(common_utils.TorchaudioTestCase):
backend = 'sox'
def test_gain(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath)
......@@ -27,8 +26,6 @@ class TestFunctionalFiltering(TestCase):
self.assertEqual(waveform_gain, sox_gain_waveform, atol=1e-04, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_dither(self):
test_filepath = common_utils.get_asset_path('steam-train-whistle-daniel_simon.wav')
waveform, _ = torchaudio.load(test_filepath)
......@@ -49,8 +46,6 @@ class TestFunctionalFiltering(TestCase):
self.assertEqual(waveform_dithered_noiseshaped, sox_dither_waveform_ns, atol=1e-02, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_vctk_transform_pipeline(self):
test_filepath_vctk = common_utils.get_asset_path('VCTK-Corpus', 'wav48', 'p224', 'p224_002.wav')
wf_vctk, sr_vctk = torchaudio.load(test_filepath_vctk)
......@@ -72,8 +67,6 @@ class TestFunctionalFiltering(TestCase):
self.assertEqual(wf_vctk, wf_vctk_sox, rtol=1e-03, atol=1e-03)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_lowpass(self):
"""
Test biquad lowpass filter, compare to SoX implementation
......@@ -92,8 +85,6 @@ class TestFunctionalFiltering(TestCase):
self.assertEqual(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_highpass(self):
"""
Test biquad highpass filter, compare to SoX implementation
......@@ -113,8 +104,6 @@ class TestFunctionalFiltering(TestCase):
# TBD - this fails at the 1e-4 level, debug why
self.assertEqual(output_waveform, sox_output_waveform, atol=1e-3, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_allpass(self):
"""
Test biquad allpass filter, compare to SoX implementation
......@@ -134,8 +123,6 @@ class TestFunctionalFiltering(TestCase):
self.assertEqual(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_bandpass_with_csg(self):
"""
Test biquad bandpass filter, compare to SoX implementation
......@@ -156,8 +143,6 @@ class TestFunctionalFiltering(TestCase):
self.assertEqual(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_bandpass_without_csg(self):
"""
Test biquad bandpass filter, compare to SoX implementation
......@@ -178,8 +163,6 @@ class TestFunctionalFiltering(TestCase):
self.assertEqual(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_bandreject(self):
"""
Test biquad bandreject filter, compare to SoX implementation
......@@ -199,8 +182,6 @@ class TestFunctionalFiltering(TestCase):
self.assertEqual(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_band_with_noise(self):
"""
Test biquad band filter with noise mode, compare to SoX implementation
......@@ -221,8 +202,6 @@ class TestFunctionalFiltering(TestCase):
self.assertEqual(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_band_without_noise(self):
"""
Test biquad band filter without noise mode, compare to SoX implementation
......@@ -243,8 +222,6 @@ class TestFunctionalFiltering(TestCase):
self.assertEqual(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_treble(self):
"""
Test biquad treble filter, compare to SoX implementation
......@@ -265,8 +242,6 @@ class TestFunctionalFiltering(TestCase):
self.assertEqual(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_bass(self):
"""
Test biquad bass filter, compare to SoX implementation
......@@ -287,8 +262,6 @@ class TestFunctionalFiltering(TestCase):
self.assertEqual(output_waveform, sox_output_waveform, atol=1.5e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_deemph(self):
"""
Test biquad deemph filter, compare to SoX implementation
......@@ -305,8 +278,6 @@ class TestFunctionalFiltering(TestCase):
self.assertEqual(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_riaa(self):
"""
Test biquad riaa filter, compare to SoX implementation
......@@ -323,8 +294,6 @@ class TestFunctionalFiltering(TestCase):
self.assertEqual(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_contrast(self):
"""
Test contrast effect, compare to SoX implementation
......@@ -341,8 +310,6 @@ class TestFunctionalFiltering(TestCase):
self.assertEqual(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_dcshift_with_limiter(self):
"""
Test dcshift effect, compare to SoX implementation
......@@ -360,8 +327,6 @@ class TestFunctionalFiltering(TestCase):
self.assertEqual(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_dcshift_without_limiter(self):
"""
Test dcshift effect, compare to SoX implementation
......@@ -378,8 +343,6 @@ class TestFunctionalFiltering(TestCase):
self.assertEqual(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_overdrive(self):
"""
Test overdrive effect, compare to SoX implementation
......@@ -397,8 +360,6 @@ class TestFunctionalFiltering(TestCase):
self.assertEqual(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_phaser_sine(self):
"""
Test phaser effect with sine moduldation, compare to SoX implementation
......@@ -419,8 +380,6 @@ class TestFunctionalFiltering(TestCase):
self.assertEqual(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_phaser_triangle(self):
"""
Test phaser effect with triangle modulation, compare to SoX implementation
......@@ -441,8 +400,6 @@ class TestFunctionalFiltering(TestCase):
self.assertEqual(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_flanger_triangle_linear(self):
"""
Test flanger effect with triangle modulation and linear interpolation, compare to SoX implementation
......@@ -465,8 +422,6 @@ class TestFunctionalFiltering(TestCase):
torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_flanger_triangle_quad(self):
"""
Test flanger effect with triangle modulation and quadratic interpolation, compare to SoX implementation
......@@ -489,8 +444,6 @@ class TestFunctionalFiltering(TestCase):
torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_flanger_sine_linear(self):
"""
Test flanger effect with sine modulation and linear interpolation, compare to SoX implementation
......@@ -513,8 +466,6 @@ class TestFunctionalFiltering(TestCase):
torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_flanger_sine_quad(self):
"""
Test flanger effect with sine modulation and quadratic interpolation, compare to SoX implementation
......@@ -537,8 +488,6 @@ class TestFunctionalFiltering(TestCase):
torch.testing.assert_allclose(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_equalizer(self):
"""
Test biquad peaking equalizer filter, compare to SoX implementation
......@@ -559,8 +508,6 @@ class TestFunctionalFiltering(TestCase):
self.assertEqual(output_waveform, sox_output_waveform, atol=1e-4, rtol=1e-5)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_perf_biquad_filtering(self):
fn_sine = common_utils.get_asset_path('whitenoise.wav')
......
import math
import unittest
import torch
import torchaudio
import math
from . import common_utils
from .common_utils import AudioBackendScope, BACKENDS
class Test_SoxEffectsChain(unittest.TestCase):
@common_utils.skipIfNoSoxBackend
class Test_SoxEffectsChain(common_utils.TorchaudioTestCase):
backend = 'sox'
test_filepath = common_utils.get_asset_path("steam-train-whistle-daniel_simon.mp3")
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_single_channel(self):
fn_sine = common_utils.get_asset_path("sinewave.wav")
E = torchaudio.sox_effects.SoxEffectsChain()
......@@ -21,8 +22,6 @@ class Test_SoxEffectsChain(unittest.TestCase):
# check if effects worked
# print(x.size())
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_rate_channels(self):
target_rate = 16000
target_channels = 1
......@@ -35,8 +34,6 @@ class Test_SoxEffectsChain(unittest.TestCase):
self.assertEqual(sr, target_rate)
self.assertEqual(x.size(0), target_channels)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_lowpass_speed(self):
speed = .8
si, _ = torchaudio.info(self.test_filepath)
......@@ -49,8 +46,6 @@ class Test_SoxEffectsChain(unittest.TestCase):
# check if effects worked
self.assertEqual(x.size(1), int((si.length / si.channels) / speed))
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_ulaw_and_siginfo(self):
si_out = torchaudio.sox_signalinfo_t()
ei_out = torchaudio.sox_encodinginfo_t()
......@@ -68,8 +63,6 @@ class Test_SoxEffectsChain(unittest.TestCase):
self.assertLess(x.unique().size(0), 2**8 + 1)
self.assertEqual(x.numel(), si_in.length)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_band_chorus(self):
si_in, ei_in = torchaudio.info(self.test_filepath)
ei_in.encoding = torchaudio.get_sox_encoding_t(1)
......@@ -84,8 +77,6 @@ class Test_SoxEffectsChain(unittest.TestCase):
self.assertEqual(x.size(0), si_in.channels)
self.assertGreaterEqual(x.size(1) * x.size(0), si_in.length)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_synth(self):
si_in, ei_in = torchaudio.info(self.test_filepath)
len_in_seconds = si_in.length / si_in.channels / si_in.rate
......@@ -99,8 +90,6 @@ class Test_SoxEffectsChain(unittest.TestCase):
self.assertEqual(x.size(0), si_in.channels)
self.assertEqual(si_in.length, x.size(0) * x.size(1))
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_gain(self):
E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(self.test_filepath)
......@@ -124,8 +113,6 @@ class Test_SoxEffectsChain(unittest.TestCase):
E.clear_chain()
self.assertLess(x.abs().max().item(), 1.)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_tempo_or_speed(self):
tempo = .8
si, _ = torchaudio.info(self.test_filepath)
......@@ -159,8 +146,6 @@ class Test_SoxEffectsChain(unittest.TestCase):
# check if effect worked
self.assertAlmostEqual(x.size(1), math.ceil((si.length / si.channels) / speed), delta=1)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_trim(self):
x_orig, _ = torchaudio.load(self.test_filepath)
offset = "10000s"
......@@ -174,8 +159,6 @@ class Test_SoxEffectsChain(unittest.TestCase):
# check if effect worked
self.assertTrue(x.allclose(x_orig[:, offset_int:(offset_int + num_frames_int)], rtol=1e-4, atol=1e-4))
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_silence_contrast(self):
si, _ = torchaudio.info(self.test_filepath)
E = torchaudio.sox_effects.SoxEffectsChain()
......@@ -186,8 +169,6 @@ class Test_SoxEffectsChain(unittest.TestCase):
# check if effect worked
self.assertLess(x.numel(), si.length)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_reverse(self):
x_orig, _ = torchaudio.load(self.test_filepath)
E = torchaudio.sox_effects.SoxEffectsChain()
......@@ -198,8 +179,6 @@ class Test_SoxEffectsChain(unittest.TestCase):
rev_idx = torch.LongTensor(range(x_orig.size(1))[::-1])
self.assertTrue(x_orig.allclose(x_rev[:, rev_idx], rtol=1e-5, atol=2e-5))
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_compand_fade(self):
E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(self.test_filepath)
......@@ -209,8 +188,6 @@ class Test_SoxEffectsChain(unittest.TestCase):
# check if effect worked
# print(x.size())
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_biquad_delay(self):
si, _ = torchaudio.info(self.test_filepath)
E = torchaudio.sox_effects.SoxEffectsChain()
......@@ -222,8 +199,6 @@ class Test_SoxEffectsChain(unittest.TestCase):
# check if effect worked
self.assertTrue(x.size(1) == (si.length / si.channels) + 15000)
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_invalid_effect_name(self):
E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(self.test_filepath)
......@@ -231,8 +206,6 @@ class Test_SoxEffectsChain(unittest.TestCase):
with self.assertRaises(LookupError):
E.append_effect_to_chain("special", [""])
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_unimplemented_effect(self):
E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(self.test_filepath)
......@@ -240,8 +213,6 @@ class Test_SoxEffectsChain(unittest.TestCase):
with self.assertRaises(NotImplementedError):
E.append_effect_to_chain("spectrogram", [""])
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_invalid_effect_options(self):
E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(self.test_filepath)
......@@ -250,8 +221,6 @@ class Test_SoxEffectsChain(unittest.TestCase):
with self.assertRaises(RuntimeError):
E.sox_build_flow_effects()
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_fade(self):
x_orig, _ = torchaudio.load(self.test_filepath)
fade_in_len = 44100
......@@ -268,8 +237,6 @@ class Test_SoxEffectsChain(unittest.TestCase):
# check if effect worked
self.assertTrue(x.allclose(fade(x_orig), rtol=1e-4, atol=1e-4))
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_vol(self):
x_orig, _ = torchaudio.load(self.test_filepath)
......@@ -284,8 +251,6 @@ class Test_SoxEffectsChain(unittest.TestCase):
# check if effect worked
self.assertTrue(x.allclose(z, rtol=1e-4, atol=1e-4))
@unittest.skipIf("sox" not in BACKENDS, "sox not available")
@AudioBackendScope("sox")
def test_vad(self):
sample_files = [
common_utils.get_asset_path("vad-go-stereo-44100.wav"),
......
......@@ -2,7 +2,6 @@ import math
import unittest
import torch
from torch.testing._internal.common_utils import TestCase
import torchaudio
import torchaudio.transforms as transforms
import torchaudio.functional as F
......@@ -10,7 +9,8 @@ import torchaudio.functional as F
from . import common_utils
class Tester(TestCase):
class Tester(common_utils.TorchaudioTestCase):
backend = 'default'
# create a sinewave signal for testing
sample_rate = 16000
......
......@@ -4,21 +4,21 @@ from . import common_utils
from .torchscript_consistency_impl import Functional, Transforms
class TestFunctionalFloat32(Functional, common_utils.TestCase):
class TestFunctionalFloat32(Functional, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cpu')
class TestFunctionalFloat64(Functional, common_utils.TestCase):
class TestFunctionalFloat64(Functional, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')
class TestTransformsFloat32(Transforms, common_utils.TestCase):
class TestTransformsFloat32(Transforms, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cpu')
class TestTransformsFloat64(Transforms, common_utils.TestCase):
class TestTransformsFloat64(Transforms, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cpu')
......@@ -5,24 +5,24 @@ from .torchscript_consistency_impl import Functional, Transforms
@common_utils.skipIfNoCuda
class TestFunctionalFloat32(Functional, common_utils.TestCase):
class TestFunctionalFloat32(Functional, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cuda')
@common_utils.skipIfNoCuda
class TestFunctionalFloat64(Functional, common_utils.TestCase):
class TestFunctionalFloat64(Functional, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')
@common_utils.skipIfNoCuda
class TestTransformsFloat32(Transforms, common_utils.TestCase):
class TestTransformsFloat32(Transforms, common_utils.PytorchTestCase):
dtype = torch.float32
device = torch.device('cuda')
@common_utils.skipIfNoCuda
class TestTransformsFloat64(Transforms, common_utils.TestCase):
class TestTransformsFloat64(Transforms, common_utils.PytorchTestCase):
dtype = torch.float64
device = torch.device('cuda')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment