Unverified Commit 2fd32dd0 authored by moto's avatar moto Committed by GitHub
Browse files

List backends dynamically based on availability (#697)

parent e3247e30
...@@ -10,7 +10,7 @@ from torch.testing._internal.common_utils import TestCase ...@@ -10,7 +10,7 @@ from torch.testing._internal.common_utils import TestCase
import torchaudio import torchaudio
_TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__)) _TEST_DIR_PATH = os.path.dirname(os.path.realpath(__file__))
BACKENDS = torchaudio._backend._audio_backends BACKENDS = torchaudio._backend._BACKENDS
def get_asset_path(*paths): def get_asset_path(*paths):
......
from typing import Any from typing import Any, Optional
import platform from torchaudio._internal import module_utils as _mod_utils
from . import _soundfile_backend, _sox_backend
_BACKEND = None
_BACKENDS = {}
from . import _soundfile_backend, _sox_backend if _mod_utils.is_module_available('soundfile'):
_BACKENDS['soundfile'] = _soundfile_backend
if _mod_utils.is_module_available('torchaudio._torchaudio'):
_BACKENDS['sox'] = _sox_backend
if 'sox' in _BACKENDS:
_BACKEND = 'sox'
elif 'soundfile' in _BACKENDS:
_BACKEND = 'soundfile'
if platform.system() == "Windows":
_audio_backend = "soundfile" def list_audio_backends():
_audio_backends = {"soundfile": _soundfile_backend} return list(_BACKENDS.keys())
else:
_audio_backend = "sox"
_audio_backends = {"sox": _sox_backend, "soundfile": _soundfile_backend}
def set_audio_backend(backend: str) -> None: def set_audio_backend(backend: str) -> None:
""" """
Specifies the package used to load. Specifies the package used to load.
Args: Args:
backend (str): Name of the backend. One of {}. backend (str): Name of the backend. One of "sox" or "soundfile",
""".format(_audio_backends.keys()) based on availability of the system.
global _audio_backend """
if backend not in _audio_backends: if backend not in _BACKENDS:
raise ValueError( raise RuntimeError(
"Invalid backend '{}'. Options are {}.".format(backend, _audio_backends.keys()) f'Backend "{backend}" is not one of '
) f'available backends: {list_audio_backends()}.')
_audio_backend = backend global _BACKEND
_BACKEND = backend
def get_audio_backend() -> str: def get_audio_backend() -> Optional[str]:
""" """
Gets the name of the package used to load. Gets the name of the package used to load.
""" """
return _audio_backend return _BACKEND
def _get_audio_backend_module() -> Any: def _get_audio_backend_module() -> Any:
""" """
Gets the module backend to load. Gets the module backend to load.
""" """
backend = get_audio_backend() if _BACKEND is None:
return _audio_backends[backend] raise RuntimeError('Backend is not initialized.')
return _BACKENDS[_BACKEND]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment