Unverified Commit dde08ba1 authored by moto-meta's avatar moto-meta Committed by GitHub
Browse files

Simplify the logic to initialize sox

Differential Revision: D50197331

Pull Request resolved: https://github.com/pytorch/audio/pull/3654
parent f62367a6
......@@ -130,17 +130,4 @@ auto apply_effects_file(
return std::tuple<torch::Tensor, int64_t>(
tensor, chain.getOutputSampleRate());
}
namespace {
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"torchaudio::sox_effects_initialize_sox_effects",
&initialize_sox_effects);
m.def("torchaudio::sox_effects_shutdown_sox_effects", &shutdown_sox_effects);
m.def("torchaudio::sox_effects_apply_effects_tensor", &apply_effects_tensor);
m.def("torchaudio::sox_effects_apply_effects_file", &apply_effects_file);
}
} // namespace
} // namespace torchaudio::sox
......@@ -125,11 +125,4 @@ void save_audio_file(
chain.addOutputFile(sf);
chain.run();
}
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("torchaudio::sox_io_get_info", &get_info_file);
m.def("torchaudio::sox_io_load_audio_file", &load_audio_file);
m.def("torchaudio::sox_io_save_audio_file", &save_audio_file);
}
} // namespace torchaudio::sox
#include <libtorchaudio/sox/effects.h>
#include <libtorchaudio/sox/io.h>
#include <libtorchaudio/sox/utils.h>
#include <torch/extension.h>
......@@ -5,6 +7,16 @@ namespace torchaudio {
namespace sox {
namespace {
TORCH_LIBRARY(torchaudio_sox, m) {
m.def("torchaudio_sox::get_info", &get_info_file);
m.def("torchaudio_sox::load_audio_file", &load_audio_file);
m.def("torchaudio_sox::save_audio_file", &save_audio_file);
m.def("torchaudio_sox::initialize_sox_effects", &initialize_sox_effects);
m.def("torchaudio_sox::shutdown_sox_effects", &shutdown_sox_effects);
m.def("torchaudio_sox::apply_effects_tensor", &apply_effects_tensor);
m.def("torchaudio_sox::apply_effects_file", &apply_effects_file);
}
PYBIND11_MODULE(_torchaudio_sox, m) {
m.def("set_seed", &set_seed, "Set random seed.");
m.def("set_verbosity", &set_verbosity, "Set verbosity.");
......
......@@ -2,10 +2,13 @@ import os
from typing import BinaryIO, Optional, Tuple, Union
import torch
import torchaudio
from .backend import Backend
from .common import AudioMetaData
sox_ext = torchaudio._extension.lazy_import_sox_ext()
class SoXBackend(Backend):
@staticmethod
......@@ -16,7 +19,7 @@ class SoXBackend(Backend):
"Please use an alternative backend that does support reading from file-like objects, e.g. FFmpeg.",
)
else:
sinfo = torch.ops.torchaudio.sox_io_get_info(uri, format)
sinfo = sox_ext.get_info(uri, format)
if sinfo:
return AudioMetaData(*sinfo)
else:
......@@ -38,9 +41,7 @@ class SoXBackend(Backend):
"Please use an alternative backend that does support loading from file-like objects, e.g. FFmpeg.",
)
else:
ret = torch.ops.torchaudio.sox_io_load_audio_file(
uri, frame_offset, num_frames, normalize, channels_first, format
)
ret = sox_ext.load_audio_file(uri, frame_offset, num_frames, normalize, channels_first, format)
if not ret:
raise RuntimeError(f"Failed to load audio from {uri}.")
return ret
......@@ -62,7 +63,7 @@ class SoXBackend(Backend):
"Please use an alternative backend that does support writing to file-like objects, e.g. FFmpeg.",
)
else:
torch.ops.torchaudio.sox_io_save_audio_file(
sox_ext.save_audio_file(
uri,
src,
sample_rate,
......
......@@ -4,7 +4,7 @@ from typing import BinaryIO, Dict, Optional, Tuple, Type, Union
import torch
from torchaudio._extension import _SOX_INITIALIZED, lazy_import_ffmpeg_ext
from torchaudio._extension import lazy_import_ffmpeg_ext, lazy_import_sox_ext
from . import soundfile_backend
......@@ -20,7 +20,7 @@ def get_available_backends() -> Dict[str, Type[Backend]]:
backend_specs: Dict[str, Type[Backend]] = {}
if lazy_import_ffmpeg_ext().is_available():
backend_specs["ffmpeg"] = FFmpegBackend
if _SOX_INITIALIZED:
if lazy_import_sox_ext().is_available():
backend_specs["sox"] = SoXBackend
if soundfile_backend._IS_SOUNDFILE_AVAILABLE:
backend_specs["soundfile"] = SoundfileBackend
......
......@@ -2,17 +2,9 @@ import logging
import os
import sys
from torchaudio._internal.module_utils import eval_env, fail_with_message, is_module_available, no_op
from .utils import (
_check_cuda_version,
_fail_since_no_sox,
_init_dll_path,
_init_ffmpeg,
_init_sox,
_LazyImporter,
_load_lib,
)
from torchaudio._internal.module_utils import fail_with_message, is_module_available, no_op
from .utils import _check_cuda_version, _init_dll_path, _init_ffmpeg, _init_sox, _LazyImporter, _load_lib
_LG = logging.getLogger(__name__)
......@@ -22,11 +14,10 @@ _LG = logging.getLogger(__name__)
# Builder uses it for debugging purpose, so we export it.
# https://github.com/pytorch/builder/blob/e2e4542b8eb0bdf491214451a1a4128bd606cce2/test/smoke_test/smoke_test.py#L80
__all__ = [
"fail_if_no_sox",
"_check_cuda_version",
"_IS_TORCHAUDIO_EXT_AVAILABLE",
"_IS_RIR_AVAILABLE",
"_SOX_INITIALIZED",
"lazy_import_sox_ext",
"lazy_import_ffmpeg_ext",
]
......@@ -54,34 +45,16 @@ if _IS_TORCHAUDIO_EXT_AVAILABLE:
_IS_ALIGN_AVAILABLE = torchaudio.lib._torchaudio.is_align_available()
# Initialize libsox-related features
_SOX_INITIALIZED = False
_USE_SOX = False if os.name == "nt" else eval_env("TORCHAUDIO_USE_SOX", True)
_SOX_MODULE_AVAILABLE = is_module_available("torchaudio.lib._torchaudio_sox")
if _USE_SOX and _SOX_MODULE_AVAILABLE:
try:
_init_sox()
_SOX_INITIALIZED = True
except Exception:
# The initialization of sox extension will fail if supported sox
# libraries are not found in the system.
# Since the rest of the torchaudio works without it, we do not report the
# error here.
# The error will be raised when user code attempts to use these features.
_LG.debug("Failed to initialize sox extension", exc_info=True)
if os.name == "nt":
fail_if_no_sox = fail_with_message("requires sox extension, which is not supported on Windows.")
elif not _USE_SOX:
fail_if_no_sox = fail_with_message("requires sox extension, but it is disabled. (TORCHAUDIO_USE_SOX=0)")
elif not _SOX_MODULE_AVAILABLE:
fail_if_no_sox = fail_with_message(
"requires sox extension, but TorchAudio is not compiled with it. "
"Please build TorchAudio with libsox support. (BUILD_SOX=1)"
)
else:
fail_if_no_sox = no_op if _SOX_INITIALIZED else _fail_since_no_sox
_SOX_EXT = None
def lazy_import_sox_ext():
"""Load SoX integration based on availability in lazy manner"""
global _SOX_EXT
if _SOX_EXT is None:
_SOX_EXT = _LazyImporter("_torchaudio_sox", _init_sox)
return _SOX_EXT
_FFMPEG_EXT = None
......
......@@ -9,10 +9,10 @@ import importlib
import logging
import os
import types
from functools import wraps
from pathlib import Path
import torch
from torchaudio._internal.module_utils import eval_env
_LG = logging.getLogger(__name__)
_LIB_DIR = Path(__file__).parent.parent / "lib"
......@@ -62,16 +62,49 @@ def _load_lib(lib: str) -> bool:
return True
def _init_sox():
def _import_sox_ext():
if os.name == "nt":
raise RuntimeError("sox extension is not supported on Windows")
if not eval_env("TORCHAUDIO_USE_SOX", True):
raise RuntimeError("sox extension is disabled. (TORCHAUDIO_USE_SOX=0)")
ext = "torchaudio.lib._torchaudio_sox"
if not importlib.util.find_spec(ext):
raise RuntimeError(
# fmt: off
"TorchAudio is not built with sox extension. "
"Please build TorchAudio with libsox support. (BUILD_SOX=1)"
# fmt: on
)
_load_lib("libtorchaudio_sox")
import torchaudio.lib._torchaudio_sox # noqa
return importlib.import_module(ext)
torchaudio.lib._torchaudio_sox.set_verbosity(0)
def _init_sox():
ext = _import_sox_ext()
ext.set_verbosity(0)
import atexit
torch.ops.torchaudio.sox_effects_initialize_sox_effects()
atexit.register(torch.ops.torchaudio.sox_effects_shutdown_sox_effects)
torch.ops.torchaudio_sox.initialize_sox_effects()
atexit.register(torch.ops.torchaudio_sox.shutdown_sox_effects)
# Bundle functions registered with TORCH_LIBRARY into extension
# so that they can also be accessed in the same (lazy) manner
# from the extension.
keys = [
"get_info",
"load_audio_file",
"save_audio_file",
"apply_effects_tensor",
"apply_effects_file",
]
for key in keys:
setattr(ext, key, getattr(torch.ops.torchaudio_sox, key))
return ext
_FFMPEG_VERS = ["6", "5", "4", ""]
......@@ -197,22 +230,3 @@ def _check_cuda_version():
"Please install the TorchAudio version that matches your PyTorch version."
)
return version
def _fail_since_no_sox(func):
@wraps(func)
def wrapped(*_args, **_kwargs):
try:
# Note:
# We run _init_sox again just to show users the stacktrace.
# _init_sox would not succeed here.
_init_sox()
except Exception as err:
raise RuntimeError(
f"{func.__name__} requires sox extension which is not available. "
"Please refer to the stacktrace above for how to resolve this."
) from err
# This should not happen in normal execution, but just in case.
return func(*_args, **_kwargs)
return wrapped
......@@ -5,8 +5,9 @@ import torch
import torchaudio
from torchaudio import AudioMetaData
sox_ext = torchaudio._extension.lazy_import_sox_ext()
@torchaudio._extension.fail_if_no_sox
def info(
filepath: str,
format: Optional[str] = None,
......@@ -29,11 +30,10 @@ def info(
if hasattr(filepath, "read"):
raise RuntimeError("sox_io backend does not support file-like object.")
filepath = os.fspath(filepath)
sinfo = torch.ops.torchaudio.sox_io_get_info(filepath, format)
sinfo = sox_ext.get_info(filepath, format)
return AudioMetaData(*sinfo)
@torchaudio._extension.fail_if_no_sox
def load(
filepath: str,
frame_offset: int = 0,
......@@ -123,12 +123,9 @@ def load(
if hasattr(filepath, "read"):
raise RuntimeError("sox_io backend does not support file-like object.")
filepath = os.fspath(filepath)
return torch.ops.torchaudio.sox_io_load_audio_file(
filepath, frame_offset, num_frames, normalize, channels_first, format
)
return sox_ext.load_audio_file(filepath, frame_offset, num_frames, normalize, channels_first, format)
@torchaudio._extension.fail_if_no_sox
def save(
filepath: str,
src: torch.Tensor,
......@@ -285,7 +282,7 @@ def save(
if hasattr(filepath, "write"):
raise RuntimeError("sox_io backend does not handle file-like object.")
filepath = os.fspath(filepath)
torch.ops.torchaudio.sox_io_save_audio_file(
sox_ext.save_audio_file(
filepath,
src,
sample_rate,
......
......@@ -1295,7 +1295,6 @@ def spectral_centroid(
return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim)
@torchaudio._extension.fail_if_no_sox
@deprecated("Please migrate to :py:class:`torchaudio.io.AudioEffector`.", remove=False)
def apply_codec(
waveform: Tensor,
......@@ -1329,11 +1328,13 @@ def apply_codec(
Tensor: Resulting Tensor.
If ``channels_first=True``, it has `(channel, time)` else `(time, channel)`.
"""
from torchaudio.backend import _sox_io_backend
with tempfile.NamedTemporaryFile() as f:
torchaudio.backend.sox_io_backend.save(
torchaudio.backend._sox_io_backend.save(
f.name, waveform, sample_rate, channels_first, compression, format, encoding, bits_per_sample
)
augmented, sr = torchaudio.backend.sox_io_backend.load(f.name, channels_first=channels_first, format=format)
augmented, sr = _sox_io_backend.load(f.name, channels_first=channels_first, format=format)
if sr != sample_rate:
augmented = resample(augmented, sr, sample_rate)
return augmented
......@@ -1371,7 +1372,8 @@ def _get_sinc_resample_kernel(
warnings.warn(
f'"{resampling_method}" resampling method name is being deprecated and replaced by '
f'"{method_map[resampling_method]}" in the next release. '
"The default behavior remains unchanged."
"The default behavior remains unchanged.",
stacklevel=3,
)
elif resampling_method not in ["sinc_interp_hann", "sinc_interp_kaiser"]:
raise ValueError("Invalid resampling method: {}".format(resampling_method))
......
......@@ -7,6 +7,9 @@ from torchaudio._internal.module_utils import deprecated
from torchaudio.utils.sox_utils import list_effects
sox_ext = torchaudio._extension.lazy_import_sox_ext()
@deprecated("Please remove the call. This function is called automatically.")
def init_sox_effects():
"""Initialize resources required to use sox effects.
......@@ -36,7 +39,6 @@ def shutdown_sox_effects():
pass
@torchaudio._extension.fail_if_no_sox
def effect_names() -> List[str]:
"""Gets list of valid sox effect names
......@@ -50,7 +52,6 @@ def effect_names() -> List[str]:
return list(list_effects().keys())
@torchaudio._extension.fail_if_no_sox
def apply_effects_tensor(
tensor: torch.Tensor,
sample_rate: int,
......@@ -152,10 +153,9 @@ def apply_effects_tensor(
>>> waveform, sample_rate = transform(waveform, input_sample_rate)
>>> assert sample_rate == 8000
"""
return torch.ops.torchaudio.sox_effects_apply_effects_tensor(tensor, sample_rate, effects, channels_first)
return sox_ext.apply_effects_tensor(tensor, sample_rate, effects, channels_first)
@torchaudio._extension.fail_if_no_sox
def apply_effects_file(
path: str,
effects: List[List[str]],
......@@ -269,4 +269,4 @@ def apply_effects_file(
"Please use torchaudio.io.AudioEffector."
)
path = os.fspath(path)
return torch.ops.torchaudio.sox_effects_apply_effects_file(path, effects, normalize, channels_first, format)
return sox_ext.apply_effects_file(path, effects, normalize, channels_first, format)
......@@ -6,8 +6,9 @@ from typing import Dict, List
import torchaudio
sox_ext = torchaudio._extension.lazy_import_sox_ext()
@torchaudio._extension.fail_if_no_sox
def set_seed(seed: int):
"""Set libsox's PRNG
......@@ -17,10 +18,9 @@ def set_seed(seed: int):
See Also:
http://sox.sourceforge.net/sox.html
"""
torchaudio.lib._torchaudio_sox.set_seed(seed)
sox_ext.set_seed(seed)
@torchaudio._extension.fail_if_no_sox
def set_verbosity(verbosity: int):
"""Set libsox's verbosity
......@@ -35,10 +35,9 @@ def set_verbosity(verbosity: int):
See Also:
http://sox.sourceforge.net/sox.html
"""
torchaudio.lib._torchaudio_sox.set_verbosity(verbosity)
sox_ext.set_verbosity(verbosity)
@torchaudio._extension.fail_if_no_sox
def set_buffer_size(buffer_size: int):
"""Set buffer size for sox effect chain
......@@ -48,10 +47,9 @@ def set_buffer_size(buffer_size: int):
See Also:
http://sox.sourceforge.net/sox.html
"""
torchaudio.lib._torchaudio_sox.set_buffer_size(buffer_size)
sox_ext.set_buffer_size(buffer_size)
@torchaudio._extension.fail_if_no_sox
def set_use_threads(use_threads: bool):
"""Set multithread option for sox effect chain
......@@ -62,44 +60,40 @@ def set_use_threads(use_threads: bool):
See Also:
http://sox.sourceforge.net/sox.html
"""
torchaudio.lib._torchaudio_sox.set_use_threads(use_threads)
sox_ext.set_use_threads(use_threads)
@torchaudio._extension.fail_if_no_sox
def list_effects() -> Dict[str, str]:
"""List the available sox effect names
Returns:
Dict[str, str]: Mapping from ``effect name`` to ``usage``
"""
return dict(torchaudio.lib._torchaudio_sox.list_effects())
return dict(sox_ext.list_effects())
@torchaudio._extension.fail_if_no_sox
def list_read_formats() -> List[str]:
"""List the supported audio formats for read
Returns:
List[str]: List of supported audio formats
"""
return torchaudio.lib._torchaudio_sox.list_read_formats()
return sox_ext.list_read_formats()
@torchaudio._extension.fail_if_no_sox
def list_write_formats() -> List[str]:
"""List the supported audio formats for write
Returns:
List[str]: List of supported audio formats
"""
return torchaudio.lib._torchaudio_sox.list_write_formats()
return sox_ext.list_write_formats()
@torchaudio._extension.fail_if_no_sox
def get_buffer_size() -> int:
"""Get buffer size for sox effect chain
Returns:
int: size in bytes of buffers used for processing audio.
"""
return torchaudio.lib._torchaudio_sox.get_buffer_size()
return sox_ext.get_buffer_size()
......@@ -293,6 +293,8 @@ class TestLoad(LoadTestBase):
class TestLoadParams(TempDirMixin, PytorchTestCase):
"""Test the correctness of frame parameters of `sox_io_backend.load`"""
_load = partial(get_load_func(), backend="sox")
def _test(self, func, frame_offset, num_frames, channels_first, normalize):
original = get_wav_data("int16", num_channels=2, normalize=False)
path = self.get_temp_path("test.wav")
......@@ -316,7 +318,7 @@ class TestLoadParams(TempDirMixin, PytorchTestCase):
def test_sox(self, frame_offset, num_frames, channels_first, normalize):
"""The combination of properly changes the output tensor"""
self._test(torch.ops.torchaudio.sox_io_load_audio_file, frame_offset, num_frames, channels_first, normalize)
self._test(self._load, frame_offset, num_frames, channels_first, normalize)
@skipIfNoSox
......
......@@ -313,7 +313,7 @@ class TestLoadParams(TempDirMixin, PytorchTestCase):
def test_sox(self, frame_offset, num_frames, channels_first, normalize):
"""The combination of properly changes the output tensor"""
self._test(torch.ops.torchaudio.sox_io_load_audio_file, frame_offset, num_frames, channels_first, normalize)
self._test(sox_io_backend.load, frame_offset, num_frames, channels_first, normalize)
@skipIfNoSox
......
......@@ -112,6 +112,7 @@ class TorchaudioTestCase(TestBaseMixin, PytorchTestCase):
_IS_FFMPEG_AVAILABLE = torchaudio._extension.lazy_import_ffmpeg_ext().is_available()
_IS_SOX_AVAILABLE = torchaudio._extension.lazy_import_sox_ext().is_available()
_IS_CTC_DECODER_AVAILABLE = None
_IS_CUDA_CTC_DECODER_AVAILABLE = None
......@@ -209,7 +210,7 @@ skipIfCudaSmallMemory = _skipIf(
key="CUDA_SMALL_MEMORY",
)
skipIfNoSox = _skipIf(
not torchaudio._extension._SOX_INITIALIZED,
not _IS_SOX_AVAILABLE,
reason="Sox features are not available.",
key="NO_SOX",
)
......@@ -217,7 +218,7 @@ skipIfNoSox = _skipIf(
def skipIfNoSoxDecoder(ext):
return _skipIf(
not torchaudio._extension._SOX_INITIALIZED or ext not in torchaudio.utils.sox_utils.list_read_formats(),
not _IS_SOX_AVAILABLE or ext not in torchaudio.utils.sox_utils.list_read_formats(),
f'sox does not handle "{ext}" for read.',
key="NO_SOX_DECODER",
)
......@@ -225,7 +226,7 @@ def skipIfNoSoxDecoder(ext):
def skipIfNoSoxEncoder(ext):
return _skipIf(
not torchaudio._extension._SOX_INITIALIZED or ext not in torchaudio.utils.sox_utils.list_write_formats(),
not _IS_SOX_AVAILABLE or ext not in torchaudio.utils.sox_utils.list_write_formats(),
f'sox does not handle "{ext}" for write.',
key="NO_SOX_ENCODER",
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment