Unverified Commit e61b77dc authored by moto's avatar moto Committed by GitHub
Browse files

Base guard logic on module availability (#692)

* Replace `backed_guard` with `requires_module`

* Remove backend_guard
parent 2416e5d0
......@@ -7,9 +7,9 @@ import torch
from torch.testing._internal.common_utils import TestCase
import torchaudio
import torchaudio.functional as F
from torchaudio.common_utils import _check_module_exists
from torchaudio._internal.module_utils import is_module_available
LIBROSA_AVAILABLE = _check_module_exists('librosa')
LIBROSA_AVAILABLE = is_module_available('librosa')
if LIBROSA_AVAILABLE:
import numpy as np
......
......@@ -14,12 +14,12 @@ from torchaudio import (
)
from torchaudio._backend import (
check_input,
_audio_backend_guard,
_get_audio_backend_module,
get_audio_backend,
set_audio_backend,
)
from torchaudio._soundfile_backend import SignalInfo, EncodingInfo
from torchaudio._internal import module_utils as _mod_utils
try:
from .version import __version__, git_version # noqa: F401
......@@ -27,6 +27,10 @@ except ImportError:
pass
if _mod_utils.is_module_available('torchaudio._torchaudio'):
from . import _torchaudio
def load(filepath: Union[str, Path],
out: Optional[Tensor] = None,
normalization: Union[bool, float, Callable] = True,
......@@ -75,7 +79,6 @@ def load(filepath: Union[str, Path],
1.
"""
return _get_audio_backend_module().load(
filepath,
out=out,
......@@ -124,7 +127,7 @@ def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, chan
)
@_audio_backend_guard("sox")
@_mod_utils.requires_module('torchaudio._torchaudio')
def save_encinfo(filepath: str,
src: Tensor,
channels_first: bool = True,
......@@ -192,8 +195,6 @@ def save_encinfo(filepath: str,
src = src.transpose(1, 0)
# save data to file
src = src.contiguous()
from . import _torchaudio
_torchaudio.write_audio_file(filepath, src, signalinfo, encodinginfo, filetype)
......@@ -210,12 +211,11 @@ def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
Example
>>> si, ei = torchaudio.info('foo.wav')
>>> rate, channels, encoding = si.rate, si.channels, ei.encoding
"""
"""
return _get_audio_backend_module().info(filepath)
@_audio_backend_guard("sox")
@_mod_utils.requires_module('torchaudio._torchaudio')
def sox_signalinfo_t() -> SignalInfo:
r"""Create a sox_signalinfo_t object. This object can be used to set the sample
rate, number of channels, length, bit precision and headroom multiplier
......@@ -235,12 +235,10 @@ def sox_signalinfo_t() -> SignalInfo:
>>> si.precision = 16
>>> si.length = 0
"""
from . import _torchaudio
return _torchaudio.sox_signalinfo_t()
@_audio_backend_guard("sox")
@_mod_utils.requires_module('torchaudio._torchaudio')
def sox_encodinginfo_t() -> EncodingInfo:
r"""Create a sox_encodinginfo_t object. This object can be used to set the encoding
type, bit precision, compression factor, reverse bytes, reverse nibbles,
......@@ -270,8 +268,6 @@ def sox_encodinginfo_t() -> EncodingInfo:
>>> ei.opposite_endian = torchaudio.get_sox_bool(0)
"""
from . import _torchaudio
ei = _torchaudio.sox_encodinginfo_t()
sdo = get_sox_option_t(2) # sox_default_option
ei.reverse_bytes = sdo
......@@ -280,7 +276,7 @@ def sox_encodinginfo_t() -> EncodingInfo:
return ei
@_audio_backend_guard("sox")
@_mod_utils.requires_module('torchaudio._torchaudio')
def get_sox_encoding_t(i: int = None) -> EncodingInfo:
r"""Get enum of sox_encoding_t for sox encodings.
......@@ -291,8 +287,6 @@ def get_sox_encoding_t(i: int = None) -> EncodingInfo:
Returns:
sox_encoding_t: A sox_encoding_t type for output encoding
"""
from . import _torchaudio
if i is None:
# one can see all possible values using the .__members__ attribute
return _torchaudio.sox_encoding_t
......@@ -300,7 +294,7 @@ def get_sox_encoding_t(i: int = None) -> EncodingInfo:
return _torchaudio.sox_encoding_t(i)
@_audio_backend_guard("sox")
@_mod_utils.requires_module('torchaudio._torchaudio')
def get_sox_option_t(i: int = 2) -> Any:
r"""Get enum of sox_option_t for sox encodinginfo options.
......@@ -311,15 +305,13 @@ def get_sox_option_t(i: int = 2) -> Any:
Returns:
sox_option_t: A sox_option_t type
"""
from . import _torchaudio
if i is None:
return _torchaudio.sox_option_t
else:
return _torchaudio.sox_option_t(i)
@_audio_backend_guard("sox")
@_mod_utils.requires_module('torchaudio._torchaudio')
def get_sox_bool(i: int = 0) -> Any:
r"""Get enum of sox_bool for sox encodinginfo options.
......@@ -331,8 +323,6 @@ def get_sox_bool(i: int = 0) -> Any:
Returns:
sox_bool: A sox_bool type
"""
from . import _torchaudio
if i is None:
return _torchaudio.sox_bool
else:
......@@ -350,7 +340,7 @@ _SOX_SUCCESS_CODE = 0
# https://fossies.org/dox/sox-14.4.2/sox_8h.html#a8e07e80cebeff3339265d89c387cea93a9ef2b87ec303edfe40751d9a85fadeeb
@_audio_backend_guard("sox")
@_mod_utils.requires_module("torchaudio._torchaudio")
def initialize_sox() -> int:
"""Initialize sox for use with effects chains.
......@@ -370,7 +360,6 @@ def initialize_sox() -> int:
if _SOX_INITIALIZED is None:
raise RuntimeError('SoX effects chain has been already shut down. Can not initialize again.')
if not _SOX_INITIALIZED:
from . import _torchaudio
code = _torchaudio.initialize_sox()
if code == _SOX_SUCCESS_CODE:
_SOX_INITIALIZED = True
......@@ -379,7 +368,7 @@ def initialize_sox() -> int:
return _SOX_SUCCESS_CODE
@_audio_backend_guard("sox")
@_mod_utils.requires_module("torchaudio._torchaudio")
def shutdown_sox() -> int:
"""Showdown sox for effects chain.
......@@ -394,7 +383,6 @@ def shutdown_sox() -> int:
"""
global _SOX_INITIALIZED
if _SOX_INITIALIZED:
from . import _torchaudio
code = _torchaudio.shutdown_sox()
if code == _SOX_INITIALIZED:
_SOX_INITIALIZED = None
......
......@@ -45,22 +45,6 @@ def _get_audio_backend_module() -> Any:
return _audio_backends[backend]
def _audio_backend_guard(backends: Union[str, List[str]]) -> Any:
if isinstance(backends, str):
backends = [backends]
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
if get_audio_backend() not in backends:
raise RuntimeError("Function {} requires backend to be one of {}.".format(func.__name__, backends))
return func(*args, **kwargs)
return wrapper
return decorator
def check_input(src: Tensor) -> None:
if not torch.is_tensor(src):
raise TypeError('Expected a tensor, got %s' % type(src))
......
import importlib.util
from functools import wraps
def _check_module_exists(*modules: str) -> bool:
def is_module_available(*modules: str) -> bool:
r"""Returns if a top-level module with :attr:`name` exists *without**
importing it. This is generally safer than try-catch block around a
`import X`. It avoids third party libraries breaking assumptions of some of
......@@ -9,3 +10,26 @@ def _check_module_exists(*modules: str) -> bool:
(see librosa/#747, torchvision/#544).
"""
return all(importlib.util.find_spec(m) is not None for m in modules)
def requires_module(*modules: str):
"""Decorate function to give error message if invoked without required optional modules.
This decorator is to give better error message to users rather
than raising ``NameError: name 'module' is not defined`` at random places.
"""
missing = [m for m in modules if not is_module_available(m)]
if not missing:
# fall through. If all the modules are available, no need to decorate
def decorator(func):
return func
else:
req = f'module: {missing[0]}' if len(missing) == 1 else f'modules: {missing}'
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
raise RuntimeError(f'{func.__module__}.{func.__name__} requires {req}')
return wrapped
return decorator
......@@ -3,10 +3,16 @@ from typing import Optional, Tuple
import torch
from torch import Tensor
import torchaudio
from torchaudio._internal import module_utils as _mod_utils
from torchaudio._soundfile_backend import SignalInfo, EncodingInfo
if _mod_utils.is_module_available('torchaudio._torchaudio'):
from . import _torchaudio
@_mod_utils.requires_module('torchaudio._torchaudio')
def load(filepath: str,
out: Optional[Tensor] = None,
normalization: bool = True,
......@@ -35,7 +41,6 @@ def load(filepath: str,
if offset < 0:
raise ValueError("Expected positive offset value")
from . import _torchaudio
sample_rate = _torchaudio.read_audio_file(
filepath,
out,
......@@ -53,6 +58,7 @@ def load(filepath: str,
return out, sample_rate
@_mod_utils.requires_module('torchaudio._torchaudio')
def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
r"""See torchaudio.save"""
......@@ -65,8 +71,7 @@ def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, chan
return torchaudio.save_encinfo(filepath, src, channels_first, si)
@_mod_utils.requires_module('torchaudio._torchaudio')
def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
r"""See torchaudio.info"""
from . import _torchaudio
return _torchaudio.get_info(filepath)
......@@ -5,11 +5,9 @@ from typing import Any, Callable, Iterable, Tuple
import torch
from torch import Tensor
from torchaudio.common_utils import _check_module_exists
from torchaudio._internal import module_utils as _mod_utils
_KALDI_IO_AVAILABLE = _check_module_exists('kaldi_io', 'numpy')
if _KALDI_IO_AVAILABLE:
if _mod_utils.is_module_available('kaldi_io', 'numpy'):
import numpy as np
import kaldi_io
......@@ -38,15 +36,13 @@ def _convert_method_output_to_tensor(file_or_fd: Any,
Returns:
Iterable[Tuple[str, Tensor]]: The string is the key and the tensor is vec/mat
"""
if not _KALDI_IO_AVAILABLE:
raise ImportError('Could not import kaldi_io. Did you install it?')
for key, np_arr in fn(file_or_fd):
if convert_contiguous:
np_arr = np.ascontiguousarray(np_arr)
yield key, torch.from_numpy(np_arr)
@_mod_utils.requires_module('kaldi_io', 'numpy')
def read_vec_int_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
r"""Create generator of (key,vector<int>) tuples, which reads from the ark file/stream.
......@@ -66,6 +62,7 @@ def read_vec_int_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_vec_int_ark, convert_contiguous=True)
@_mod_utils.requires_module('kaldi_io', 'numpy')
def read_vec_flt_scp(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
r"""Create generator of (key,vector<float32/float64>) tuples, read according to Kaldi scp.
......@@ -82,6 +79,7 @@ def read_vec_flt_scp(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_vec_flt_scp)
@_mod_utils.requires_module('kaldi_io', 'numpy')
def read_vec_flt_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
r"""Create generator of (key,vector<float32/float64>) tuples, which reads from the ark file/stream.
......@@ -98,6 +96,7 @@ def read_vec_flt_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_vec_flt_ark)
@_mod_utils.requires_module('kaldi_io', 'numpy')
def read_mat_scp(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
r"""Create generator of (key,matrix<float32/float64>) tuples, read according to Kaldi scp.
......@@ -114,6 +113,7 @@ def read_mat_scp(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
return _convert_method_output_to_tensor(file_or_fd, kaldi_io.read_mat_scp)
@_mod_utils.requires_module('kaldi_io', 'numpy')
def read_mat_ark(file_or_fd: Any) -> Iterable[Tuple[str, Tensor]]:
r"""Create generator of (key,matrix<float32/float64>) tuples, which reads from the ark file/stream.
......
......@@ -3,10 +3,14 @@ from typing import Any, Callable, List, Optional, Tuple, Union
import torch
import torchaudio
from torch import Tensor
from torchaudio._backend import _audio_backend_guard
from torchaudio._internal import module_utils as _mod_utils
@_audio_backend_guard("sox")
if _mod_utils.is_module_available('torchaudio._torchaudio'):
from . import _torchaudio
@_mod_utils.requires_module('torchaudio._torchaudio')
def effect_names() -> List[str]:
"""Gets list of valid sox effect names
......@@ -15,12 +19,10 @@ def effect_names() -> List[str]:
Example
>>> EFFECT_NAMES = torchaudio.sox_effects.effect_names()
"""
from . import _torchaudio
return _torchaudio.get_effect_names()
@_audio_backend_guard("sox")
@_mod_utils.requires_module('torchaudio._torchaudio')
def SoxEffect():
r"""Create an object for passing sox effect information between python and c++
......@@ -28,8 +30,6 @@ def SoxEffect():
SoxEffect: An object with the following attributes: ename (str) which is the
name of effect, and eopts (List[str]) which is a list of effect options.
"""
from . import _torchaudio
return _torchaudio.SoxEffect()
......@@ -123,7 +123,7 @@ class SoxEffectsChain(object):
e.eopts = eargs
self.chain.append(e)
@_audio_backend_guard("sox")
@_mod_utils.requires_module('torchaudio._torchaudio')
def sox_build_flow_effects(self,
out: Optional[Tensor] = None) -> Tuple[Tensor, int]:
r"""Build effects chain and flow effects from input file to output tensor
......@@ -150,7 +150,6 @@ class SoxEffectsChain(object):
# print("effect options:", [x.eopts for x in self.chain])
torchaudio.initialize_sox()
from . import _torchaudio
sr = _torchaudio.build_flow_effects(self.input_file,
out,
self.channels_first,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment