Unverified Commit 4b583eab authored by moto's avatar moto Committed by GitHub
Browse files

Add sox_io_backend (#726)

parent 894959a7
...@@ -2,7 +2,7 @@ import itertools ...@@ -2,7 +2,7 @@ import itertools
from typing import Optional from typing import Optional
import torch import torch
from torchaudio.backend import sox_io_backend import torchaudio
from parameterized import parameterized from parameterized import parameterized
from ..common_utils import ( from ..common_utils import (
...@@ -21,11 +21,11 @@ from .common import ( ...@@ -21,11 +21,11 @@ from .common import (
def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo: def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo:
return sox_io_backend.info(filepath) return torchaudio.info(filepath)
def py_load_func(filepath: str, normalize: bool, channels_first: bool): def py_load_func(filepath: str, normalize: bool, channels_first: bool):
return sox_io_backend.load( return torchaudio.load(
filepath, normalize=normalize, channels_first=channels_first) filepath, normalize=normalize, channels_first=channels_first)
...@@ -36,13 +36,15 @@ def py_save_func( ...@@ -36,13 +36,15 @@ def py_save_func(
channels_first: bool = True, channels_first: bool = True,
compression: Optional[float] = None, compression: Optional[float] = None,
): ):
sox_io_backend.save(filepath, tensor, sample_rate, channels_first, compression) torchaudio.save(filepath, tensor, sample_rate, channels_first, compression)
@skipIfNoExec('sox') @skipIfNoExec('sox')
@skipIfNoExtension @skipIfNoExtension
class SoxIO(TempDirMixin, TorchaudioTestCase): class SoxIO(TempDirMixin, TorchaudioTestCase):
"""TorchScript-ability Test suite for `sox_io_backend`""" """TorchScript-ability Test suite for `sox_io_backend`"""
backend = 'sox_io'
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'], ['float32', 'int32', 'int16', 'uint8'],
[8000, 16000], [8000, 16000],
......
import unittest
import torchaudio import torchaudio
from . import common_utils from . import common_utils
...@@ -33,6 +31,12 @@ class TestBackendSwitch_SoX(BackendSwitchMixin, common_utils.TorchaudioTestCase) ...@@ -33,6 +31,12 @@ class TestBackendSwitch_SoX(BackendSwitchMixin, common_utils.TorchaudioTestCase)
backend_module = torchaudio.backend.sox_backend backend_module = torchaudio.backend.sox_backend
@common_utils.skipIfNoExtension
class TestBackendSwitch_SoXIO(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'sox_io'
backend_module = torchaudio.backend.sox_io_backend
@common_utils.skipIfNoModule('soundfile') @common_utils.skipIfNoModule('soundfile')
class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase): class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'soundfile' backend = 'soundfile'
......
...@@ -30,11 +30,15 @@ class Test_LoadSave(unittest.TestCase): ...@@ -30,11 +30,15 @@ class Test_LoadSave(unittest.TestCase):
def test_1_save(self): def test_1_save(self):
for backend in BACKENDS_MP3: for backend in BACKENDS_MP3:
if backend == 'sox_io':
continue
with self.subTest(): with self.subTest():
torchaudio.set_audio_backend(backend) torchaudio.set_audio_backend(backend)
self._test_1_save(self.test_filepath, False) self._test_1_save(self.test_filepath, False)
for backend in BACKENDS: for backend in BACKENDS:
if backend == 'sox_io':
continue
with self.subTest(): with self.subTest():
torchaudio.set_audio_backend(backend) torchaudio.set_audio_backend(backend)
self._test_1_save(self.test_filepath_wav, True) self._test_1_save(self.test_filepath_wav, True)
...@@ -81,6 +85,8 @@ class Test_LoadSave(unittest.TestCase): ...@@ -81,6 +85,8 @@ class Test_LoadSave(unittest.TestCase):
def test_1_save_sine(self): def test_1_save_sine(self):
for backend in BACKENDS: for backend in BACKENDS:
if backend == 'sox_io':
continue
with self.subTest(): with self.subTest():
torchaudio.set_audio_backend(backend) torchaudio.set_audio_backend(backend)
self._test_1_save_sine() self._test_1_save_sine()
...@@ -114,11 +120,15 @@ class Test_LoadSave(unittest.TestCase): ...@@ -114,11 +120,15 @@ class Test_LoadSave(unittest.TestCase):
def test_2_load(self): def test_2_load(self):
for backend in BACKENDS_MP3: for backend in BACKENDS_MP3:
if backend == 'sox_io':
continue
with self.subTest(): with self.subTest():
torchaudio.set_audio_backend(backend) torchaudio.set_audio_backend(backend)
self._test_2_load(self.test_filepath, 278756) self._test_2_load(self.test_filepath, 278756)
for backend in BACKENDS: for backend in BACKENDS:
if backend == 'sox_io':
continue
with self.subTest(): with self.subTest():
torchaudio.set_audio_backend(backend) torchaudio.set_audio_backend(backend)
self._test_2_load(self.test_filepath_wav, 276858) self._test_2_load(self.test_filepath_wav, 276858)
...@@ -155,6 +165,8 @@ class Test_LoadSave(unittest.TestCase): ...@@ -155,6 +165,8 @@ class Test_LoadSave(unittest.TestCase):
def test_2_load_nonormalization(self): def test_2_load_nonormalization(self):
for backend in BACKENDS_MP3: for backend in BACKENDS_MP3:
if backend == 'sox_io':
continue
with self.subTest(): with self.subTest():
torchaudio.set_audio_backend(backend) torchaudio.set_audio_backend(backend)
self._test_2_load_nonormalization(self.test_filepath, 278756) self._test_2_load_nonormalization(self.test_filepath, 278756)
...@@ -172,6 +184,8 @@ class Test_LoadSave(unittest.TestCase): ...@@ -172,6 +184,8 @@ class Test_LoadSave(unittest.TestCase):
def test_3_load_and_save_is_identity(self): def test_3_load_and_save_is_identity(self):
for backend in BACKENDS: for backend in BACKENDS:
if backend == 'sox_io':
continue
with self.subTest(): with self.subTest():
torchaudio.set_audio_backend(backend) torchaudio.set_audio_backend(backend)
self._test_3_load_and_save_is_identity() self._test_3_load_and_save_is_identity()
...@@ -210,6 +224,8 @@ class Test_LoadSave(unittest.TestCase): ...@@ -210,6 +224,8 @@ class Test_LoadSave(unittest.TestCase):
def test_4_load_partial(self): def test_4_load_partial(self):
for backend in BACKENDS_MP3: for backend in BACKENDS_MP3:
if backend == 'sox_io':
continue
with self.subTest(): with self.subTest():
torchaudio.set_audio_backend(backend) torchaudio.set_audio_backend(backend)
self._test_4_load_partial() self._test_4_load_partial()
...@@ -252,6 +268,8 @@ class Test_LoadSave(unittest.TestCase): ...@@ -252,6 +268,8 @@ class Test_LoadSave(unittest.TestCase):
def test_5_get_info(self): def test_5_get_info(self):
for backend in BACKENDS: for backend in BACKENDS:
if backend == 'sox_io':
continue
with self.subTest(): with self.subTest():
torchaudio.set_audio_backend(backend) torchaudio.set_audio_backend(backend)
self._test_5_get_info() self._test_5_get_info()
......
...@@ -7,6 +7,7 @@ from torchaudio._internal.module_utils import is_module_available ...@@ -7,6 +7,7 @@ from torchaudio._internal.module_utils import is_module_available
from . import ( from . import (
no_backend, no_backend,
sox_backend, sox_backend,
sox_io_backend,
soundfile_backend, soundfile_backend,
) )
...@@ -24,6 +25,7 @@ def list_audio_backends() -> List[str]: ...@@ -24,6 +25,7 @@ def list_audio_backends() -> List[str]:
backends.append('soundfile') backends.append('soundfile')
if is_module_available('torchaudio._torchaudio'): if is_module_available('torchaudio._torchaudio'):
backends.append('sox') backends.append('sox')
backends.append('sox_io')
return backends return backends
...@@ -43,6 +45,8 @@ def set_audio_backend(backend: Optional[str]) -> None: ...@@ -43,6 +45,8 @@ def set_audio_backend(backend: Optional[str]) -> None:
module = no_backend module = no_backend
elif backend == 'sox': elif backend == 'sox':
module = sox_backend module = sox_backend
elif backend == 'sox_io':
module = sox_io_backend
elif backend == 'soundfile': elif backend == 'soundfile':
module = soundfile_backend module = soundfile_backend
else: else:
...@@ -69,6 +73,8 @@ def get_audio_backend() -> Optional[str]: ...@@ -69,6 +73,8 @@ def get_audio_backend() -> Optional[str]:
return None return None
if torchaudio.load == sox_backend.load: if torchaudio.load == sox_backend.load:
return 'sox' return 'sox'
if torchaudio.load == sox_io_backend.load:
return 'sox_io'
if torchaudio.load == soundfile_backend.load: if torchaudio.load == soundfile_backend.load:
return 'soundfile' return 'soundfile'
raise ValueError('Unknown backend.') raise ValueError('Unknown backend.')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment