Unverified Commit ecfed4d9 authored by Caroline Chen's avatar Caroline Chen Committed by GitHub
Browse files

Make sox selective (#1338)

parent 98d0d593
......@@ -16,8 +16,8 @@ from torchaudio_unittest.common_utils import (
HttpServerMixin,
PytorchTestCase,
skipIfNoExec,
skipIfNoExtension,
skipIfNoModule,
skipIfNoSox,
get_asset_path,
get_wav_data,
save_wav,
......@@ -33,7 +33,7 @@ if _mod_utils.is_module_available("requests"):
@skipIfNoExec('sox')
@skipIfNoExtension
@skipIfNoSox
class TestInfo(TempDirMixin, PytorchTestCase):
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
......@@ -253,7 +253,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.encoding == "PCM_S"
@skipIfNoExtension
@skipIfNoSox
class TestInfoOpus(PytorchTestCase):
@parameterized.expand(list(itertools.product(
['96k'],
......@@ -271,7 +271,7 @@ class TestInfoOpus(PytorchTestCase):
assert info.encoding == "OPUS"
@skipIfNoExtension
@skipIfNoSox
class TestLoadWithoutExtension(PytorchTestCase):
def test_mp3(self):
"""Providing `format` allows to read mp3 without extension
......@@ -306,7 +306,7 @@ class FileObjTestBase(TempDirMixin):
return path
@skipIfNoExtension
@skipIfNoSox
@skipIfNoExec('sox')
class TestFileObject(FileObjTestBase, PytorchTestCase):
def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames):
......@@ -438,7 +438,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
assert sinfo.encoding == get_encoding(ext, dtype)
@skipIfNoExtension
@skipIfNoSox
@skipIfNoExec('sox')
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase):
......
......@@ -11,8 +11,8 @@ from torchaudio_unittest.common_utils import (
HttpServerMixin,
PytorchTestCase,
skipIfNoExec,
skipIfNoExtension,
skipIfNoModule,
skipIfNoSox,
get_asset_path,
get_wav_data,
load_wav,
......@@ -200,7 +200,7 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
@skipIfNoExec('sox')
@skipIfNoExtension
@skipIfNoSox
class TestLoad(LoadTestBase):
"""Test the correctness of `sox_io_backend.load` for various formats"""
@parameterized.expand(list(itertools.product(
......@@ -332,7 +332,7 @@ class TestLoad(LoadTestBase):
@skipIfNoExec('sox')
@skipIfNoExtension
@skipIfNoSox
class TestLoadParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of frame parameters of `sox_io_backend.load`"""
original = None
......@@ -363,7 +363,7 @@ class TestLoadParams(TempDirMixin, PytorchTestCase):
self.assertEqual(found, expected)
@skipIfNoExtension
@skipIfNoSox
class TestLoadWithoutExtension(PytorchTestCase):
def test_mp3(self):
"""Providing format allows to read mp3 without extension
......@@ -393,7 +393,7 @@ class CloggedFileObj:
return ret
@skipIfNoExtension
@skipIfNoSox
@skipIfNoExec('sox')
class TestFileObject(TempDirMixin, PytorchTestCase):
"""
......@@ -553,7 +553,7 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
self.assertEqual(expected, found)
@skipIfNoExtension
@skipIfNoSox
@skipIfNoExec('sox')
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
......
......@@ -7,7 +7,7 @@ from torchaudio_unittest.common_utils import (
TempDirMixin,
PytorchTestCase,
skipIfNoExec,
skipIfNoExtension,
skipIfNoSox,
get_wav_data,
)
from .common import (
......@@ -17,7 +17,7 @@ from .common import (
@skipIfNoExec('sox')
@skipIfNoExtension
@skipIfNoSox
class TestRoundTripIO(TempDirMixin, PytorchTestCase):
"""save/load round trip should not degrade data for lossless formats"""
@parameterized.expand(list(itertools.product(
......
......@@ -10,7 +10,7 @@ from torchaudio_unittest.common_utils import (
TorchaudioTestCase,
PytorchTestCase,
skipIfNoExec,
skipIfNoExtension,
skipIfNoSox,
get_wav_data,
load_wav,
save_wav,
......@@ -157,7 +157,7 @@ def nested_params(*params):
@skipIfNoExec('sox')
@skipIfNoExtension
@skipIfNoSox
class SaveTest(SaveTestBase):
@nested_params(
["path", "fileobj", "bytesio"],
......@@ -354,7 +354,7 @@ class SaveTest(SaveTestBase):
@skipIfNoExec('sox')
@skipIfNoExtension
@skipIfNoSox
class TestSaveParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of optional parameters of `sox_io_backend.save`"""
@parameterized.expand([(True, ), (False, )], name_func=name_func)
......
......@@ -4,26 +4,26 @@ import unittest
from torchaudio.utils import sox_utils
from torchaudio.backend import sox_io_backend
from torchaudio._internal.module_utils import is_module_available
from torchaudio._internal.module_utils import is_sox_available
from parameterized import parameterized
from torchaudio_unittest.common_utils import (
TempDirMixin,
TorchaudioTestCase,
skipIfNoExtension,
skipIfNoSox,
get_wav_data,
)
from .common import name_func
skipIfNoMP3 = unittest.skipIf(
not is_module_available('torchaudio._torchaudio') or
not is_sox_available() or
'mp3' not in sox_utils.list_read_formats() or
'mp3' not in sox_utils.list_write_formats(),
'"sox_io" backend does not support MP3')
@skipIfNoExtension
@skipIfNoSox
class SmokeTest(TempDirMixin, TorchaudioTestCase):
"""Run smoke test on various audio format
......@@ -88,7 +88,7 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase):
self.run_smoke_test('flac', sample_rate, num_channels, compression=compression_level)
@skipIfNoExtension
@skipIfNoSox
class SmokeTestFileObj(TorchaudioTestCase):
"""Run smoke test on various audio format
......
......@@ -9,7 +9,7 @@ from torchaudio_unittest.common_utils import (
TempDirMixin,
TorchaudioTestCase,
skipIfNoExec,
skipIfNoExtension,
skipIfNoSox,
get_wav_data,
save_wav,
load_wav,
......@@ -45,7 +45,7 @@ def py_save_func(
@skipIfNoExec('sox')
@skipIfNoExtension
@skipIfNoSox
class SoxIO(TempDirMixin, TorchaudioTestCase):
"""TorchScript-ability Test suite for `sox_io_backend`"""
backend = 'sox_io'
......
......@@ -25,7 +25,7 @@ class TestBackendSwitch_NoBackend(BackendSwitchMixin, common_utils.TorchaudioTes
backend_module = torchaudio.backend.no_backend
@common_utils.skipIfNoExtension
@common_utils.skipIfNoSox
class TestBackendSwitch_SoXIO(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'sox_io'
backend_module = torchaudio.backend.sox_io_backend
......
......@@ -16,6 +16,7 @@ from .case_utils import (
skipIfNoExec,
skipIfNoModule,
skipIfNoExtension,
skipIfNoSox,
skipIfNoSoxBackend,
)
from .wav_utils import (
......@@ -30,5 +31,5 @@ from .parameterized_utils import (
__all__ = ['get_asset_path', 'get_whitenoise', 'get_sinusoid', 'set_audio_backend',
'TempDirMixin', 'HttpServerMixin', 'TestBaseMixin', 'PytorchTestCase', 'TorchaudioTestCase',
'skipIfNoCuda', 'skipIfNoExec', 'skipIfNoModule', 'skipIfNoExtension', 'skipIfNoSoxBackend',
'get_wav_data', 'normalize_wav', 'load_wav', 'save_wav', 'load_params']
'skipIfNoCuda', 'skipIfNoExec', 'skipIfNoModule', 'skipIfNoExtension', 'skipIfNoSox',
'skipIfNoSoxBackend', 'get_wav_data', 'normalize_wav', 'load_wav', 'save_wav', 'load_params']
......@@ -8,7 +8,10 @@ import unittest
import torch
from torch.testing._internal.common_utils import TestCase as PytorchTestCase
import torchaudio
from torchaudio._internal.module_utils import is_module_available
from torchaudio._internal.module_utils import (
is_module_available,
is_sox_available
)
from .backend_utils import set_audio_backend
......@@ -95,6 +98,7 @@ def skipIfNoModule(module, display_name=None):
skipIfNoSoxBackend = unittest.skipIf(
'sox' not in torchaudio.list_audio_backends(), 'Sox backend not available')
skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')
skipIfNoSox = unittest.skipIf(not is_sox_available(), reason='Sox not available')
def skipIfNoExtension(test_item):
......
......@@ -7,6 +7,7 @@ from torchaudio_unittest.common_utils import (
TorchaudioTestCase,
get_whitenoise,
save_wav,
skipIfNoSox
)
from torchaudio.datasets import tedlium
......@@ -144,5 +145,6 @@ class TestTedliumSoundfile(Tedlium, TorchaudioTestCase):
if platform.system() != "Windows":
@skipIfNoSox
class TestTedliumSoxIO(Tedlium, TorchaudioTestCase):
backend = "sox_io"
......@@ -10,7 +10,7 @@ import itertools
from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import (
TorchaudioTestCase,
skipIfNoExtension,
skipIfNoSox,
)
from torchaudio_unittest.backend.sox_io.common import name_func
......@@ -220,7 +220,7 @@ class TestMaskAlongAxisIID(common_utils.TorchaudioTestCase):
assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()
@skipIfNoExtension
@skipIfNoSox
class TestApplyCodec(TorchaudioTestCase):
backend = "sox_io"
......
......@@ -11,7 +11,7 @@ import torchaudio
from torchaudio_unittest.common_utils import (
TempDirMixin,
PytorchTestCase,
skipIfNoExtension,
skipIfNoSox,
get_whitenoise,
save_wav,
)
......@@ -71,7 +71,7 @@ def init_random_seed(worker_id):
dataset.rng = np.random.RandomState(worker_id)
@skipIfNoExtension
@skipIfNoSox
@skipIf(
platform.system() == 'Darwin' and
sys.version_info.major == 3 and
......@@ -134,7 +134,7 @@ def speed(path):
return torchaudio.sox_effects.apply_effects_tensor(wav, sample_rate, effects)[0]
@skipIfNoExtension
@skipIfNoSox
class TestProcessPoolExecutor(TempDirMixin, PytorchTestCase):
backend = "sox_io"
......
......@@ -4,7 +4,7 @@ from parameterized import parameterized
from torchaudio_unittest.common_utils import (
TempDirMixin,
TorchaudioTestCase,
skipIfNoExtension,
skipIfNoSox,
get_wav_data,
get_sinusoid,
save_wav,
......@@ -14,7 +14,7 @@ from .common import (
)
@skipIfNoExtension
@skipIfNoSox
class SmokeTest(TempDirMixin, TorchaudioTestCase):
"""Run smoke test on various effects
......
......@@ -11,7 +11,7 @@ from torchaudio_unittest.common_utils import (
TempDirMixin,
HttpServerMixin,
PytorchTestCase,
skipIfNoExtension,
skipIfNoSox,
skipIfNoModule,
skipIfNoExec,
get_asset_path,
......@@ -31,7 +31,7 @@ if _mod_utils.is_module_available("requests"):
import requests
@skipIfNoExtension
@skipIfNoSox
class TestSoxEffects(PytorchTestCase):
def test_init(self):
"""Calling init_sox_effects multiple times does not crush"""
......@@ -39,7 +39,7 @@ class TestSoxEffects(PytorchTestCase):
sox_effects.init_sox_effects()
@skipIfNoExtension
@skipIfNoSox
class TestSoxEffectsTensor(TempDirMixin, PytorchTestCase):
"""Test suite for `apply_effects_tensor` function"""
@parameterized.expand(list(itertools.product(
......@@ -91,7 +91,7 @@ class TestSoxEffectsTensor(TempDirMixin, PytorchTestCase):
self.assertEqual(expected, found)
@skipIfNoExtension
@skipIfNoSox
class TestSoxEffectsFile(TempDirMixin, PytorchTestCase):
"""Test suite for `apply_effects_file` function"""
@parameterized.expand(list(itertools.product(
......@@ -163,7 +163,7 @@ class TestSoxEffectsFile(TempDirMixin, PytorchTestCase):
self.assertEqual(found, expected)
@skipIfNoExtension
@skipIfNoSox
class TestFileFormats(TempDirMixin, PytorchTestCase):
"""`apply_effects_file` gives the same result as sox on various file formats"""
@parameterized.expand(list(itertools.product(
......@@ -256,7 +256,7 @@ class TestFileFormats(TempDirMixin, PytorchTestCase):
self.assertEqual(found, expected)
@skipIfNoExtension
@skipIfNoSox
class TestApplyEffectFileWithoutExtension(PytorchTestCase):
def test_mp3(self):
"""Providing format allows to read mp3 without extension
......@@ -275,7 +275,7 @@ class TestApplyEffectFileWithoutExtension(PytorchTestCase):
@skipIfNoExec('sox')
@skipIfNoExtension
@skipIfNoSox
class TestFileObject(TempDirMixin, PytorchTestCase):
@parameterized.expand([
('wav', None),
......@@ -384,7 +384,7 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
self.assertEqual(found, expected)
@skipIfNoExtension
@skipIfNoSox
@skipIfNoExec('sox')
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
......
......@@ -7,7 +7,7 @@ from parameterized import parameterized
from torchaudio_unittest.common_utils import (
TempDirMixin,
PytorchTestCase,
skipIfNoExtension,
skipIfNoSox,
get_sinusoid,
save_wav,
)
......@@ -43,7 +43,7 @@ class SoxEffectFileTransform(torch.nn.Module):
return sox_effects.apply_effects_file(path, self.effects, self.channels_first)
@skipIfNoExtension
@skipIfNoSox
class TestTorchScript(TempDirMixin, PytorchTestCase):
@parameterized.expand(
load_params("sox_effect_test_args.json"),
......
......@@ -2,11 +2,11 @@ from torchaudio.utils import sox_utils
from torchaudio_unittest.common_utils import (
PytorchTestCase,
skipIfNoExtension,
skipIfNoSox,
)
@skipIfNoExtension
@skipIfNoSox
class TestSoxUtils(PytorchTestCase):
"""Smoke tests for sox_util module"""
def test_set_seed(self):
......
......@@ -10,12 +10,8 @@ if (BUILD_SOX)
add_subdirectory(sox)
target_include_directories(libsox INTERFACE ${SOX_INCLUDE_DIR})
target_link_libraries(libsox INTERFACE ${SOX_LIBRARIES})
else()
# If not building and linking libsox statically, then we expect that
# sox library and header are found in search path
target_link_libraries(libsox INTERFACE -lsox)
list(APPEND TORCHAUDIO_THIRD_PARTIES libsox)
endif()
list(APPEND TORCHAUDIO_THIRD_PARTIES libsox)
################################################################################
# kaldi
......
......@@ -3,6 +3,8 @@ import importlib.util
from typing import Optional
from functools import wraps
import torch
def is_module_available(*modules: str) -> bool:
r"""Returns if a top-level module with :attr:`name` exists *without**
......@@ -56,3 +58,20 @@ def deprecated(direction: str, version: Optional[str] = None):
return func(*args, **kwargs)
return wrapped
return decorator
def is_sox_available():
return is_module_available('torchaudio._torchaudio') and torch.ops.torchaudio.is_sox_available()
def requires_sox():
if is_sox_available():
def decorator(func):
return func
else:
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
raise RuntimeError(f'{func.__module__}.{func.__name__} requires sox')
return wrapped
return decorator
......@@ -10,7 +10,7 @@ import torchaudio
from .common import AudioMetaData
@_mod_utils.requires_module('torchaudio._torchaudio')
@_mod_utils.requires_sox()
def info(
filepath: str,
format: Optional[str] = None,
......@@ -54,7 +54,7 @@ def info(
return AudioMetaData(*sinfo)
@_mod_utils.requires_module('torchaudio._torchaudio')
@_mod_utils.requires_sox()
def load(
filepath: str,
frame_offset: int = 0,
......@@ -151,7 +151,7 @@ def load(
filepath, frame_offset, num_frames, normalize, channels_first, format)
@_mod_utils.requires_module('torchaudio._torchaudio')
@_mod_utils.requires_sox()
def save(
filepath: str,
src: torch.Tensor,
......
......@@ -3,7 +3,7 @@ import warnings
from typing import Optional, List
import torchaudio
from torchaudio._internal.module_utils import is_module_available
from torchaudio._internal import module_utils as _mod_utils
from . import (
no_backend,
sox_io_backend,
......@@ -24,9 +24,9 @@ def list_audio_backends() -> List[str]:
List[str]: The list of available backends.
"""
backends = []
if is_module_available('soundfile'):
if _mod_utils.is_module_available('soundfile'):
backends.append('soundfile')
if is_module_available('torchaudio._torchaudio'):
if _mod_utils.is_sox_available():
backends.append('sox_io')
return backends
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment