Unverified commit dde08ba1, authored by moto-meta and committed by GitHub
Browse files

Simplify the logic to initialize sox

Differential Revision: D50197331

Pull Request resolved: https://github.com/pytorch/audio/pull/3654
parent f62367a6
...@@ -130,17 +130,4 @@ auto apply_effects_file( ...@@ -130,17 +130,4 @@ auto apply_effects_file(
return std::tuple<torch::Tensor, int64_t>( return std::tuple<torch::Tensor, int64_t>(
tensor, chain.getOutputSampleRate()); tensor, chain.getOutputSampleRate());
} }
namespace {
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"torchaudio::sox_effects_initialize_sox_effects",
&initialize_sox_effects);
m.def("torchaudio::sox_effects_shutdown_sox_effects", &shutdown_sox_effects);
m.def("torchaudio::sox_effects_apply_effects_tensor", &apply_effects_tensor);
m.def("torchaudio::sox_effects_apply_effects_file", &apply_effects_file);
}
} // namespace
} // namespace torchaudio::sox } // namespace torchaudio::sox
...@@ -125,11 +125,4 @@ void save_audio_file( ...@@ -125,11 +125,4 @@ void save_audio_file(
chain.addOutputFile(sf); chain.addOutputFile(sf);
chain.run(); chain.run();
} }
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("torchaudio::sox_io_get_info", &get_info_file);
m.def("torchaudio::sox_io_load_audio_file", &load_audio_file);
m.def("torchaudio::sox_io_save_audio_file", &save_audio_file);
}
} // namespace torchaudio::sox } // namespace torchaudio::sox
#include <libtorchaudio/sox/effects.h>
#include <libtorchaudio/sox/io.h>
#include <libtorchaudio/sox/utils.h> #include <libtorchaudio/sox/utils.h>
#include <torch/extension.h> #include <torch/extension.h>
...@@ -5,6 +7,16 @@ namespace torchaudio { ...@@ -5,6 +7,16 @@ namespace torchaudio {
namespace sox { namespace sox {
namespace { namespace {
TORCH_LIBRARY(torchaudio_sox, m) {
m.def("torchaudio_sox::get_info", &get_info_file);
m.def("torchaudio_sox::load_audio_file", &load_audio_file);
m.def("torchaudio_sox::save_audio_file", &save_audio_file);
m.def("torchaudio_sox::initialize_sox_effects", &initialize_sox_effects);
m.def("torchaudio_sox::shutdown_sox_effects", &shutdown_sox_effects);
m.def("torchaudio_sox::apply_effects_tensor", &apply_effects_tensor);
m.def("torchaudio_sox::apply_effects_file", &apply_effects_file);
}
PYBIND11_MODULE(_torchaudio_sox, m) { PYBIND11_MODULE(_torchaudio_sox, m) {
m.def("set_seed", &set_seed, "Set random seed."); m.def("set_seed", &set_seed, "Set random seed.");
m.def("set_verbosity", &set_verbosity, "Set verbosity."); m.def("set_verbosity", &set_verbosity, "Set verbosity.");
......
...@@ -2,10 +2,13 @@ import os ...@@ -2,10 +2,13 @@ import os
from typing import BinaryIO, Optional, Tuple, Union from typing import BinaryIO, Optional, Tuple, Union
import torch import torch
import torchaudio
from .backend import Backend from .backend import Backend
from .common import AudioMetaData from .common import AudioMetaData
sox_ext = torchaudio._extension.lazy_import_sox_ext()
class SoXBackend(Backend): class SoXBackend(Backend):
@staticmethod @staticmethod
...@@ -16,7 +19,7 @@ class SoXBackend(Backend): ...@@ -16,7 +19,7 @@ class SoXBackend(Backend):
"Please use an alternative backend that does support reading from file-like objects, e.g. FFmpeg.", "Please use an alternative backend that does support reading from file-like objects, e.g. FFmpeg.",
) )
else: else:
sinfo = torch.ops.torchaudio.sox_io_get_info(uri, format) sinfo = sox_ext.get_info(uri, format)
if sinfo: if sinfo:
return AudioMetaData(*sinfo) return AudioMetaData(*sinfo)
else: else:
...@@ -38,9 +41,7 @@ class SoXBackend(Backend): ...@@ -38,9 +41,7 @@ class SoXBackend(Backend):
"Please use an alternative backend that does support loading from file-like objects, e.g. FFmpeg.", "Please use an alternative backend that does support loading from file-like objects, e.g. FFmpeg.",
) )
else: else:
ret = torch.ops.torchaudio.sox_io_load_audio_file( ret = sox_ext.load_audio_file(uri, frame_offset, num_frames, normalize, channels_first, format)
uri, frame_offset, num_frames, normalize, channels_first, format
)
if not ret: if not ret:
raise RuntimeError(f"Failed to load audio from {uri}.") raise RuntimeError(f"Failed to load audio from {uri}.")
return ret return ret
...@@ -62,7 +63,7 @@ class SoXBackend(Backend): ...@@ -62,7 +63,7 @@ class SoXBackend(Backend):
"Please use an alternative backend that does support writing to file-like objects, e.g. FFmpeg.", "Please use an alternative backend that does support writing to file-like objects, e.g. FFmpeg.",
) )
else: else:
torch.ops.torchaudio.sox_io_save_audio_file( sox_ext.save_audio_file(
uri, uri,
src, src,
sample_rate, sample_rate,
......
...@@ -4,7 +4,7 @@ from typing import BinaryIO, Dict, Optional, Tuple, Type, Union ...@@ -4,7 +4,7 @@ from typing import BinaryIO, Dict, Optional, Tuple, Type, Union
import torch import torch
from torchaudio._extension import _SOX_INITIALIZED, lazy_import_ffmpeg_ext from torchaudio._extension import lazy_import_ffmpeg_ext, lazy_import_sox_ext
from . import soundfile_backend from . import soundfile_backend
...@@ -20,7 +20,7 @@ def get_available_backends() -> Dict[str, Type[Backend]]: ...@@ -20,7 +20,7 @@ def get_available_backends() -> Dict[str, Type[Backend]]:
backend_specs: Dict[str, Type[Backend]] = {} backend_specs: Dict[str, Type[Backend]] = {}
if lazy_import_ffmpeg_ext().is_available(): if lazy_import_ffmpeg_ext().is_available():
backend_specs["ffmpeg"] = FFmpegBackend backend_specs["ffmpeg"] = FFmpegBackend
if _SOX_INITIALIZED: if lazy_import_sox_ext().is_available():
backend_specs["sox"] = SoXBackend backend_specs["sox"] = SoXBackend
if soundfile_backend._IS_SOUNDFILE_AVAILABLE: if soundfile_backend._IS_SOUNDFILE_AVAILABLE:
backend_specs["soundfile"] = SoundfileBackend backend_specs["soundfile"] = SoundfileBackend
......
...@@ -2,17 +2,9 @@ import logging ...@@ -2,17 +2,9 @@ import logging
import os import os
import sys import sys
from torchaudio._internal.module_utils import eval_env, fail_with_message, is_module_available, no_op from torchaudio._internal.module_utils import fail_with_message, is_module_available, no_op
from .utils import ( from .utils import _check_cuda_version, _init_dll_path, _init_ffmpeg, _init_sox, _LazyImporter, _load_lib
_check_cuda_version,
_fail_since_no_sox,
_init_dll_path,
_init_ffmpeg,
_init_sox,
_LazyImporter,
_load_lib,
)
_LG = logging.getLogger(__name__) _LG = logging.getLogger(__name__)
...@@ -22,11 +14,10 @@ _LG = logging.getLogger(__name__) ...@@ -22,11 +14,10 @@ _LG = logging.getLogger(__name__)
# Builder uses it for debugging purpose, so we export it. # Builder uses it for debugging purpose, so we export it.
# https://github.com/pytorch/builder/blob/e2e4542b8eb0bdf491214451a1a4128bd606cce2/test/smoke_test/smoke_test.py#L80 # https://github.com/pytorch/builder/blob/e2e4542b8eb0bdf491214451a1a4128bd606cce2/test/smoke_test/smoke_test.py#L80
__all__ = [ __all__ = [
"fail_if_no_sox",
"_check_cuda_version", "_check_cuda_version",
"_IS_TORCHAUDIO_EXT_AVAILABLE", "_IS_TORCHAUDIO_EXT_AVAILABLE",
"_IS_RIR_AVAILABLE", "_IS_RIR_AVAILABLE",
"_SOX_INITIALIZED", "lazy_import_sox_ext",
"lazy_import_ffmpeg_ext", "lazy_import_ffmpeg_ext",
] ]
...@@ -54,34 +45,16 @@ if _IS_TORCHAUDIO_EXT_AVAILABLE: ...@@ -54,34 +45,16 @@ if _IS_TORCHAUDIO_EXT_AVAILABLE:
_IS_ALIGN_AVAILABLE = torchaudio.lib._torchaudio.is_align_available() _IS_ALIGN_AVAILABLE = torchaudio.lib._torchaudio.is_align_available()
# Initialize libsox-related features _SOX_EXT = None
_SOX_INITIALIZED = False
_USE_SOX = False if os.name == "nt" else eval_env("TORCHAUDIO_USE_SOX", True)
_SOX_MODULE_AVAILABLE = is_module_available("torchaudio.lib._torchaudio_sox") def lazy_import_sox_ext():
if _USE_SOX and _SOX_MODULE_AVAILABLE: """Load SoX integration based on availability in lazy manner"""
try:
_init_sox() global _SOX_EXT
_SOX_INITIALIZED = True if _SOX_EXT is None:
except Exception: _SOX_EXT = _LazyImporter("_torchaudio_sox", _init_sox)
# The initialization of sox extension will fail if supported sox return _SOX_EXT
# libraries are not found in the system.
# Since the rest of the torchaudio works without it, we do not report the
# error here.
# The error will be raised when user code attempts to use these features.
_LG.debug("Failed to initialize sox extension", exc_info=True)
if os.name == "nt":
fail_if_no_sox = fail_with_message("requires sox extension, which is not supported on Windows.")
elif not _USE_SOX:
fail_if_no_sox = fail_with_message("requires sox extension, but it is disabled. (TORCHAUDIO_USE_SOX=0)")
elif not _SOX_MODULE_AVAILABLE:
fail_if_no_sox = fail_with_message(
"requires sox extension, but TorchAudio is not compiled with it. "
"Please build TorchAudio with libsox support. (BUILD_SOX=1)"
)
else:
fail_if_no_sox = no_op if _SOX_INITIALIZED else _fail_since_no_sox
_FFMPEG_EXT = None _FFMPEG_EXT = None
......
...@@ -9,10 +9,10 @@ import importlib ...@@ -9,10 +9,10 @@ import importlib
import logging import logging
import os import os
import types import types
from functools import wraps
from pathlib import Path from pathlib import Path
import torch import torch
from torchaudio._internal.module_utils import eval_env
_LG = logging.getLogger(__name__) _LG = logging.getLogger(__name__)
_LIB_DIR = Path(__file__).parent.parent / "lib" _LIB_DIR = Path(__file__).parent.parent / "lib"
...@@ -62,16 +62,49 @@ def _load_lib(lib: str) -> bool: ...@@ -62,16 +62,49 @@ def _load_lib(lib: str) -> bool:
return True return True
def _init_sox(): def _import_sox_ext():
if os.name == "nt":
raise RuntimeError("sox extension is not supported on Windows")
if not eval_env("TORCHAUDIO_USE_SOX", True):
raise RuntimeError("sox extension is disabled. (TORCHAUDIO_USE_SOX=0)")
ext = "torchaudio.lib._torchaudio_sox"
if not importlib.util.find_spec(ext):
raise RuntimeError(
# fmt: off
"TorchAudio is not built with sox extension. "
"Please build TorchAudio with libsox support. (BUILD_SOX=1)"
# fmt: on
)
_load_lib("libtorchaudio_sox") _load_lib("libtorchaudio_sox")
import torchaudio.lib._torchaudio_sox # noqa return importlib.import_module(ext)
torchaudio.lib._torchaudio_sox.set_verbosity(0) def _init_sox():
ext = _import_sox_ext()
ext.set_verbosity(0)
import atexit import atexit
torch.ops.torchaudio.sox_effects_initialize_sox_effects() torch.ops.torchaudio_sox.initialize_sox_effects()
atexit.register(torch.ops.torchaudio.sox_effects_shutdown_sox_effects) atexit.register(torch.ops.torchaudio_sox.shutdown_sox_effects)
# Bundle functions registered with TORCH_LIBRARY into extension
# so that they can also be accessed in the same (lazy) manner
# from the extension.
keys = [
"get_info",
"load_audio_file",
"save_audio_file",
"apply_effects_tensor",
"apply_effects_file",
]
for key in keys:
setattr(ext, key, getattr(torch.ops.torchaudio_sox, key))
return ext
_FFMPEG_VERS = ["6", "5", "4", ""] _FFMPEG_VERS = ["6", "5", "4", ""]
...@@ -197,22 +230,3 @@ def _check_cuda_version(): ...@@ -197,22 +230,3 @@ def _check_cuda_version():
"Please install the TorchAudio version that matches your PyTorch version." "Please install the TorchAudio version that matches your PyTorch version."
) )
return version return version
def _fail_since_no_sox(func):
@wraps(func)
def wrapped(*_args, **_kwargs):
try:
# Note:
# We run _init_sox again just to show users the stacktrace.
# _init_sox would not succeed here.
_init_sox()
except Exception as err:
raise RuntimeError(
f"{func.__name__} requires sox extension which is not available. "
"Please refer to the stacktrace above for how to resolve this."
) from err
# This should not happen in normal execution, but just in case.
return func(*_args, **_kwargs)
return wrapped
...@@ -5,8 +5,9 @@ import torch ...@@ -5,8 +5,9 @@ import torch
import torchaudio import torchaudio
from torchaudio import AudioMetaData from torchaudio import AudioMetaData
sox_ext = torchaudio._extension.lazy_import_sox_ext()
@torchaudio._extension.fail_if_no_sox
def info( def info(
filepath: str, filepath: str,
format: Optional[str] = None, format: Optional[str] = None,
...@@ -29,11 +30,10 @@ def info( ...@@ -29,11 +30,10 @@ def info(
if hasattr(filepath, "read"): if hasattr(filepath, "read"):
raise RuntimeError("sox_io backend does not support file-like object.") raise RuntimeError("sox_io backend does not support file-like object.")
filepath = os.fspath(filepath) filepath = os.fspath(filepath)
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format) sinfo = sox_ext.get_info(filepath, format)
return AudioMetaData(*sinfo) return AudioMetaData(*sinfo)
@torchaudio._extension.fail_if_no_sox
def load( def load(
filepath: str, filepath: str,
frame_offset: int = 0, frame_offset: int = 0,
...@@ -123,12 +123,9 @@ def load( ...@@ -123,12 +123,9 @@ def load(
if hasattr(filepath, "read"): if hasattr(filepath, "read"):
raise RuntimeError("sox_io backend does not support file-like object.") raise RuntimeError("sox_io backend does not support file-like object.")
filepath = os.fspath(filepath) filepath = os.fspath(filepath)
return torch.ops.torchaudio.sox_io_load_audio_file( return sox_ext.load_audio_file(filepath, frame_offset, num_frames, normalize, channels_first, format)
filepath, frame_offset, num_frames, normalize, channels_first, format
)
@torchaudio._extension.fail_if_no_sox
def save( def save(
filepath: str, filepath: str,
src: torch.Tensor, src: torch.Tensor,
...@@ -285,7 +282,7 @@ def save( ...@@ -285,7 +282,7 @@ def save(
if hasattr(filepath, "write"): if hasattr(filepath, "write"):
raise RuntimeError("sox_io backend does not handle file-like object.") raise RuntimeError("sox_io backend does not handle file-like object.")
filepath = os.fspath(filepath) filepath = os.fspath(filepath)
torch.ops.torchaudio.sox_io_save_audio_file( sox_ext.save_audio_file(
filepath, filepath,
src, src,
sample_rate, sample_rate,
......
...@@ -1295,7 +1295,6 @@ def spectral_centroid( ...@@ -1295,7 +1295,6 @@ def spectral_centroid(
return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim) return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim)
@torchaudio._extension.fail_if_no_sox
@deprecated("Please migrate to :py:class:`torchaudio.io.AudioEffector`.", remove=False) @deprecated("Please migrate to :py:class:`torchaudio.io.AudioEffector`.", remove=False)
def apply_codec( def apply_codec(
waveform: Tensor, waveform: Tensor,
...@@ -1329,11 +1328,13 @@ def apply_codec( ...@@ -1329,11 +1328,13 @@ def apply_codec(
Tensor: Resulting Tensor. Tensor: Resulting Tensor.
If ``channels_first=True``, it has `(channel, time)` else `(time, channel)`. If ``channels_first=True``, it has `(channel, time)` else `(time, channel)`.
""" """
from torchaudio.backend import _sox_io_backend
with tempfile.NamedTemporaryFile() as f: with tempfile.NamedTemporaryFile() as f:
torchaudio.backend.sox_io_backend.save( torchaudio.backend._sox_io_backend.save(
f.name, waveform, sample_rate, channels_first, compression, format, encoding, bits_per_sample f.name, waveform, sample_rate, channels_first, compression, format, encoding, bits_per_sample
) )
augmented, sr = torchaudio.backend.sox_io_backend.load(f.name, channels_first=channels_first, format=format) augmented, sr = _sox_io_backend.load(f.name, channels_first=channels_first, format=format)
if sr != sample_rate: if sr != sample_rate:
augmented = resample(augmented, sr, sample_rate) augmented = resample(augmented, sr, sample_rate)
return augmented return augmented
...@@ -1371,7 +1372,8 @@ def _get_sinc_resample_kernel( ...@@ -1371,7 +1372,8 @@ def _get_sinc_resample_kernel(
warnings.warn( warnings.warn(
f'"{resampling_method}" resampling method name is being deprecated and replaced by ' f'"{resampling_method}" resampling method name is being deprecated and replaced by '
f'"{method_map[resampling_method]}" in the next release. ' f'"{method_map[resampling_method]}" in the next release. '
"The default behavior remains unchanged." "The default behavior remains unchanged.",
stacklevel=3,
) )
elif resampling_method not in ["sinc_interp_hann", "sinc_interp_kaiser"]: elif resampling_method not in ["sinc_interp_hann", "sinc_interp_kaiser"]:
raise ValueError("Invalid resampling method: {}".format(resampling_method)) raise ValueError("Invalid resampling method: {}".format(resampling_method))
......
...@@ -7,6 +7,9 @@ from torchaudio._internal.module_utils import deprecated ...@@ -7,6 +7,9 @@ from torchaudio._internal.module_utils import deprecated
from torchaudio.utils.sox_utils import list_effects from torchaudio.utils.sox_utils import list_effects
sox_ext = torchaudio._extension.lazy_import_sox_ext()
@deprecated("Please remove the call. This function is called automatically.") @deprecated("Please remove the call. This function is called automatically.")
def init_sox_effects(): def init_sox_effects():
"""Initialize resources required to use sox effects. """Initialize resources required to use sox effects.
...@@ -36,7 +39,6 @@ def shutdown_sox_effects(): ...@@ -36,7 +39,6 @@ def shutdown_sox_effects():
pass pass
@torchaudio._extension.fail_if_no_sox
def effect_names() -> List[str]: def effect_names() -> List[str]:
"""Gets list of valid sox effect names """Gets list of valid sox effect names
...@@ -50,7 +52,6 @@ def effect_names() -> List[str]: ...@@ -50,7 +52,6 @@ def effect_names() -> List[str]:
return list(list_effects().keys()) return list(list_effects().keys())
@torchaudio._extension.fail_if_no_sox
def apply_effects_tensor( def apply_effects_tensor(
tensor: torch.Tensor, tensor: torch.Tensor,
sample_rate: int, sample_rate: int,
...@@ -152,10 +153,9 @@ def apply_effects_tensor( ...@@ -152,10 +153,9 @@ def apply_effects_tensor(
>>> waveform, sample_rate = transform(waveform, input_sample_rate) >>> waveform, sample_rate = transform(waveform, input_sample_rate)
>>> assert sample_rate == 8000 >>> assert sample_rate == 8000
""" """
return torch.ops.torchaudio.sox_effects_apply_effects_tensor(tensor, sample_rate, effects, channels_first) return sox_ext.apply_effects_tensor(tensor, sample_rate, effects, channels_first)
@torchaudio._extension.fail_if_no_sox
def apply_effects_file( def apply_effects_file(
path: str, path: str,
effects: List[List[str]], effects: List[List[str]],
...@@ -269,4 +269,4 @@ def apply_effects_file( ...@@ -269,4 +269,4 @@ def apply_effects_file(
"Please use torchaudio.io.AudioEffector." "Please use torchaudio.io.AudioEffector."
) )
path = os.fspath(path) path = os.fspath(path)
return torch.ops.torchaudio.sox_effects_apply_effects_file(path, effects, normalize, channels_first, format) return sox_ext.apply_effects_file(path, effects, normalize, channels_first, format)
...@@ -6,8 +6,9 @@ from typing import Dict, List ...@@ -6,8 +6,9 @@ from typing import Dict, List
import torchaudio import torchaudio
sox_ext = torchaudio._extension.lazy_import_sox_ext()
@torchaudio._extension.fail_if_no_sox
def set_seed(seed: int): def set_seed(seed: int):
"""Set libsox's PRNG """Set libsox's PRNG
...@@ -17,10 +18,9 @@ def set_seed(seed: int): ...@@ -17,10 +18,9 @@ def set_seed(seed: int):
See Also: See Also:
http://sox.sourceforge.net/sox.html http://sox.sourceforge.net/sox.html
""" """
torchaudio.lib._torchaudio_sox.set_seed(seed) sox_ext.set_seed(seed)
@torchaudio._extension.fail_if_no_sox
def set_verbosity(verbosity: int): def set_verbosity(verbosity: int):
"""Set libsox's verbosity """Set libsox's verbosity
...@@ -35,10 +35,9 @@ def set_verbosity(verbosity: int): ...@@ -35,10 +35,9 @@ def set_verbosity(verbosity: int):
See Also: See Also:
http://sox.sourceforge.net/sox.html http://sox.sourceforge.net/sox.html
""" """
torchaudio.lib._torchaudio_sox.set_verbosity(verbosity) sox_ext.set_verbosity(verbosity)
@torchaudio._extension.fail_if_no_sox
def set_buffer_size(buffer_size: int): def set_buffer_size(buffer_size: int):
"""Set buffer size for sox effect chain """Set buffer size for sox effect chain
...@@ -48,10 +47,9 @@ def set_buffer_size(buffer_size: int): ...@@ -48,10 +47,9 @@ def set_buffer_size(buffer_size: int):
See Also: See Also:
http://sox.sourceforge.net/sox.html http://sox.sourceforge.net/sox.html
""" """
torchaudio.lib._torchaudio_sox.set_buffer_size(buffer_size) sox_ext.set_buffer_size(buffer_size)
@torchaudio._extension.fail_if_no_sox
def set_use_threads(use_threads: bool): def set_use_threads(use_threads: bool):
"""Set multithread option for sox effect chain """Set multithread option for sox effect chain
...@@ -62,44 +60,40 @@ def set_use_threads(use_threads: bool): ...@@ -62,44 +60,40 @@ def set_use_threads(use_threads: bool):
See Also: See Also:
http://sox.sourceforge.net/sox.html http://sox.sourceforge.net/sox.html
""" """
torchaudio.lib._torchaudio_sox.set_use_threads(use_threads) sox_ext.set_use_threads(use_threads)
@torchaudio._extension.fail_if_no_sox
def list_effects() -> Dict[str, str]: def list_effects() -> Dict[str, str]:
"""List the available sox effect names """List the available sox effect names
Returns: Returns:
Dict[str, str]: Mapping from ``effect name`` to ``usage`` Dict[str, str]: Mapping from ``effect name`` to ``usage``
""" """
return dict(torchaudio.lib._torchaudio_sox.list_effects()) return dict(sox_ext.list_effects())
@torchaudio._extension.fail_if_no_sox
def list_read_formats() -> List[str]: def list_read_formats() -> List[str]:
"""List the supported audio formats for read """List the supported audio formats for read
Returns: Returns:
List[str]: List of supported audio formats List[str]: List of supported audio formats
""" """
return torchaudio.lib._torchaudio_sox.list_read_formats() return sox_ext.list_read_formats()
@torchaudio._extension.fail_if_no_sox
def list_write_formats() -> List[str]: def list_write_formats() -> List[str]:
"""List the supported audio formats for write """List the supported audio formats for write
Returns: Returns:
List[str]: List of supported audio formats List[str]: List of supported audio formats
""" """
return torchaudio.lib._torchaudio_sox.list_write_formats() return sox_ext.list_write_formats()
@torchaudio._extension.fail_if_no_sox
def get_buffer_size() -> int: def get_buffer_size() -> int:
"""Get buffer size for sox effect chain """Get buffer size for sox effect chain
Returns: Returns:
int: size in bytes of buffers used for processing audio. int: size in bytes of buffers used for processing audio.
""" """
return torchaudio.lib._torchaudio_sox.get_buffer_size() return sox_ext.get_buffer_size()
...@@ -293,6 +293,8 @@ class TestLoad(LoadTestBase): ...@@ -293,6 +293,8 @@ class TestLoad(LoadTestBase):
class TestLoadParams(TempDirMixin, PytorchTestCase): class TestLoadParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of frame parameters of `sox_io_backend.load`""" """Test the correctness of frame parameters of `sox_io_backend.load`"""
_load = partial(get_load_func(), backend="sox")
def _test(self, func, frame_offset, num_frames, channels_first, normalize): def _test(self, func, frame_offset, num_frames, channels_first, normalize):
original = get_wav_data("int16", num_channels=2, normalize=False) original = get_wav_data("int16", num_channels=2, normalize=False)
path = self.get_temp_path("test.wav") path = self.get_temp_path("test.wav")
...@@ -316,7 +318,7 @@ class TestLoadParams(TempDirMixin, PytorchTestCase): ...@@ -316,7 +318,7 @@ class TestLoadParams(TempDirMixin, PytorchTestCase):
def test_sox(self, frame_offset, num_frames, channels_first, normalize): def test_sox(self, frame_offset, num_frames, channels_first, normalize):
"""The combination of properly changes the output tensor""" """The combination of properly changes the output tensor"""
self._test(torch.ops.torchaudio.sox_io_load_audio_file, frame_offset, num_frames, channels_first, normalize) self._test(self._load, frame_offset, num_frames, channels_first, normalize)
@skipIfNoSox @skipIfNoSox
......
...@@ -313,7 +313,7 @@ class TestLoadParams(TempDirMixin, PytorchTestCase): ...@@ -313,7 +313,7 @@ class TestLoadParams(TempDirMixin, PytorchTestCase):
def test_sox(self, frame_offset, num_frames, channels_first, normalize): def test_sox(self, frame_offset, num_frames, channels_first, normalize):
"""The combination of properly changes the output tensor""" """The combination of properly changes the output tensor"""
self._test(torch.ops.torchaudio.sox_io_load_audio_file, frame_offset, num_frames, channels_first, normalize) self._test(sox_io_backend.load, frame_offset, num_frames, channels_first, normalize)
@skipIfNoSox @skipIfNoSox
......
...@@ -112,6 +112,7 @@ class TorchaudioTestCase(TestBaseMixin, PytorchTestCase): ...@@ -112,6 +112,7 @@ class TorchaudioTestCase(TestBaseMixin, PytorchTestCase):
_IS_FFMPEG_AVAILABLE = torchaudio._extension.lazy_import_ffmpeg_ext().is_available() _IS_FFMPEG_AVAILABLE = torchaudio._extension.lazy_import_ffmpeg_ext().is_available()
_IS_SOX_AVAILABLE = torchaudio._extension.lazy_import_sox_ext().is_available()
_IS_CTC_DECODER_AVAILABLE = None _IS_CTC_DECODER_AVAILABLE = None
_IS_CUDA_CTC_DECODER_AVAILABLE = None _IS_CUDA_CTC_DECODER_AVAILABLE = None
...@@ -209,7 +210,7 @@ skipIfCudaSmallMemory = _skipIf( ...@@ -209,7 +210,7 @@ skipIfCudaSmallMemory = _skipIf(
key="CUDA_SMALL_MEMORY", key="CUDA_SMALL_MEMORY",
) )
skipIfNoSox = _skipIf( skipIfNoSox = _skipIf(
not torchaudio._extension._SOX_INITIALIZED, not _IS_SOX_AVAILABLE,
reason="Sox features are not available.", reason="Sox features are not available.",
key="NO_SOX", key="NO_SOX",
) )
...@@ -217,7 +218,7 @@ skipIfNoSox = _skipIf( ...@@ -217,7 +218,7 @@ skipIfNoSox = _skipIf(
def skipIfNoSoxDecoder(ext): def skipIfNoSoxDecoder(ext):
return _skipIf( return _skipIf(
not torchaudio._extension._SOX_INITIALIZED or ext not in torchaudio.utils.sox_utils.list_read_formats(), not _IS_SOX_AVAILABLE or ext not in torchaudio.utils.sox_utils.list_read_formats(),
f'sox does not handle "{ext}" for read.', f'sox does not handle "{ext}" for read.',
key="NO_SOX_DECODER", key="NO_SOX_DECODER",
) )
...@@ -225,7 +226,7 @@ def skipIfNoSoxDecoder(ext): ...@@ -225,7 +226,7 @@ def skipIfNoSoxDecoder(ext):
def skipIfNoSoxEncoder(ext): def skipIfNoSoxEncoder(ext):
return _skipIf( return _skipIf(
not torchaudio._extension._SOX_INITIALIZED or ext not in torchaudio.utils.sox_utils.list_write_formats(), not _IS_SOX_AVAILABLE or ext not in torchaudio.utils.sox_utils.list_write_formats(),
f'sox does not handle "{ext}" for write.', f'sox does not handle "{ext}" for write.',
key="NO_SOX_ENCODER", key="NO_SOX_ENCODER",
) )
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment