Unverified Commit 793eeab8 authored by moto's avatar moto Committed by GitHub
Browse files

Add load function (#731)

This is part of a series of PRs that add the new "sox_io" backend (#726). It depends on #718 and #728.

This PR adds the `load` function to the "sox_io" backend, which is tested on the following audio formats:
 - `wav`
 - `mp3`
 - `flac`
 - `ogg/vorbis` *

By default, "sox_io" backend returns Tensor with `float32` dtype and the shape of `[channel, time]`. The samples are normalized to fit in the range of `[-1.0, 1.0]`.

Unlike the existing "sox" backend, the new `load` function can handle WAV files natively. When the input format is WAV with an integer type (such as 32-bit signed integer, 16-bit signed integer or 8-bit unsigned integer), providing `normalize=False` makes this function return an integer Tensor whose samples are expressed within the whole range of the corresponding dtype, that is, an `int32` Tensor for `32-bit PCM`, `int16` for `16-bit PCM` and `uint8` for `8-bit PCM`. This behavior follows [scipy.io.wavfile.read](https://docs.scipy.org/doc/scipy/reference/generated/scipy.io.wavfile.read.html). The `normalize` parameter has no effect for other formats; for those, the load function always returns normalized values as a `float32` Tensor.

__* Note__ The current binary distribution of torchaudio does not contain `ogg/vorbis` and `opus` codecs. To handle these files, one needs to build torchaudio from the source with proper codecs in the system.

__Note 2__ Since this PR, `scipy` becomes a required module for running the tests.
parent 0f0d0af3
from typing import Optional, Tuple

import scipy.io.wavfile
import torch
def get_test_name(func, _, params): def get_test_name(func, _, params):
return f'{func.__name__}_{"_".join(str(p) for p in params.args)}' return f'{func.__name__}_{"_".join(str(p) for p in params.args)}'
def normalize_wav(tensor: torch.Tensor) -> torch.Tensor:
    """Scale integer PCM samples to ``float32`` values in ``[-1.0, 1.0]``.

    ``int32``, ``int16`` and ``uint8`` tensors are converted to ``float32``.
    ``uint8`` samples are first shifted by 128 so that they are centered on
    zero. Positive and negative samples are divided by different magnitudes
    because the integer ranges are asymmetric (e.g. ``[-32768, 32767]``).

    Tensors of any other dtype (including ``float32``) are returned as-is,
    without copying.
    """
    # dtype -> (offset to subtract, divisor for positives, divisor for negatives)
    scaling = {
        torch.int32: (0, 2147483647., 2147483648.),
        torch.int16: (0, 32767., 32768.),
        torch.uint8: (128, 127., 128.),
    }
    params = scaling.get(tensor.dtype)
    if params is None:
        return tensor

    offset, positive_divisor, negative_divisor = params
    samples = tensor.to(torch.float32)
    if offset:
        samples = samples - offset
    samples[samples > 0] /= positive_divisor
    samples[samples < 0] /= negative_divisor
    return samples
def get_wav_data(
        dtype: str,
        num_channels: int,
        *,
        num_frames: Optional[int] = None,
        normalize: bool = True,
        channels_first: bool = True,
):
    """Generate a linear signal of the given dtype and number of channels.

    The data range is

        [-1.0, 1.0] for float32,
        [-2147483648, 2147483647] for int32,
        [-32768, 32767] for int16,
        [0, 255] for uint8.

    Args:
        dtype: Name of the torch dtype; one of ``'float32'``, ``'int32'``,
            ``'int16'`` or ``'uint8'``.
        num_channels: Number of channels; every channel holds the same ramp.
        num_frames: Number of points of the linear interpolation. Defaults
            to 256 for uint8 and ``1 << 16`` otherwise, so that the int16
            value range is completely covered.
        normalize: When ``True``, the result is passed through
            ``normalize_wav``.
        channels_first: When ``True`` the result has shape
            ``[channel, time]``, otherwise ``[time, channel]``.

    Raises:
        ValueError: If ``dtype`` names a torch attribute that is not one of
            the supported dtypes.
    """
    value_ranges = {
        'uint8': (0, 255),
        'float32': (-1., 1.),
        'int32': (-2147483648, 2147483647),
        'int16': (-32768, 32767),
    }
    # Resolve the dtype first so that unknown names keep failing with
    # AttributeError, exactly as before.
    dtype_ = getattr(torch, dtype)
    if dtype not in value_ranges:
        # Previously this case fell through every branch and crashed later
        # with an UnboundLocalError on ``base``.
        raise ValueError(
            f'Unsupported dtype: {dtype}. Expected one of {sorted(value_ranges)}.')
    if num_frames is None:
        num_frames = 256 if dtype == 'uint8' else 1 << 16
    low, high = value_ranges[dtype]
    base = torch.linspace(low, high, num_frames, dtype=dtype_)
    data = base.repeat([num_channels, 1])
    if not channels_first:
        data = data.transpose(1, 0)
    if normalize:
        data = normalize_wav(data)
    return data
def load_wav(path: str, normalize=True, channels_first=True) -> Tuple[torch.Tensor, int]:
    """Load a wav file with scipy, i.e. without relying on torchaudio.

    Args:
        path: Path to the wav file.
        normalize: When ``True``, the samples are passed through
            ``normalize_wav``.
        channels_first: When ``True`` the returned Tensor has shape
            ``[channel, time]``, otherwise ``[time, channel]``.

    Returns:
        The pair ``(data, sample_rate)``.
    """
    sample_rate, data = scipy.io.wavfile.read(path)
    # Copy so that the Tensor owns writable memory.
    data = torch.from_numpy(data.copy())
    if data.ndim == 1:
        # Mono files are read as 1D arrays; add the channel dimension.
        data = data.unsqueeze(1)
    if normalize:
        data = normalize_wav(data)
    if channels_first:
        data = data.transpose(1, 0)
    return data, sample_rate
def save_wav(path, data, sample_rate, channels_first=True):
    """Write a Tensor to a wav file with scipy, i.e. without torchaudio.

    scipy expects ``[time, channel]`` data, so ``[channel, time]`` input
    (``channels_first=True``) is transposed before writing.
    """
    frames = data.transpose(1, 0) if channels_first else data
    scipy.io.wavfile.write(path, sample_rate, frames.numpy())
...@@ -26,6 +26,9 @@ def gen_audio_file( ...@@ -26,6 +26,9 @@ def gen_audio_file(
*, encoding=None, bit_depth=None, compression=None, attenuation=None, duration=1, *, encoding=None, bit_depth=None, compression=None, attenuation=None, duration=1,
): ):
"""Generate synthetic audio file with `sox` command.""" """Generate synthetic audio file with `sox` command."""
if path.endswith('.wav'):
raise RuntimeError(
'Use get_wav_data and save_wav to generate wav file for accurate result.')
command = [ command = [
'sox', 'sox',
'-V', # verbose '-V', # verbose
...@@ -51,4 +54,17 @@ def gen_audio_file( ...@@ -51,4 +54,17 @@ def gen_audio_file(
command += ['vol', f'-{attenuation}dB'] command += ['vol', f'-{attenuation}dB']
print(' '.join(command)) print(' '.join(command))
subprocess.run(command, check=True) subprocess.run(command, check=True)
subprocess.run(['soxi', path], check=True)
def convert_audio_file(
        src_path, dst_path,
        *, bit_depth=None, compression=None):
    """Convert audio file with `sox` command.

    Args:
        src_path: Path of the input file.
        dst_path: Path of the output file.
        bit_depth: When given, passed to sox as ``--bits``.
        compression: When given, passed to sox as ``--compression``.

    Raises:
        subprocess.CalledProcessError: If the ``sox`` command fails.
    """
    command = ['sox', '-V', str(src_path)]
    if bit_depth is not None:
        command += ['--bits', str(bit_depth)]
    if compression is not None:
        command += ['--compression', str(compression)]
    # Stringify like ``src_path`` so that path-like objects do not break
    # the ``' '.join`` below.
    command += [str(dst_path)]
    print(' '.join(command))
    subprocess.run(command, check=True)
...@@ -10,7 +10,9 @@ from ..common_utils import ( ...@@ -10,7 +10,9 @@ from ..common_utils import (
skipIfNoExtension, skipIfNoExtension,
) )
from .common import ( from .common import (
get_test_name get_test_name,
get_wav_data,
save_wav,
) )
from . import sox_utils from . import sox_utils
...@@ -27,12 +29,8 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -27,12 +29,8 @@ class TestInfo(TempDirMixin, PytorchTestCase):
"""`sox_io_backend.info` can check wav file correctly""" """`sox_io_backend.info` can check wav file correctly"""
duration = 1 duration = 1
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav') path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
sox_utils.gen_audio_file( data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
path, sample_rate, num_channels, save_wav(path, data, sample_rate)
bit_depth=sox_utils.get_bit_depth(dtype),
encoding=sox_utils.get_encoding(dtype),
duration=duration,
)
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate assert info.get_sample_rate() == sample_rate
assert info.get_num_frames() == sample_rate * duration assert info.get_num_frames() == sample_rate * duration
...@@ -47,12 +45,8 @@ class TestInfo(TempDirMixin, PytorchTestCase): ...@@ -47,12 +45,8 @@ class TestInfo(TempDirMixin, PytorchTestCase):
"""`sox_io_backend.info` can check wav file with channels more than 2 correctly""" """`sox_io_backend.info` can check wav file with channels more than 2 correctly"""
duration = 1 duration = 1
path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav') path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
sox_utils.gen_audio_file( data = get_wav_data(dtype, num_channels, normalize=False, num_frames=duration * sample_rate)
path, sample_rate, num_channels, save_wav(path, data, sample_rate)
bit_depth=sox_utils.get_bit_depth(dtype),
encoding=sox_utils.get_encoding(dtype),
duration=duration,
)
info = sox_io_backend.info(path) info = sox_io_backend.info(path)
assert info.get_sample_rate() == sample_rate assert info.get_sample_rate() == sample_rate
assert info.get_num_frames() == sample_rate * duration assert info.get_num_frames() == sample_rate * duration
......
import itertools
from torchaudio.backend import sox_io_backend
from parameterized import parameterized
from ..common_utils import (
TempDirMixin,
PytorchTestCase,
skipIfNoExec,
skipIfNoExtension,
)
from .common import (
get_test_name,
get_wav_data,
load_wav,
save_wav,
)
from . import sox_utils
class LoadTestBase(TempDirMixin, PytorchTestCase):
    """Shared assertions that compare `sox_io_backend.load` with reference data."""

    def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
        """`sox_io_backend.load` can load wav format correctly.

        Wav data loaded with sox_io backend should match those with scipy
        """
        path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}_{normalize}.wav')
        data = get_wav_data(dtype, num_channels, normalize=normalize, num_frames=duration * sample_rate)
        save_wav(path, data, sample_rate)
        expected = load_wav(path, normalize=normalize)[0]
        data, sr = sox_io_backend.load(path, normalize=normalize)
        assert sr == sample_rate
        self.assertEqual(data, expected)

    def _assert_compressed(
            self, ext, sample_rate, num_channels, compression, duration,
            *, bit_depth, atol, rtol):
        """Compare torchaudio-loaded compressed audio with a sox-converted wav reference.

        Encoding introduces delay and boundary effects so the reference wav
        file is created from the compressed file itself.

             x
             |
             |    1. Generate compressed file with Sox
             |
             v    2. Convert to wav with Sox
        compressed -------------------------> wav
             |                                 |
             | 3. Load with torchaudio         | 4. Load with scipy
             |                                 |
             v                                 v
          tensor ----------> x <----------- tensor
                        5. Compare

        Underlying assumptions are;
        i. Conversion of the compressed file to wav with Sox preserves data.
        ii. Loading wav file with scipy is correct.

        By combining i & ii, step 2. and 4. allows to load reference data
        without using torchaudio
        """
        path = self.get_temp_path(f'{sample_rate}_{num_channels}_{compression}_{duration}.{ext}')
        ref_path = f'{path}.wav'

        # 1. Generate the compressed file with sox
        sox_utils.gen_audio_file(
            path, sample_rate, num_channels,
            compression=compression, bit_depth=bit_depth, duration=duration)
        # 2. Convert to wav with sox
        sox_utils.convert_audio_file(path, ref_path)
        # 3. Load the compressed file with torchaudio
        data, sr = sox_io_backend.load(path)
        # 4. Load wav with scipy
        data_ref = load_wav(ref_path)[0]
        # 5. Compare
        assert sr == sample_rate
        self.assertEqual(data, data_ref, atol=atol, rtol=rtol)

    def assert_mp3(self, sample_rate, num_channels, bit_rate, duration):
        """`sox_io_backend.load` can load mp3 format.

        The result is compared against a wav reference that Sox creates from
        the generated mp3 file; see `_assert_compressed` for the strategy.
        """
        self._assert_compressed(
            'mp3', sample_rate, num_channels, bit_rate, duration,
            bit_depth=None, atol=3e-03, rtol=1.3e-06)

    def assert_flac(self, sample_rate, num_channels, compression_level, duration):
        """`sox_io_backend.load` can load flac format.

        This test takes the same strategy as mp3 to compare the result
        """
        self._assert_compressed(
            'flac', sample_rate, num_channels, compression_level, duration,
            bit_depth=16, atol=4e-05, rtol=1.3e-06)

    def assert_vorbis(self, sample_rate, num_channels, quality_level, duration):
        """`sox_io_backend.load` can load vorbis format.

        This test takes the same strategy as mp3 to compare the result
        """
        self._assert_compressed(
            'vorbis', sample_rate, num_channels, quality_level, duration,
            bit_depth=16, atol=4e-05, rtol=1.3e-06)
@skipIfNoExec('sox')
@skipIfNoExtension
class TestLoad(LoadTestBase):
    """Test the correctness of `sox_io_backend.load` for various formats"""

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [8000, 16000],
        [1, 2],
        [False, True],
    )), name_func=get_test_name)
    def test_wav(self, dtype, sample_rate, num_channels, normalize):
        """One-second wav files of every supported dtype load correctly."""
        duration = 1
        self.assert_wav(dtype, sample_rate, num_channels, normalize, duration=duration)

    @parameterized.expand(list(itertools.product(
        ['int16'],
        [16000],
        [2],
        [False],
    )), name_func=get_test_name)
    def test_wav_large(self, dtype, sample_rate, num_channels, normalize):
        """A two-hour wav file loads correctly."""
        duration = 2 * 60 * 60  # two hours, in seconds
        self.assert_wav(dtype, sample_rate, num_channels, normalize, duration)

    @parameterized.expand(list(itertools.product(
        ['float32', 'int32', 'int16', 'uint8'],
        [4, 8, 16, 32],
    )), name_func=get_test_name)
    def test_multiple_channels(self, dtype, num_channels):
        """Wav files with more than 2 channels load correctly."""
        self.assert_wav(dtype, 8000, num_channels, False, duration=1)

    @parameterized.expand(list(itertools.product(
        [8000, 16000, 44100],
        [1, 2],
        [96, 128, 160, 192, 224, 256, 320],
    )), name_func=get_test_name)
    def test_mp3(self, sample_rate, num_channels, bit_rate):
        """One-second mp3 files load correctly."""
        duration = 1
        self.assert_mp3(sample_rate, num_channels, bit_rate, duration=duration)

    @parameterized.expand(list(itertools.product(
        [16000],
        [2],
        [128],
    )), name_func=get_test_name)
    def test_mp3_large(self, sample_rate, num_channels, bit_rate):
        """A two-hour mp3 file loads correctly."""
        duration = 2 * 60 * 60  # two hours, in seconds
        self.assert_mp3(sample_rate, num_channels, bit_rate, duration)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [0, 1, 2, 3, 4, 5, 6, 7, 8],
    )), name_func=get_test_name)
    def test_flac(self, sample_rate, num_channels, compression_level):
        """One-second flac files load correctly."""
        duration = 1
        self.assert_flac(sample_rate, num_channels, compression_level, duration=duration)

    @parameterized.expand(list(itertools.product(
        [16000],
        [2],
        [0],
    )), name_func=get_test_name)
    def test_flac_large(self, sample_rate, num_channels, compression_level):
        """A two-hour flac file loads correctly."""
        duration = 2 * 60 * 60  # two hours, in seconds
        self.assert_flac(sample_rate, num_channels, compression_level, duration)

    @parameterized.expand(list(itertools.product(
        [8000, 16000],
        [1, 2],
        [-1, 0, 1, 2, 3, 3.6, 5, 10],
    )), name_func=get_test_name)
    def test_vorbis(self, sample_rate, num_channels, quality_level):
        """One-second vorbis files load correctly."""
        duration = 1
        self.assert_vorbis(sample_rate, num_channels, quality_level, duration=duration)

    @parameterized.expand(list(itertools.product(
        [16000],
        [2],
        [10],
    )), name_func=get_test_name)
    def test_vorbis_large(self, sample_rate, num_channels, quality_level):
        """A two-hour vorbis file loads correctly."""
        duration = 2 * 60 * 60  # two hours, in seconds
        self.assert_vorbis(sample_rate, num_channels, quality_level, duration)
@skipIfNoExec('sox')
@skipIfNoExtension
class TestLoadParams(TempDirMixin, PytorchTestCase):
    """Test the correctness of frame parameters of `sox_io_backend.load`"""
    # Reference waveform and the path of the file it is saved to;
    # both are populated in `setUp`.
    original = None
    path = None

    def setUp(self):
        """Write a 2-channel float32 reference signal to a temporary file."""
        super().setUp()
        self.original = get_wav_data('float32', num_channels=2)
        self.path = self.get_temp_path('test.wave')
        save_wav(self.path, self.original, 8000)

    @parameterized.expand(list(itertools.product(
        [0, 1, 10, 100, 1000],
        [-1, 1, 10, 100, 1000],
    )), name_func=get_test_name)
    def test_frame(self, frame_offset, num_frames):
        """num_frames and frame_offset correctly specify the region of data"""
        found, _ = sox_io_backend.load(self.path, frame_offset, num_frames)
        if num_frames == -1:
            # -1 means "everything from frame_offset to the end".
            expected = self.original[:, frame_offset:]
        else:
            expected = self.original[:, frame_offset:frame_offset + num_frames]
        self.assertEqual(found, expected)

    @parameterized.expand([(True, ), (False, )], name_func=get_test_name)
    def test_channels_first(self, channels_first):
        """channels_first swaps axes"""
        found, _ = sox_io_backend.load(self.path, channels_first=channels_first)
        if channels_first:
            expected = self.original
        else:
            expected = self.original.transpose(1, 0)
        self.assertEqual(found, expected)
...@@ -12,29 +12,34 @@ from ..common_utils import ( ...@@ -12,29 +12,34 @@ from ..common_utils import (
) )
from .common import ( from .common import (
get_test_name, get_test_name,
get_wav_data,
save_wav
) )
from . import sox_utils
def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo: def py_info_func(filepath: str) -> torch.classes.torchaudio.SignalInfo:
return sox_io_backend.info(filepath) return sox_io_backend.info(filepath)
def py_load_func(filepath: str, normalize: bool, channels_first: bool):
    """Thin wrapper around `sox_io_backend.load`.

    The TorchScript test compiles this function with `torch.jit.script` and
    compares the scripted result with a plain Python call.
    """
    return sox_io_backend.load(
        filepath, normalize=normalize, channels_first=channels_first)
@skipIfNoExec('sox') @skipIfNoExec('sox')
@skipIfNoExtension @skipIfNoExtension
class SoxIO(TempDirMixin, TorchaudioTestCase): class SoxIO(TempDirMixin, TorchaudioTestCase):
"""TorchScript-ability Test suite for `sox_io_backend`"""
@parameterized.expand(list(itertools.product( @parameterized.expand(list(itertools.product(
['float32', 'int32', 'int16', 'uint8'], ['float32', 'int32', 'int16', 'uint8'],
[8000, 16000], [8000, 16000],
[1, 2], [1, 2],
)), name_func=get_test_name) )), name_func=get_test_name)
def test_info_wav(self, dtype, sample_rate, num_channels): def test_info_wav(self, dtype, sample_rate, num_channels):
"""`sox_io_backend.info` is torchscript-able and returns the same result"""
audio_path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav') audio_path = self.get_temp_path(f'{dtype}_{sample_rate}_{num_channels}.wav')
sox_utils.gen_audio_file( data = get_wav_data(dtype, num_channels, normalize=False, num_frames=1 * sample_rate)
audio_path, sample_rate, num_channels, save_wav(audio_path, data, sample_rate)
bit_depth=sox_utils.get_bit_depth(dtype),
encoding=sox_utils.get_encoding(dtype),
)
script_path = self.get_temp_path('info_func') script_path = self.get_temp_path('info_func')
torch.jit.script(py_info_func).save(script_path) torch.jit.script(py_info_func).save(script_path)
...@@ -46,3 +51,28 @@ class SoxIO(TempDirMixin, TorchaudioTestCase): ...@@ -46,3 +51,28 @@ class SoxIO(TempDirMixin, TorchaudioTestCase):
assert py_info.get_sample_rate() == ts_info.get_sample_rate() assert py_info.get_sample_rate() == ts_info.get_sample_rate()
assert py_info.get_num_frames() == ts_info.get_num_frames() assert py_info.get_num_frames() == ts_info.get_num_frames()
assert py_info.get_num_channels() == ts_info.get_num_channels() assert py_info.get_num_channels() == ts_info.get_num_channels()
@parameterized.expand(list(itertools.product(
    ['float32', 'int32', 'int16', 'uint8'],
    [8000, 16000],
    [1, 2],
    [False, True],
    [False, True],
)), name_func=get_test_name)
def test_load_wav(self, dtype, sample_rate, num_channels, normalize, channels_first):
    """`sox_io_backend.load` is torchscript-able and returns the same result"""
    # Prepare a one-second wav file holding un-normalized samples.
    file_name = f'test_load_{dtype}_{sample_rate}_{num_channels}_{normalize}.wav'
    audio_path = self.get_temp_path(file_name)
    waveform = get_wav_data(dtype, num_channels, normalize=False, num_frames=sample_rate)
    save_wav(audio_path, waveform, sample_rate)

    # Round-trip the scripted function through serialization.
    script_path = self.get_temp_path('load_func')
    scripted = torch.jit.script(py_load_func)
    scripted.save(script_path)
    ts_load_func = torch.jit.load(script_path)

    load_kwargs = {'normalize': normalize, 'channels_first': channels_first}
    py_data, py_sr = py_load_func(audio_path, **load_kwargs)
    ts_data, ts_sr = ts_load_func(audio_path, **load_kwargs)

    self.assertEqual(py_sr, ts_sr)
    self.assertEqual(py_data, ts_data)
from typing import Tuple
import torch import torch
from torchaudio._internal import ( from torchaudio._internal import (
module_utils as _mod_utils, module_utils as _mod_utils,
...@@ -8,3 +10,71 @@ from torchaudio._internal import ( ...@@ -8,3 +10,71 @@ from torchaudio._internal import (
def info(filepath: str) -> torch.classes.torchaudio.SignalInfo: def info(filepath: str) -> torch.classes.torchaudio.SignalInfo:
"""Get signal information of an audio file.""" """Get signal information of an audio file."""
return torch.ops.torchaudio.sox_io_get_info(filepath) return torch.ops.torchaudio.sox_io_get_info(filepath)
@_mod_utils.requires_module('torchaudio._torchaudio')
def load(
        filepath: str,
        frame_offset: int = 0,
        num_frames: int = -1,
        normalize: bool = True,
        channels_first: bool = True,
) -> Tuple[torch.Tensor, int]:
    """Load audio data from file.

    This function can handle all the codecs that the underlying libsox can
    handle. Note the following, however.

    Note 1:
        Current torchaudio's binary release only contains codecs for MP3, FLAC and OGG/VORBIS.
        If you need other formats, you need to build torchaudio from source with libsox and
        the corresponding codecs. Refer to README for this.

    Note 2:
        This function is tested on the following formats:

        - WAV
            - 32-bit floating-point
            - 32-bit signed integer
            - 16-bit signed integer
            - 8-bit unsigned integer
        - MP3
        - FLAC
        - OGG/VORBIS

    By default, the returned Tensor has ``float32`` dtype and the shape ``[channel, time]``,
    with samples normalized to fit in the range ``[-1.0, 1.0]``.

    When the input is an integer WAV file (32-bit signed, 16-bit signed or 8-bit unsigned
    integer; 24-bit signed integer is not supported), passing ``normalize=False`` makes this
    function return an integer Tensor whose samples span the whole range of the
    corresponding dtype: ``int32`` for 32-bit signed PCM, ``int16`` for 16-bit signed PCM
    and ``uint8`` for 8-bit unsigned PCM.

    ``normalize`` has no effect on 32-bit floating-point WAV and on other formats such as
    flac and mp3. For these, the result is always a ``float32`` Tensor with values
    normalized to ``[-1.0, 1.0]``.

    Args:
        filepath: Path to audio file
        frame_offset: Number of frames to skip before reading starts.
        num_frames: Maximum number of frames to read. -1 reads all the remaining samples,
            starting from ``frame_offset``. Fewer frames may be returned when the file
            does not contain enough of them.
        normalize: When ``True``, this function always returns ``float32``, and sample values
            are normalized to ``[-1.0, 1.0]``. If the input file is integer WAV, giving
            ``False`` changes the resulting Tensor type to an integer type. This argument has
            no effect for formats other than integer WAV.
        channels_first: When True, the returned Tensor has dimension ``[channel, time]``.
            Otherwise, the returned Tensor's dimension is ``[time, channel]``.

    Returns:
        The pair ``(tensor, sample_rate)``. The Tensor has an integer type if the input file
        is integer WAV and normalization is off, otherwise ``float32``. Its shape is
        ``[channel, time]`` if ``channels_first=True``, else ``[time, channel]``.
    """
    loaded = torch.ops.torchaudio.sox_io_load_audio_file(
        filepath, frame_offset, num_frames, normalize, channels_first)
    waveform = loaded.get_tensor()
    sample_rate = loaded.get_sample_rate()
    return waveform, sample_rate
load_wav = load
...@@ -3,11 +3,15 @@ ...@@ -3,11 +3,15 @@
#include <torchaudio/csrc/sox_effects.h> #include <torchaudio/csrc/sox_effects.h>
#include <torchaudio/csrc/sox_io.h> #include <torchaudio/csrc/sox_io.h>
#include <torchaudio/csrc/sox_utils.h>
#include <torchaudio/csrc/typedefs.h> #include <torchaudio/csrc/typedefs.h>
namespace torchaudio { namespace torchaudio {
namespace { namespace {
////////////////////////////////////////////////////////////////////////////////
// typedefs.h
////////////////////////////////////////////////////////////////////////////////
static auto registerSignalInfo = static auto registerSignalInfo =
torch::class_<SignalInfo>("torchaudio", "SignalInfo") torch::class_<SignalInfo>("torchaudio", "SignalInfo")
.def(torch::init<int64_t, int64_t, int64_t>()) .def(torch::init<int64_t, int64_t, int64_t>())
...@@ -15,12 +19,33 @@ static auto registerSignalInfo = ...@@ -15,12 +19,33 @@ static auto registerSignalInfo =
.def("get_num_channels", &SignalInfo::getNumChannels) .def("get_num_channels", &SignalInfo::getNumChannels)
.def("get_num_frames", &SignalInfo::getNumFrames); .def("get_num_frames", &SignalInfo::getNumFrames);
////////////////////////////////////////////////////////////////////////////////
// sox_utils.h
////////////////////////////////////////////////////////////////////////////////
// Registers sox_utils::TensorSignal as the custom class
// "torchaudio.TensorSignal", exposing its constructor and the three getters
// under snake_case method names.
static auto registerTensorSignal =
    torch::class_<sox_utils::TensorSignal>("torchaudio", "TensorSignal")
        .def(torch::init<torch::Tensor, int64_t, bool>())
        .def("get_tensor", &sox_utils::TensorSignal::getTensor)
        .def("get_sample_rate", &sox_utils::TensorSignal::getSampleRate)
        .def("get_channels_first", &sox_utils::TensorSignal::getChannelsFirst);
////////////////////////////////////////////////////////////////////////////////
// sox_io.h
////////////////////////////////////////////////////////////////////////////////
static auto registerGetInfo = torch::RegisterOperators().op( static auto registerGetInfo = torch::RegisterOperators().op(
torch::RegisterOperators::options() torch::RegisterOperators::options()
.schema( .schema(
"torchaudio::sox_io_get_info(str path) -> __torch__.torch.classes.torchaudio.SignalInfo info") "torchaudio::sox_io_get_info(str path) -> __torch__.torch.classes.torchaudio.SignalInfo info")
.catchAllKernel<decltype(sox_io::get_info), &sox_io::get_info>()); .catchAllKernel<decltype(sox_io::get_info), &sox_io::get_info>());
// Registers sox_io::load_audio_file as the operator
// "torchaudio::sox_io_load_audio_file". The schema takes the path, frame
// offset, frame count and the normalize/channels_first flags, and returns a
// TensorSignal custom class instance.
static auto registerLoadAudioFile = torch::RegisterOperators().op(
    torch::RegisterOperators::options()
        .schema(
            "torchaudio::sox_io_load_audio_file(str path, int frame_offset, int num_frames, bool normalize, bool channels_first) -> __torch__.torch.classes.torchaudio.TensorSignal signal")
        .catchAllKernel<
            decltype(sox_io::load_audio_file),
            &sox_io::load_audio_file>());
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// sox_effects.h // sox_effects.h
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
......
#include <sox.h> #include <sox.h>
#include <torchaudio/csrc/sox_io.h> #include <torchaudio/csrc/sox_io.h>
#include <torchaudio/csrc/sox_utils.h>
using namespace torch::indexing; using namespace torch::indexing;
using namespace torchaudio::sox_utils;
namespace torchaudio { namespace torchaudio {
namespace sox_io { namespace sox_io {
namespace { c10::intrusive_ptr<torchaudio::SignalInfo> get_info(const std::string& path) {
SoxFormat sf(sox_open_read(
/// Helper struct to safely close the sox_format_t descriptor. path.c_str(),
struct SoxDescriptor { /*signal=*/nullptr,
explicit SoxDescriptor(sox_format_t* fd) noexcept : fd_(fd) {} /*encoding=*/nullptr,
SoxDescriptor(const SoxDescriptor& other) = delete; /*filetype=*/nullptr));
SoxDescriptor(SoxDescriptor&& other) = delete;
SoxDescriptor& operator=(const SoxDescriptor& other) = delete; if (sf.get() == nullptr) {
SoxDescriptor& operator=(SoxDescriptor&& other) = delete; throw std::runtime_error("Error opening audio file");
~SoxDescriptor() {
if (fd_ != nullptr) {
sox_close(fd_);
}
}
sox_format_t* operator->() noexcept {
return fd_;
}
sox_format_t* get() noexcept {
return fd_;
} }
private: return c10::make_intrusive<torchaudio::SignalInfo>(
sox_format_t* fd_; static_cast<int64_t>(sf->signal.rate),
}; static_cast<int64_t>(sf->signal.channels),
static_cast<int64_t>(sf->signal.length / sf->signal.channels));
}
} // namespace c10::intrusive_ptr<TensorSignal> load_audio_file(
const std::string& path,
const int64_t frame_offset,
const int64_t num_frames,
const bool normalize,
const bool channels_first) {
if (frame_offset < 0) {
throw std::runtime_error(
"Invalid argument: frame_offset must be non-negative.");
}
if (num_frames == 0 || num_frames < -1) {
throw std::runtime_error(
"Invalid argument: num_frames must be -1 or greater than 0.");
}
c10::intrusive_ptr<::torchaudio::SignalInfo> get_info( SoxFormat sf(sox_open_read(
const std::string& file_name) { path.c_str(),
SoxDescriptor sd(sox_open_read(
file_name.c_str(),
/*signal=*/nullptr, /*signal=*/nullptr,
/*encoding=*/nullptr, /*encoding=*/nullptr,
/*filetype=*/nullptr)); /*filetype=*/nullptr));
if (sd.get() == nullptr) { validate_input_file(sf);
throw std::runtime_error("Error opening audio file");
const int64_t num_channels = sf->signal.channels;
const int64_t num_total_samples = sf->signal.length;
const int64_t sample_start = sf->signal.channels * frame_offset;
if (sox_seek(sf.get(), sample_start, 0) == SOX_EOF) {
throw std::runtime_error("Error reading audio file: offset past EOF.");
}
const int64_t sample_end = [&]() {
if (num_frames == -1)
return num_total_samples;
const int64_t sample_end_ = num_channels * num_frames + sample_start;
if (num_total_samples < sample_end_) {
// For lossy encoding, it is difficult to predict exact size of buffer for
// reading the number of samples required.
// So we allocate buffer size of given `num_frames` and ask sox to read as
// much as possible. For lossless format, sox reads exact number of
// samples, but for lossy encoding, sox can end up reading less. (i.e.
// mp3) For the consistent behavior specification between lossy/lossless
// format, we allow users to provide `num_frames` value that exceeds #of
// available samples, and we adjust it here.
return num_total_samples;
} }
return sample_end_;
}();
const int64_t max_samples = sample_end - sample_start;
// Read samples into buffer
std::vector<sox_sample_t> buffer;
buffer.reserve(max_samples);
const int64_t num_samples = sox_read(sf.get(), buffer.data(), max_samples);
if (num_samples == 0) {
throw std::runtime_error(
"Error reading audio file: empty file or read operation failed.");
}
// NOTE: num_samples may be smaller than max_samples if the input
// format is compressed (i.e. mp3).
// Convert to Tensor
auto tensor = convert_to_tensor(
buffer.data(),
num_samples,
num_channels,
get_dtype(sf->encoding.encoding, sf->signal.precision),
normalize,
channels_first);
return c10::make_intrusive<::torchaudio::SignalInfo>( return c10::make_intrusive<TensorSignal>(
static_cast<int64_t>(sd->signal.rate), tensor, static_cast<int64_t>(sf->signal.rate), channels_first);
static_cast<int64_t>(sd->signal.channels),
static_cast<int64_t>(sd->signal.length / sd->signal.channels));
} }
} // namespace sox_io } // namespace sox_io
......
#ifndef TORCHAUDIO_SOX_IO_H
#define TORCHAUDIO_SOX_IO_H
#include <torch/script.h> #include <torch/script.h>
#include <torchaudio/csrc/sox_utils.h>
#include <torchaudio/csrc/typedefs.h> #include <torchaudio/csrc/typedefs.h>
namespace torchaudio { namespace torchaudio {
namespace sox_io { namespace sox_io {
c10::intrusive_ptr<::torchaudio::SignalInfo> get_info( c10::intrusive_ptr<torchaudio::SignalInfo> get_info(const std::string& path);
const std::string& file_name);
c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_file(
const std::string& path,
const int64_t frame_offset = 0,
const int64_t num_frames = -1,
const bool normalize = true,
const bool channels_first = true);
} // namespace sox_io } // namespace sox_io
} // namespace torchaudio } // namespace torchaudio
#endif
#include <c10/core/ScalarType.h>
#include <sox.h>
#include <torchaudio/csrc/sox_utils.h>
namespace torchaudio {
namespace sox_utils {
/// Store the given Tensor together with its sample rate and the flag that
/// tells whether the Tensor is laid out channels-first.
TensorSignal::TensorSignal(
    torch::Tensor tensor_,
    int64_t sample_rate_,
    bool channels_first_)
    : tensor(tensor_),
      sample_rate(sample_rate_),
      channels_first(channels_first_) {}

/// Return the stored Tensor.
torch::Tensor TensorSignal::getTensor() const {
  return tensor;
}

/// Return the stored sample rate.
int64_t TensorSignal::getSampleRate() const {
  return sample_rate;
}

/// Return the stored channels-first flag.
bool TensorSignal::getChannelsFirst() const {
  return channels_first;
}
/// Wrap the given descriptor, which may be nullptr.
SoxFormat::SoxFormat(sox_format_t* fd) noexcept : fd_(fd) {}

/// Close the wrapped descriptor, unless it is nullptr.
SoxFormat::~SoxFormat() {
  if (fd_) {
    sox_close(fd_);
  }
}

/// Member access on the wrapped descriptor.
sox_format_t* SoxFormat::operator->() const noexcept {
  return fd_;
}

/// Return the wrapped descriptor without giving up ownership.
sox_format_t* SoxFormat::get() const noexcept {
  return fd_;
}
/// Verify that the input file was opened, has a known encoding and reports a
/// non-zero length.
/// @throws std::runtime_error when any of the checks fails.
void validate_input_file(const SoxFormat& sf) {
  // The descriptor is nullptr when opening the file failed.
  if (sf.get() == nullptr) {
    throw std::runtime_error("Error loading audio file: failed to open file.");
  }
  if (sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
    throw std::runtime_error("Error loading audio file: unknown encoding.");
  }
  if (sf->signal.length == 0) {
    throw std::runtime_error("Error reading audio file: unknown length.");
  }
}
/// Map a sox encoding and precision to the dtype of the Tensor to return.
/// @throws std::runtime_error for signed PCM whose precision is neither 16
/// nor 32 bits.
caffe2::TypeMeta get_dtype(
    const sox_encoding_t encoding,
    const unsigned precision) {
  // Default to float32 for every format that is not integer PCM, including
  //   32-bit floating-point WAV,
  //   MP3,
  //   FLAC,
  //   VORBIS etc...
  c10::ScalarType scalar_type = torch::kFloat32;

  if (encoding == SOX_ENCODING_UNSIGNED) {
    // 8-bit PCM WAV
    scalar_type = torch::kUInt8;
  } else if (encoding == SOX_ENCODING_SIGN2) {
    // 16-bit or 32-bit PCM WAV
    if (precision == 16) {
      scalar_type = torch::kInt16;
    } else if (precision == 32) {
      scalar_type = torch::kInt32;
    } else {
      throw std::runtime_error(
          "Only 16 and 32 bits are supported for signed PCM.");
    }
  }
  return c10::scalarTypeToTypeMeta(scalar_type);
}
/// Convert a buffer of sox samples into a Tensor of the requested dtype.
/// The buffer is first viewed as an int32 Tensor of shape
/// [num_samples / num_channels, num_channels]. The int16 and uint8 branches
/// apply in-place operations to that view before converting, so they write
/// into `buffer`.
/// @throws std::runtime_error when normalization is off and `dtype` is none
/// of float32, int32, int16 and uint8.
torch::Tensor convert_to_tensor(
    sox_sample_t* buffer,
    const int32_t num_samples,
    const int32_t num_channels,
    const caffe2::TypeMeta dtype,
    const bool normalize,
    const bool channels_first) {
  auto t = torch::from_blob(
      buffer, {num_samples / num_channels, num_channels}, torch::kInt32);
  // Note: Tensor created from_blob does not own data but borrows
  // So make sure to create a new copy after processing samples.
  if (normalize || dtype == torch::kFloat32) {
    t = t.to(torch::kFloat32);
    // Scale positive samples by 1 / 2147483647 and negative samples by
    // 1 / 2147483648; zeros are multiplied by 0 and stay 0.
    t *= (t > 0) / 2147483647. + (t < 0) / 2147483648.;
  } else if (dtype == torch::kInt32) {
    t = t.clone();
  } else if (dtype == torch::kInt16) {
    // Drop the lower 16 bits of the 32-bit samples (in place).
    t.floor_divide_(1 << 16);
    t = t.to(torch::kInt16);
  } else if (dtype == torch::kUInt8) {
    // Drop the lower 24 bits (in place), then shift by 128 before the
    // conversion to uint8.
    t.floor_divide_(1 << 24);
    t += 128;
    t = t.to(torch::kUInt8);
  } else {
    throw std::runtime_error("Unsupported dtype.");
  }
  if (channels_first) {
    // [time, channel] -> [channel, time]
    t = t.transpose(1, 0);
  }
  return t.contiguous();
}
} // namespace sox_utils
} // namespace torchaudio
#ifndef TORCHAUDIO_SOX_UTILS_H
#define TORCHAUDIO_SOX_UTILS_H
#include <sox.h>
#include <torch/script.h>
namespace torchaudio {
namespace sox_utils {
/// Bundle of an audio Tensor, its sample rate and a flag describing whether
/// the Tensor is laid out channels-first. Derives from
/// torch::CustomClassHolder.
struct TensorSignal : torch::CustomClassHolder {
  // Audio samples.
  torch::Tensor tensor;
  // Sample rate of the audio.
  int64_t sample_rate;
  // Whether `tensor` is laid out channels-first.
  bool channels_first;

  TensorSignal(
      torch::Tensor tensor_,
      int64_t sample_rate_,
      bool channels_first_);

  // Accessors for the fields above.
  torch::Tensor getTensor() const;
  int64_t getSampleRate() const;
  bool getChannelsFirst() const;
};
/// Helper class to automatically close a sox_format_t* on destruction.
/// Copy and move operations are deleted, so exactly one object manages a
/// given descriptor.
struct SoxFormat {
  // Wraps `fd`.
  explicit SoxFormat(sox_format_t* fd) noexcept;
  SoxFormat(const SoxFormat& other) = delete;
  SoxFormat(SoxFormat&& other) = delete;
  SoxFormat& operator=(const SoxFormat& other) = delete;
  SoxFormat& operator=(SoxFormat&& other) = delete;
  ~SoxFormat();
  // Access to the wrapped descriptor.
  sox_format_t* operator->() const noexcept;
  sox_format_t* get() const noexcept;

 private:
  sox_format_t* fd_;
};
///
/// Verify that input file is found, has known encoding, and not empty
void validate_input_file(const SoxFormat& sf);
///
/// Get target dtype for the given encoding and precision.
caffe2::TypeMeta get_dtype(
const sox_encoding_t encoding,
const unsigned precision);
///
/// Convert sox_sample_t buffer to uint8/int16/int32/float32 Tensor
/// NOTE: This function might modify the values in the input buffer to
/// reduce the number of memory copy.
/// @param buffer Pointer to buffer that contains audio data.
/// @param num_samples The number of samples to read.
/// @param num_channels The number of channels. Used to reshape the resulting
/// Tensor.
/// @param dtype Target dtype. Determines the output dtype and value range in
/// conjunction with normalization.
/// @param normalize Perform normalization. Only effective when dtype is not
/// kFloat32. When effective, the output tensor is kFloat32 type and value range
/// is [-1.0, 1.0]
/// @param channels_first When True, output Tensor has shape of [num_channels,
/// num_frames].
torch::Tensor convert_to_tensor(
sox_sample_t* buffer,
const int32_t num_samples,
const int32_t num_channels,
const caffe2::TypeMeta dtype,
const bool normalize,
const bool channels_first);
} // namespace sox_utils
} // namespace torchaudio
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment