Unverified Commit 774ebc78 authored by Vincent QB's avatar Vincent QB Committed by GitHub
Browse files

Backend switch (#355)

* move sox inside function calls.

* add backend switch mechanism.

* import sox at call time (inside the functions), not at module import time.

* add backend list.

* backend tests.

* creating hidden modules for backend.

* naming backend same as file: soundfile.

* remove docstring in backend file.

* test soundfile info.

* soundfile doesn't support int64.

* adding test for wav file.

* raise an error on an unsupported parameter instead of silently ignoring it.

* adding test across backend. using float32 as done in sox.

* backend guard decorator.
parent 4887ff41
......@@ -7,14 +7,41 @@ import math
import os
class AudioBackendScope:
    """Context manager that switches the torchaudio audio backend for a block.

    The backend that is active when the scope object is constructed is
    remembered and re-selected when the ``with`` block exits.
    """

    def __init__(self, backend):
        # Snapshot the currently selected backend so __exit__ can restore it.
        self.previous_backend = torchaudio.get_audio_backend()
        self.new_backend = backend

    def __enter__(self):
        """Activate the requested backend and hand its name to the caller."""
        torchaudio.set_audio_backend(self.new_backend)
        return self.new_backend

    def __exit__(self, type, value, traceback):
        """Re-select the backend recorded at construction time.

        Nothing is returned, so an exception raised inside the block
        propagates to the caller.
        """
        torchaudio.set_audio_backend(self.previous_backend)
class Test_LoadSave(unittest.TestCase):
test_dirpath, test_dir = common_utils.create_temp_assets_dir()
test_filepath = os.path.join(test_dirpath, "assets",
"steam-train-whistle-daniel_simon.mp3")
test_filepath_wav = os.path.join(test_dirpath, "assets",
"steam-train-whistle-daniel_simon.wav")
def test_1_save(self):
    """Run the save checks for every supported backend/asset combination.

    The mp3 asset is only run with the sox backend (without normalization);
    the wav asset is run with both backends (with normalization).
    """
    cases = [("sox", self.test_filepath, False)]
    cases += [(name, self.test_filepath_wav, True) for name in ("sox", "soundfile")]
    for backend, filepath, normalization in cases:
        # Label each sub-test so a failure report names the backend and asset.
        with self.subTest(backend=backend, filepath=filepath):
            with AudioBackendScope(backend):
                self._test_1_save(filepath, normalization)
def _test_1_save(self, test_filepath, normalization):
# load signal
x, sr = torchaudio.load(self.test_filepath, normalization=False)
x, sr = torchaudio.load(test_filepath, normalization=normalization)
# check save
new_filepath = os.path.join(self.test_dirpath, "test.wav")
......@@ -52,6 +79,14 @@ class Test_LoadSave(unittest.TestCase):
"test.wav")
torchaudio.save(new_filepath, x, sr)
def test_1_save_sine(self):
    """Run the sine-wave save checks once per audio backend."""
    for backend in ["sox", "soundfile"]:
        # Label the sub-test so a failure report names the backend.
        with self.subTest(backend=backend):
            with AudioBackendScope(backend):
                self._test_1_save_sine()
def _test_1_save_sine(self):
# save created file
sinewave_filepath = os.path.join(self.test_dirpath, "assets",
"sinewave.wav")
......@@ -78,34 +113,36 @@ class Test_LoadSave(unittest.TestCase):
os.unlink(new_filepath)
def test_2_load(self):
    """Run the load checks for every supported backend/asset combination.

    The mp3 asset (expected length 278756) is only run with sox; the wav
    asset (expected length 276858) is run with both backends.
    """
    cases = [("sox", self.test_filepath, 278756)]
    cases += [(name, self.test_filepath_wav, 276858) for name in ("sox", "soundfile")]
    for backend, filepath, length in cases:
        # Label each sub-test so a failure report names the backend and asset.
        with self.subTest(backend=backend, filepath=filepath):
            with AudioBackendScope(backend):
                self._test_2_load(filepath, length)
def _test_2_load(self, test_filepath, length):
# check normal loading
x, sr = torchaudio.load(self.test_filepath)
x, sr = torchaudio.load(test_filepath)
self.assertEqual(sr, 44100)
self.assertEqual(x.size(), (2, 278756))
# check no normalizing
x, _ = torchaudio.load(self.test_filepath, normalization=False)
self.assertTrue(x.min() <= -1.0)
self.assertTrue(x.max() >= 1.0)
self.assertEqual(x.size(), (2, length))
# check offset
offset = 15
x, _ = torchaudio.load(self.test_filepath)
x_offset, _ = torchaudio.load(self.test_filepath, offset=offset)
x, _ = torchaudio.load(test_filepath)
x_offset, _ = torchaudio.load(test_filepath, offset=offset)
self.assertTrue(x[:, offset:].allclose(x_offset))
# check number of frames
n = 201
x, _ = torchaudio.load(self.test_filepath, num_frames=n)
x, _ = torchaudio.load(test_filepath, num_frames=n)
self.assertTrue(x.size(), (2, n))
# check channels first
x, _ = torchaudio.load(self.test_filepath, channels_first=False)
self.assertEqual(x.size(), (278756, 2))
# check different input tensor type
x, _ = torchaudio.load(self.test_filepath, torch.LongTensor(), normalization=False)
self.assertTrue(isinstance(x, torch.LongTensor))
x, _ = torchaudio.load(test_filepath, channels_first=False)
self.assertEqual(x.size(), (length, 2))
# check raising errors
with self.assertRaises(OSError):
......@@ -116,7 +153,30 @@ class Test_LoadSave(unittest.TestCase):
os.path.dirname(self.test_dirpath), "torchaudio")
torchaudio.load(tdir)
def test_2_load_nonormalization(self):
    """Run the un-normalized load checks; only the sox backend is exercised."""
    for backend in ["sox"]:
        # Label the sub-test so a failure report names the backend.
        with self.subTest(backend=backend):
            with AudioBackendScope(backend):
                self._test_2_load_nonormalization(self.test_filepath, 278756)
def _test_2_load_nonormalization(self, test_filepath, length):
    """Check loading with ``normalization=False``.

    ``length`` is accepted for symmetry with the other load helpers but is
    not used by these checks.
    """
    # Without normalization the samples are expected to reach beyond [-1, 1].
    raw, _ = torchaudio.load(test_filepath, normalization=False)
    self.assertTrue(raw.min() <= -1.0)
    self.assertTrue(raw.max() >= 1.0)

    # A caller-supplied output tensor determines the type of the result.
    as_long, _ = torchaudio.load(test_filepath, torch.LongTensor(), normalization=False)
    self.assertIsInstance(as_long, torch.LongTensor)
def test_3_load_and_save_is_identity(self):
    """Run the load/save round-trip check once per audio backend."""
    for backend in ["sox", "soundfile"]:
        # Label the sub-test so a failure report names the backend.
        with self.subTest(backend=backend):
            with AudioBackendScope(backend):
                self._test_3_load_and_save_is_identity()
def _test_3_load_and_save_is_identity(self):
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
tensor, sample_rate = torchaudio.load(input_path)
output_path = os.path.join(self.test_dirpath, 'test.wav')
......@@ -126,7 +186,35 @@ class Test_LoadSave(unittest.TestCase):
self.assertEqual(sample_rate, sample_rate2)
os.unlink(output_path)
def test_3_load_and_save_is_identity_across_backend(self):
    """Check the round trip when one backend writes and the other reads back."""
    for backend1, backend2 in (("sox", "soundfile"), ("soundfile", "sox")):
        # Label each direction so a failure report says which pairing broke.
        with self.subTest(backend1=backend1, backend2=backend2):
            self._test_3_load_and_save_is_identity_across_backend(backend1, backend2)
def _test_3_load_and_save_is_identity_across_backend(self, backend1, backend2):
    """Load and save with ``backend1``, reload with ``backend2``, then compare.

    Args:
        backend1: Name of the backend used to load the asset and write the copy.
        backend2: Name of the backend used to read the copy back.
    """
    input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
    output_path = os.path.join(self.test_dirpath, 'test.wav')

    try:
        with AudioBackendScope(backend1):
            tensor1, sample_rate1 = torchaudio.load(input_path)
            torchaudio.save(output_path, tensor1, sample_rate1)

        with AudioBackendScope(backend2):
            tensor2, sample_rate2 = torchaudio.load(output_path)

        self.assertTrue(tensor1.allclose(tensor2))
        self.assertEqual(sample_rate1, sample_rate2)
    finally:
        # Always remove the temporary copy, even when a check above fails, so
        # a stale test.wav cannot leak into later tests. The existence guard
        # keeps a failed save from being masked by an unlink error.
        if os.path.exists(output_path):
            os.unlink(output_path)
def test_4_load_partial(self):
    """Run the partial-load checks; only the sox backend is exercised."""
    for backend in ["sox"]:
        # Label the sub-test so a failure report names the backend.
        with self.subTest(backend=backend):
            with AudioBackendScope(backend):
                self._test_4_load_partial()
def _test_4_load_partial(self):
num_frames = 101
offset = 201
# load entire mono sinewave wav file, load a partial copy and then compare
......@@ -163,6 +251,12 @@ class Test_LoadSave(unittest.TestCase):
torchaudio.load(input_sine_path, offset=100000)
def test_5_get_info(self):
    """Run the ``torchaudio.info`` checks once per audio backend."""
    for backend in ["sox", "soundfile"]:
        # Label the sub-test so a failure report names the backend.
        with self.subTest(backend=backend):
            with AudioBackendScope(backend):
                self._test_5_get_info()
def _test_5_get_info(self):
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
channels, samples, rate, precision = (1, 64000, 16000, 16)
si, ei = torchaudio.info(input_path)
......
......@@ -2,9 +2,22 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os.path
import torch
import _torch_sox
from torchaudio import transforms, datasets, kaldi_io, sox_effects, compliance
from torchaudio import (
compliance,
datasets,
kaldi_io,
sox_effects,
transforms,
)
from torchaudio._backend import (
check_input,
_audio_backend_guard,
_get_audio_backend_module,
get_audio_backend,
set_audio_backend,
)
try:
from .version import __version__, git_version # noqa: F401
......@@ -12,13 +25,6 @@ except ImportError:
pass
def check_input(src):
    """Ensure ``src`` is a tensor that is not stored on a CUDA device.

    Raises:
        TypeError: If ``src`` is not a tensor, or if it is a CUDA tensor.
    """
    is_tensor = torch.is_tensor(src)
    if is_tensor and not src.is_cuda:
        return
    if not is_tensor:
        raise TypeError('Expected a tensor, got %s' % type(src))
    raise TypeError('Expected a CPU based tensor, got %s' % type(src))
def load(filepath,
out=None,
normalization=True,
......@@ -67,36 +73,18 @@ def load(filepath,
1.
"""
# stringify if `pathlib.Path` (noop if already `str`)
filepath = str(filepath)
# check if valid file
if not os.path.isfile(filepath):
raise OSError("{} not found or is a directory".format(filepath))
# initialize output tensor
if out is not None:
check_input(out)
else:
out = torch.FloatTensor()
if num_frames < -1:
raise ValueError("Expected value for num_samples -1 (entire file) or >=0")
if offset < 0:
raise ValueError("Expected positive offset value")
sample_rate = _torch_sox.read_audio_file(filepath,
out,
channels_first,
num_frames,
offset,
signalinfo,
encodinginfo,
filetype)
# normalize if needed
_audio_normalization(out, normalization)
return out, sample_rate
return getattr(_get_audio_backend_module(), 'load')(
filepath,
out=out,
normalization=normalization,
channels_first=channels_first,
num_frames=num_frames,
offset=offset,
signalinfo=signalinfo,
encodinginfo=encodinginfo,
filetype=filetype,
)
def load_wav(filepath, **kwargs):
......@@ -128,15 +116,13 @@ def save(filepath, src, sample_rate, precision=16, channels_first=True):
channels_first (bool): Set channels first or length first in result. (
Default: ``True``)
"""
si = sox_signalinfo_t()
ch_idx = 0 if channels_first else 1
si.rate = sample_rate
si.channels = 1 if src.dim() == 1 else src.size(ch_idx)
si.length = src.numel()
si.precision = precision
return save_encinfo(filepath, src, channels_first, si)
return getattr(_get_audio_backend_module(), 'save')(
filepath, src, sample_rate, precision=precision, channels_first=channels_first
)
@_audio_backend_guard("sox")
def save_encinfo(filepath,
src,
channels_first=True,
......@@ -203,6 +189,8 @@ def save_encinfo(filepath,
src = src.transpose(1, 0)
# save data to file
src = src.contiguous()
import _torch_sox
_torch_sox.write_audio_file(filepath, src, signalinfo, encodinginfo, filetype)
......@@ -220,9 +208,11 @@ def info(filepath):
>>> si, ei = torchaudio.info('foo.wav')
>>> rate, channels, encoding = si.rate, si.channels, ei.encoding
"""
return _torch_sox.get_info(filepath)
return getattr(_get_audio_backend_module(), 'info')(filepath)
@_audio_backend_guard("sox")
def sox_signalinfo_t():
r"""Create a sox_signalinfo_t object. This object can be used to set the sample
rate, number of channels, length, bit precision and headroom multiplier
......@@ -242,9 +232,12 @@ def sox_signalinfo_t():
>>> si.precision = 16
>>> si.length = 0
"""
import _torch_sox
return _torch_sox.sox_signalinfo_t()
@_audio_backend_guard("sox")
def sox_encodinginfo_t():
r"""Create a sox_encodinginfo_t object. This object can be used to set the encoding
type, bit precision, compression factor, reverse bytes, reverse nibbles,
......@@ -274,6 +267,8 @@ def sox_encodinginfo_t():
>>> ei.opposite_endian = torchaudio.get_sox_bool(0)
"""
import _torch_sox
ei = _torch_sox.sox_encodinginfo_t()
sdo = get_sox_option_t(2) # sox_default_option
ei.reverse_bytes = sdo
......@@ -282,6 +277,7 @@ def sox_encodinginfo_t():
return ei
@_audio_backend_guard("sox")
def get_sox_encoding_t(i=None):
r"""Get enum of sox_encoding_t for sox encodings.
......@@ -292,6 +288,8 @@ def get_sox_encoding_t(i=None):
Returns:
sox_encoding_t: A sox_encoding_t type for output encoding
"""
import _torch_sox
if i is None:
# one can see all possible values using the .__members__ attribute
return _torch_sox.sox_encoding_t
......@@ -299,6 +297,7 @@ def get_sox_encoding_t(i=None):
return _torch_sox.sox_encoding_t(i)
@_audio_backend_guard("sox")
def get_sox_option_t(i=2):
r"""Get enum of sox_option_t for sox encodinginfo options.
......@@ -309,12 +308,15 @@ def get_sox_option_t(i=2):
Returns:
sox_option_t: A sox_option_t type
"""
import _torch_sox
if i is None:
return _torch_sox.sox_option_t
else:
return _torch_sox.sox_option_t(i)
@_audio_backend_guard("sox")
def get_sox_bool(i=0):
r"""Get enum of sox_bool for sox encodinginfo options.
......@@ -326,24 +328,32 @@ def get_sox_bool(i=0):
Returns:
sox_bool: A sox_bool type
"""
import _torch_sox
if i is None:
return _torch_sox.sox_bool
else:
return _torch_sox.sox_bool(i)
@_audio_backend_guard("sox")
def initialize_sox():
    """Initialize sox for use with effects chains.

    This is not required for simple loading. Importantly, only run
    `initialize_sox` once, and do not shut sox down after each effects chain;
    shut it down once, when you are finished with all effects chains.

    Returns:
        Whatever ``_torch_sox.initialize_sox()`` returns.
    """
    # Imported at call time so the sox extension is only needed when this
    # function is actually used.
    import _torch_sox
    return _torch_sox.initialize_sox()
@_audio_backend_guard("sox")
def shutdown_sox():
    """Shut down sox after use with effects chains.

    Not required for simple loading. Importantly, only call once. Attempting
    to re-initialize sox afterwards will result in seg faults.

    Returns:
        Whatever ``_torch_sox.shutdown_sox()`` returns.
    """
    # Imported at call time so the sox extension is only needed when this
    # function is actually used.
    import _torch_sox
    return _torch_sox.shutdown_sox()
......
from functools import wraps
import torch
from . import _soundfile_backend, _sox_backend
# Name of the currently selected backend; "sox" is the initial selection.
_audio_backend = "sox"
# Registry mapping every accepted backend name to the module implementing it.
_audio_backends = {"sox": _sox_backend, "soundfile": _soundfile_backend}
def set_audio_backend(backend):
    """Select the package used to load audio.

    Args:
        backend (str): Name of the backend. Must be one of the keys of the
            module-level ``_audio_backends`` registry.

    Raises:
        ValueError: If ``backend`` is not a registered backend name. The
            current selection is left unchanged in that case.
    """
    # NOTE: the description above is a plain literal on purpose. A string
    # followed by ``.format(...)`` is an ordinary expression, not a docstring,
    # so it would leave ``__doc__`` empty and be re-evaluated on every call.
    global _audio_backend
    if backend not in _audio_backends:
        raise ValueError(
            # list(...) renders as "['sox', ...]" rather than a dict_keys repr.
            "Invalid backend '{}'. Options are {}.".format(backend, list(_audio_backends))
        )
    _audio_backend = backend
def get_audio_backend():
    """Return the name of the audio backend that is currently selected."""
    return _audio_backend
def _get_audio_backend_module():
    """Return the module that implements the currently selected backend."""
    return _audio_backends[get_audio_backend()]
def _audio_backend_guard(backends):
    """Build a decorator restricting a function to the given backend(s).

    Args:
        backends: A single backend name, or a container of backend names, for
            which the decorated function may be called.

    Returns:
        A decorator. The wrapped function raises ``RuntimeError`` when the
        currently selected backend is not among ``backends`` and otherwise
        forwards to the original function.
    """
    # A bare name is treated as a one-element list of allowed backends.
    allowed = [backends] if isinstance(backends, str) else backends

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # The selection is checked at call time, not at decoration time.
            if get_audio_backend() in allowed:
                return func(*args, **kwargs)
            raise RuntimeError("Function {} requires backend to be one of {}.".format(func.__name__, allowed))
        return wrapper

    return decorator
def check_input(src):
    """Ensure ``src`` is a tensor that is not stored on a CUDA device.

    Raises:
        TypeError: If ``src`` is not a tensor, or if it is a CUDA tensor.
    """
    is_tensor = torch.is_tensor(src)
    if is_tensor and not src.is_cuda:
        return
    if not is_tensor:
        raise TypeError('Expected a tensor, got %s' % type(src))
    raise TypeError('Expected a CPU based tensor, got %s' % type(src))
import os
import torch
# Bits per sample for each soundfile subtype name this backend recognizes;
# info() looks the file's subtype up here to report its precision.
_subtype_to_precision = {
    'PCM_S8': 8,
    'PCM_16': 16,
    'PCM_24': 24,
    'PCM_32': 32,
    'PCM_U8': 8
}
class SignalInfo:
    """Plain attribute container describing an audio signal.

    Every field defaults to ``None`` and is stored exactly as passed in.
    """

    def __init__(self, channels=None, rate=None, precision=None, length=None):
        self.channels, self.rate = channels, rate
        self.precision, self.length = precision, length
class EncodingInfo:
    """Plain attribute container describing how audio samples are encoded.

    Every field defaults to ``None`` and is stored exactly as passed in.
    """

    def __init__(
        self,
        encoding=None,
        bits_per_sample=None,
        compression=None,
        reverse_bytes=None,
        reverse_nibbles=None,
        reverse_bits=None,
        opposite_endian=None
    ):
        # Sample format.
        self.encoding, self.bits_per_sample = encoding, bits_per_sample
        self.compression = compression
        # Byte/bit ordering flags.
        self.reverse_bytes, self.reverse_nibbles = reverse_bytes, reverse_nibbles
        self.reverse_bits, self.opposite_endian = reverse_bits, opposite_endian
def check_input(src):
    """Ensure ``src`` is a tensor that is not stored on a CUDA device.

    Raises:
        TypeError: If ``src`` is not a tensor, or if it is a CUDA tensor.
    """
    is_tensor = torch.is_tensor(src)
    if is_tensor and not src.is_cuda:
        return
    if not is_tensor:
        raise TypeError("Expected a tensor, got %s" % type(src))
    raise TypeError("Expected a CPU based tensor, got %s" % type(src))
def load(
    filepath,
    out=None,
    normalization=True,
    channels_first=True,
    num_frames=0,
    offset=0,
    signalinfo=None,
    encodinginfo=None,
    filetype=None,
):
    r"""See torchaudio.load

    This backend rejects ``out``, a falsy ``normalization``, ``signalinfo``
    and ``encodinginfo``; anything but their defaults raises. ``filetype`` is
    accepted but not used.

    Raises:
        AssertionError: If an unsupported argument is supplied. The type is
            kept from the previous ``assert`` statements, but the checks are
            now explicit so they also run under ``python -O`` and carry a
            message.
        OSError: If ``filepath`` is not an existing regular file.
        ValueError: If ``num_frames < -1`` or ``offset < 0``.
    """
    unsupported = (
        ("out", out is not None),
        ("normalization=False", not normalization),
        ("signalinfo", signalinfo is not None),
        ("encodinginfo", encodinginfo is not None),
    )
    for label, given in unsupported:
        if given:
            raise AssertionError(
                "The soundfile backend does not support {}".format(label)
            )

    # stringify if `pathlib.Path` (noop if already `str`)
    filepath = str(filepath)

    # check if valid file
    if not os.path.isfile(filepath):
        raise OSError("{} not found or is a directory".format(filepath))

    if num_frames < -1:
        raise ValueError("Expected value for num_samples -1 (entire file) or >=0")
    if num_frames == 0:
        # 0 (the default) is passed on to soundfile as -1.
        num_frames = -1
    if offset < 0:
        raise ValueError("Expected positive offset value")

    # Imported at call time so soundfile is only needed when this backend runs.
    import soundfile

    # TODO call libsoundfile directly to avoid numpy
    # always_2d=True guarantees a 2-D array even for mono files.
    out, sample_rate = soundfile.read(
        filepath, frames=num_frames, start=offset, dtype="float32", always_2d=True
    )
    # Transpose what soundfile returned to get the channels-first layout.
    out = torch.from_numpy(out).t()
    if not channels_first:
        out = out.t()

    # No separate normalization step is applied; samples are requested from
    # soundfile as float32.
    return out, sample_rate
def save(filepath, src, sample_rate, precision=16, channels_first=True):
    r"""See torchaudio.save

    The caller's ``src`` tensor is never modified: a 1-D signal is given its
    channel axis on a new tensor object instead of in place.

    Raises:
        OSError: If the directory that should contain ``filepath`` is missing.
        TypeError: If ``src`` is rejected by ``check_input``.
        ValueError: If ``src`` has more than two dimensions or more than 16
            channels.
    """
    ch_idx = 0 if channels_first else 1

    # check if save directory exists
    abs_dirpath = os.path.dirname(os.path.abspath(filepath))
    if not os.path.isdir(abs_dirpath):
        raise OSError("Directory does not exist: {}".format(abs_dirpath))

    # check that src is a CPU tensor
    check_input(src)

    # Check/Fix shape of source data
    if src.dim() == 1:
        # 1d tensors are assumed to be mono signals; the out-of-place
        # unsqueeze keeps the tensor owned by the caller 1-D.
        src = src.unsqueeze(ch_idx)
    elif src.dim() > 2 or src.size(ch_idx) > 16:
        # assumes num_channels < 16
        raise ValueError(
            "Expected format where C < 16, but found {}".format(src.size()))

    if channels_first:
        # Channels-first data is transposed before being handed to soundfile.
        src = src.t()

    if src.dtype == torch.int64:
        # Soundfile doesn't support int64
        src = src.type(torch.int32)

    # Translate the bit depth into a soundfile subtype name.
    subtype = "PCM_S8" if precision == 8 else "PCM_" + str(precision)

    # Imported at call time so soundfile is only needed when this backend runs.
    import soundfile
    return soundfile.write(filepath, src, sample_rate, subtype)
def info(filepath):
    r"""See torchaudio.info"""
    # Imported at call time so soundfile is only needed when this backend runs.
    import soundfile

    meta = soundfile.info(filepath)
    # Raises KeyError for a subtype missing from _subtype_to_precision.
    bits = _subtype_to_precision[meta.subtype]
    signal_info = SignalInfo(
        channels=meta.channels,
        rate=meta.samplerate,
        precision=bits,
        length=meta.frames,
    )
    return signal_info, EncodingInfo(bits_per_sample=bits)
import os.path
import torch
import torchaudio
def load(
    filepath,
    out=None,
    normalization=True,
    channels_first=True,
    num_frames=0,
    offset=0,
    signalinfo=None,
    encodinginfo=None,
    filetype=None,
):
    r"""See torchaudio.load"""
    # Accept `pathlib.Path` as well as `str`.
    path = str(filepath)
    if not os.path.isfile(path):
        raise OSError("{} not found or is a directory".format(path))

    # Use the caller's tensor when one is given, otherwise start from an
    # empty float tensor.
    if out is None:
        out = torch.FloatTensor()
    else:
        torchaudio.check_input(out)

    if num_frames < -1:
        raise ValueError("Expected value for num_samples -1 (entire file) or >=0")
    if offset < 0:
        raise ValueError("Expected positive offset value")

    # Imported at call time so the sox extension is only needed when this
    # backend is actually used.
    import _torch_sox

    read_args = (
        path,
        out,
        channels_first,
        num_frames,
        offset,
        signalinfo,
        encodinginfo,
        filetype,
    )
    sample_rate = _torch_sox.read_audio_file(*read_args)

    # normalize if needed
    torchaudio._audio_normalization(out, normalization)

    return out, sample_rate
def save(filepath, src, sample_rate, precision=16, channels_first=True):
    r"""See torchaudio.save"""
    channel_axis = 0 if channels_first else 1

    # Describe the signal for sox, then delegate the write to save_encinfo.
    signal_info = torchaudio.sox_signalinfo_t()
    signal_info.rate = sample_rate
    # A 1-D tensor is treated as a single-channel signal.
    signal_info.channels = src.size(channel_axis) if src.dim() != 1 else 1
    signal_info.length = src.numel()
    signal_info.precision = precision
    return torchaudio.save_encinfo(filepath, src, channels_first, signal_info)
def info(filepath):
    r"""See torchaudio.info

    Delegates to ``_torch_sox.get_info`` and returns its result unchanged.
    """
    # Imported at call time so the sox extension is only needed when this
    # backend is actually used.
    import _torch_sox
    return _torch_sox.get_info(filepath)
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
import _torch_sox
import torchaudio
from torchaudio._backend import _audio_backend_guard
@_audio_backend_guard("sox")
def effect_names():
"""Gets list of valid sox effect names
......@@ -13,9 +15,12 @@ def effect_names():
Example
>>> EFFECT_NAMES = torchaudio.sox_effects.effect_names()
"""
import _torch_sox
return _torch_sox.get_effect_names()
@_audio_backend_guard("sox")
def SoxEffect():
r"""Create an object for passing sox effect information between python and c++
......@@ -23,6 +28,8 @@ def SoxEffect():
SoxEffect: An object with the following attributes: ename (str) which is the
name of effect, and eopts (List[str]) which is a list of effect options.
"""
import _torch_sox
return _torch_sox.SoxEffect()
......@@ -71,7 +78,6 @@ class SoxEffectsChain(object):
"""
EFFECTS_AVAILABLE = set(effect_names())
EFFECTS_UNIMPLEMENTED = set(["spectrogram", "splice", "noiseprof", "fir"])
def __init__(self, normalization=True, channels_first=True, out_siginfo=None, out_encinfo=None, filetype="raw"):
......@@ -84,6 +90,9 @@ class SoxEffectsChain(object):
self.normalization = normalization
self.channels_first = channels_first
# Define in __init__ to avoid calling at import time
self.EFFECTS_AVAILABLE = set(effect_names())
def append_effect_to_chain(self, ename, eargs=None):
r"""Append effect to a sox effects chain.
......@@ -107,6 +116,7 @@ class SoxEffectsChain(object):
e.eopts = eargs
self.chain.append(e)
@_audio_backend_guard("sox")
def sox_build_flow_effects(self, out=None):
r"""Build effects chain and flow effects from input file to output tensor
......@@ -130,6 +140,8 @@ class SoxEffectsChain(object):
self.chain.append(e)
# print("effect options:", [x.eopts for x in self.chain])
import _torch_sox
sr = _torch_sox.build_flow_effects(self.input_file,
out,
self.channels_first,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment