Unverified commit 463a8b2c authored by moto, committed by GitHub

Support file-like object in load function (#1158)

parent 422edb18
......@@ -46,9 +46,9 @@ if [ "${os}" == Linux ] ; then
# TODO: move this to docker
apt install -y -q libsndfile1
conda install -y -c conda-forge codecov pytest pytest-cov
pip install kaldi-io 'librosa>=0.8.0' parameterized SoundFile scipy
pip install kaldi-io 'librosa>=0.8.0' parameterized SoundFile scipy 'requests>=2.20'
else
# Note: installing librosa via pip fail because it will try to compile numba.
conda install -y -c conda-forge codecov pytest pytest-cov 'librosa>=0.8.0' parameterized scipy
pip install kaldi-io SoundFile
pip install kaldi-io SoundFile 'requests>=2.20'
fi
......@@ -8,6 +8,7 @@ from .backend_utils import (
)
from .case_utils import (
TempDirMixin,
HttpServerMixin,
TestBaseMixin,
PytorchTestCase,
TorchaudioTestCase,
......
import shutil
import os.path
import subprocess
import tempfile
import time
import unittest
import torch
......@@ -40,6 +42,32 @@ class TempDirMixin:
return path
class HttpServerMixin(TempDirMixin):
"""Mixin that serves temporary directory as web server
This class creates temporary directory and serve the directory as HTTP service.
The server is up through the execution of all the test suite defined under the subclass.
"""
_proc = None
_port = 8000
@classmethod
def setUpClass(cls):
super().setUpClass()
cls._proc = subprocess.Popen(
['python', '-m', 'http.server', f'{cls._port}'],
cwd=cls.get_base_temp_dir())
time.sleep(1.0)
@classmethod
def tearDownClass(cls):
super().tearDownClass()
cls._proc.kill()
def get_url(self, *route):
return f'http://localhost:{self._port}/{self.id()}/{"/".join(route)}'
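A minimal usage sketch of this mixin (the class and file names below are illustrative, not part of the commit; the real consumers are the sox_io HTTP tests further down). Files created via get_temp_path() land in a per-test subdirectory named after self.id(), which is exactly the prefix that get_url() reconstructs:

import requests  # added to the test requirements in the CI script above

class ExampleHttpTest(HttpServerMixin, PytorchTestCase):  # hypothetical test class
    def test_roundtrip(self):
        # Write a file into the served temporary directory ...
        path = self.get_temp_path('hello.txt')
        with open(path, 'w') as f:
            f.write('hello')
        # ... and fetch it back through the local HTTP server.
        assert requests.get(self.get_url('hello.txt')).content == b'hello'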
class TestBaseMixin:
"""Mixin to provide consistent way to define device/dtype/backend aware TestCase"""
dtype = None
......
import os
import tarfile
from unittest.mock import patch
import torch
......@@ -299,3 +300,58 @@ class TestLoadFormat(TempDirMixin, PytorchTestCase):
@skipIfFormatNotSupported("FLAC")
def test_flac(self, format_):
self._test_format(format_)
@skipIfNoModule("soundfile")
class TestFileObject(TempDirMixin, PytorchTestCase):
def _test_fileobj(self, ext):
"""Loading audio via file-like object works"""
sample_rate = 16000
path = self.get_temp_path(f'test.{ext}')
data = get_wav_data('float32', num_channels=2).numpy().T
soundfile.write(path, data, sample_rate)
expected = soundfile.read(path, dtype='float32')[0].T
with open(path, 'rb') as fileobj:
found, sr = soundfile_backend.load(fileobj)
assert sr == sample_rate
self.assertEqual(expected, found)
def test_fileobj_wav(self):
"""Loading audio via file-like object works"""
self._test_fileobj('wav')
@skipIfFormatNotSupported("FLAC")
def test_fileobj_flac(self):
"""Loading audio via file-like object works"""
self._test_fileobj('flac')
def _test_tarfile(self, ext):
"""Loading audio via file-like object works"""
sample_rate = 16000
audio_file = f'test.{ext}'
audio_path = self.get_temp_path(audio_file)
archive_path = self.get_temp_path('archive.tar.gz')
data = get_wav_data('float32', num_channels=2).numpy().T
soundfile.write(audio_path, data, sample_rate)
expected = soundfile.read(audio_path, dtype='float32')[0].T
with tarfile.TarFile(archive_path, 'w') as tarobj:
tarobj.add(audio_path, arcname=audio_file)
with tarfile.TarFile(archive_path, 'r') as tarobj:
fileobj = tarobj.extractfile(audio_file)
found, sr = soundfile_backend.load(fileobj)
assert sr == sample_rate
self.assertEqual(expected, found)
def test_tarfile_wav(self):
"""Loading audio via file-like object works"""
self._test_tarfile('wav')
@skipIfFormatNotSupported("FLAC")
def test_tarfile_flac(self):
"""Loading audio via file-like object works"""
self._test_tarfile('flac')
import io
import itertools
import tarfile
from torchaudio.backend import sox_io_backend
from parameterized import parameterized
from torchaudio.backend import sox_io_backend
from torchaudio._internal import module_utils as _mod_utils
from torchaudio_unittest.common_utils import (
TempDirMixin,
HttpServerMixin,
PytorchTestCase,
skipIfNoExec,
skipIfNoExtension,
skipIfNoModule,
get_asset_path,
get_wav_data,
load_wav,
......@@ -19,6 +24,10 @@ from .common import (
)
if _mod_utils.is_module_available("requests"):
import requests
class LoadTestBase(TempDirMixin, PytorchTestCase):
def assert_wav(self, dtype, sample_rate, num_channels, normalize, duration):
"""`sox_io_backend.load` can load wav format correctly.
......@@ -369,3 +378,156 @@ class TestLoadWithoutExtension(PytorchTestCase):
path = get_asset_path("mp3_without_ext")
_, sr = sox_io_backend.load(path, format="mp3")
assert sr == 16000
@skipIfNoExtension
@skipIfNoExec('sox')
class TestFileObject(TempDirMixin, PytorchTestCase):
"""
In this test suite, the result of file-like object input is compared against file path input,
because the `load` function is rigorously tested for file path inputs to match libsox's results.
"""
@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_fileobj(self, ext, compression):
"""Loading audio via file object returns the same result as via file path."""
sample_rate = 16000
format_ = ext if ext in ['mp3'] else None
path = self.get_temp_path(f'test.{ext}')
sox_utils.gen_audio_file(
path, sample_rate, num_channels=2,
compression=compression)
expected, _ = sox_io_backend.load(path)
with open(path, 'rb') as fileobj:
found, sr = sox_io_backend.load(fileobj, format=format_)
assert sr == sample_rate
self.assertEqual(expected, found)
@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_bytesio(self, ext, compression):
"""Loading audio via BytesIO object returns the same result as via file path."""
sample_rate = 16000
format_ = ext if ext in ['mp3'] else None
path = self.get_temp_path(f'test.{ext}')
sox_utils.gen_audio_file(
path, sample_rate, num_channels=2,
compression=compression)
expected, _ = sox_io_backend.load(path)
with open(path, 'rb') as file_:
fileobj = io.BytesIO(file_.read())
found, sr = sox_io_backend.load(fileobj, format=format_)
assert sr == sample_rate
self.assertEqual(expected, found)
@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_tarfile(self, ext, compression):
"""Loading compressed audio via file-like object returns the same result as via file path."""
sample_rate = 16000
format_ = ext if ext in ['mp3'] else None
audio_file = f'test.{ext}'
audio_path = self.get_temp_path(audio_file)
archive_path = self.get_temp_path('archive.tar.gz')
sox_utils.gen_audio_file(
audio_path, sample_rate, num_channels=2,
compression=compression)
expected, _ = sox_io_backend.load(audio_path)
with tarfile.TarFile(archive_path, 'w') as tarobj:
tarobj.add(audio_path, arcname=audio_file)
with tarfile.TarFile(archive_path, 'r') as tarobj:
fileobj = tarobj.extractfile(audio_file)
found, sr = sox_io_backend.load(fileobj, format=format_)
assert sr == sample_rate
self.assertEqual(expected, found)
@skipIfNoExtension
@skipIfNoExec('sox')
@skipIfNoModule("requests")
class TestFileObjectHttp(HttpServerMixin, PytorchTestCase):
@parameterized.expand([
('wav', None),
('mp3', 128),
('mp3', 320),
('flac', 0),
('flac', 5),
('flac', 8),
('vorbis', -1),
('vorbis', 10),
('amb', None),
])
def test_requests(self, ext, compression):
sample_rate = 16000
format_ = ext if ext in ['mp3'] else None
audio_file = f'test.{ext}'
audio_path = self.get_temp_path(audio_file)
sox_utils.gen_audio_file(
audio_path, sample_rate, num_channels=2, compression=compression)
expected, _ = sox_io_backend.load(audio_path)
url = self.get_url(audio_file)
with requests.get(url, stream=True) as resp:
found, sr = sox_io_backend.load(resp.raw, format=format_)
assert sr == sample_rate
self.assertEqual(expected, found)
@parameterized.expand(list(itertools.product(
[0, 1, 10, 100, 1000],
[-1, 1, 10, 100, 1000],
)), name_func=name_func)
def test_frame(self, frame_offset, num_frames):
"""num_frames and frame_offset correctly specify the region of data"""
sample_rate = 8000
audio_file = 'test.wav'
audio_path = self.get_temp_path(audio_file)
original = get_wav_data('float32', num_channels=2)
save_wav(audio_path, original, sample_rate)
frame_end = None if num_frames == -1 else frame_offset + num_frames
expected = original[:, frame_offset:frame_end]
url = self.get_url(audio_file)
with requests.get(url, stream=True) as resp:
found, sr = sox_io_backend.load(resp.raw, frame_offset, num_frames)
assert sr == sample_rate
self.assertEqual(expected, found)
......@@ -82,10 +82,12 @@ def load(
``[-1.0, 1.0]``.
Args:
filepath (str or pathlib.Path): Path to audio file.
This function also handles ``pathlib.Path`` objects, but is annotated as ``str``
for consistency with the "sox_io" backend, which has a restriction on type annotation
for TorchScript compiler compatibility.
filepath (path-like object or file-like object):
Source of audio data.
Note:
* This argument is intentionally annotated as ``str`` only,
for consistency with the "sox_io" backend, which has a restriction
on type annotation due to TorchScript compiler compatibility.
frame_offset (int):
Number of frames to skip before start reading data.
num_frames (int):
......
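A short usage sketch of the file-like input documented above for the soundfile backend. The module path mirrors the sox_io import used in the tests, and 'test.wav' is a placeholder file name:

from torchaudio.backend import soundfile_backend

# Path input keeps working as before.
waveform, sample_rate = soundfile_backend.load('test.wav')

# File-like input: any object exposing read() is accepted.
with open('test.wav', 'rb') as fileobj:
    waveform, sample_rate = soundfile_backend.load(fileobj)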
import os
from typing import Tuple, Optional
import torch
......@@ -5,6 +6,7 @@ from torchaudio._internal import (
module_utils as _mod_utils,
)
import torchaudio
from .common import AudioMetaData
......@@ -82,9 +84,17 @@ def load(
``[-1.0, 1.0]``.
Args:
filepath (str or pathlib.Path):
Path to audio file. This function also handles ``pathlib.Path`` objects, but is
annotated as ``str`` for TorchScript compiler compatibility.
filepath (path-like object or file-like object):
Source of audio data. When the function is not compiled by TorchScript
(e.g. ``torch.jit.script``), the following types are accepted:
* ``path-like``: file path
* ``file-like``: Object with a ``read(size: int) -> bytes`` method,
which returns a byte string of at most ``size`` length.
When the function is compiled by TorchScript, only the ``str`` type is allowed.
Note:
* This argument is intentionally annotated as ``str`` only due to
TorchScript compiler compatibility.
frame_offset (int):
Number of frames to skip before start reading data.
num_frames (int):
......@@ -112,8 +122,13 @@ def load(
integer type, else ``float32`` type. If ``channels_first=True``, it has
``[channel, time]`` else ``[time, channel]``.
"""
# Cast to str in case type is `pathlib.Path`
filepath = str(filepath)
if not torch.jit.is_scripting():
if hasattr(filepath, 'read'):
return torchaudio._torchaudio.load_audio_fileobj(
filepath, frame_offset, num_frames, normalize, channels_first, format)
signal = torch.ops.torchaudio.sox_io_load_audio_file(
os.fspath(filepath), frame_offset, num_frames, normalize, channels_first, format)
return signal.get_tensor(), signal.get_sample_rate()
signal = torch.ops.torchaudio.sox_io_load_audio_file(
filepath, frame_offset, num_frames, normalize, channels_first, format)
return signal.get_tensor(), signal.get_sample_rate()
......
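A sketch of how the dispatch above plays out from the caller's side (file names are placeholders; this mirrors the test_fileobj/test_bytesio cases earlier in the diff). When loading from a file-like object there is no file name to infer the format from, so compressed formats such as mp3 need an explicit format argument:

import io
from torchaudio.backend import sox_io_backend

# A plain path goes through the TorchScript operator sox_io_load_audio_file.
waveform, sample_rate = sox_io_backend.load('test.mp3')

# Anything with a read() method is routed to load_audio_fileobj instead.
with open('test.mp3', 'rb') as f:
    buffer = io.BytesIO(f.read())
waveform, sample_rate = sox_io_backend.load(buffer, format='mp3')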
#include <torch/extension.h>
#include <torchaudio/csrc/sox/io.h>
#include <torchaudio/csrc/sox/legacy.h>
PYBIND11_MODULE(_torchaudio, m) {
py::class_<sox_signalinfo_t>(m, "sox_signalinfo_t")
.def(py::init<>())
......@@ -94,4 +96,8 @@ PYBIND11_MODULE(_torchaudio, m) {
"get_info",
&torch::audio::get_info,
"Gets information about an audio file");
m.def(
"load_audio_fileobj",
&torchaudio::sox_io::load_audio_fileobj,
"Load audio from file object.");
}
......@@ -135,5 +135,88 @@ c10::intrusive_ptr<TensorSignal> apply_effects_file(
tensor, chain.getOutputSampleRate(), channels_first_);
}
#ifdef TORCH_API_INCLUDE_EXTENSION_H
std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
py::object fileobj,
std::vector<std::vector<std::string>> effects,
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first,
c10::optional<std::string>& format) {
// Streaming decoding over a file-like object is tricky because libsox operates on a FILE pointer.
// The following is what the `sox` and `play` commands do:
// - file input -> FILE pointer
// - URL input -> call wget in a subprocess and pipe the data -> FILE pointer
// - stdin -> FILE pointer
//
// We want, instead, to fetch byte strings chunk by chunk, consume them, and discard them.
//
// Here is the approach:
// 1. Initialize sox_format_t using sox_open_mem_read, providing the initial chunk of the byte string.
// This performs header-based format detection, if necessary, then fills in the metadata of
// sox_format_t. Internally, sox_open_mem_read uses fmemopen, which returns a FILE* that points to
// the buffer of the provided byte string.
// 2. Each time sox reads a chunk from the FILE*, we update the underlying buffer so that it
// starts with the unseen data, and append the new data read from the given fileobj.
// This tricks libsox into behaving as if it keeps reading from the FILE* continuously.
// Prepare the buffer used throughout the lifecycle of the SoxEffectsChain.
// Use std::string and let it manage the memory.
// 4096 is the minimum size required by auto_detect_format
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L40-L48
const size_t in_buffer_size = 4096;
std::string in_buffer(in_buffer_size, 'x');
auto* in_buf = const_cast<char*>(in_buffer.data());
// Fetch the header, and copy it to the buffer.
auto header = static_cast<std::string>(static_cast<py::bytes>(fileobj.attr("read")(4096)));
memcpy(static_cast<void*>(in_buf),
static_cast<void*>(const_cast<char*>(header.data())), header.length());
// Open file (this starts reading the header)
SoxFormat sf(sox_open_mem_read(
in_buf,
in_buffer_size,
/*signal=*/nullptr,
/*encoding=*/nullptr,
/*filetype=*/format.has_value() ? format.value().c_str() : nullptr));
// In case of streamed data, length can be 0
validate_input_file(sf, /*check_length=*/false);
// Prepare output buffer
std::vector<sox_sample_t> out_buffer;
out_buffer.reserve(sf->signal.length);
// Create and run SoxEffectsChain
const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision);
torchaudio::sox_effects_chain::SoxEffectsChain chain(
/*input_encoding=*/sf->encoding,
/*output_encoding=*/get_encodinginfo("wav", dtype, 0.));
chain.addInputFileObj(sf, in_buf, in_buffer_size, &fileobj);
for (const auto& effect : effects) {
chain.addEffect(effect);
}
chain.addOutputBuffer(&out_buffer);
chain.run();
// Create tensor from buffer
bool channels_first_ = channels_first.value_or(true);
auto tensor = convert_to_tensor(
/*buffer=*/out_buffer.data(),
/*num_samples=*/out_buffer.size(),
/*num_channels=*/chain.getOutputNumChannels(),
dtype,
normalize.value_or(true),
channels_first_);
return std::make_tuple(
tensor,
static_cast<int64_t>(chain.getOutputSampleRate()));
}
#endif // TORCH_API_INCLUDE_EXTENSION_H
} // namespace sox_effects
} // namespace torchaudio
#ifndef TORCHAUDIO_SOX_EFFECTS_H
#define TORCHAUDIO_SOX_EFFECTS_H
#ifdef TORCH_API_INCLUDE_EXTENSION_H
#include <torch/extension.h>
#endif // TORCH_API_INCLUDE_EXTENSION_H
#include <torch/script.h>
#include <torchaudio/csrc/sox/utils.h>
......@@ -22,6 +26,17 @@ c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> apply_effects_file(
c10::optional<bool>& channels_first,
c10::optional<std::string>& format);
#ifdef TORCH_API_INCLUDE_EXTENSION_H
std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
py::object fileobj,
std::vector<std::vector<std::string>> effects,
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first,
c10::optional<std::string>& format);
#endif // TORCH_API_INCLUDE_EXTENSION_H
} // namespace sox_effects
} // namespace torchaudio
......
......@@ -198,7 +198,7 @@ void SoxEffectsChain::addInputTensor(TensorSignal* signal) {
priv->signal = signal;
priv->index = 0;
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
throw std::runtime_error("Failed to add effect: input_tensor");
throw std::runtime_error("Internal Error: Failed to add effect: input_tensor");
}
}
......@@ -207,7 +207,7 @@ void SoxEffectsChain::addOutputBuffer(
SoxEffect e(sox_create_effect(get_tensor_output_handler()));
static_cast<TensorOutputPriv*>(e->priv)->buffer = output_buffer;
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
throw std::runtime_error("Failed to add effect: output_tensor");
throw std::runtime_error("Internal Error: Failed to add effect: output_tensor");
}
}
......@@ -219,7 +219,7 @@ void SoxEffectsChain::addInputFile(sox_format_t* sf) {
sox_effect_options(e, 1, opts);
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
std::ostringstream stream;
stream << "Failed to add effect: input " << sf->filename;
stream << "Internal Error: Failed to add effect: input " << sf->filename;
throw std::runtime_error(stream.str());
}
}
......@@ -230,7 +230,7 @@ void SoxEffectsChain::addOutputFile(sox_format_t* sf) {
static_cast<FileOutputPriv*>(e->priv)->sf = sf;
if (sox_add_effect(sec_, e, &interm_sig_, &out_sig_) != SOX_SUCCESS) {
std::ostringstream stream;
stream << "Failed to add effect: output " << sf->filename;
stream << "Internal Error: Failed to add effect: output " << sf->filename;
throw std::runtime_error(stream.str());
}
}
......@@ -266,7 +266,7 @@ void SoxEffectsChain::addEffect(const std::vector<std::string> effect) {
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
std::ostringstream stream;
stream << "Failed to add effect: \"" << name;
stream << "Internal Error: Failed to add effect: \"" << name;
for (size_t i = 1; i < num_args; ++i) {
stream << " " << effect[i];
}
......@@ -283,5 +283,132 @@ int64_t SoxEffectsChain::getOutputSampleRate() {
return interm_sig_.rate;
}
#ifdef TORCH_API_INCLUDE_EXTENSION_H
namespace {
/// Helper struct for passing a file-like object to SoxEffectsChain
struct FileObjInputPriv {
sox_format_t* sf;
py::object* fileobj;
char* buffer;
uint64_t buffer_size;
};
/// Callback function to feed byte string
/// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/sox.h#L1268-L1278
int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
auto priv = static_cast<FileObjInputPriv *>(effp->priv);
auto sf = priv->sf;
auto fileobj = priv->fileobj;
auto buffer = priv->buffer;
auto buffer_size = priv->buffer_size;
// 1. Refresh the buffer
//
// NOTE:
// Since the underlying FILE* was opened with `fmemopen`, the only way
// libsox detects EOF is by reaching the end of the buffer. (A null byte won't help.)
// Therefore we need to align the content at the end of the buffer; otherwise,
// libsox will keep reading content beyond the intended length.
//
// Before:
//
// |<--------consumed------->|<-remaining->|
// |*************************|-------------|
// ^ ftell
//
// After:
//
// |<-offset->|<-remaining->|<--new data-->|
// |**********|-------------|++++++++++++++|
// ^ ftell
const auto num_consumed = sf->tell_off;
const auto num_remain = buffer_size - num_consumed;
// 1.1. First, we fetch the data to see if there is data to fill the buffer
py::bytes chunk_ = fileobj->attr("read")(num_consumed);
const auto num_refill = py::len(chunk_);
const auto offset = buffer_size - (num_remain + num_refill);
if (num_refill > num_consumed) {
std::ostringstream message;
message << "Tried to read up to " << num_consumed << " bytes but "
<< "received " << num_refill << " bytes. "
<< "The given object does not conform to the read protocol of a file object.";
throw std::runtime_error(message.str());
}
// 1.2. Move the unconsumed data towards the beginning of buffer.
if (num_remain) {
auto src = static_cast<void*>(buffer + num_consumed);
auto dst = static_cast<void*>(buffer + offset);
memmove(dst, src, num_remain);
}
// 1.3. Refill the remaining buffer.
if (num_refill) {
auto chunk = static_cast<std::string>(chunk_);
auto src = static_cast<void*>(const_cast<char*>(chunk.c_str()));
auto dst = buffer + offset + num_remain;
memcpy(dst, src, num_refill);
}
// 1.4. Set the file pointer to the new offset
sf->tell_off = offset;
fseek ((FILE*)sf->fp, offset, SEEK_SET);
// 2. Perform decoding operation
// The following part is practically same as "input" effect
// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/input.c#L30-L48
// Ensure that it's a multiple of the number of channels
*osamp -= *osamp % effp->out_signal.channels;
// Read up to *osamp samples into obuf;
// store the actual number read back to *osamp
*osamp = sox_read(sf, obuf, *osamp);
return *osamp? SOX_SUCCESS : SOX_EOF;
}
sox_effect_handler_t* get_fileobj_input_handler() {
static sox_effect_handler_t handler{/*name=*/"input_fileobj_object",
/*usage=*/NULL,
/*flags=*/SOX_EFF_MCHAN,
/*getopts=*/NULL,
/*start=*/NULL,
/*flow=*/NULL,
/*drain=*/fileobj_input_drain,
/*stop=*/NULL,
/*kill=*/NULL,
/*priv_size=*/sizeof(FileObjInputPriv)};
return &handler;
}
} // namespace
void SoxEffectsChain::addInputFileObj(
sox_format_t* sf,
char* buffer,
uint64_t buffer_size,
py::object* fileobj) {
in_sig_ = sf->signal;
interm_sig_ = in_sig_;
SoxEffect e(sox_create_effect(get_fileobj_input_handler()));
auto priv = static_cast<FileObjInputPriv*>(e->priv);
priv->sf = sf;
priv->fileobj = fileobj;
priv->buffer = buffer;
priv->buffer_size = buffer_size;
if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) {
throw std::runtime_error("Internal Error: Failed to add effect: input fileobj");
}
}
#endif // TORCH_API_INCLUDE_EXTENSION_H
} // namespace sox_effects_chain
} // namespace torchaudio
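The drain callback above relies on fileobj.read(n) returning at most n bytes, and raises an error otherwise. For a source that does not honor that contract, a thin wrapper on the Python side can enforce it; the sketch below is illustrative only and not part of the commit (CappedReader is a hypothetical name):

class CappedReader:
    """Wrap a readable object so that read(n) never returns more than n bytes."""
    def __init__(self, src):
        self._src = src
        self._buf = b''

    def read(self, n):
        # Top up the internal buffer only as much as needed, then hand out at most n bytes.
        if len(self._buf) < n:
            self._buf += self._src.read(n - len(self._buf))
        out, self._buf = self._buf[:n], self._buf[n:]
        return out

# e.g. found, sr = sox_io_backend.load(CappedReader(some_stream), format='mp3')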
......@@ -4,6 +4,10 @@
#include <sox.h>
#include <torchaudio/csrc/sox/utils.h>
#ifdef TORCH_API_INCLUDE_EXTENSION_H
#include <torch/extension.h>
#endif // TORCH_API_INCLUDE_EXTENSION_H
namespace torchaudio {
namespace sox_effects_chain {
......@@ -33,6 +37,16 @@ class SoxEffectsChain {
void addEffect(const std::vector<std::string> effect);
int64_t getOutputNumChannels();
int64_t getOutputSampleRate();
#ifdef TORCH_API_INCLUDE_EXTENSION_H
void addInputFileObj(
sox_format_t* sf,
char* buffer,
uint64_t buffer_size,
py::object* fileobj);
#endif // TORCH_API_INCLUDE_EXTENSION_H
};
} // namespace sox_effects_chain
......
......@@ -49,13 +49,11 @@ c10::intrusive_ptr<SignalInfo> get_info(
static_cast<int64_t>(sf->signal.length / sf->signal.channels));
}
c10::intrusive_ptr<TensorSignal> load_audio_file(
const std::string& path,
namespace {
std::vector<std::vector<std::string>> get_effects(
c10::optional<int64_t>& frame_offset,
c10::optional<int64_t>& num_frames,
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first,
c10::optional<std::string>& format) {
c10::optional<int64_t>& num_frames) {
const auto offset = frame_offset.value_or(0);
if (offset < 0) {
throw std::runtime_error(
......@@ -79,7 +77,19 @@ c10::intrusive_ptr<TensorSignal> load_audio_file(
os_offset << offset << "s";
effects.emplace_back(std::vector<std::string>{"trim", os_offset.str()});
}
return effects;
}
} // namespace
c10::intrusive_ptr<TensorSignal> load_audio_file(
const std::string& path,
c10::optional<int64_t>& frame_offset,
c10::optional<int64_t>& num_frames,
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first,
c10::optional<std::string>& format) {
auto effects = get_effects(frame_offset, num_frames);
return torchaudio::sox_effects::apply_effects_file(
path, effects, normalize, channels_first, format);
}
......@@ -123,5 +133,21 @@ void save_audio_file(
chain.run();
}
#ifdef TORCH_API_INCLUDE_EXTENSION_H
std::tuple<torch::Tensor, int64_t> load_audio_fileobj(
py::object fileobj,
c10::optional<int64_t>& frame_offset,
c10::optional<int64_t>& num_frames,
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first,
c10::optional<std::string>& format) {
auto effects = get_effects(frame_offset, num_frames);
return torchaudio::sox_effects::apply_effects_fileobj(
fileobj, effects, normalize, channels_first, format);
}
#endif // TORCH_API_INCLUDE_EXTENSION_H
} // namespace sox_io
} // namespace torchaudio
#ifndef TORCHAUDIO_SOX_IO_H
#define TORCHAUDIO_SOX_IO_H
#ifdef TORCH_API_INCLUDE_EXTENSION_H
#include <torch/extension.h>
#endif // TORCH_API_INCLUDE_EXTENSION_H
#include <torch/script.h>
#include <torchaudio/csrc/sox/utils.h>
......@@ -38,6 +42,18 @@ void save_audio_file(
const c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal>& signal,
const double compression = 0.);
#ifdef TORCH_API_INCLUDE_EXTENSION_H
std::tuple<torch::Tensor, int64_t> load_audio_fileobj(
py::object fileobj,
c10::optional<int64_t>& frame_offset,
c10::optional<int64_t>& num_frames,
c10::optional<bool>& normalize,
c10::optional<bool>& channels_first,
c10::optional<std::string>& format);
#endif // TORCH_API_INCLUDE_EXTENSION_H
} // namespace sox_io
} // namespace torchaudio
......
......@@ -92,15 +92,15 @@ SoxFormat::operator sox_format_t*() const noexcept {
return fd_;
}
void validate_input_file(const SoxFormat& sf) {
void validate_input_file(const SoxFormat& sf, bool check_length) {
if (static_cast<sox_format_t*>(sf) == nullptr) {
throw std::runtime_error("Error loading audio file: failed to open file.");
}
if (sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
throw std::runtime_error("Error loading audio file: unknown encoding.");
}
if (sf->signal.length == 0) {
throw std::runtime_error("Error reading audio file: unkown length.");
if (check_length && sf->signal.length == 0) {
throw std::runtime_error("Error reading audio file: unknown length.");
}
}
......
......@@ -67,7 +67,7 @@ struct SoxFormat {
///
/// Verify that input file is found, has known encoding, and is not empty
void validate_input_file(const SoxFormat& sf);
void validate_input_file(const SoxFormat& sf, bool check_length=true);
///
/// Verify that input Tensor is 2D, CPU and either uint8, int16, int32 or float32
......