Unverified Commit 774ebc78 authored by Vincent QB's avatar Vincent QB Committed by GitHub
Browse files

Backend switch (#355)

* move sox inside function calls.

* add backend switch mechanism.

* import sox at runtime, not import.

* add backend list.

* backend tests.

* creating hidden modules for backend.

* naming backend same as file: soundfile.

* remove docstring in backend file.

* test soundfile info.

* soundfile doesn't support int64.

* adding test for wav file.

* error with incorrect parameter instead of silent ignore.

* adding test across backend. using float32 as done in sox.

* backend guard decorator.
parent 4887ff41
...@@ -7,14 +7,41 @@ import math ...@@ -7,14 +7,41 @@ import math
import os import os
class AudioBackendScope:
def __init__(self, backend):
self.new_backend = backend
self.previous_backend = torchaudio.get_audio_backend()
def __enter__(self):
torchaudio.set_audio_backend(self.new_backend)
return self.new_backend
def __exit__(self, type, value, traceback):
backend = self.previous_backend
torchaudio.set_audio_backend(backend)
class Test_LoadSave(unittest.TestCase): class Test_LoadSave(unittest.TestCase):
test_dirpath, test_dir = common_utils.create_temp_assets_dir() test_dirpath, test_dir = common_utils.create_temp_assets_dir()
test_filepath = os.path.join(test_dirpath, "assets", test_filepath = os.path.join(test_dirpath, "assets",
"steam-train-whistle-daniel_simon.mp3") "steam-train-whistle-daniel_simon.mp3")
test_filepath_wav = os.path.join(test_dirpath, "assets",
"steam-train-whistle-daniel_simon.wav")
def test_1_save(self): def test_1_save(self):
for backend in ["sox"]:
with self.subTest():
with AudioBackendScope(backend):
self._test_1_save(self.test_filepath, False)
for backend in ["sox", "soundfile"]:
with self.subTest():
with AudioBackendScope(backend):
self._test_1_save(self.test_filepath_wav, True)
def _test_1_save(self, test_filepath, normalization):
# load signal # load signal
x, sr = torchaudio.load(self.test_filepath, normalization=False) x, sr = torchaudio.load(test_filepath, normalization=normalization)
# check save # check save
new_filepath = os.path.join(self.test_dirpath, "test.wav") new_filepath = os.path.join(self.test_dirpath, "test.wav")
...@@ -52,6 +79,14 @@ class Test_LoadSave(unittest.TestCase): ...@@ -52,6 +79,14 @@ class Test_LoadSave(unittest.TestCase):
"test.wav") "test.wav")
torchaudio.save(new_filepath, x, sr) torchaudio.save(new_filepath, x, sr)
def test_1_save_sine(self):
for backend in ["sox", "soundfile"]:
with self.subTest():
with AudioBackendScope(backend):
self._test_1_save_sine()
def _test_1_save_sine(self):
# save created file # save created file
sinewave_filepath = os.path.join(self.test_dirpath, "assets", sinewave_filepath = os.path.join(self.test_dirpath, "assets",
"sinewave.wav") "sinewave.wav")
...@@ -78,34 +113,36 @@ class Test_LoadSave(unittest.TestCase): ...@@ -78,34 +113,36 @@ class Test_LoadSave(unittest.TestCase):
os.unlink(new_filepath) os.unlink(new_filepath)
def test_2_load(self): def test_2_load(self):
for backend in ["sox"]:
with self.subTest():
with AudioBackendScope(backend):
self._test_2_load(self.test_filepath, 278756)
for backend in ["sox", "soundfile"]:
with self.subTest():
with AudioBackendScope(backend):
self._test_2_load(self.test_filepath_wav, 276858)
def _test_2_load(self, test_filepath, length):
# check normal loading # check normal loading
x, sr = torchaudio.load(self.test_filepath) x, sr = torchaudio.load(test_filepath)
self.assertEqual(sr, 44100) self.assertEqual(sr, 44100)
self.assertEqual(x.size(), (2, 278756)) self.assertEqual(x.size(), (2, length))
# check no normalizing
x, _ = torchaudio.load(self.test_filepath, normalization=False)
self.assertTrue(x.min() <= -1.0)
self.assertTrue(x.max() >= 1.0)
# check offset # check offset
offset = 15 offset = 15
x, _ = torchaudio.load(self.test_filepath) x, _ = torchaudio.load(test_filepath)
x_offset, _ = torchaudio.load(self.test_filepath, offset=offset) x_offset, _ = torchaudio.load(test_filepath, offset=offset)
self.assertTrue(x[:, offset:].allclose(x_offset)) self.assertTrue(x[:, offset:].allclose(x_offset))
# check number of frames # check number of frames
n = 201 n = 201
x, _ = torchaudio.load(self.test_filepath, num_frames=n) x, _ = torchaudio.load(test_filepath, num_frames=n)
self.assertTrue(x.size(), (2, n)) self.assertTrue(x.size(), (2, n))
# check channels first # check channels first
x, _ = torchaudio.load(self.test_filepath, channels_first=False) x, _ = torchaudio.load(test_filepath, channels_first=False)
self.assertEqual(x.size(), (278756, 2)) self.assertEqual(x.size(), (length, 2))
# check different input tensor type
x, _ = torchaudio.load(self.test_filepath, torch.LongTensor(), normalization=False)
self.assertTrue(isinstance(x, torch.LongTensor))
# check raising errors # check raising errors
with self.assertRaises(OSError): with self.assertRaises(OSError):
...@@ -116,7 +153,30 @@ class Test_LoadSave(unittest.TestCase): ...@@ -116,7 +153,30 @@ class Test_LoadSave(unittest.TestCase):
os.path.dirname(self.test_dirpath), "torchaudio") os.path.dirname(self.test_dirpath), "torchaudio")
torchaudio.load(tdir) torchaudio.load(tdir)
def test_2_load_nonormalization(self):
for backend in ["sox"]:
with self.subTest():
with AudioBackendScope(backend):
self._test_2_load_nonormalization(self.test_filepath, 278756)
def _test_2_load_nonormalization(self, test_filepath, length):
# check no normalizing
x, _ = torchaudio.load(test_filepath, normalization=False)
self.assertTrue(x.min() <= -1.0)
self.assertTrue(x.max() >= 1.0)
# check different input tensor type
x, _ = torchaudio.load(test_filepath, torch.LongTensor(), normalization=False)
self.assertTrue(isinstance(x, torch.LongTensor))
def test_3_load_and_save_is_identity(self): def test_3_load_and_save_is_identity(self):
for backend in ["sox", "soundfile"]:
with self.subTest():
with AudioBackendScope(backend):
self._test_3_load_and_save_is_identity()
def _test_3_load_and_save_is_identity(self):
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav') input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
tensor, sample_rate = torchaudio.load(input_path) tensor, sample_rate = torchaudio.load(input_path)
output_path = os.path.join(self.test_dirpath, 'test.wav') output_path = os.path.join(self.test_dirpath, 'test.wav')
...@@ -126,7 +186,35 @@ class Test_LoadSave(unittest.TestCase): ...@@ -126,7 +186,35 @@ class Test_LoadSave(unittest.TestCase):
self.assertEqual(sample_rate, sample_rate2) self.assertEqual(sample_rate, sample_rate2)
os.unlink(output_path) os.unlink(output_path)
def test_3_load_and_save_is_identity_across_backend(self):
with self.subTest():
self._test_3_load_and_save_is_identity_across_backend("sox", "soundfile")
with self.subTest():
self._test_3_load_and_save_is_identity_across_backend("soundfile", "sox")
def _test_3_load_and_save_is_identity_across_backend(self, backend1, backend2):
with AudioBackendScope(backend1):
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
tensor1, sample_rate1 = torchaudio.load(input_path)
output_path = os.path.join(self.test_dirpath, 'test.wav')
torchaudio.save(output_path, tensor1, sample_rate1)
with AudioBackendScope(backend2):
tensor2, sample_rate2 = torchaudio.load(output_path)
self.assertTrue(tensor1.allclose(tensor2))
self.assertEqual(sample_rate1, sample_rate2)
os.unlink(output_path)
def test_4_load_partial(self): def test_4_load_partial(self):
for backend in ["sox"]:
with self.subTest():
with AudioBackendScope(backend):
self._test_4_load_partial()
def _test_4_load_partial(self):
num_frames = 101 num_frames = 101
offset = 201 offset = 201
# load entire mono sinewave wav file, load a partial copy and then compare # load entire mono sinewave wav file, load a partial copy and then compare
...@@ -163,6 +251,12 @@ class Test_LoadSave(unittest.TestCase): ...@@ -163,6 +251,12 @@ class Test_LoadSave(unittest.TestCase):
torchaudio.load(input_sine_path, offset=100000) torchaudio.load(input_sine_path, offset=100000)
def test_5_get_info(self): def test_5_get_info(self):
for backend in ["sox", "soundfile"]:
with self.subTest():
with AudioBackendScope(backend):
self._test_5_get_info()
def _test_5_get_info(self):
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav') input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
channels, samples, rate, precision = (1, 64000, 16000, 16) channels, samples, rate, precision = (1, 64000, 16000, 16)
si, ei = torchaudio.info(input_path) si, ei = torchaudio.info(input_path)
......
...@@ -2,9 +2,22 @@ from __future__ import absolute_import, division, print_function, unicode_litera ...@@ -2,9 +2,22 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import os.path import os.path
import torch import torch
import _torch_sox
from torchaudio import transforms, datasets, kaldi_io, sox_effects, compliance from torchaudio import (
compliance,
datasets,
kaldi_io,
sox_effects,
transforms,
)
from torchaudio._backend import (
check_input,
_audio_backend_guard,
_get_audio_backend_module,
get_audio_backend,
set_audio_backend,
)
try: try:
from .version import __version__, git_version # noqa: F401 from .version import __version__, git_version # noqa: F401
...@@ -12,13 +25,6 @@ except ImportError: ...@@ -12,13 +25,6 @@ except ImportError:
pass pass
def check_input(src):
if not torch.is_tensor(src):
raise TypeError('Expected a tensor, got %s' % type(src))
if src.is_cuda:
raise TypeError('Expected a CPU based tensor, got %s' % type(src))
def load(filepath, def load(filepath,
out=None, out=None,
normalization=True, normalization=True,
...@@ -67,36 +73,18 @@ def load(filepath, ...@@ -67,36 +73,18 @@ def load(filepath,
1. 1.
""" """
# stringify if `pathlib.Path` (noop if already `str`)
filepath = str(filepath)
# check if valid file
if not os.path.isfile(filepath):
raise OSError("{} not found or is a directory".format(filepath))
# initialize output tensor
if out is not None:
check_input(out)
else:
out = torch.FloatTensor()
if num_frames < -1:
raise ValueError("Expected value for num_samples -1 (entire file) or >=0")
if offset < 0:
raise ValueError("Expected positive offset value")
sample_rate = _torch_sox.read_audio_file(filepath, return getattr(_get_audio_backend_module(), 'load')(
out, filepath,
channels_first, out=out,
num_frames, normalization=normalization,
offset, channels_first=channels_first,
signalinfo, num_frames=num_frames,
encodinginfo, offset=offset,
filetype) signalinfo=signalinfo,
encodinginfo=encodinginfo,
# normalize if needed filetype=filetype,
_audio_normalization(out, normalization) )
return out, sample_rate
def load_wav(filepath, **kwargs): def load_wav(filepath, **kwargs):
...@@ -128,15 +116,13 @@ def save(filepath, src, sample_rate, precision=16, channels_first=True): ...@@ -128,15 +116,13 @@ def save(filepath, src, sample_rate, precision=16, channels_first=True):
channels_first (bool): Set channels first or length first in result. ( channels_first (bool): Set channels first or length first in result. (
Default: ``True``) Default: ``True``)
""" """
si = sox_signalinfo_t()
ch_idx = 0 if channels_first else 1
si.rate = sample_rate
si.channels = 1 if src.dim() == 1 else src.size(ch_idx)
si.length = src.numel()
si.precision = precision
return save_encinfo(filepath, src, channels_first, si)
return getattr(_get_audio_backend_module(), 'save')(
filepath, src, sample_rate, precision=precision, channels_first=channels_first
)
@_audio_backend_guard("sox")
def save_encinfo(filepath, def save_encinfo(filepath,
src, src,
channels_first=True, channels_first=True,
...@@ -203,6 +189,8 @@ def save_encinfo(filepath, ...@@ -203,6 +189,8 @@ def save_encinfo(filepath,
src = src.transpose(1, 0) src = src.transpose(1, 0)
# save data to file # save data to file
src = src.contiguous() src = src.contiguous()
import _torch_sox
_torch_sox.write_audio_file(filepath, src, signalinfo, encodinginfo, filetype) _torch_sox.write_audio_file(filepath, src, signalinfo, encodinginfo, filetype)
...@@ -220,9 +208,11 @@ def info(filepath): ...@@ -220,9 +208,11 @@ def info(filepath):
>>> si, ei = torchaudio.info('foo.wav') >>> si, ei = torchaudio.info('foo.wav')
>>> rate, channels, encoding = si.rate, si.channels, ei.encoding >>> rate, channels, encoding = si.rate, si.channels, ei.encoding
""" """
return _torch_sox.get_info(filepath)
return getattr(_get_audio_backend_module(), 'info')(filepath)
@_audio_backend_guard("sox")
def sox_signalinfo_t(): def sox_signalinfo_t():
r"""Create a sox_signalinfo_t object. This object can be used to set the sample r"""Create a sox_signalinfo_t object. This object can be used to set the sample
rate, number of channels, length, bit precision and headroom multiplier rate, number of channels, length, bit precision and headroom multiplier
...@@ -242,9 +232,12 @@ def sox_signalinfo_t(): ...@@ -242,9 +232,12 @@ def sox_signalinfo_t():
>>> si.precision = 16 >>> si.precision = 16
>>> si.length = 0 >>> si.length = 0
""" """
import _torch_sox
return _torch_sox.sox_signalinfo_t() return _torch_sox.sox_signalinfo_t()
@_audio_backend_guard("sox")
def sox_encodinginfo_t(): def sox_encodinginfo_t():
r"""Create a sox_encodinginfo_t object. This object can be used to set the encoding r"""Create a sox_encodinginfo_t object. This object can be used to set the encoding
type, bit precision, compression factor, reverse bytes, reverse nibbles, type, bit precision, compression factor, reverse bytes, reverse nibbles,
...@@ -274,6 +267,8 @@ def sox_encodinginfo_t(): ...@@ -274,6 +267,8 @@ def sox_encodinginfo_t():
>>> ei.opposite_endian = torchaudio.get_sox_bool(0) >>> ei.opposite_endian = torchaudio.get_sox_bool(0)
""" """
import _torch_sox
ei = _torch_sox.sox_encodinginfo_t() ei = _torch_sox.sox_encodinginfo_t()
sdo = get_sox_option_t(2) # sox_default_option sdo = get_sox_option_t(2) # sox_default_option
ei.reverse_bytes = sdo ei.reverse_bytes = sdo
...@@ -282,6 +277,7 @@ def sox_encodinginfo_t(): ...@@ -282,6 +277,7 @@ def sox_encodinginfo_t():
return ei return ei
@_audio_backend_guard("sox")
def get_sox_encoding_t(i=None): def get_sox_encoding_t(i=None):
r"""Get enum of sox_encoding_t for sox encodings. r"""Get enum of sox_encoding_t for sox encodings.
...@@ -292,6 +288,8 @@ def get_sox_encoding_t(i=None): ...@@ -292,6 +288,8 @@ def get_sox_encoding_t(i=None):
Returns: Returns:
sox_encoding_t: A sox_encoding_t type for output encoding sox_encoding_t: A sox_encoding_t type for output encoding
""" """
import _torch_sox
if i is None: if i is None:
# one can see all possible values using the .__members__ attribute # one can see all possible values using the .__members__ attribute
return _torch_sox.sox_encoding_t return _torch_sox.sox_encoding_t
...@@ -299,6 +297,7 @@ def get_sox_encoding_t(i=None): ...@@ -299,6 +297,7 @@ def get_sox_encoding_t(i=None):
return _torch_sox.sox_encoding_t(i) return _torch_sox.sox_encoding_t(i)
@_audio_backend_guard("sox")
def get_sox_option_t(i=2): def get_sox_option_t(i=2):
r"""Get enum of sox_option_t for sox encodinginfo options. r"""Get enum of sox_option_t for sox encodinginfo options.
...@@ -309,12 +308,15 @@ def get_sox_option_t(i=2): ...@@ -309,12 +308,15 @@ def get_sox_option_t(i=2):
Returns: Returns:
sox_option_t: A sox_option_t type sox_option_t: A sox_option_t type
""" """
import _torch_sox
if i is None: if i is None:
return _torch_sox.sox_option_t return _torch_sox.sox_option_t
else: else:
return _torch_sox.sox_option_t(i) return _torch_sox.sox_option_t(i)
@_audio_backend_guard("sox")
def get_sox_bool(i=0): def get_sox_bool(i=0):
r"""Get enum of sox_bool for sox encodinginfo options. r"""Get enum of sox_bool for sox encodinginfo options.
...@@ -326,24 +328,32 @@ def get_sox_bool(i=0): ...@@ -326,24 +328,32 @@ def get_sox_bool(i=0):
Returns: Returns:
sox_bool: A sox_bool type sox_bool: A sox_bool type
""" """
import _torch_sox
if i is None: if i is None:
return _torch_sox.sox_bool return _torch_sox.sox_bool
else: else:
return _torch_sox.sox_bool(i) return _torch_sox.sox_bool(i)
@_audio_backend_guard("sox")
def initialize_sox(): def initialize_sox():
"""Initialize sox for use with effects chains. This is not required for simple """Initialize sox for use with effects chains. This is not required for simple
loading. Importantly, only run `initialize_sox` once and do not shutdown loading. Importantly, only run `initialize_sox` once and do not shutdown
after each effect chain, but rather once you are finished with all effects chains. after each effect chain, but rather once you are finished with all effects chains.
""" """
import _torch_sox
return _torch_sox.initialize_sox() return _torch_sox.initialize_sox()
@_audio_backend_guard("sox")
def shutdown_sox(): def shutdown_sox():
"""Showdown sox for effects chain. Not required for simple loading. Importantly, """Showdown sox for effects chain. Not required for simple loading. Importantly,
only call once. Attempting to re-initialize sox will result in seg faults. only call once. Attempting to re-initialize sox will result in seg faults.
""" """
import _torch_sox
return _torch_sox.shutdown_sox() return _torch_sox.shutdown_sox()
......
from functools import wraps
import torch
from . import _soundfile_backend, _sox_backend
_audio_backend = "sox"
_audio_backends = {"sox": _sox_backend, "soundfile": _soundfile_backend}
def set_audio_backend(backend):
"""
Specifies the package used to load.
Args:
backend (string): Name of the backend. One of {}.
""".format(_audio_backends.keys())
global _audio_backend
if backend not in _audio_backends:
raise ValueError(
"Invalid backend '{}'. Options are {}.".format(backend, _audio_backends.keys())
)
_audio_backend = backend
def get_audio_backend():
"""
Gets the name of the package used to load.
"""
return _audio_backend
def _get_audio_backend_module():
"""
Gets the module backend to load.
"""
backend = get_audio_backend()
return _audio_backends[backend]
def _audio_backend_guard(backends):
if isinstance(backends, str):
backends = [backends]
def decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
if get_audio_backend() not in backends:
raise RuntimeError("Function {} requires backend to be one of {}.".format(func.__name__, backends))
return func(*args, **kwargs)
return wrapper
return decorator
def check_input(src):
if not torch.is_tensor(src):
raise TypeError('Expected a tensor, got %s' % type(src))
if src.is_cuda:
raise TypeError('Expected a CPU based tensor, got %s' % type(src))
import os
import torch
_subtype_to_precision = {
'PCM_S8': 8,
'PCM_16': 16,
'PCM_24': 24,
'PCM_32': 32,
'PCM_U8': 8
}
class SignalInfo:
def __init__(self, channels=None, rate=None, precision=None, length=None):
self.channels = channels
self.rate = rate
self.precision = precision
self.length = length
class EncodingInfo:
def __init__(
self,
encoding=None,
bits_per_sample=None,
compression=None,
reverse_bytes=None,
reverse_nibbles=None,
reverse_bits=None,
opposite_endian=None
):
self.encoding = encoding
self.bits_per_sample = bits_per_sample
self.compression = compression
self.reverse_bytes = reverse_bytes
self.reverse_nibbles = reverse_nibbles
self.reverse_bits = reverse_bits
self.opposite_endian = opposite_endian
def check_input(src):
if not torch.is_tensor(src):
raise TypeError("Expected a tensor, got %s" % type(src))
if src.is_cuda:
raise TypeError("Expected a CPU based tensor, got %s" % type(src))
def load(
filepath,
out=None,
normalization=True,
channels_first=True,
num_frames=0,
offset=0,
signalinfo=None,
encodinginfo=None,
filetype=None,
):
r"""See torchaudio.load"""
assert out is None
assert normalization
assert signalinfo is None
assert encodinginfo is None
# stringify if `pathlib.Path` (noop if already `str`)
filepath = str(filepath)
# check if valid file
if not os.path.isfile(filepath):
raise OSError("{} not found or is a directory".format(filepath))
if num_frames < -1:
raise ValueError("Expected value for num_samples -1 (entire file) or >=0")
if num_frames == 0:
num_frames = -1
if offset < 0:
raise ValueError("Expected positive offset value")
import soundfile
# initialize output tensor
# TODO call libsoundfile directly to avoid numpy
out, sample_rate = soundfile.read(
filepath, frames=num_frames, start=offset, dtype="float32", always_2d=True
)
out = torch.from_numpy(out).t()
if not channels_first:
out = out.t()
# normalize if needed
# _audio_normalization(out, normalization)
return out, sample_rate
def save(filepath, src, sample_rate, precision=16, channels_first=True):
r"""See torchaudio.save"""
ch_idx, len_idx = (0, 1) if channels_first else (1, 0)
# check if save directory exists
abs_dirpath = os.path.dirname(os.path.abspath(filepath))
if not os.path.isdir(abs_dirpath):
raise OSError("Directory does not exist: {}".format(abs_dirpath))
# check that src is a CPU tensor
check_input(src)
# Check/Fix shape of source data
if src.dim() == 1:
# 1d tensors as assumed to be mono signals
src.unsqueeze_(ch_idx)
elif src.dim() > 2 or src.size(ch_idx) > 16:
# assumes num_channels < 16
raise ValueError(
"Expected format where C < 16, but found {}".format(src.size()))
if channels_first:
src = src.t()
if src.dtype == torch.int64:
# Soundfile doesn't support int64
src = src.type(torch.int32)
precision = "PCM_S8" if precision == 8 else "PCM_" + str(precision)
import soundfile
return soundfile.write(filepath, src, sample_rate, precision)
def info(filepath):
r"""See torchaudio.info"""
import soundfile
sfi = soundfile.info(filepath)
precision = _subtype_to_precision[sfi.subtype]
si = SignalInfo(sfi.channels, sfi.samplerate, precision, sfi.frames)
ei = EncodingInfo(bits_per_sample=precision)
return si, ei
import os.path
import torch
import torchaudio
def load(
filepath,
out=None,
normalization=True,
channels_first=True,
num_frames=0,
offset=0,
signalinfo=None,
encodinginfo=None,
filetype=None,
):
r"""See torchaudio.load"""
# stringify if `pathlib.Path` (noop if already `str`)
filepath = str(filepath)
# check if valid file
if not os.path.isfile(filepath):
raise OSError("{} not found or is a directory".format(filepath))
# initialize output tensor
if out is not None:
torchaudio.check_input(out)
else:
out = torch.FloatTensor()
if num_frames < -1:
raise ValueError("Expected value for num_samples -1 (entire file) or >=0")
if offset < 0:
raise ValueError("Expected positive offset value")
import _torch_sox
sample_rate = _torch_sox.read_audio_file(
filepath,
out,
channels_first,
num_frames,
offset,
signalinfo,
encodinginfo,
filetype
)
# normalize if needed
torchaudio._audio_normalization(out, normalization)
return out, sample_rate
def save(filepath, src, sample_rate, precision=16, channels_first=True):
r"""See torchaudio.save"""
si = torchaudio.sox_signalinfo_t()
ch_idx = 0 if channels_first else 1
si.rate = sample_rate
si.channels = 1 if src.dim() == 1 else src.size(ch_idx)
si.length = src.numel()
si.precision = precision
return torchaudio.save_encinfo(filepath, src, channels_first, si)
def info(filepath):
r"""See torchaudio.info"""
import _torch_sox
return _torch_sox.get_info(filepath)
from __future__ import absolute_import, division, print_function, unicode_literals from __future__ import absolute_import, division, print_function, unicode_literals
import torch import torch
import _torch_sox
import torchaudio import torchaudio
from torchaudio._backend import _audio_backend_guard
@_audio_backend_guard("sox")
def effect_names(): def effect_names():
"""Gets list of valid sox effect names """Gets list of valid sox effect names
...@@ -13,9 +15,12 @@ def effect_names(): ...@@ -13,9 +15,12 @@ def effect_names():
Example Example
>>> EFFECT_NAMES = torchaudio.sox_effects.effect_names() >>> EFFECT_NAMES = torchaudio.sox_effects.effect_names()
""" """
import _torch_sox
return _torch_sox.get_effect_names() return _torch_sox.get_effect_names()
@_audio_backend_guard("sox")
def SoxEffect(): def SoxEffect():
r"""Create an object for passing sox effect information between python and c++ r"""Create an object for passing sox effect information between python and c++
...@@ -23,6 +28,8 @@ def SoxEffect(): ...@@ -23,6 +28,8 @@ def SoxEffect():
SoxEffect: An object with the following attributes: ename (str) which is the SoxEffect: An object with the following attributes: ename (str) which is the
name of effect, and eopts (List[str]) which is a list of effect options. name of effect, and eopts (List[str]) which is a list of effect options.
""" """
import _torch_sox
return _torch_sox.SoxEffect() return _torch_sox.SoxEffect()
...@@ -71,7 +78,6 @@ class SoxEffectsChain(object): ...@@ -71,7 +78,6 @@ class SoxEffectsChain(object):
""" """
EFFECTS_AVAILABLE = set(effect_names())
EFFECTS_UNIMPLEMENTED = set(["spectrogram", "splice", "noiseprof", "fir"]) EFFECTS_UNIMPLEMENTED = set(["spectrogram", "splice", "noiseprof", "fir"])
def __init__(self, normalization=True, channels_first=True, out_siginfo=None, out_encinfo=None, filetype="raw"): def __init__(self, normalization=True, channels_first=True, out_siginfo=None, out_encinfo=None, filetype="raw"):
...@@ -84,6 +90,9 @@ class SoxEffectsChain(object): ...@@ -84,6 +90,9 @@ class SoxEffectsChain(object):
self.normalization = normalization self.normalization = normalization
self.channels_first = channels_first self.channels_first = channels_first
# Define in __init__ to avoid calling at import time
self.EFFECTS_AVAILABLE = set(effect_names())
def append_effect_to_chain(self, ename, eargs=None): def append_effect_to_chain(self, ename, eargs=None):
r"""Append effect to a sox effects chain. r"""Append effect to a sox effects chain.
...@@ -107,6 +116,7 @@ class SoxEffectsChain(object): ...@@ -107,6 +116,7 @@ class SoxEffectsChain(object):
e.eopts = eargs e.eopts = eargs
self.chain.append(e) self.chain.append(e)
@_audio_backend_guard("sox")
def sox_build_flow_effects(self, out=None): def sox_build_flow_effects(self, out=None):
r"""Build effects chain and flow effects from input file to output tensor r"""Build effects chain and flow effects from input file to output tensor
...@@ -130,6 +140,8 @@ class SoxEffectsChain(object): ...@@ -130,6 +140,8 @@ class SoxEffectsChain(object):
self.chain.append(e) self.chain.append(e)
# print("effect options:", [x.eopts for x in self.chain]) # print("effect options:", [x.eopts for x in self.chain])
import _torch_sox
sr = _torch_sox.build_flow_effects(self.input_file, sr = _torch_sox.build_flow_effects(self.input_file,
out, out,
self.channels_first, self.channels_first,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment