Unverified Commit 4b583eab authored by moto's avatar moto Committed by GitHub
Browse files

Add sox_io_backend (#726)

parent 894959a7
......@@ -2,7 +2,7 @@ import itertools
from typing import Optional
import torch
from torchaudio.backend import sox_io_backend
import torchaudio
from parameterized import parameterized
from ..common_utils import (
......@@ -21,11 +21,11 @@ from .common import (
def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo:
return sox_io_backend.info(filepath)
return torchaudio.info(filepath)
def py_load_func(filepath: str, normalize: bool, channels_first: bool):
return sox_io_backend.load(
return torchaudio.load(
filepath, normalize=normalize, channels_first=channels_first)
......@@ -36,13 +36,15 @@ def py_save_func(
channels_first: bool = True,
compression: Optional[float] = None,
):
sox_io_backend.save(filepath, tensor, sample_rate, channels_first, compression)
torchaudio.save(filepath, tensor, sample_rate, channels_first, compression)
@skipIfNoExec('sox')
@skipIfNoExtension
class SoxIO(TempDirMixin, TorchaudioTestCase):
"""TorchScript-ability Test suite for `sox_io_backend`"""
backend = 'sox_io'
@parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'],
[8000, 16000],
......
import unittest
import torchaudio
from . import common_utils
......@@ -33,6 +31,12 @@ class TestBackendSwitch_SoX(BackendSwitchMixin, common_utils.TorchaudioTestCase)
backend_module = torchaudio.backend.sox_backend
@common_utils.skipIfNoExtension
class TestBackendSwitch_SoXIO(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'sox_io'
backend_module = torchaudio.backend.sox_io_backend
@common_utils.skipIfNoModule('soundfile')
class TestBackendSwitch_soundfile(BackendSwitchMixin, common_utils.TorchaudioTestCase):
backend = 'soundfile'
......
......@@ -30,11 +30,15 @@ class Test_LoadSave(unittest.TestCase):
def test_1_save(self):
for backend in BACKENDS_MP3:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_1_save(self.test_filepath, False)
for backend in BACKENDS:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_1_save(self.test_filepath_wav, True)
......@@ -81,6 +85,8 @@ class Test_LoadSave(unittest.TestCase):
def test_1_save_sine(self):
for backend in BACKENDS:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_1_save_sine()
......@@ -114,11 +120,15 @@ class Test_LoadSave(unittest.TestCase):
def test_2_load(self):
for backend in BACKENDS_MP3:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_2_load(self.test_filepath, 278756)
for backend in BACKENDS:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_2_load(self.test_filepath_wav, 276858)
......@@ -155,6 +165,8 @@ class Test_LoadSave(unittest.TestCase):
def test_2_load_nonormalization(self):
for backend in BACKENDS_MP3:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_2_load_nonormalization(self.test_filepath, 278756)
......@@ -172,6 +184,8 @@ class Test_LoadSave(unittest.TestCase):
def test_3_load_and_save_is_identity(self):
for backend in BACKENDS:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_3_load_and_save_is_identity()
......@@ -210,6 +224,8 @@ class Test_LoadSave(unittest.TestCase):
def test_4_load_partial(self):
for backend in BACKENDS_MP3:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_4_load_partial()
......@@ -252,6 +268,8 @@ class Test_LoadSave(unittest.TestCase):
def test_5_get_info(self):
for backend in BACKENDS:
if backend == 'sox_io':
continue
with self.subTest():
torchaudio.set_audio_backend(backend)
self._test_5_get_info()
......
......@@ -7,6 +7,7 @@ from torchaudio._internal.module_utils import is_module_available
from . import (
no_backend,
sox_backend,
sox_io_backend,
soundfile_backend,
)
......@@ -24,6 +25,7 @@ def list_audio_backends() -> List[str]:
backends.append('soundfile')
if is_module_available('torchaudio._torchaudio'):
backends.append('sox')
backends.append('sox_io')
return backends
......@@ -43,6 +45,8 @@ def set_audio_backend(backend: Optional[str]) -> None:
module = no_backend
elif backend == 'sox':
module = sox_backend
elif backend == 'sox_io':
module = sox_io_backend
elif backend == 'soundfile':
module = soundfile_backend
else:
......@@ -69,6 +73,8 @@ def get_audio_backend() -> Optional[str]:
return None
if torchaudio.load == sox_backend.load:
return 'sox'
if torchaudio.load == sox_io_backend.load:
return 'sox_io'
if torchaudio.load == soundfile_backend.load:
return 'soundfile'
raise ValueError('Unknown backend.')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment