Commit 6fb21ab1 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Move backend initialization to toplevel (#3548)

Summary:
The backend dispatcher is implemented in `torchaudio._backend`, while the legacy backend is implemented in `torchaudio.backend`.

The initialization happen in `torchaudio._backend`.
This commit moves it to `torchaudio.__init__`, so that `backend` and `_backend` is more independent.

Pull Request resolved: https://github.com/pytorch/audio/pull/3548

Reviewed By: huangruizhe

Differential Revision: D48219244

Pulled By: mthrok

fbshipit-source-id: e694cb232794f90902a60ee51c7bf11b7f0548a0
parent 2d1138c5
from unittest.mock import patch
import torchaudio import torchaudio
from torchaudio_unittest import common_utils from torchaudio_unittest import common_utils
...@@ -10,13 +8,12 @@ class BackendSwitchMixin: ...@@ -10,13 +8,12 @@ class BackendSwitchMixin:
backend = None backend = None
backend_module = None backend_module = None
@patch("torchaudio.backend.utils._is_backend_dispatcher_enabled", lambda: False)
def test_switch(self): def test_switch(self):
torchaudio.set_audio_backend(self.backend) torchaudio.backend.utils.set_audio_backend(self.backend)
if self.backend is None: if self.backend is None:
assert torchaudio.get_audio_backend() is None assert torchaudio.backend.utils.get_audio_backend() is None
else: else:
assert torchaudio.get_audio_backend() == self.backend assert torchaudio.backend.utils.get_audio_backend() == self.backend
assert torchaudio.load == self.backend_module.load assert torchaudio.load == self.backend_module.load
assert torchaudio.save == self.backend_module.save assert torchaudio.save == self.backend_module.save
assert torchaudio.info == self.backend_module.info assert torchaudio.info == self.backend_module.info
......
...@@ -12,13 +12,27 @@ from torchaudio import ( # noqa: F401 ...@@ -12,13 +12,27 @@ from torchaudio import ( # noqa: F401
utils, utils,
) )
from torchaudio.backend import get_audio_backend, list_audio_backends, set_audio_backend
try: try:
from .version import __version__, git_version # noqa: F401 from .version import __version__, git_version # noqa: F401
except ImportError: except ImportError:
pass pass
def _is_backend_dispatcher_enabled():
import os
return os.getenv("TORCHAUDIO_USE_BACKEND_DISPATCHER", default="1") == "1"
if _is_backend_dispatcher_enabled():
from ._backend import _init_backend, get_audio_backend, list_audio_backends, set_audio_backend
else:
from .backend import _init_backend, get_audio_backend, list_audio_backends, set_audio_backend
_init_backend()
__all__ = [ __all__ = [
"io", "io",
"compliance", "compliance",
......
from .utils import get_info_func, get_load_func, get_save_func import warnings
from typing import List, Optional
import torchaudio
info = get_info_func() from . import utils
load = get_load_func()
save = get_save_func()
# TODO: Once legacy global backend is removed, move this to torchaudio.__init__
def _init_backend():
torchaudio.info = utils.get_info_func()
torchaudio.load = utils.get_load_func()
torchaudio.save = utils.get_save_func()
def list_audio_backends() -> List[str]:
return list(utils.get_available_backends().keys())
# Temporary until global backend is removed
def get_audio_backend() -> Optional[str]:
warnings.warn("I/O Dispatcher is enabled. There is no global audio backend.", stacklevel=2)
return None
# Temporary until global backend is removed
def set_audio_backend(_: Optional[str]):
warnings.warn("I/O Dispatcher is enabled. set_audio_backend is a no-op", stacklevel=2)
# flake8: noqa from .utils import _init_backend, get_audio_backend, list_audio_backends, set_audio_backend
import torchaudio
from . import utils
from .utils import _is_backend_dispatcher_enabled, get_audio_backend, list_audio_backends, set_audio_backend
if _is_backend_dispatcher_enabled(): __all__ = ["_init_backend", "get_audio_backend", "list_audio_backends", "set_audio_backend"]
from torchaudio._backend.utils import get_info_func, get_load_func, get_save_func
torchaudio.info = get_info_func()
torchaudio.load = get_load_func()
torchaudio.save = get_save_func()
else:
utils._init_audio_backend()
"""Defines utilities for switching audio backends""" """Defines utilities for switching audio backends"""
import os
import warnings import warnings
from typing import List, Optional from typing import List, Optional
...@@ -15,19 +14,12 @@ __all__ = [ ...@@ -15,19 +14,12 @@ __all__ = [
] ]
def _is_backend_dispatcher_enabled() -> bool:
return os.getenv("TORCHAUDIO_USE_BACKEND_DISPATCHER", default="1") == "1"
def list_audio_backends() -> List[str]: def list_audio_backends() -> List[str]:
"""List available backends """List available backends
Returns: Returns:
List[str]: The list of available backends. List[str]: The list of available backends.
""" """
if _is_backend_dispatcher_enabled():
warnings.warn("list_audio_backend's return value is irrelevant when the I/O backend dispatcher is enabled.")
backends = [] backends = []
if _mod_utils.is_module_available("soundfile"): if _mod_utils.is_module_available("soundfile"):
backends.append("soundfile") backends.append("soundfile")
...@@ -44,10 +36,6 @@ def set_audio_backend(backend: Optional[str]): ...@@ -44,10 +36,6 @@ def set_audio_backend(backend: Optional[str]):
One of ``"sox_io"`` or ``"soundfile"`` based on availability One of ``"sox_io"`` or ``"soundfile"`` based on availability
of the system. If ``None`` is provided the current backend is unassigned. of the system. If ``None`` is provided the current backend is unassigned.
""" """
if _is_backend_dispatcher_enabled():
warnings.warn("set_audio_backend is a no-op when the I/O backend dispatcher is enabled.")
return
if backend is not None and backend not in list_audio_backends(): if backend is not None and backend not in list_audio_backends():
raise RuntimeError(f'Backend "{backend}" is not one of ' f"available backends: {list_audio_backends()}.") raise RuntimeError(f'Backend "{backend}" is not one of ' f"available backends: {list_audio_backends()}.")
...@@ -64,7 +52,7 @@ def set_audio_backend(backend: Optional[str]): ...@@ -64,7 +52,7 @@ def set_audio_backend(backend: Optional[str]):
setattr(torchaudio, func, getattr(module, func)) setattr(torchaudio, func, getattr(module, func))
def _init_audio_backend(): def _init_backend():
backends = list_audio_backends() backends = list_audio_backends()
if "sox_io" in backends: if "sox_io" in backends:
set_audio_backend("sox_io") set_audio_backend("sox_io")
...@@ -81,9 +69,6 @@ def get_audio_backend() -> Optional[str]: ...@@ -81,9 +69,6 @@ def get_audio_backend() -> Optional[str]:
Returns: Returns:
Optional[str]: The name of the current backend or ``None`` if no backend is assigned. Optional[str]: The name of the current backend or ``None`` if no backend is assigned.
""" """
if _is_backend_dispatcher_enabled():
warnings.warn("get_audio_backend's return value is irrelevant when the I/O backend dispatcher is enabled.")
if torchaudio.load == no_backend.load: if torchaudio.load == no_backend.load:
return None return None
if torchaudio.load == sox_io_backend.load: if torchaudio.load == sox_io_backend.load:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment