Unverified Commit a20da5e3 authored by moto's avatar moto Committed by GitHub
Browse files

Refactor test utilities (#756)

parent 6b159054
from .data_utils import (
get_asset_path,
get_whitenoise,
get_sinusoid,
)
from .backend_utils import (
set_audio_backend,
BACKENDS,
BACKENDS_MP3,
)
from .test_case_utils import (
TempDirMixin,
TestBaseMixin,
PytorchTestCase,
TorchaudioTestCase,
skipIfNoCuda,
skipIfNoExec,
skipIfNoModule,
skipIfNoExtension,
skipIfNoSoxBackend,
)
from .wav_utils import (
get_wav_data,
normalize_wav,
load_wav,
save_wav,
)
from .parameterized_utils import (
load_params,
)
from . import sox_utils
import unittest
import torchaudio
from .import data_utils
BACKENDS = torchaudio.list_audio_backends()
def _filter_backends_with_mp3(backends):
# Filter out backends that do not support mp3
test_filepath = data_utils.get_asset_path('steam-train-whistle-daniel_simon.mp3')
def supports_mp3(backend):
torchaudio.set_audio_backend(backend)
try:
torchaudio.load(test_filepath)
return True
except (RuntimeError, ImportError):
return False
return [backend for backend in backends if supports_mp3(backend)]
BACKENDS_MP3 = _filter_backends_with_mp3(BACKENDS)
def set_audio_backend(backend):
"""Allow additional backend value, 'default'"""
if backend == 'default':
if 'sox' in BACKENDS:
be = 'sox'
elif 'soundfile' in BACKENDS:
be = 'soundfile'
else:
raise unittest.SkipTest('No default backend available')
else:
be = backend
torchaudio.set_audio_backend(be)
import os import os.path
import shutil
import tempfile
import unittest
from typing import Union from typing import Union
from shutil import copytree
import torch import torch
from torch.testing._internal.common_utils import TestCase as PytorchTestCase
import torchaudio
from torchaudio._internal.module_utils import is_module_available
_TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
BACKENDS = torchaudio.list_audio_backends() _TEST_DIR_PATH = os.path.realpath(
os.path.join(os.path.dirname(__file__), '..'))
def get_asset_path(*paths): def get_asset_path(*paths):
...@@ -19,138 +13,6 @@ def get_asset_path(*paths): ...@@ -19,138 +13,6 @@ def get_asset_path(*paths):
return os.path.join(_TEST_DIR_PATH, 'assets', *paths) return os.path.join(_TEST_DIR_PATH, 'assets', *paths)
def create_temp_assets_dir():
"""
Creates a temporary directory and moves all files from test/assets there.
Returns a Tuple[string, TemporaryDirectory] which is the folder path
and object.
"""
tmp_dir = tempfile.TemporaryDirectory()
copytree(os.path.join(_TEST_DIR_PATH, "assets"),
os.path.join(tmp_dir.name, "assets"))
return tmp_dir.name, tmp_dir
def random_float_tensor(seed, size, a=22695477, c=1, m=2 ** 32):
""" Generates random tensors given a seed and size
https://en.wikipedia.org/wiki/Linear_congruential_generator
X_{n + 1} = (a * X_n + c) % m
Using Borland C/C++ values
The tensor will have values between [0,1)
Inputs:
seed (int): an int
size (Tuple[int]): the size of the output tensor
a (int): the multiplier constant to the generator
c (int): the additive constant to the generator
m (int): the modulus constant to the generator
"""
num_elements = 1
for s in size:
num_elements *= s
arr = [(a * seed + c) % m]
for i in range(num_elements - 1):
arr.append((a * arr[i] + c) % m)
return torch.tensor(arr).float().view(size) / m
def filter_backends_with_mp3(backends):
# Filter out backends that do not support mp3
test_filepath = get_asset_path('steam-train-whistle-daniel_simon.mp3')
def supports_mp3(backend):
torchaudio.set_audio_backend(backend)
try:
torchaudio.load(test_filepath)
return True
except (RuntimeError, ImportError):
return False
return [backend for backend in backends if supports_mp3(backend)]
BACKENDS_MP3 = filter_backends_with_mp3(BACKENDS)
def set_audio_backend(backend):
"""Allow additional backend value, 'default'"""
if backend == 'default':
if 'sox' in BACKENDS:
be = 'sox'
elif 'soundfile' in BACKENDS:
be = 'soundfile'
else:
raise unittest.SkipTest('No default backend available')
else:
be = backend
torchaudio.set_audio_backend(be)
class TempDirMixin:
"""Mixin to provide easy access to temp dir"""
temp_dir_ = None
base_temp_dir = None
temp_dir = None
@classmethod
def setUpClass(cls):
super().setUpClass()
# If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory.
# this is handy for debugging.
key = 'TORCHAUDIO_TEST_TEMP_DIR'
if key in os.environ:
cls.base_temp_dir = os.environ[key]
else:
cls.temp_dir_ = tempfile.TemporaryDirectory()
cls.base_temp_dir = cls.temp_dir_.name
@classmethod
def tearDownClass(cls):
super().tearDownClass()
if isinstance(cls.temp_dir_, tempfile.TemporaryDirectory):
cls.temp_dir_.cleanup()
def setUp(self):
self.temp_dir = os.path.join(self.base_temp_dir, self.id())
def get_temp_path(self, *paths):
path = os.path.join(self.temp_dir, *paths)
os.makedirs(os.path.dirname(path), exist_ok=True)
return path
class TestBaseMixin:
"""Mixin to provide consistent way to define device/dtype/backend aware TestCase"""
dtype = None
device = None
backend = None
def setUp(self):
super().setUp()
set_audio_backend(self.backend)
class TorchaudioTestCase(TestBaseMixin, PytorchTestCase):
pass
def skipIfNoExec(cmd):
return unittest.skipIf(shutil.which(cmd) is None, f'`{cmd}` is not available')
def skipIfNoModule(module, display_name=None):
display_name = display_name or module
return unittest.skipIf(not is_module_available(module), f'"{display_name}" is not available')
skipIfNoSoxBackend = unittest.skipIf('sox' not in BACKENDS, 'Sox backend not available')
skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')
skipIfNoExtension = skipIfNoModule('torchaudio._torchaudio', 'torchaudio C++ extension')
def get_whitenoise( def get_whitenoise(
*, *,
sample_rate: int = 16000, sample_rate: int = 16000,
......
import json
from parameterized import param
from .data_utils import get_asset_path
def load_params(*paths):
with open(get_asset_path(*paths), 'r') as file:
return [param(json.loads(line)) for line in file]
import shutil
import os.path
import tempfile
import unittest
import torch
from torch.testing._internal.common_utils import TestCase as PytorchTestCase
import torchaudio
from torchaudio._internal.module_utils import is_module_available
from .backend_utils import set_audio_backend
class TempDirMixin:
"""Mixin to provide easy access to temp dir"""
temp_dir_ = None
base_temp_dir = None
temp_dir = None
@classmethod
def setUpClass(cls):
super().setUpClass()
# If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory.
# this is handy for debugging.
key = 'TORCHAUDIO_TEST_TEMP_DIR'
if key in os.environ:
cls.base_temp_dir = os.environ[key]
else:
cls.temp_dir_ = tempfile.TemporaryDirectory()
cls.base_temp_dir = cls.temp_dir_.name
@classmethod
def tearDownClass(cls):
super().tearDownClass()
if isinstance(cls.temp_dir_, tempfile.TemporaryDirectory):
cls.temp_dir_.cleanup()
def setUp(self):
super().setUp()
self.temp_dir = os.path.join(self.base_temp_dir, self.id())
def get_temp_path(self, *paths):
path = os.path.join(self.temp_dir, *paths)
os.makedirs(os.path.dirname(path), exist_ok=True)
return path
class TestBaseMixin:
"""Mixin to provide consistent way to define device/dtype/backend aware TestCase"""
dtype = None
device = None
backend = None
def setUp(self):
super().setUp()
set_audio_backend(self.backend)
class TorchaudioTestCase(TestBaseMixin, PytorchTestCase):
pass
def skipIfNoExec(cmd):
return unittest.skipIf(shutil.which(cmd) is None, f'`{cmd}` is not available')
def skipIfNoModule(module, display_name=None):
display_name = display_name or module
return unittest.skipIf(not is_module_available(module), f'"{display_name}" is not available')
skipIfNoSoxBackend = unittest.skipIf(
'sox' not in torchaudio.list_audio_backends(), 'Sox backend not available')
skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')
skipIfNoExtension = skipIfNoModule('torchaudio._torchaudio', 'torchaudio C++ extension')
from typing import Optional
import torch
import scipy.io.wavfile
def normalize_wav(tensor: torch.Tensor) -> torch.Tensor:
if tensor.dtype == torch.float32:
pass
elif tensor.dtype == torch.int32:
tensor = tensor.to(torch.float32)
tensor[tensor > 0] /= 2147483647.
tensor[tensor < 0] /= 2147483648.
elif tensor.dtype == torch.int16:
tensor = tensor.to(torch.float32)
tensor[tensor > 0] /= 32767.
tensor[tensor < 0] /= 32768.
elif tensor.dtype == torch.uint8:
tensor = tensor.to(torch.float32) - 128
tensor[tensor > 0] /= 127.
tensor[tensor < 0] /= 128.
return tensor
def get_wav_data(
dtype: str,
num_channels: int,
*,
num_frames: Optional[int] = None,
normalize: bool = True,
channels_first: bool = True,
):
"""Generate linear signal of the given dtype and num_channels
Data range is
[-1.0, 1.0] for float32,
[-2147483648, 2147483647] for int32
[-32768, 32767] for int16
[0, 255] for uint8
num_frames allow to change the linear interpolation parameter.
Default values are 256 for uint8, else 1 << 16.
1 << 16 as default is so that int16 value range is completely covered.
"""
dtype_ = getattr(torch, dtype)
if num_frames is None:
if dtype == 'uint8':
num_frames = 256
else:
num_frames = 1 << 16
if dtype == 'uint8':
base = torch.linspace(0, 255, num_frames, dtype=dtype_)
if dtype == 'float32':
base = torch.linspace(-1., 1., num_frames, dtype=dtype_)
if dtype == 'int32':
base = torch.linspace(-2147483648, 2147483647, num_frames, dtype=dtype_)
if dtype == 'int16':
base = torch.linspace(-32768, 32767, num_frames, dtype=dtype_)
data = base.repeat([num_channels, 1])
if not channels_first:
data = data.transpose(1, 0)
if normalize:
data = normalize_wav(data)
return data
def load_wav(path: str, normalize=True, channels_first=True) -> torch.Tensor:
"""Load wav file without torchaudio"""
sample_rate, data = scipy.io.wavfile.read(path)
data = torch.from_numpy(data.copy())
if data.ndim == 1:
data = data.unsqueeze(1)
if normalize:
data = normalize_wav(data)
if channels_first:
data = data.transpose(1, 0)
return data, sample_rate
def save_wav(path, data, sample_rate, channels_first=True):
"""Save wav file without torchaudio"""
if channels_first:
data = data.transpose(1, 0)
scipy.io.wavfile.write(path, sample_rate, data.numpy())
...@@ -10,6 +10,31 @@ from . import common_utils ...@@ -10,6 +10,31 @@ from . import common_utils
from .functional_impl import Lfilter from .functional_impl import Lfilter
def random_float_tensor(seed, size, a=22695477, c=1, m=2 ** 32):
""" Generates random tensors given a seed and size
https://en.wikipedia.org/wiki/Linear_congruential_generator
X_{n + 1} = (a * X_n + c) % m
Using Borland C/C++ values
The tensor will have values between [0,1)
Inputs:
seed (int): an int
size (Tuple[int]): the size of the output tensor
a (int): the multiplier constant to the generator
c (int): the additive constant to the generator
m (int): the modulus constant to the generator
"""
num_elements = 1
for s in size:
num_elements *= s
arr = [(a * seed + c) % m]
for i in range(num_elements - 1):
arr.append((a * arr[i] + c) % m)
return torch.tensor(arr).float().view(size) / m
class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase): class TestLFilterFloat32(Lfilter, common_utils.PytorchTestCase):
dtype = torch.float32 dtype = torch.float32
device = torch.device('cpu') device = torch.device('cpu')
...@@ -49,7 +74,7 @@ def _test_istft_is_inverse_of_stft(kwargs): ...@@ -49,7 +74,7 @@ def _test_istft_is_inverse_of_stft(kwargs):
for data_size in [(2, 20), (3, 15), (4, 10)]: for data_size in [(2, 20), (3, 15), (4, 10)]:
for i in range(100): for i in range(100):
sound = common_utils.random_float_tensor(i, data_size) sound = random_float_tensor(i, data_size)
stft = torch.stft(sound, **kwargs) stft = torch.stft(sound, **kwargs)
estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs) estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)
...@@ -211,8 +236,8 @@ class TestIstft(common_utils.TorchaudioTestCase): ...@@ -211,8 +236,8 @@ class TestIstft(common_utils.TorchaudioTestCase):
def _test_linearity_of_istft(self, data_size, kwargs, atol=1e-6, rtol=1e-8): def _test_linearity_of_istft(self, data_size, kwargs, atol=1e-6, rtol=1e-8):
for i in range(self.number_of_trials): for i in range(self.number_of_trials):
tensor1 = common_utils.random_float_tensor(i, data_size) tensor1 = random_float_tensor(i, data_size)
tensor2 = common_utils.random_float_tensor(i * 2, data_size) tensor2 = random_float_tensor(i * 2, data_size)
a, b = torch.rand(2) a, b = torch.rand(2)
istft1 = torchaudio.functional.istft(tensor1, **kwargs) istft1 = torchaudio.functional.istft(tensor1, **kwargs)
istft2 = torchaudio.functional.istft(tensor2, **kwargs) istft2 = torchaudio.functional.istft(tensor2, **kwargs)
......
"""Test suites for checking numerical compatibility against Kaldi""" """Test suites for checking numerical compatibility against Kaldi"""
import json
import subprocess import subprocess
import kaldi_io import kaldi_io
...@@ -8,7 +7,8 @@ import torchaudio.functional as F ...@@ -8,7 +7,8 @@ import torchaudio.functional as F
import torchaudio.compliance.kaldi import torchaudio.compliance.kaldi
from . import common_utils from . import common_utils
from parameterized import parameterized, param from .common_utils import load_params
from parameterized import parameterized
def _convert_args(**kwargs): def _convert_args(**kwargs):
...@@ -43,11 +43,6 @@ def _run_kaldi(command, input_type, input_value): ...@@ -43,11 +43,6 @@ def _run_kaldi(command, input_type, input_value):
return torch.from_numpy(result.copy()) # copy supresses some torch warning return torch.from_numpy(result.copy()) # copy supresses some torch warning
def _load_params(path):
with open(path, 'r') as file:
return [param(json.loads(line)) for line in file]
class Kaldi(common_utils.TestBaseMixin): class Kaldi(common_utils.TestBaseMixin):
backend = 'sox' backend = 'sox'
...@@ -71,7 +66,7 @@ class Kaldi(common_utils.TestBaseMixin): ...@@ -71,7 +66,7 @@ class Kaldi(common_utils.TestBaseMixin):
kaldi_result = _run_kaldi(command, 'ark', tensor) kaldi_result = _run_kaldi(command, 'ark', tensor)
self.assert_equal(result, expected=kaldi_result) self.assert_equal(result, expected=kaldi_result)
@parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_fbank_args.json'))) @parameterized.expand(load_params('kaldi_test_fbank_args.json'))
@common_utils.skipIfNoExec('compute-fbank-feats') @common_utils.skipIfNoExec('compute-fbank-feats')
def test_fbank(self, kwargs): def test_fbank(self, kwargs):
"""fbank should be numerically compatible with compute-fbank-feats""" """fbank should be numerically compatible with compute-fbank-feats"""
...@@ -82,7 +77,7 @@ class Kaldi(common_utils.TestBaseMixin): ...@@ -82,7 +77,7 @@ class Kaldi(common_utils.TestBaseMixin):
kaldi_result = _run_kaldi(command, 'scp', wave_file) kaldi_result = _run_kaldi(command, 'scp', wave_file)
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
@parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_spectrogram_args.json'))) @parameterized.expand(load_params('kaldi_test_spectrogram_args.json'))
@common_utils.skipIfNoExec('compute-spectrogram-feats') @common_utils.skipIfNoExec('compute-spectrogram-feats')
def test_spectrogram(self, kwargs): def test_spectrogram(self, kwargs):
"""spectrogram should be numerically compatible with compute-spectrogram-feats""" """spectrogram should be numerically compatible with compute-spectrogram-feats"""
...@@ -93,7 +88,7 @@ class Kaldi(common_utils.TestBaseMixin): ...@@ -93,7 +88,7 @@ class Kaldi(common_utils.TestBaseMixin):
kaldi_result = _run_kaldi(command, 'scp', wave_file) kaldi_result = _run_kaldi(command, 'scp', wave_file)
self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8)
@parameterized.expand(_load_params(common_utils.get_asset_path('kaldi_test_mfcc_args.json'))) @parameterized.expand(load_params('kaldi_test_mfcc_args.json'))
@common_utils.skipIfNoExec('compute-mfcc-feats') @common_utils.skipIfNoExec('compute-mfcc-feats')
def test_mfcc(self, kwargs): def test_mfcc(self, kwargs):
"""mfcc should be numerically compatible with compute-mfcc-feats""" """mfcc should be numerically compatible with compute-mfcc-feats"""
......
from typing import Optional def name_func(func, _, params):
return f'{func.__name__}_{"_".join(str(arg) for arg in params.args)}'
import torch
import scipy.io.wavfile
def get_test_name(func, _, params):
return f'{func.__name__}_{"_".join(str(p) for p in params.args)}'
def normalize_wav(tensor: torch.Tensor) -> torch.Tensor:
if tensor.dtype == torch.float32:
pass
elif tensor.dtype == torch.int32:
tensor = tensor.to(torch.float32)
tensor[tensor > 0] /= 2147483647.
tensor[tensor < 0] /= 2147483648.
elif tensor.dtype == torch.int16:
tensor = tensor.to(torch.float32)
tensor[tensor > 0] /= 32767.
tensor[tensor < 0] /= 32768.
elif tensor.dtype == torch.uint8:
tensor = tensor.to(torch.float32) - 128
tensor[tensor > 0] /= 127.
tensor[tensor < 0] /= 128.
return tensor
def get_wav_data(
dtype: str,
num_channels: int,
*,
num_frames: Optional[int] = None,
normalize: bool = True,
channels_first: bool = True,
):
"""Generate linear signal of the given dtype and num_channels
Data range is
[-1.0, 1.0] for float32,
[-2147483648, 2147483647] for int32
[-32768, 32767] for int16
[0, 255] for uint8
num_frames allow to change the linear interpolation parameter.
Default values are 256 for uint8, else 1 << 16.
1 << 16 as default is so that int16 value range is completely covered.
"""
dtype_ = getattr(torch, dtype)
if num_frames is None:
if dtype == 'uint8':
num_frames = 256
else:
num_frames = 1 << 16
if dtype == 'uint8':
base = torch.linspace(0, 255, num_frames, dtype=dtype_)
if dtype == 'float32':
base = torch.linspace(-1., 1., num_frames, dtype=dtype_)
if dtype == 'int32':
base = torch.linspace(-2147483648, 2147483647, num_frames, dtype=dtype_)
if dtype == 'int16':
base = torch.linspace(-32768, 32767, num_frames, dtype=dtype_)
data = base.repeat([num_channels, 1])
if not channels_first:
data = data.transpose(1, 0)
if normalize:
data = normalize_wav(data)
return data
def load_wav(path: str, normalize=True, channels_first=True) -> torch.Tensor:
"""Load wav file without torchaudio"""
sample_rate, data = scipy.io.wavfile.read(path)
data = torch.from_numpy(data.copy())
if data.ndim == 1:
data = data.unsqueeze(1)
if normalize:
data = normalize_wav(data)
if channels_first:
data = data.transpose(1, 0)
return data, sample_rate
def save_wav(path, data, sample_rate, channels_first=True):
"""Save wav file without torchaudio"""
if channels_first:
data = data.transpose(1, 0)
scipy.io.wavfile.write(path, sample_rate, data.numpy())
...@@ -8,13 +8,13 @@ from ..common_utils import ( ...@@ -8,13 +8,13 @@ from ..common_utils import (
PytorchTestCase, PytorchTestCase,
skipIfNoExec, skipIfNoExec,
skipIfNoExtension, skipIfNoExtension,
) sox_utils,
from .common import (
get_test_name,
get_wav_data, get_wav_data,
save_wav, save_wav,
) )
from . import sox_utils from .common import (
name_func,
)
@skipIfNoExec('sox') @skipIfNoExec('sox')
...@@ -24,7 +24,7 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -24,7 +24,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
['float32', 'int32', 'int16', 'uint8'], ['float32', 'int32', 'int16', 'uint8'],
[8000, 16000], [8000, 16000],
[1, 2], [1, 2],
)), name_func=get_test_name) )), name_func=name_func)
def test_wav(self, dtype, sample_rate, num_channels): def test_wav(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file correctly""" """`sox_io_backend.info` can check wav file correctly"""
duration = 1 duration = 1
...@@ -40,7 +40,7 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -40,7 +40,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
['float32', 'int32', 'int16', 'uint8'], ['float32', 'int32', 'int16', 'uint8'],
[8000, 16000], [8000, 16000],
[4, 8, 16, 32], [4, 8, 16, 32],
)), name_func=get_test_name) )), name_func=name_func)
def test_wav_multiple_channels(self, dtype, sample_rate, num_channels): def test_wav_multiple_channels(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` can check wav file with channels more than 2 correctly""" """`sox_io_backend.info` can check wav file with channels more than 2 correctly"""
duration = 1 duration = 1
...@@ -56,7 +56,7 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -56,7 +56,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
[8000, 16000], [8000, 16000],
[1, 2], [1, 2],
[96, 128, 160, 192, 224, 256, 320], [96, 128, 160, 192, 224, 256, 320],
)), name_func=get_test_name) )), name_func=name_func)
def test_mp3(self, sample_rate, num_channels, bit_rate): def test_mp3(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.info` can check mp3 file correctly""" """`sox_io_backend.info` can check mp3 file correctly"""
duration = 1 duration = 1
...@@ -75,7 +75,7 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -75,7 +75,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
[8000, 16000], [8000, 16000],
[1, 2], [1, 2],
list(range(9)), list(range(9)),
)), name_func=get_test_name) )), name_func=name_func)
def test_flac(self, sample_rate, num_channels, compression_level): def test_flac(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.info` can check flac file correctly""" """`sox_io_backend.info` can check flac file correctly"""
duration = 1 duration = 1
...@@ -93,7 +93,7 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -93,7 +93,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
[8000, 16000], [8000, 16000],
[1, 2], [1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10], [-1, 0, 1, 2, 3, 3.6, 5, 10],
)), name_func=get_test_name) )), name_func=name_func)
def test_vorbis(self, sample_rate, num_channels, quality_level): def test_vorbis(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.info` can check vorbis file correctly""" """`sox_io_backend.info` can check vorbis file correctly"""
duration = 1 duration = 1
......
...@@ -8,14 +8,14 @@ from ..common_utils import ( ...@@ -8,14 +8,14 @@ from ..common_utils import (
PytorchTestCase, PytorchTestCase,
skipIfNoExec, skipIfNoExec,
skipIfNoExtension, skipIfNoExtension,
)
from .common import (
get_test_name,
get_wav_data, get_wav_data,
load_wav, load_wav,
save_wav, save_wav,
sox_utils,
)
from .common import (
name_func,
) )
from . import sox_utils
class LoadTestBase(TempDirMixin, PytorchTestCase): class LoadTestBase(TempDirMixin, PytorchTestCase):
...@@ -129,7 +129,7 @@ class TestLoad(LoadTestBase): ...@@ -129,7 +129,7 @@ class TestLoad(LoadTestBase):
[8000, 16000], [8000, 16000],
[1, 2], [1, 2],
[False, True], [False, True],
)), name_func=get_test_name) )), name_func=name_func)
def test_wav(self, dtype, sample_rate, num_channels, normalize): def test_wav(self, dtype, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load wav format correctly.""" """`sox_io_backend.load` can load wav format correctly."""
self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1) self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=1)
...@@ -139,7 +139,7 @@ class TestLoad(LoadTestBase): ...@@ -139,7 +139,7 @@ class TestLoad(LoadTestBase):
[16000], [16000],
[2], [2],
[False], [False],
)), name_func=get_test_name) )), name_func=name_func)
def test_wav_large(self, dtype, sample_rate, num_channels, normalize): def test_wav_large(self, dtype, sample_rate, num_channels, normalize):
"""`sox_io_backend.load` can load large wav file correctly.""" """`sox_io_backend.load` can load large wav file correctly."""
two_hours = 2 * 60 * 60 two_hours = 2 * 60 * 60
...@@ -148,7 +148,7 @@ class TestLoad(LoadTestBase): ...@@ -148,7 +148,7 @@ class TestLoad(LoadTestBase):
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'], ['float32', 'int32', 'int16', 'uint8'],
[4, 8, 16, 32], [4, 8, 16, 32],
)), name_func=get_test_name) )), name_func=name_func)
def test_multiple_channels(self, dtype, num_channels): def test_multiple_channels(self, dtype, num_channels):
"""`sox_io_backend.load` can load wav file with more than 2 channels.""" """`sox_io_backend.load` can load wav file with more than 2 channels."""
sample_rate = 8000 sample_rate = 8000
...@@ -159,7 +159,7 @@ class TestLoad(LoadTestBase): ...@@ -159,7 +159,7 @@ class TestLoad(LoadTestBase):
[8000, 16000, 44100], [8000, 16000, 44100],
[1, 2], [1, 2],
[96, 128, 160, 192, 224, 256, 320], [96, 128, 160, 192, 224, 256, 320],
)), name_func=get_test_name) )), name_func=name_func)
def test_mp3(self, sample_rate, num_channels, bit_rate): def test_mp3(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.load` can load mp3 format correctly.""" """`sox_io_backend.load` can load mp3 format correctly."""
self.assert_mp3(sample_rate, num_channels, bit_rate, duration=1) self.assert_mp3(sample_rate, num_channels, bit_rate, duration=1)
...@@ -168,7 +168,7 @@ class TestLoad(LoadTestBase): ...@@ -168,7 +168,7 @@ class TestLoad(LoadTestBase):
[16000], [16000],
[2], [2],
[128], [128],
)), name_func=get_test_name) )), name_func=name_func)
def test_mp3_large(self, sample_rate, num_channels, bit_rate): def test_mp3_large(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.load` can load large mp3 file correctly.""" """`sox_io_backend.load` can load large mp3 file correctly."""
two_hours = 2 * 60 * 60 two_hours = 2 * 60 * 60
...@@ -178,7 +178,7 @@ class TestLoad(LoadTestBase): ...@@ -178,7 +178,7 @@ class TestLoad(LoadTestBase):
[8000, 16000], [8000, 16000],
[1, 2], [1, 2],
list(range(9)), list(range(9)),
)), name_func=get_test_name) )), name_func=name_func)
def test_flac(self, sample_rate, num_channels, compression_level): def test_flac(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.load` can load flac format correctly.""" """`sox_io_backend.load` can load flac format correctly."""
self.assert_flac(sample_rate, num_channels, compression_level, duration=1) self.assert_flac(sample_rate, num_channels, compression_level, duration=1)
...@@ -187,7 +187,7 @@ class TestLoad(LoadTestBase): ...@@ -187,7 +187,7 @@ class TestLoad(LoadTestBase):
[16000], [16000],
[2], [2],
[0], [0],
)), name_func=get_test_name) )), name_func=name_func)
def test_flac_large(self, sample_rate, num_channels, compression_level): def test_flac_large(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.load` can load large flac file correctly.""" """`sox_io_backend.load` can load large flac file correctly."""
two_hours = 2 * 60 * 60 two_hours = 2 * 60 * 60
...@@ -197,7 +197,7 @@ class TestLoad(LoadTestBase): ...@@ -197,7 +197,7 @@ class TestLoad(LoadTestBase):
[8000, 16000], [8000, 16000],
[1, 2], [1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10], [-1, 0, 1, 2, 3, 3.6, 5, 10],
)), name_func=get_test_name) )), name_func=name_func)
def test_vorbis(self, sample_rate, num_channels, quality_level): def test_vorbis(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.load` can load vorbis format correctly.""" """`sox_io_backend.load` can load vorbis format correctly."""
self.assert_vorbis(sample_rate, num_channels, quality_level, duration=1) self.assert_vorbis(sample_rate, num_channels, quality_level, duration=1)
...@@ -206,7 +206,7 @@ class TestLoad(LoadTestBase): ...@@ -206,7 +206,7 @@ class TestLoad(LoadTestBase):
[16000], [16000],
[2], [2],
[10], [10],
)), name_func=get_test_name) )), name_func=name_func)
def test_vorbis_large(self, sample_rate, num_channels, quality_level): def test_vorbis_large(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.load` can load large vorbis file correctly.""" """`sox_io_backend.load` can load large vorbis file correctly."""
two_hours = 2 * 60 * 60 two_hours = 2 * 60 * 60
...@@ -230,14 +230,14 @@ class TestLoadParams(TempDirMixin, PytorchTestCase): ...@@ -230,14 +230,14 @@ class TestLoadParams(TempDirMixin, PytorchTestCase):
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
[0, 1, 10, 100, 1000], [0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000], [-1, 1, 10, 100, 1000],
)), name_func=get_test_name) )), name_func=name_func)
def test_frame(self, frame_offset, num_frames): def test_frame(self, frame_offset, num_frames):
"""num_frames and frame_offset correctly specify the region of data""" """num_frames and frame_offset correctly specify the region of data"""
found, _ = sox_io_backend.load(self.path, frame_offset, num_frames) found, _ = sox_io_backend.load(self.path, frame_offset, num_frames)
frame_end = None if num_frames == -1 else frame_offset + num_frames frame_end = None if num_frames == -1 else frame_offset + num_frames
self.assertEqual(found, self.original[:, frame_offset:frame_end]) self.assertEqual(found, self.original[:, frame_offset:frame_end])
@parameterized.expand([(True, ), (False, )], name_func=get_test_name) @parameterized.expand([(True, ), (False, )], name_func=name_func)
def test_channels_first(self, channels_first): def test_channels_first(self, channels_first):
"""channels_first swaps axes""" """channels_first swaps axes"""
found, _ = sox_io_backend.load(self.path, channels_first=channels_first) found, _ = sox_io_backend.load(self.path, channels_first=channels_first)
......
...@@ -8,10 +8,10 @@ from ..common_utils import ( ...@@ -8,10 +8,10 @@ from ..common_utils import (
PytorchTestCase, PytorchTestCase,
skipIfNoExec, skipIfNoExec,
skipIfNoExtension, skipIfNoExtension,
get_wav_data,
) )
from .common import ( from .common import (
get_test_name, name_func,
get_wav_data,
) )
...@@ -23,7 +23,7 @@ class TestRoundTripIO(TempDirMixin, PytorchTestCase): ...@@ -23,7 +23,7 @@ class TestRoundTripIO(TempDirMixin, PytorchTestCase):
['float32', 'int32', 'int16', 'uint8'], ['float32', 'int32', 'int16', 'uint8'],
[8000, 16000], [8000, 16000],
[1, 2], [1, 2],
)), name_func=get_test_name) )), name_func=name_func)
def test_wav(self, dtype, sample_rate, num_channels): def test_wav(self, dtype, sample_rate, num_channels):
"""save/load round trip should not degrade data for wav formats""" """save/load round trip should not degrade data for wav formats"""
original = get_wav_data(dtype, num_channels, normalize=False) original = get_wav_data(dtype, num_channels, normalize=False)
...@@ -39,7 +39,7 @@ class TestRoundTripIO(TempDirMixin, PytorchTestCase): ...@@ -39,7 +39,7 @@ class TestRoundTripIO(TempDirMixin, PytorchTestCase):
[8000, 16000], [8000, 16000],
[1, 2], [1, 2],
list(range(9)), list(range(9)),
)), name_func=get_test_name) )), name_func=name_func)
def test_flac(self, sample_rate, num_channels, compression_level): def test_flac(self, sample_rate, num_channels, compression_level):
"""save/load round trip should not degrade data for flac formats""" """save/load round trip should not degrade data for flac formats"""
original = get_wav_data('float32', num_channels) original = get_wav_data('float32', num_channels)
......
...@@ -8,14 +8,14 @@ from ..common_utils import ( ...@@ -8,14 +8,14 @@ from ..common_utils import (
PytorchTestCase, PytorchTestCase,
skipIfNoExec, skipIfNoExec,
skipIfNoExtension, skipIfNoExtension,
)
from .common import (
get_test_name,
get_wav_data, get_wav_data,
load_wav, load_wav,
save_wav, save_wav,
sox_utils,
)
from .common import (
name_func,
) )
from . import sox_utils
class SaveTestBase(TempDirMixin, PytorchTestCase): class SaveTestBase(TempDirMixin, PytorchTestCase):
...@@ -176,7 +176,7 @@ class TestSave(SaveTestBase): ...@@ -176,7 +176,7 @@ class TestSave(SaveTestBase):
['float32', 'int32', 'int16', 'uint8'], ['float32', 'int32', 'int16', 'uint8'],
[8000, 16000], [8000, 16000],
[1, 2], [1, 2],
)), name_func=get_test_name) )), name_func=name_func)
def test_wav(self, dtype, sample_rate, num_channels): def test_wav(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.save` can save wav format.""" """`sox_io_backend.save` can save wav format."""
self.assert_wav(dtype, sample_rate, num_channels, num_frames=None) self.assert_wav(dtype, sample_rate, num_channels, num_frames=None)
...@@ -185,7 +185,7 @@ class TestSave(SaveTestBase): ...@@ -185,7 +185,7 @@ class TestSave(SaveTestBase):
['float32'], ['float32'],
[16000], [16000],
[2], [2],
)), name_func=get_test_name) )), name_func=name_func)
def test_wav_large(self, dtype, sample_rate, num_channels): def test_wav_large(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.save` can save large wav file.""" """`sox_io_backend.save` can save large wav file."""
two_hours = 2 * 60 * 60 * sample_rate two_hours = 2 * 60 * 60 * sample_rate
...@@ -194,7 +194,7 @@ class TestSave(SaveTestBase): ...@@ -194,7 +194,7 @@ class TestSave(SaveTestBase):
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'], ['float32', 'int32', 'int16', 'uint8'],
[4, 8, 16, 32], [4, 8, 16, 32],
)), name_func=get_test_name) )), name_func=name_func)
def test_multiple_channels(self, dtype, num_channels): def test_multiple_channels(self, dtype, num_channels):
"""`sox_io_backend.save` can save wav with more than 2 channels.""" """`sox_io_backend.save` can save wav with more than 2 channels."""
sample_rate = 8000 sample_rate = 8000
...@@ -204,7 +204,7 @@ class TestSave(SaveTestBase): ...@@ -204,7 +204,7 @@ class TestSave(SaveTestBase):
[8000, 16000], [8000, 16000],
[1, 2], [1, 2],
[-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320], [-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320],
)), name_func=get_test_name) )), name_func=name_func)
def test_mp3(self, sample_rate, num_channels, bit_rate): def test_mp3(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.save` can save mp3 format.""" """`sox_io_backend.save` can save mp3 format."""
self.assert_mp3(sample_rate, num_channels, bit_rate, duration=1) self.assert_mp3(sample_rate, num_channels, bit_rate, duration=1)
...@@ -213,7 +213,7 @@ class TestSave(SaveTestBase): ...@@ -213,7 +213,7 @@ class TestSave(SaveTestBase):
[16000], [16000],
[2], [2],
[128], [128],
)), name_func=get_test_name) )), name_func=name_func)
def test_mp3_large(self, sample_rate, num_channels, bit_rate): def test_mp3_large(self, sample_rate, num_channels, bit_rate):
"""`sox_io_backend.save` can save large mp3 file.""" """`sox_io_backend.save` can save large mp3 file."""
two_hours = 2 * 60 * 60 two_hours = 2 * 60 * 60
...@@ -223,7 +223,7 @@ class TestSave(SaveTestBase): ...@@ -223,7 +223,7 @@ class TestSave(SaveTestBase):
[8000, 16000], [8000, 16000],
[1, 2], [1, 2],
list(range(9)), list(range(9)),
)), name_func=get_test_name) )), name_func=name_func)
def test_flac(self, sample_rate, num_channels, compression_level): def test_flac(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.save` can save flac format.""" """`sox_io_backend.save` can save flac format."""
self.assert_flac(sample_rate, num_channels, compression_level, duration=1) self.assert_flac(sample_rate, num_channels, compression_level, duration=1)
...@@ -232,7 +232,7 @@ class TestSave(SaveTestBase): ...@@ -232,7 +232,7 @@ class TestSave(SaveTestBase):
[16000], [16000],
[2], [2],
[0], [0],
)), name_func=get_test_name) )), name_func=name_func)
def test_flac_large(self, sample_rate, num_channels, compression_level): def test_flac_large(self, sample_rate, num_channels, compression_level):
"""`sox_io_backend.save` can save large flac file.""" """`sox_io_backend.save` can save large flac file."""
two_hours = 2 * 60 * 60 two_hours = 2 * 60 * 60
...@@ -242,7 +242,7 @@ class TestSave(SaveTestBase): ...@@ -242,7 +242,7 @@ class TestSave(SaveTestBase):
[8000, 16000], [8000, 16000],
[1, 2], [1, 2],
[-1, 0, 1, 2, 3, 3.6, 5, 10], [-1, 0, 1, 2, 3, 3.6, 5, 10],
)), name_func=get_test_name) )), name_func=name_func)
def test_vorbis(self, sample_rate, num_channels, quality_level): def test_vorbis(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.save` can save vorbis format.""" """`sox_io_backend.save` can save vorbis format."""
self.assert_vorbis(sample_rate, num_channels, quality_level, duration=20) self.assert_vorbis(sample_rate, num_channels, quality_level, duration=20)
...@@ -255,7 +255,7 @@ class TestSave(SaveTestBase): ...@@ -255,7 +255,7 @@ class TestSave(SaveTestBase):
[16000], [16000],
[2], [2],
[10], [10],
)), name_func=get_test_name) )), name_func=name_func)
def test_vorbis_large(self, sample_rate, num_channels, quality_level): def test_vorbis_large(self, sample_rate, num_channels, quality_level):
"""`sox_io_backend.save` can save large vorbis file correctly.""" """`sox_io_backend.save` can save large vorbis file correctly."""
two_hours = 2 * 60 * 60 two_hours = 2 * 60 * 60
...@@ -267,7 +267,7 @@ class TestSave(SaveTestBase): ...@@ -267,7 +267,7 @@ class TestSave(SaveTestBase):
@skipIfNoExtension @skipIfNoExtension
class TestSaveParams(TempDirMixin, PytorchTestCase): class TestSaveParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of optional parameters of `sox_io_backend.save`""" """Test the correctness of optional parameters of `sox_io_backend.save`"""
@parameterized.expand([(True, ), (False, )], name_func=get_test_name) @parameterized.expand([(True, ), (False, )], name_func=name_func)
def test_channels_first(self, channels_first): def test_channels_first(self, channels_first):
"""channels_first swaps axes""" """channels_first swaps axes"""
path = self.get_temp_path('data.wav') path = self.get_temp_path('data.wav')
...@@ -280,7 +280,7 @@ class TestSaveParams(TempDirMixin, PytorchTestCase): ...@@ -280,7 +280,7 @@ class TestSaveParams(TempDirMixin, PytorchTestCase):
@parameterized.expand([ @parameterized.expand([
'float32', 'int32', 'int16', 'uint8' 'float32', 'int32', 'int16', 'uint8'
], name_func=get_test_name) ], name_func=name_func)
def test_noncontiguous(self, dtype): def test_noncontiguous(self, dtype):
"""Noncontiguous tensors are saved correctly""" """Noncontiguous tensors are saved correctly"""
path = self.get_temp_path('data.wav') path = self.get_temp_path('data.wav')
......
...@@ -10,14 +10,14 @@ from ..common_utils import ( ...@@ -10,14 +10,14 @@ from ..common_utils import (
TorchaudioTestCase, TorchaudioTestCase,
skipIfNoExec, skipIfNoExec,
skipIfNoExtension, skipIfNoExtension,
)
from .common import (
get_test_name,
get_wav_data, get_wav_data,
save_wav, save_wav,
load_wav, load_wav,
sox_utils,
)
from .common import (
name_func,
) )
from . import sox_utils
def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo: def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo:
...@@ -47,7 +47,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase): ...@@ -47,7 +47,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
['float32', 'int32', 'int16', 'uint8'], ['float32', 'int32', 'int16', 'uint8'],
[8000, 16000], [8000, 16000],
[1, 2], [1, 2],
)), name_func=get_test_name) )), name_func=name_func)
def test_info_wav(self, dtype, sample_rate, num_channels): def test_info_wav(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` is torchscript-able and returns the same result""" """`sox_io_backend.info` is torchscript-able and returns the same result"""
audio_path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav') audio_path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
...@@ -71,7 +71,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase): ...@@ -71,7 +71,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
[1, 2], [1, 2],
[False, True], [False, True],
[False, True], [False, True],
)), name_func=get_test_name) )), name_func=name_func)
def test_load_wav(self, dtype, sample_rate, num_channels, normalize, channels_first): def test_load_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
"""`sox_io_backend.load` is torchscript-able and returns the same result""" """`sox_io_backend.load` is torchscript-able and returns the same result"""
audio_path = self.get_temp_path(f'test_load_{dtype}_{sample_rate}_{num_channels}_{normalize}.wav') audio_path = self.get_temp_path(f'test_load_{dtype}_{sample_rate}_{num_channels}_{normalize}.wav')
...@@ -94,7 +94,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase): ...@@ -94,7 +94,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
['float32', 'int32', 'int16', 'uint8'], ['float32', 'int32', 'int16', 'uint8'],
[8000, 16000], [8000, 16000],
[1, 2], [1, 2],
)), name_func=get_test_name) )), name_func=name_func)
def test_save_wav(self, dtype, sample_rate, num_channels): def test_save_wav(self, dtype, sample_rate, num_channels):
script_path = self.get_temp_path('save_func.zip') script_path = self.get_temp_path('save_func.zip')
torch.jit.script(py_save_func).save(script_path) torch.jit.script(py_save_func).save(script_path)
...@@ -119,7 +119,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase): ...@@ -119,7 +119,7 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
[8000, 16000], [8000, 16000],
[1, 2], [1, 2],
list(range(9)), list(range(9)),
)), name_func=get_test_name) )), name_func=name_func)
def test_save_flac(self, sample_rate, num_channels, compression_level): def test_save_flac(self, sample_rate, num_channels, compression_level):
script_path = self.get_temp_path('save_func.zip') script_path = self.get_temp_path('save_func.zip')
torch.jit.script(py_save_func).save(script_path) torch.jit.script(py_save_func).save(script_path)
......
import os import os
import math import math
import shutil
import tempfile
import unittest import unittest
import torch import torch
import torchaudio import torchaudio
from .common_utils import BACKENDS, BACKENDS_MP3, create_temp_assets_dir from .common_utils import BACKENDS, BACKENDS_MP3, get_asset_path
def create_temp_assets_dir():
"""
Creates a temporary directory and moves all files from test/assets there.
Returns a Tuple[string, TemporaryDirectory] which is the folder path
and object.
"""
tmp_dir = tempfile.TemporaryDirectory()
shutil.copytree(get_asset_path(), os.path.join(tmp_dir.name, "assets"))
return tmp_dir.name, tmp_dir
class Test_LoadSave(unittest.TestCase): class Test_LoadSave(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment