import os
import tempfile
from typing import Type, Iterable
from contextlib import contextmanager
from shutil import copytree

import torch
import torchaudio
import pytest

_TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))

# Names of the audio backends torchaudio reports. Note: this reads a private
# torchaudio attribute, so it may break across torchaudio versions.
BACKENDS = torchaudio._backend._audio_backends


def get_asset_path(*paths):
    """Return the full path of a test asset under ``<this dir>/assets``."""
    return os.path.join(_TEST_DIR_PATH, 'assets', *paths)


def create_temp_assets_dir():
    """Create a temporary directory and copy ``test/assets`` into it.

    Returns:
        Tuple[str, tempfile.TemporaryDirectory]: the path of the temporary
        root directory (which contains an ``assets`` subdirectory holding the
        copies) and the ``TemporaryDirectory`` object that owns it.

    The caller must keep the returned ``TemporaryDirectory`` object alive for
    as long as the files are needed: the directory and its contents are
    removed when that object is cleaned up or destroyed.
    """
    tmp_dir = tempfile.TemporaryDirectory()
    copytree(os.path.join(_TEST_DIR_PATH, "assets"),
             os.path.join(tmp_dir.name, "assets"))
    return tmp_dir.name, tmp_dir


def random_float_tensor(seed, size, a=22695477, c=1, m=2 ** 32):
    """Generate a deterministic pseudo-random float tensor from ``seed``.

    Uses a linear congruential generator
    (https://en.wikipedia.org/wiki/Linear_congruential_generator)::

        X_{n + 1} = (a * X_n + c) % m

    with the Borland C/C++ constants by default. The first element is
    ``X_1 = (a * seed + c) % m``. Each state is divided by ``m``, giving values
    intended to lie in ``[0, 1)``. Note: the integer states are converted to
    float32 *before* the division. A state very close to ``m`` may therefore
    round up so that the element equals exactly 1.0.

    Inputs:
        seed (int): the initial state of the generator
        size (int or Tuple[int]): the shape of the output tensor; dimensions
            of size 0 are allowed and yield an empty tensor
        a (int): the multiplier constant of the generator
        c (int): the additive constant of the generator
        m (int): the modulus constant of the generator

    Returns:
        torch.Tensor: a float32 tensor of shape ``size``.
    """
    if isinstance(size, int):
        size = (size,)

    num_elements = 1
    for dim in size:
        num_elements *= dim

    # Produce exactly ``num_elements`` states. The previous implementation
    # always emitted at least one value, so ``view`` failed for zero-sized
    # shapes.
    values = []
    state = seed
    for _ in range(num_elements):
        state = (a * state + c) % m
        values.append(state)

    return torch.tensor(values, dtype=torch.int64).float().view(size) / m


@contextmanager
def AudioBackendScope(new_backend):
    """Temporarily switch torchaudio's audio backend to ``new_backend``.

    The previously active backend is restored on exit. It is restored even
    if setting the new backend or the body of the ``with`` block raises.
    """
    previous_backend = torchaudio.get_audio_backend()
    try:
        torchaudio.set_audio_backend(new_backend)
        yield
    finally:
        torchaudio.set_audio_backend(previous_backend)


def filter_backends_with_mp3(backends):
    """Return the subset of ``backends`` that can load the test mp3 asset.

    Each backend is probed by loading
    ``assets/steam-train-whistle-daniel_simon.mp3`` with that backend active.
    A backend is excluded if the load raises ``RuntimeError`` or
    ``ImportError``. The input order is preserved.
    """
    test_filepath = get_asset_path('steam-train-whistle-daniel_simon.mp3')

    def supports_mp3(backend):
        try:
            with AudioBackendScope(backend):
                torchaudio.load(test_filepath)
            return True
        except (RuntimeError, ImportError):
            return False

    return [backend for backend in backends if supports_mp3(backend)]


# Evaluated at import time: probes every backend by loading the mp3 asset.
BACKENDS_MP3 = filter_backends_with_mp3(BACKENDS)


class TestBaseMixin:
    """Base mixin for test classes parametrized by dtype and device.

    ``define_test_suite`` creates subclasses that fill in both attributes.
    """
    dtype = None  # a torch dtype (e.g. torch.float32) in generated suites
    device = None  # a torch.device in generated suites


def define_test_suite(testbase: Type[TestBaseMixin], dtype: str, device: str):
    """Create a concrete test class from ``testbase`` for one dtype/device pair.

    Args:
        testbase: the mixin class holding the test methods.
        dtype: either ``'float32'`` or ``'float64'``.
        device: either ``'cpu'`` or ``'cuda'``.

    Returns:
        A subclass of ``testbase`` named e.g. ``TestFoo_CPU_Float32``. Its
        ``dtype`` is set to the matching ``torch`` dtype and its ``device``
        to ``torch.device(device)``. CUDA suites are marked with
        ``pytest.mark.skipif`` so they are skipped when CUDA is unavailable.

    Raises:
        NotImplementedError: if ``dtype`` or ``device`` is not supported.
    """
    if dtype not in ['float32', 'float64']:
        raise NotImplementedError(f'Unexpected dtype: {dtype}')
    if device not in ['cpu', 'cuda']:
        raise NotImplementedError(f'Unexpected device: {device}')

    name = f'Test{testbase.__name__}_{device.upper()}_{dtype.capitalize()}'
    attrs = {'dtype': getattr(torch, dtype), 'device': torch.device(device)}
    testsuite = type(name, (testbase,), attrs)

    if device == 'cuda':
        testsuite = pytest.mark.skipif(
            not torch.cuda.is_available(), reason='CUDA not available')(testsuite)
    return testsuite


def define_test_suites(
        scope: dict,
        testbases: Iterable[Type[TestBaseMixin]],
        dtypes: Iterable[str] = ('float32', 'float64'),
        devices: Iterable[str] = ('cpu', 'cuda'),
) -> None:
    """Generate a test class for every (testbase, device, dtype) combination.

    Each generated class is stored in ``scope`` under its class name.
    Typically the caller passes a module's ``globals()`` so pytest collects
    the classes; presumably that is the intended use, so verify with callers.

    Raises:
        NotImplementedError: propagated from ``define_test_suite`` for an
            unsupported dtype or device.
    """
    for suite in testbases:
        for device in devices:
            for dtype in dtypes:
                t = define_test_suite(suite, dtype, device)
                scope[t.__name__] = t