Unverified Commit e9f19c35 authored by moto's avatar moto Committed by GitHub
Browse files

Refactor backend and not rely on global variables on switching (#698)

* Refactor backend switching

1. Do not rely on global variables for backend switch
   So that load/save/info/load_wav functions will be torchscript-able
2. Add no_backend module for the case where no backend module is available
   [bonus] This allows the whole codebase to be imported on systems that have neither the torchaudio C++ extension nor soundfile.
parent 87a761d6
import unittest
import torchaudio
from torchaudio._internal.module_utils import is_module_available
class BackendSwitch:
    """Mixin verifying that ``set_audio_backend`` and ``get_audio_backend`` agree.

    Concrete test cases override ``backend`` (the value handed to
    ``torchaudio.set_audio_backend``) and ``backend_module`` (the module whose
    I/O functions are expected on the ``torchaudio`` namespace afterwards).
    """
    backend = None
    backend_module = None

    def test_switch(self):
        torchaudio.set_audio_backend(self.backend)
        reported = torchaudio.get_audio_backend()
        if self.backend is not None:
            assert reported == self.backend
        else:
            assert reported is None
        # Every I/O entry point on ``torchaudio`` must come from the selected module.
        for name in ('load', 'load_wav', 'save', 'info'):
            assert getattr(torchaudio, name) == getattr(self.backend_module, name)
class TestBackendSwitch_NoBackend(BackendSwitch, unittest.TestCase):
    """Switch test for ``backend=None``; expects the ``no_backend`` functions."""
    backend = None
    backend_module = torchaudio.backend.no_backend
@unittest.skipIf(
    not is_module_available('torchaudio._torchaudio'),
    'torchaudio C++ extension not available')
class TestBackendSwitch_SoX(BackendSwitch, unittest.TestCase):
    """Switch test for the ``'sox'`` backend; skipped when ``torchaudio._torchaudio`` is missing."""
    backend = 'sox'
    backend_module = torchaudio.backend.sox_backend
@unittest.skipIf(not is_module_available('soundfile'), '"soundfile" not available')
class TestBackendSwitch_soundfile(BackendSwitch, unittest.TestCase):
    """Switch test for the ``'soundfile'`` backend; skipped when ``soundfile`` is missing."""
    backend = 'soundfile'
    backend_module = torchaudio.backend.soundfile_backend
......@@ -29,11 +29,13 @@ class TestDatasets(unittest.TestCase):
data[0]
@unittest.skipIf("sox" not in common_utils.BACKENDS, "sox not available")
@common_utils.AudioBackendScope('sox')
def test_commonvoice(self):
    # Smoke test: build the dataset for the "tatar" subset and fetch its first item.
    # NOTE(review): AudioBackendScope('sox') presumably pins the sox backend for
    # the duration of the test -- confirm against common_utils.
    data = COMMONVOICE(self.path, url="tatar")
    data[0]
@unittest.skipIf("sox" not in common_utils.BACKENDS, "sox not available")
@common_utils.AudioBackendScope('sox')
def test_commonvoice_diskcache(self):
data = COMMONVOICE(self.path, url="tatar")
data = diskcache_iterator(data)
......@@ -43,6 +45,7 @@ class TestDatasets(unittest.TestCase):
data[0]
@unittest.skipIf("sox" not in common_utils.BACKENDS, "sox not available")
@common_utils.AudioBackendScope('sox')
def test_commonvoice_bg(self):
data = COMMONVOICE(self.path, url="tatar")
data = bg_iterator(data, 5)
......
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union
from torch import Tensor
from torchaudio._internal import module_utils as _mod_utils
from torchaudio import (
compliance,
......@@ -11,7 +7,6 @@ from torchaudio import (
transforms
)
from torchaudio.backend import (
_get_audio_backend_module,
list_audio_backends,
get_audio_backend,
set_audio_backend,
......@@ -57,116 +52,3 @@ def shutdown_sox():
This function is deprecated. See ``torchaudio.sox_effects.shutdown_sox_effects``
"""
_shutdown_sox_effects()
def load(filepath: Union[str, Path],
         out: Optional[Tensor] = None,
         normalization: Union[bool, float, Callable] = True,
         channels_first: bool = True,
         num_frames: int = 0,
         offset: int = 0,
         signalinfo: Optional[SignalInfo] = None,
         encodinginfo: Optional[EncodingInfo] = None,
         filetype: Optional[str] = None) -> Tuple[Tensor, int]:
    r"""Loads an audio file from disk into a tensor

    All arguments are forwarded unchanged to the ``load`` function of the
    module returned by ``_get_audio_backend_module()``.

    Args:
        filepath (str or Path): Path to audio file
        out (Tensor, optional): An output tensor to use instead of creating one. (Default: ``None``)
        normalization (bool, float, or callable, optional): If boolean `True`, then output is divided by `1 << 31`
            (assumes signed 32-bit audio), and normalizes to `[-1, 1]`.
            If `float`, then output is divided by that number
            If `Callable`, then the output is passed as a parameter
            to the given function, then the output is divided by
            the result. (Default: ``True``)
        channels_first (bool, optional): Set channels first or length first in result. (Default: ``True``)
        num_frames (int, optional): Number of frames to load. 0 to load everything after the offset.
            (Default: ``0``)
        offset (int, optional): Number of frames from the start of the file to begin data loading.
            (Default: ``0``)
        signalinfo (sox_signalinfo_t, optional): A sox_signalinfo_t type, which could be helpful if the
            audio type cannot be automatically determined. (Default: ``None``)
        encodinginfo (sox_encodinginfo_t, optional): A sox_encodinginfo_t type, which could be set if the
            audio type cannot be automatically determined. (Default: ``None``)
        filetype (str, optional): A filetype or extension to be set if sox cannot determine it
            automatically. (Default: ``None``)

    Returns:
        (Tensor, int): An output tensor of size `[C x L]` or `[L x C]` where L is the number
        of audio frames and C is the number of channels. An integer which is the sample rate of the
        audio (as listed in the metadata of the file)

    Example
        >>> data, sample_rate = torchaudio.load('foo.mp3')
        >>> print(data.size())
        torch.Size([2, 278756])
        >>> print(sample_rate)
        44100
        >>> data_vol_normalized, _ = torchaudio.load('foo.mp3', normalization=lambda x: torch.abs(x).max())
        >>> print(data_vol_normalized.abs().max())
        1.
    """
    # Resolve the backend at call time so a later backend switch takes effect.
    return _get_audio_backend_module().load(
        filepath,
        out=out,
        normalization=normalization,
        channels_first=channels_first,
        num_frames=num_frames,
        offset=offset,
        signalinfo=signalinfo,
        encodinginfo=encodinginfo,
        filetype=filetype,
    )
def load_wav(filepath: Union[str, Path], **kwargs: Any) -> Tuple[Tensor, int]:
    r"""Loads a wave file through ``load`` with ``normalization`` fixed to ``1 << 16``.

    The wav file is assumed to use 16 bit per sample. Any ``normalization``
    keyword supplied by the caller is overwritten.

    Args:
        filepath (str or Path): Path to audio file
        **kwargs: Remaining keyword arguments, forwarded to ``load``.

    Returns:
        (Tensor, int): An output tensor of size `[C x L]` or `[L x C]` where L is the number
        of audio frames and C is the number of channels. An integer which is the sample rate of the
        audio (as listed in the metadata of the file)
    """
    kwargs.update(normalization=1 << 16)
    return load(filepath, **kwargs)
def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
    r"""Convenience function for `save_encinfo`.

    The call is forwarded to the ``save`` function of the module returned by
    ``_get_audio_backend_module()``.

    Args:
        filepath (str): Path to audio file
        src (Tensor): An input 2D tensor of shape `[C x L]` or `[L x C]` where L is
            the number of audio frames, C is the number of channels
        sample_rate (int): An integer which is the sample rate of the
            audio (as listed in the metadata of the file)
        precision (int, optional): Bit precision (Default: ``16``)
        channels_first (bool, optional): Set channels first or length first in result.
            (Default: ``True``)
    """
    # Resolve the backend at call time so a later backend switch takes effect.
    return _get_audio_backend_module().save(
        filepath, src, sample_rate, precision=precision, channels_first=channels_first
    )
def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
    r"""Gets metadata from an audio file without loading the signal.

    The call is forwarded to the ``info`` function of the module returned by
    ``_get_audio_backend_module()``.

    Args:
        filepath (str): Path to audio file

    Returns:
        (sox_signalinfo_t, sox_encodinginfo_t): A si (sox_signalinfo_t) signal
        info as a python object. An ei (sox_encodinginfo_t) encoding info

    Example
        >>> si, ei = torchaudio.info('foo.wav')
        >>> rate, channels, encoding = si.rate, si.channels, ei.encoding
    """
    # Resolve the backend at call time so a later backend switch takes effect.
    return _get_audio_backend_module().info(filepath)
from . import utils
from .utils import (
_get_audio_backend_module,
list_audio_backends,
get_audio_backend,
set_audio_backend,
......
from typing import Any, Optional, Tuple
from typing import Any, Optional
class SignalInfo:
......@@ -29,3 +29,115 @@ class EncodingInfo:
self.reverse_nibbles = reverse_nibbles
self.reverse_bits = reverse_bits
self.opposite_endian = opposite_endian
_LOAD_DOCSTRING = r"""Loads an audio file from disk into a tensor
Args:
filepath: Path to audio file
out: An optional output tensor to use instead of creating one. (Default: ``None``)
normalization: Optional normalization.
If boolean `True`, then output is divided by `1 << 31`.
Assuming the input is signed 32-bit audio, this normalizes to `[-1, 1]`.
If `float`, then output is divided by that number.
If `Callable`, then the output is passed as a paramete to the given function,
then the output is divided by the result. (Default: ``True``)
channels_first: Set channels first or length first in result. (Default: ``True``)
num_frames: Number of frames to load. 0 to load everything after the offset.
(Default: ``0``)
offset: Number of frames from the start of the file to begin data loading.
(Default: ``0``)
signalinfo: A sox_signalinfo_t type, which could be helpful if the
audio type cannot be automatically determined. (Default: ``None``)
encodinginfo: A sox_encodinginfo_t type, which could be set if the
audio type cannot be automatically determined. (Default: ``None``)
filetype: A filetype or extension to be set if sox cannot determine it
automatically. (Default: ``None``)
Returns:
(Tensor, int): An output tensor of size `[C x L]` or `[L x C]` where
L is the number of audio frames and
C is the number of channels.
An integer which is the sample rate of the audio (as listed in the metadata of the file)
Example
>>> data, sample_rate = torchaudio.load('foo.mp3')
>>> print(data.size())
torch.Size([2, 278756])
>>> print(sample_rate)
44100
>>> data_vol_normalized, _ = torchaudio.load('foo.mp3', normalization=lambda x: torch.abs(x).max())
>>> print(data_vol_normalized.abs().max())
1.
"""
# Shared user-facing docstring for every backend's ``load_wav``; attached to the
# implementations by the ``_impl_load_wav`` decorator.
_LOAD_WAV_DOCSTRING = r""" Loads a wave file.
It assumes that the wav file uses 16 bit per sample that needs normalization by
shifting the input right by 16 bits.
Args:
filepath: Path to audio file
Returns:
(Tensor, int): An output tensor of size `[C x L]` or `[L x C]` where L is the number
of audio frames and C is the number of channels. An integer which is the sample rate of the
audio (as listed in the metadata of the file)
"""
_SAVE_DOCSTRING = r"""Saves a Tensor on file as an audio file
Args:
filepath: Path to audio file
src: An input 2D tensor of shape `[C x L]` or `[L x C]` where L is
the number of audio frames, C is the number of channels
sample_rate: An integer which is the sample rate of the
audio (as listed in the metadata of the file)
precision Bit precision (Default: ``16``)
channels_first (bool, optional): Set channels first or length first in result. (
Default: ``True``)
"""
# Shared user-facing docstring for every backend's ``info``; attached to the
# implementations by the ``_impl_info`` decorator.
_INFO_DOCSTRING = r"""Gets metadata from an audio file without loading the signal.
Args:
filepath: Path to audio file
Returns:
(sox_signalinfo_t, sox_encodinginfo_t): A si (sox_signalinfo_t) signal
info as a python object. An ei (sox_encodinginfo_t) encoding info
Example
>>> si, ei = torchaudio.info('foo.wav')
>>> rate, channels, encoding = si.rate, si.channels, ei.encoding
"""
def _impl_load(func):
    """Decorator: give ``func`` the shared ``load`` docstring and return it.

    Args:
        func: A backend's ``load`` implementation.

    Returns:
        The same ``func`` object, with ``__doc__`` replaced.
    """
    func.__doc__ = _LOAD_DOCSTRING
    return func
def _impl_load_wav(func):
    """Decorator: give ``func`` the shared ``load_wav`` docstring and return it.

    Args:
        func: A backend's ``load_wav`` implementation.

    Returns:
        The same ``func`` object, with ``__doc__`` replaced.
    """
    func.__doc__ = _LOAD_WAV_DOCSTRING
    return func
def _impl_save(func):
    """Decorator: give ``func`` the shared ``save`` docstring and return it.

    Args:
        func: A backend's ``save`` implementation.

    Returns:
        The same ``func`` object, with ``__doc__`` replaced.
    """
    func.__doc__ = _SAVE_DOCSTRING
    return func
def _impl_info(func):
    """Decorator: give ``func`` the shared ``info`` docstring and return it.

    Args:
        func: A backend's ``info`` implementation.

    Returns:
        The same ``func`` object, with ``__doc__`` replaced.
    """
    func.__doc__ = _INFO_DOCSTRING
    return func
from pathlib import Path
from typing import Any, Callable, Optional, Tuple, Union
from torch import Tensor
from . import common
from .common import SignalInfo, EncodingInfo
@common._impl_load
def load(filepath: Union[str, Path],
         out: Optional[Tensor] = None,
         normalization: Union[bool, float, Callable] = True,
         channels_first: bool = True,
         num_frames: int = 0,
         offset: int = 0,
         signalinfo: Optional[SignalInfo] = None,
         encodinginfo: Optional[EncodingInfo] = None,
         filetype: Optional[str] = None) -> Tuple[Tensor, int]:
    # Placeholder ``load``: always raises, never touches its arguments.
    # The user-facing docstring is attached by the ``common._impl_load`` decorator.
    raise RuntimeError('No audio I/O backend is available.')
@common._impl_load_wav
def load_wav(filepath: Union[str, Path], **kwargs: Any) -> Tuple[Tensor, int]:
    # Placeholder ``load_wav``: always raises, never touches its arguments.
    # The user-facing docstring is attached by the ``common._impl_load_wav`` decorator.
    raise RuntimeError('No audio I/O backend is available.')
@common._impl_save
def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
    # Placeholder ``save``: always raises, never touches its arguments.
    # The user-facing docstring is attached by the ``common._impl_save`` decorator.
    raise RuntimeError('No audio I/O backend is available.')
@common._impl_info
def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
    # Placeholder ``info``: always raises, never touches its arguments.
    # The user-facing docstring is attached by the ``common._impl_info`` decorator.
    raise RuntimeError('No audio I/O backend is available.')
......@@ -8,6 +8,7 @@ from torchaudio._internal import (
module_utils as _mod_utils,
misc_ops as _misc_ops,
)
from . import common
from .common import SignalInfo, EncodingInfo
if _mod_utils.is_module_available('soundfile'):
......@@ -24,6 +25,7 @@ _subtype_to_precision = {
@_mod_utils.requires_module('soundfile')
@common._impl_load
def load(filepath: str,
out: Optional[Tensor] = None,
normalization: Optional[bool] = True,
......@@ -71,6 +73,14 @@ def load(filepath: str,
@_mod_utils.requires_module('soundfile')
@common._impl_load_wav
def load_wav(filepath, **kwargs):
    # Always scale by 1 << 16; a caller-supplied ``normalization`` is overwritten
    # before delegating to this module's ``load``.
    kwargs.update(normalization=1 << 16)
    return load(filepath, **kwargs)
@_mod_utils.requires_module('soundfile')
@common._impl_save
def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
r"""See torchaudio.save"""
......@@ -104,6 +114,7 @@ def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, chan
@_mod_utils.requires_module('soundfile')
@common._impl_info
def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
r"""See torchaudio.info"""
......
......@@ -8,6 +8,7 @@ from torchaudio._internal import (
module_utils as _mod_utils,
misc_ops as _misc_ops,
)
from . import common
from .common import SignalInfo, EncodingInfo
if _mod_utils.is_module_available('torchaudio._torchaudio'):
......@@ -15,6 +16,7 @@ if _mod_utils.is_module_available('torchaudio._torchaudio'):
@_mod_utils.requires_module('torchaudio._torchaudio')
@common._impl_load
def load(filepath: str,
out: Optional[Tensor] = None,
normalization: bool = True,
......@@ -61,6 +63,14 @@ def load(filepath: str,
@_mod_utils.requires_module('torchaudio._torchaudio')
@common._impl_load_wav
def load_wav(filepath, **kwargs):
    # Always scale by 1 << 16; a caller-supplied ``normalization`` is overwritten
    # before delegating to this module's ``load``.
    kwargs.update(normalization=1 << 16)
    return load(filepath, **kwargs)
@_mod_utils.requires_module('torchaudio._torchaudio')
@common._impl_save
def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, channels_first: bool = True) -> None:
r"""See torchaudio.save"""
......@@ -74,6 +84,7 @@ def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, chan
@_mod_utils.requires_module('torchaudio._torchaudio')
@common._impl_info
def info(filepath: str) -> Tuple[SignalInfo, EncodingInfo]:
    r"""See torchaudio.info"""
    # Thin wrapper: the result of ``_torchaudio.get_info`` is returned as-is.
    return _torchaudio.get_info(filepath)
......
"""Defines utilities for switching audio backends"""
import warnings
from typing import Any, Optional
from typing import Optional, List
from torchaudio._internal import module_utils as _mod_utils
from . import soundfile_backend, sox_backend
import torchaudio
from torchaudio._internal.module_utils import is_module_available
from . import (
no_backend,
sox_backend,
soundfile_backend,
)
_BACKEND = None
_BACKENDS = {}
__all__ = [
'list_audio_backends',
'get_audio_backend',
'set_audio_backend',
]
def list_audio_backends():
    # Keys of the module-level ``_BACKENDS`` mapping, as a new list.
    return list(_BACKENDS)
def list_audio_backends() -> List[str]:
    """List available backends

    Returns:
        List[str]: ``'soundfile'`` and/or ``'sox'`` (in that order), one entry
        for each backend whose supporting module ``is_module_available``
        reports as available. Empty when neither is.
    """
    # (backend name, module whose availability enables it)
    candidates = (
        ('soundfile', 'soundfile'),
        ('sox', 'torchaudio._torchaudio'),
    )
    return [name for name, module in candidates if is_module_available(module)]
def set_audio_backend(backend: str) -> None:
"""
Specifies the package used to load.
def set_audio_backend(backend: Optional[str]) -> None:
"""Set the backend for I/O operation
Args:
backend (str): Name of the backend. One of "sox" or "soundfile",
based on availability of the system.
"""
if backend not in _BACKENDS:
if backend is not None and backend not in list_audio_backends():
raise RuntimeError(
f'Backend "{backend}" is not one of '
f'available backends: {list_audio_backends()}.')
global _BACKEND
_BACKEND = backend
def get_audio_backend() -> Optional[str]:
"""
Gets the name of the package used to load.
"""
return _BACKEND
if backend is None:
module = no_backend
elif backend == 'sox':
module = sox_backend
elif backend == 'soundfile':
module = soundfile_backend
else:
raise NotImplementedError(f'Unexpected backend "{backend}"')
def _get_audio_backend_module() -> Any:
"""
Gets the module backend to load.
"""
if _BACKEND is None:
raise RuntimeError('Backend is not initialized.')
return _BACKENDS[_BACKEND]
for func in ['save', 'load', 'load_wav', 'info']:
setattr(torchaudio, func, getattr(module, func))
def _init_audio_backend():
global _BACKEND
global _BACKENDS
_BACKENDS = {}
if _mod_utils.is_module_available('soundfile'):
_BACKENDS['soundfile'] = soundfile_backend
if _mod_utils.is_module_available('torchaudio._torchaudio'):
_BACKENDS['sox'] = sox_backend
if 'sox' in _BACKENDS:
_BACKEND = 'sox'
elif 'soundfile' in _BACKENDS:
_BACKEND = 'soundfile'
backends = list_audio_backends()
if 'sox' in backends:
set_audio_backend('sox')
elif 'soundfile' in backends:
set_audio_backend('soundfile')
else:
warnings.warn('No audio backend is available.')
_BACKEND = None
set_audio_backend(None)
def get_audio_backend() -> Optional[str]:
    """Get the name of the current backend

    The backend is inferred from which backend module's ``load`` is currently
    bound to ``torchaudio.load``.

    Returns:
        Optional[str]: ``None`` for ``no_backend``, otherwise ``'sox'`` or
        ``'soundfile'``.

    Raises:
        ValueError: If ``torchaudio.load`` matches none of the known backends.
    """
    active_load = torchaudio.load
    # Checked in this order: no backend, sox, soundfile.
    known = (
        (no_backend, None),
        (sox_backend, 'sox'),
        (soundfile_backend, 'soundfile'),
    )
    for module, name in known:
        if active_load == module.load:
            return name
    raise ValueError('Unknown backend.')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment