Unverified Commit 2fd32dd0 authored by moto's avatar moto Committed by GitHub
Browse files

List backends dynamically based on availability (#697)

parent e3247e30
......@@ -10,7 +10,7 @@ from torch.testing._internal.common_utils import TestCase
import torchaudio
_TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
BACKENDS = torchaudio._backend._audio_backends
BACKENDS = torchaudio._backend._BACKENDS
def get_asset_path(*paths):
......
from typing import Any
from typing import Any, Optional
import platform
from torchaudio._internal import module_utils as _mod_utils
from . import _soundfile_backend, _sox_backend
_BACKEND = None
_BACKENDS = {}
from . import _soundfile_backend, _sox_backend
if _mod_utils.is_module_available('soundfile'):
_BACKENDS['soundfile'] = _soundfile_backend
if _mod_utils.is_module_available('torchaudio._torchaudio'):
_BACKENDS['sox'] = _sox_backend
if 'sox' in _BACKENDS:
_BACKEND = 'sox'
elif 'soundfile' in _BACKENDS:
_BACKEND = 'soundfile'
if platform.system() == "Windows":
_audio_backend = "soundfile"
_audio_backends = {"soundfile": _soundfile_backend}
else:
_audio_backend = "sox"
_audio_backends = {"sox": _sox_backend, "soundfile": _soundfile_backend}
def list_audio_backends():
return list(_BACKENDS.keys())
def set_audio_backend(backend: str) -> None:
"""
Specifies the package used to load.
Args:
backend (str): Name of the backend. One of {}.
""".format(_audio_backends.keys())
global _audio_backend
if backend not in _audio_backends:
raise ValueError(
"Invalid backend '{}'. Options are {}.".format(backend, _audio_backends.keys())
)
_audio_backend = backend
backend (str): Name of the backend. One of "sox" or "soundfile",
based on availability of the system.
"""
if backend not in _BACKENDS:
raise RuntimeError(
f'Backend "{backend}" is not one of '
f'available backends: {list_audio_backends()}.')
global _BACKEND
_BACKEND = backend
def get_audio_backend() -> str:
def get_audio_backend() -> Optional[str]:
"""
Gets the name of the package used to load.
"""
return _audio_backend
return _BACKEND
def _get_audio_backend_module() -> Any:
"""
Gets the module backend to load.
"""
backend = get_audio_backend()
return _audio_backends[backend]
if _BACKEND is None:
raise RuntimeError('Backend is not initialized.')
return _BACKENDS[_BACKEND]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment