Unverified Commit d1ce29a0 authored by bowangbj's avatar bowangbj Committed by GitHub
Browse files

Check if libsndfile is available by importing soundfile (#1718)

`torchaudio` treats `soundfile` as an optional dependency, and it assumes either `soundfile` is properly installed or it is not installed at all. However, there is a third state where `soundfile` is installed but the backing library `libsndfile` is not installed, and in this case, `import torchaudio` fails.

This commit resolves it by further checking if `soundfile` is importable. 

See also: https://github.com/pytorch/audio/issues/1687
parent 2c115821
......@@ -77,6 +77,37 @@ def requires_kaldi():
return decorator
def _check_soundfile_importable():
if not is_module_available('soundfile'):
return False
try:
import soundfile # noqa: F401
return True
except Exception:
warnings.warn("Failed to import soundfile. 'soundfile' backend is not available.")
return False
_is_soundfile_importable = _check_soundfile_importable()
def is_soundfile_available():
return _is_soundfile_importable
def requires_soundfile():
if is_soundfile_available():
def decorator(func):
return func
else:
def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
raise RuntimeError(f'{func.__module__}.{func.__name__} requires soundfile')
return wrapped
return decorator
def is_sox_available():
return is_module_available('torchaudio._torchaudio') and torch.ops.torchaudio.is_sox_available()
......
......@@ -7,10 +7,9 @@ from torchaudio._internal import module_utils as _mod_utils
from .common import AudioMetaData
if _mod_utils.is_module_available("soundfile"):
if _mod_utils.is_soundfile_available():
import soundfile
# Mapping from soundfile subtype to number of bits per sample.
# This is mostly heuristical and the value is set to 0 when it is irrelevant
# (lossy formats) or when it can't be inferred.
......@@ -81,7 +80,7 @@ def _get_encoding(format: str, subtype: str):
return _SUBTYPE_TO_ENCODING.get(subtype, 'UNKNOWN')
@_mod_utils.requires_module("soundfile")
@_mod_utils.requires_soundfile()
def info(filepath: str, format: Optional[str] = None) -> AudioMetaData:
"""Get signal information of an audio file.
......@@ -120,7 +119,7 @@ _SUBTYPE2DTYPE = {
}
@_mod_utils.requires_module("soundfile")
@_mod_utils.requires_soundfile()
def load(
filepath: str,
frame_offset: int = 0,
......@@ -299,7 +298,7 @@ def _get_subtype(
raise ValueError(f"Unsupported format: {format}")
@_mod_utils.requires_module("soundfile")
@_mod_utils.requires_soundfile()
def save(
filepath: str,
src: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment