Unverified Commit ecfed4d9 authored by Caroline Chen's avatar Caroline Chen Committed by GitHub
Browse files

Make sox selective (#1338)

parent 98d0d593
...@@ -16,8 +16,8 @@ from torchaudio_unittest.common_utils import ( ...@@ -16,8 +16,8 @@ from torchaudio_unittest.common_utils import (
HttpServerMixin, HttpServerMixin,
PytorchTestCase, PytorchTestCase,
skipIfNoExec, skipIfNoExec,
skipIfNoExtension,
skipIfNoModule, skipIfNoModule,
skipIfNoSox,
get_asset_path, get_asset_path,
get_wav_data, get_wav_data,
save_wav, save_wav,
...@@ -33,7 +33,7 @@ if _mod_utils.is_module_available("requests"): ...@@ -33,7 +33,7 @@ if _mod_utils.is_module_available("requests"):
@skipIfNoExec('sox') @skipIfNoExec('sox')
@skipIfNoExtension @skipIfNoSox
class TestInfo(TempDirMixin, PytorchTestCase): class TestInfo(TempDirMixin, PytorchTestCase):
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'], ['float32', 'int32', 'int16', 'uint8'],
...@@ -253,7 +253,7 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -253,7 +253,7 @@ class TestInfo(TempDirMixin, PytorchTestCase):
assert info.encoding == "PCM_S" assert info.encoding == "PCM_S"
@skipIfNoExtension @skipIfNoSox
class TestInfoOpus(PytorchTestCase): class TestInfoOpus(PytorchTestCase):
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
['96k'], ['96k'],
...@@ -271,7 +271,7 @@ class TestInfoOpus(PytorchTestCase): ...@@ -271,7 +271,7 @@ class TestInfoOpus(PytorchTestCase):
assert info.encoding == "OPUS" assert info.encoding == "OPUS"
@skipIfNoExtension @skipIfNoSox
class TestLoadWithoutExtension(PytorchTestCase): class TestLoadWithoutExtension(PytorchTestCase):
def test_mp3(self): def test_mp3(self):
"""Providing `format` allows to read mp3 without extension """Providing `format` allows to read mp3 without extension
...@@ -306,7 +306,7 @@ class FileObjTestBase(TempDirMixin): ...@@ -306,7 +306,7 @@ class FileObjTestBase(TempDirMixin):
return path return path
@skipIfNoExtension @skipIfNoSox
@skipIfNoExec('sox') @skipIfNoExec('sox')
class TestFileObject(FileObjTestBase, PytorchTestCase): class TestFileObject(FileObjTestBase, PytorchTestCase):
def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames): def _query_fileobj(self, ext, dtype, sample_rate, num_channels, num_frames):
...@@ -438,7 +438,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase): ...@@ -438,7 +438,7 @@ class TestFileObject(FileObjTestBase, PytorchTestCase):
assert sinfo.encoding == get_encoding(ext, dtype) assert sinfo.encoding == get_encoding(ext, dtype)
@skipIfNoExtension @skipIfNoSox
@skipIfNoExec('sox') @skipIfNoExec('sox')
@skipIfNoModule("requests") @skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase): class TestFileObjectHttp(HttpServerMixin, FileObjTestBase, PytorchTestCase):
......
...@@ -11,8 +11,8 @@ from torchaudio_unittest.common_utils import ( ...@@ -11,8 +11,8 @@ from torchaudio_unittest.common_utils import (
HttpServerMixin, HttpServerMixin,
PytorchTestCase, PytorchTestCase,
skipIfNoExec, skipIfNoExec,
skipIfNoExtension,
skipIfNoModule, skipIfNoModule,
skipIfNoSox,
get_asset_path, get_asset_path,
get_wav_data, get_wav_data,
load_wav, load_wav,
...@@ -200,7 +200,7 @@ class LoadTestBase(TempDirMixin, PytorchTestCase): ...@@ -200,7 +200,7 @@ class LoadTestBase(TempDirMixin, PytorchTestCase):
@skipIfNoExec('sox') @skipIfNoExec('sox')
@skipIfNoExtension @skipIfNoSox
class TestLoad(LoadTestBase): class TestLoad(LoadTestBase):
"""Test the correctness of `sox_io_backend.load` for various formats""" """Test the correctness of `sox_io_backend.load` for various formats"""
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
...@@ -332,7 +332,7 @@ class TestLoad(LoadTestBase): ...@@ -332,7 +332,7 @@ class TestLoad(LoadTestBase):
@skipIfNoExec('sox') @skipIfNoExec('sox')
@skipIfNoExtension @skipIfNoSox
class TestLoadParams(TempDirMixin, PytorchTestCase): class TestLoadParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of frame parameters of `sox_io_backend.load`""" """Test the correctness of frame parameters of `sox_io_backend.load`"""
original = None original = None
...@@ -363,7 +363,7 @@ class TestLoadParams(TempDirMixin, PytorchTestCase): ...@@ -363,7 +363,7 @@ class TestLoadParams(TempDirMixin, PytorchTestCase):
self.assertEqual(found, expected) self.assertEqual(found, expected)
@skipIfNoExtension @skipIfNoSox
class TestLoadWithoutExtension(PytorchTestCase): class TestLoadWithoutExtension(PytorchTestCase):
def test_mp3(self): def test_mp3(self):
"""Providing format allows to read mp3 without extension """Providing format allows to read mp3 without extension
...@@ -393,7 +393,7 @@ class CloggedFileObj: ...@@ -393,7 +393,7 @@ class CloggedFileObj:
return ret return ret
@skipIfNoExtension @skipIfNoSox
@skipIfNoExec('sox') @skipIfNoExec('sox')
class TestFileObject(TempDirMixin, PytorchTestCase): class TestFileObject(TempDirMixin, PytorchTestCase):
""" """
...@@ -553,7 +553,7 @@ class TestFileObject(TempDirMixin, PytorchTestCase): ...@@ -553,7 +553,7 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
self.assertEqual(expected, found) self.assertEqual(expected, found)
@skipIfNoExtension @skipIfNoSox
@skipIfNoExec('sox') @skipIfNoExec('sox')
@skipIfNoModule("requests") @skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, PytorchTestCase): class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
......
...@@ -7,7 +7,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -7,7 +7,7 @@ from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
PytorchTestCase, PytorchTestCase,
skipIfNoExec, skipIfNoExec,
skipIfNoExtension, skipIfNoSox,
get_wav_data, get_wav_data,
) )
from .common import ( from .common import (
...@@ -17,7 +17,7 @@ from .common import ( ...@@ -17,7 +17,7 @@ from .common import (
@skipIfNoExec('sox') @skipIfNoExec('sox')
@skipIfNoExtension @skipIfNoSox
class TestRoundTripIO(TempDirMixin, PytorchTestCase): class TestRoundTripIO(TempDirMixin, PytorchTestCase):
"""save/load round trip should not degrade data for lossless formats""" """save/load round trip should not degrade data for lossless formats"""
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
......
...@@ -10,7 +10,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -10,7 +10,7 @@ from torchaudio_unittest.common_utils import (
TorchaudioTestCase, TorchaudioTestCase,
PytorchTestCase, PytorchTestCase,
skipIfNoExec, skipIfNoExec,
skipIfNoExtension, skipIfNoSox,
get_wav_data, get_wav_data,
load_wav, load_wav,
save_wav, save_wav,
...@@ -157,7 +157,7 @@ def nested_params(*params): ...@@ -157,7 +157,7 @@ def nested_params(*params):
@skipIfNoExec('sox') @skipIfNoExec('sox')
@skipIfNoExtension @skipIfNoSox
class SaveTest(SaveTestBase): class SaveTest(SaveTestBase):
@nested_params( @nested_params(
["path", "fileobj", "bytesio"], ["path", "fileobj", "bytesio"],
...@@ -354,7 +354,7 @@ class SaveTest(SaveTestBase): ...@@ -354,7 +354,7 @@ class SaveTest(SaveTestBase):
@skipIfNoExec('sox') @skipIfNoExec('sox')
@skipIfNoExtension @skipIfNoSox
class TestSaveParams(TempDirMixin, PytorchTestCase): class TestSaveParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of optional parameters of `sox_io_backend.save`""" """Test the correctness of optional parameters of `sox_io_backend.save`"""
@parameterized.expand([(True, ), (False, )], name_func=name_func) @parameterized.expand([(True, ), (False, )], name_func=name_func)
......
...@@ -4,26 +4,26 @@ import unittest ...@@ -4,26 +4,26 @@ import unittest
from torchaudio.utils import sox_utils from torchaudio.utils import sox_utils
from torchaudio.backend import sox_io_backend from torchaudio.backend import sox_io_backend
from torchaudio._internal.module_utils import is_module_available from torchaudio._internal.module_utils import is_sox_available
from parameterized import parameterized from parameterized import parameterized
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
TorchaudioTestCase, TorchaudioTestCase,
skipIfNoExtension, skipIfNoSox,
get_wav_data, get_wav_data,
) )
from .common import name_func from .common import name_func
skipIfNoMP3 = unittest.skipIf( skipIfNoMP3 = unittest.skipIf(
not is_module_available('torchaudio._torchaudio') or not is_sox_available() or
'mp3' not in sox_utils.list_read_formats() or 'mp3' not in sox_utils.list_read_formats() or
'mp3' not in sox_utils.list_write_formats(), 'mp3' not in sox_utils.list_write_formats(),
'"sox_io" backend does not support MP3') '"sox_io" backend does not support MP3')
@skipIfNoExtension @skipIfNoSox
class SmokeTest(TempDirMixin, TorchaudioTestCase): class SmokeTest(TempDirMixin, TorchaudioTestCase):
"""Run smoke test on various audio format """Run smoke test on various audio format
...@@ -88,7 +88,7 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase): ...@@ -88,7 +88,7 @@ class SmokeTest(TempDirMixin, TorchaudioTestCase):
self.run_smoke_test('flac', sample_rate, num_channels, compression=compression_level) self.run_smoke_test('flac', sample_rate, num_channels, compression=compression_level)
@skipIfNoExtension @skipIfNoSox
class SmokeTestFileObj(TorchaudioTestCase): class SmokeTestFileObj(TorchaudioTestCase):
"""Run smoke test on various audio format """Run smoke test on various audio format
......
...@@ -9,7 +9,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -9,7 +9,7 @@ from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
TorchaudioTestCase, TorchaudioTestCase,
skipIfNoExec, skipIfNoExec,
skipIfNoExtension, skipIfNoSox,
get_wav_data, get_wav_data,
save_wav, save_wav,
load_wav, load_wav,
...@@ -45,7 +45,7 @@ def py_save_func( ...@@ -45,7 +45,7 @@ def py_save_func(
@skipIfNoExec('sox') @skipIfNoExec('sox')
@skipIfNoExtension @skipIfNoSox
class SoxIO(TempDirMixin, TorchaudioTestCase): class SoxIO(TempDirMixin, TorchaudioTestCase):
"""TorchScript-ability Test suite for `sox_io_backend`""" """TorchScript-ability Test suite for `sox_io_backend`"""
backend = 'sox_io' backend = 'sox_io'
......
...@@ -25,7 +25,7 @@ class TestBackendSwitch_NoBackend(BackendSwitchMixin, common_utils.TorchaudioTes ...@@ -25,7 +25,7 @@ class TestBackendSwitch_NoBackend(BackendSwitchMixin, common_utils.TorchaudioTes
backend_module = torchaudio.backend.no_backend backend_module = torchaudio.backend.no_backend
@common_utils.skipIfNoExtension @common_utils.skipIfNoSox
class TestBackendSwitch_SoXIO(BackendSwitchMixin, common_utils.TorchaudioTestCase): class TestBackendSwitch_SoXIO(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'sox_io' backend = 'sox_io'
backend_module = torchaudio.backend.sox_io_backend backend_module = torchaudio.backend.sox_io_backend
......
...@@ -16,6 +16,7 @@ from .case_utils import ( ...@@ -16,6 +16,7 @@ from .case_utils import (
skipIfNoExec, skipIfNoExec,
skipIfNoModule, skipIfNoModule,
skipIfNoExtension, skipIfNoExtension,
skipIfNoSox,
skipIfNoSoxBackend, skipIfNoSoxBackend,
) )
from .wav_utils import ( from .wav_utils import (
...@@ -30,5 +31,5 @@ from .parameterized_utils import ( ...@@ -30,5 +31,5 @@ from .parameterized_utils import (
__all__ = ['get_asset_path', 'get_whitenoise', 'get_sinusoid', 'set_audio_backend', __all__ = ['get_asset_path', 'get_whitenoise', 'get_sinusoid', 'set_audio_backend',
'TempDirMixin', 'HttpServerMixin', 'TestBaseMixin', 'PytorchTestCase', 'TorchaudioTestCase', 'TempDirMixin', 'HttpServerMixin', 'TestBaseMixin', 'PytorchTestCase', 'TorchaudioTestCase',
'skipIfNoCuda', 'skipIfNoExec', 'skipIfNoModule', 'skipIfNoExtension', 'skipIfNoSoxBackend', 'skipIfNoCuda', 'skipIfNoExec', 'skipIfNoModule', 'skipIfNoExtension', 'skipIfNoSox',
'get_wav_data', 'normalize_wav', 'load_wav', 'save_wav', 'load_params'] 'skipIfNoSoxBackend', 'get_wav_data', 'normalize_wav', 'load_wav', 'save_wav', 'load_params']
...@@ -8,7 +8,10 @@ import unittest ...@@ -8,7 +8,10 @@ import unittest
import torch import torch
from torch.testing._internal.common_utils import TestCase as PytorchTestCase from torch.testing._internal.common_utils import TestCase as PytorchTestCase
import torchaudio import torchaudio
from torchaudio._internal.module_utils import is_module_available from torchaudio._internal.module_utils import (
is_module_available,
is_sox_available
)
from .backend_utils import set_audio_backend from .backend_utils import set_audio_backend
...@@ -95,6 +98,7 @@ def skipIfNoModule(module, display_name=None): ...@@ -95,6 +98,7 @@ def skipIfNoModule(module, display_name=None):
skipIfNoSoxBackend = unittest.skipIf( skipIfNoSoxBackend = unittest.skipIf(
'sox' not in torchaudio.list_audio_backends(), 'Sox backend not available') 'sox' not in torchaudio.list_audio_backends(), 'Sox backend not available')
skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available') skipIfNoCuda = unittest.skipIf(not torch.cuda.is_available(), reason='CUDA not available')
skipIfNoSox = unittest.skipIf(not is_sox_available(), reason='Sox not available')
def skipIfNoExtension(test_item): def skipIfNoExtension(test_item):
......
...@@ -7,6 +7,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -7,6 +7,7 @@ from torchaudio_unittest.common_utils import (
TorchaudioTestCase, TorchaudioTestCase,
get_whitenoise, get_whitenoise,
save_wav, save_wav,
skipIfNoSox
) )
from torchaudio.datasets import tedlium from torchaudio.datasets import tedlium
...@@ -144,5 +145,6 @@ class TestTedliumSoundfile(Tedlium, TorchaudioTestCase): ...@@ -144,5 +145,6 @@ class TestTedliumSoundfile(Tedlium, TorchaudioTestCase):
if platform.system() != "Windows": if platform.system() != "Windows":
@skipIfNoSox
class TestTedliumSoxIO(Tedlium, TorchaudioTestCase): class TestTedliumSoxIO(Tedlium, TorchaudioTestCase):
backend = "sox_io" backend = "sox_io"
...@@ -10,7 +10,7 @@ import itertools ...@@ -10,7 +10,7 @@ import itertools
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TorchaudioTestCase, TorchaudioTestCase,
skipIfNoExtension, skipIfNoSox,
) )
from torchaudio_unittest.backend.sox_io.common import name_func from torchaudio_unittest.backend.sox_io.common import name_func
...@@ -220,7 +220,7 @@ class TestMaskAlongAxisIID(common_utils.TorchaudioTestCase): ...@@ -220,7 +220,7 @@ class TestMaskAlongAxisIID(common_utils.TorchaudioTestCase):
assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel() assert (num_masked_columns < mask_param).sum() == num_masked_columns.numel()
@skipIfNoExtension @skipIfNoSox
class TestApplyCodec(TorchaudioTestCase): class TestApplyCodec(TorchaudioTestCase):
backend = "sox_io" backend = "sox_io"
......
...@@ -11,7 +11,7 @@ import torchaudio ...@@ -11,7 +11,7 @@ import torchaudio
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
PytorchTestCase, PytorchTestCase,
skipIfNoExtension, skipIfNoSox,
get_whitenoise, get_whitenoise,
save_wav, save_wav,
) )
...@@ -71,7 +71,7 @@ def init_random_seed(worker_id): ...@@ -71,7 +71,7 @@ def init_random_seed(worker_id):
dataset.rng = np.random.RandomState(worker_id) dataset.rng = np.random.RandomState(worker_id)
@skipIfNoExtension @skipIfNoSox
@skipIf( @skipIf(
platform.system() == 'Darwin' and platform.system() == 'Darwin' and
sys.version_info.major == 3 and sys.version_info.major == 3 and
...@@ -134,7 +134,7 @@ def speed(path): ...@@ -134,7 +134,7 @@ def speed(path):
return torchaudio.sox_effects.apply_effects_tensor(wav, sample_rate, effects)[0] return torchaudio.sox_effects.apply_effects_tensor(wav, sample_rate, effects)[0]
@skipIfNoExtension @skipIfNoSox
class TestProcessPoolExecutor(TempDirMixin, PytorchTestCase): class TestProcessPoolExecutor(TempDirMixin, PytorchTestCase):
backend = "sox_io" backend = "sox_io"
......
...@@ -4,7 +4,7 @@ from parameterized import parameterized ...@@ -4,7 +4,7 @@ from parameterized import parameterized
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
TorchaudioTestCase, TorchaudioTestCase,
skipIfNoExtension, skipIfNoSox,
get_wav_data, get_wav_data,
get_sinusoid, get_sinusoid,
save_wav, save_wav,
...@@ -14,7 +14,7 @@ from .common import ( ...@@ -14,7 +14,7 @@ from .common import (
) )
@skipIfNoExtension @skipIfNoSox
class SmokeTest(TempDirMixin, TorchaudioTestCase): class SmokeTest(TempDirMixin, TorchaudioTestCase):
"""Run smoke test on various effects """Run smoke test on various effects
......
...@@ -11,7 +11,7 @@ from torchaudio_unittest.common_utils import ( ...@@ -11,7 +11,7 @@ from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
HttpServerMixin, HttpServerMixin,
PytorchTestCase, PytorchTestCase,
skipIfNoExtension, skipIfNoSox,
skipIfNoModule, skipIfNoModule,
skipIfNoExec, skipIfNoExec,
get_asset_path, get_asset_path,
...@@ -31,7 +31,7 @@ if _mod_utils.is_module_available("requests"): ...@@ -31,7 +31,7 @@ if _mod_utils.is_module_available("requests"):
import requests import requests
@skipIfNoExtension @skipIfNoSox
class TestSoxEffects(PytorchTestCase): class TestSoxEffects(PytorchTestCase):
def test_init(self): def test_init(self):
"""Calling init_sox_effects multiple times does not crush""" """Calling init_sox_effects multiple times does not crush"""
...@@ -39,7 +39,7 @@ class TestSoxEffects(PytorchTestCase): ...@@ -39,7 +39,7 @@ class TestSoxEffects(PytorchTestCase):
sox_effects.init_sox_effects() sox_effects.init_sox_effects()
@skipIfNoExtension @skipIfNoSox
class TestSoxEffectsTensor(TempDirMixin, PytorchTestCase): class TestSoxEffectsTensor(TempDirMixin, PytorchTestCase):
"""Test suite for `apply_effects_tensor` function""" """Test suite for `apply_effects_tensor` function"""
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
...@@ -91,7 +91,7 @@ class TestSoxEffectsTensor(TempDirMixin, PytorchTestCase): ...@@ -91,7 +91,7 @@ class TestSoxEffectsTensor(TempDirMixin, PytorchTestCase):
self.assertEqual(expected, found) self.assertEqual(expected, found)
@skipIfNoExtension @skipIfNoSox
class TestSoxEffectsFile(TempDirMixin, PytorchTestCase): class TestSoxEffectsFile(TempDirMixin, PytorchTestCase):
"""Test suite for `apply_effects_file` function""" """Test suite for `apply_effects_file` function"""
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
...@@ -163,7 +163,7 @@ class TestSoxEffectsFile(TempDirMixin, PytorchTestCase): ...@@ -163,7 +163,7 @@ class TestSoxEffectsFile(TempDirMixin, PytorchTestCase):
self.assertEqual(found, expected) self.assertEqual(found, expected)
@skipIfNoExtension @skipIfNoSox
class TestFileFormats(TempDirMixin, PytorchTestCase): class TestFileFormats(TempDirMixin, PytorchTestCase):
"""`apply_effects_file` gives the same result as sox on various file formats""" """`apply_effects_file` gives the same result as sox on various file formats"""
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
...@@ -256,7 +256,7 @@ class TestFileFormats(TempDirMixin, PytorchTestCase): ...@@ -256,7 +256,7 @@ class TestFileFormats(TempDirMixin, PytorchTestCase):
self.assertEqual(found, expected) self.assertEqual(found, expected)
@skipIfNoExtension @skipIfNoSox
class TestApplyEffectFileWithoutExtension(PytorchTestCase): class TestApplyEffectFileWithoutExtension(PytorchTestCase):
def test_mp3(self): def test_mp3(self):
"""Providing format allows to read mp3 without extension """Providing format allows to read mp3 without extension
...@@ -275,7 +275,7 @@ class TestApplyEffectFileWithoutExtension(PytorchTestCase): ...@@ -275,7 +275,7 @@ class TestApplyEffectFileWithoutExtension(PytorchTestCase):
@skipIfNoExec('sox') @skipIfNoExec('sox')
@skipIfNoExtension @skipIfNoSox
class TestFileObject(TempDirMixin, PytorchTestCase): class TestFileObject(TempDirMixin, PytorchTestCase):
@parameterized.expand([ @parameterized.expand([
('wav', None), ('wav', None),
...@@ -384,7 +384,7 @@ class TestFileObject(TempDirMixin, PytorchTestCase): ...@@ -384,7 +384,7 @@ class TestFileObject(TempDirMixin, PytorchTestCase):
self.assertEqual(found, expected) self.assertEqual(found, expected)
@skipIfNoExtension @skipIfNoSox
@skipIfNoExec('sox') @skipIfNoExec('sox')
@skipIfNoModule("requests") @skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, PytorchTestCase): class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
......
...@@ -7,7 +7,7 @@ from parameterized import parameterized ...@@ -7,7 +7,7 @@ from parameterized import parameterized
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
TempDirMixin, TempDirMixin,
PytorchTestCase, PytorchTestCase,
skipIfNoExtension, skipIfNoSox,
get_sinusoid, get_sinusoid,
save_wav, save_wav,
) )
...@@ -43,7 +43,7 @@ class SoxEffectFileTransform(torch.nn.Module): ...@@ -43,7 +43,7 @@ class SoxEffectFileTransform(torch.nn.Module):
return sox_effects.apply_effects_file(path, self.effects, self.channels_first) return sox_effects.apply_effects_file(path, self.effects, self.channels_first)
@skipIfNoExtension @skipIfNoSox
class TestTorchScript(TempDirMixin, PytorchTestCase): class TestTorchScript(TempDirMixin, PytorchTestCase):
@parameterized.expand( @parameterized.expand(
load_params("sox_effect_test_args.json"), load_params("sox_effect_test_args.json"),
......
...@@ -2,11 +2,11 @@ from torchaudio.utils import sox_utils ...@@ -2,11 +2,11 @@ from torchaudio.utils import sox_utils
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
PytorchTestCase, PytorchTestCase,
skipIfNoExtension, skipIfNoSox,
) )
@skipIfNoExtension @skipIfNoSox
class TestSoxUtils(PytorchTestCase): class TestSoxUtils(PytorchTestCase):
"""Smoke tests for sox_util module""" """Smoke tests for sox_util module"""
def test_set_seed(self): def test_set_seed(self):
......
...@@ -10,12 +10,8 @@ if (BUILD_SOX) ...@@ -10,12 +10,8 @@ if (BUILD_SOX)
add_subdirectory(sox) add_subdirectory(sox)
target_include_directories(libsox INTERFACE ${SOX_INCLUDE_DIR}) target_include_directories(libsox INTERFACE ${SOX_INCLUDE_DIR})
target_link_libraries(libsox INTERFACE ${SOX_LIBRARIES}) target_link_libraries(libsox INTERFACE ${SOX_LIBRARIES})
else() list(APPEND TORCHAUDIO_THIRD_PARTIES libsox)
# If not building and linking libsox statically, then we expect that
# sox library and header are found in search path
target_link_libraries(libsox INTERFACE -lsox)
endif() endif()
list(APPEND TORCHAUDIO_THIRD_PARTIES libsox)
################################################################################ ################################################################################
# kaldi # kaldi
......
...@@ -3,6 +3,8 @@ import importlib.util ...@@ -3,6 +3,8 @@ import importlib.util
from typing import Optional from typing import Optional
from functools import wraps from functools import wraps
import torch
def is_module_available(*modules: str) -> bool: def is_module_available(*modules: str) -> bool:
r"""Returns if a top-level module with :attr:`name` exists *without** r"""Returns if a top-level module with :attr:`name` exists *without**
...@@ -56,3 +58,20 @@ def deprecated(direction: str, version: Optional[str] = None): ...@@ -56,3 +58,20 @@ def deprecated(direction: str, version: Optional[str] = None):
return func(*args, **kwargs) return func(*args, **kwargs)
return wrapped return wrapped
return decorator return decorator
def is_sox_available():
return is_module_available('torchaudio._torchaudio') and torch.ops.torchaudio.is_sox_available()
def requires_sox():
if is_sox_available():
def decorator(func):
return func
else:
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
raise RuntimeError(f'{func.__module__}.{func.__name__} requires sox')
return wrapped
return decorator
...@@ -10,7 +10,7 @@ import torchaudio ...@@ -10,7 +10,7 @@ import torchaudio
from .common import AudioMetaData from .common import AudioMetaData
@_mod_utils.requires_module('torchaudio._torchaudio') @_mod_utils.requires_sox()
def info( def info(
filepath: str, filepath: str,
format: Optional[str] = None, format: Optional[str] = None,
...@@ -54,7 +54,7 @@ def info( ...@@ -54,7 +54,7 @@ def info(
return AudioMetaData(*sinfo) return AudioMetaData(*sinfo)
@_mod_utils.requires_module('torchaudio._torchaudio') @_mod_utils.requires_sox()
def load( def load(
filepath: str, filepath: str,
frame_offset: int = 0, frame_offset: int = 0,
...@@ -151,7 +151,7 @@ def load( ...@@ -151,7 +151,7 @@ def load(
filepath, frame_offset, num_frames, normalize, channels_first, format) filepath, frame_offset, num_frames, normalize, channels_first, format)
@_mod_utils.requires_module('torchaudio._torchaudio') @_mod_utils.requires_sox()
def save( def save(
filepath: str, filepath: str,
src: torch.Tensor, src: torch.Tensor,
......
...@@ -3,7 +3,7 @@ import warnings ...@@ -3,7 +3,7 @@ import warnings
from typing import Optional, List from typing import Optional, List
import torchaudio import torchaudio
from torchaudio._internal.module_utils import is_module_available from torchaudio._internal import module_utils as _mod_utils
from . import ( from . import (
no_backend, no_backend,
sox_io_backend, sox_io_backend,
...@@ -24,9 +24,9 @@ def list_audio_backends() -> List[str]: ...@@ -24,9 +24,9 @@ def list_audio_backends() -> List[str]:
List[str]: The list of available backends. List[str]: The list of available backends.
""" """
backends = [] backends = []
if is_module_available('soundfile'): if _mod_utils.is_module_available('soundfile'):
backends.append('soundfile') backends.append('soundfile')
if is_module_available('torchaudio._torchaudio'): if _mod_utils.is_sox_available():
backends.append('sox_io') backends.append('sox_io')
return backends return backends
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment