Unverified Commit 2a67fcc1 authored by moto's avatar moto Committed by GitHub
Browse files

Extract PyBind11 feature implementations (#1739)

This PR moves the code related to PyBind11 to the dedicated directory `torchaudio/csrc/pybind`.

Previously, features related to PyBind11 (I/O for file-like objects) were implemented in `torchaudio/csrc/sox` and the binding was defined in `torchaudio/csrc/pybind.cpp`. We used the macro definition `TORCH_API_INCLUDE_EXTENSION_H` to turn the feature on/off, in addition to including/excluding `torchaudio/csrc/pybind.cpp` in the list of compiled sources.

With the previous approach, the C++ example required rebuilding libtorchaudio separately; by splitting the two completely at compile time, it should conceptually be possible to distribute libtorchaudio within the torchaudio Python package and reuse it for the C++ example.
parent feede97e
......@@ -87,11 +87,25 @@ endif()
# _torchaudio.so
################################################################################
if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
set(
EXTENSION_SOURCES
pybind/pybind.cpp
)
if(BUILD_SOX)
list(
APPEND
EXTENSION_SOURCES
pybind/sox/effects.cpp
pybind/sox/effects_chain.cpp
pybind/sox/io.cpp
pybind/sox/utils.cpp
)
endif()
add_library(
_torchaudio
SHARED
pybind.cpp
${LIBTORCHAUDIO_SOURCES}
${EXTENSION_SOURCES}
)
set_target_properties(_torchaudio PROPERTIES PREFIX "")
......@@ -105,10 +119,6 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
set_target_properties(_torchaudio PROPERTIES LINK_FLAGS "-undefined dynamic_lookup")
endif()
target_compile_definitions(
_torchaudio PRIVATE TORCH_API_INCLUDE_EXTENSION_H
)
if (BUILD_SOX)
target_compile_definitions(_torchaudio PRIVATE INCLUDE_SOX)
endif()
......
#include <torch/extension.h>
#ifdef INCLUDE_SOX
#include <torchaudio/csrc/sox/effects.h>
#include <torchaudio/csrc/sox/io.h>
#include <torchaudio/csrc/pybind/sox/effects.h>
#include <torchaudio/csrc/pybind/sox/io.h>
#endif
PYBIND11_MODULE(_torchaudio, m) {
......
#include <torchaudio/csrc/pybind/sox/effects.h>
#include <torchaudio/csrc/pybind/sox/effects_chain.h>
#include <torchaudio/csrc/pybind/sox/utils.h>
using namespace torchaudio::sox_utils;
namespace torchaudio {
namespace sox_effects {
// Decodes audio from a Python file-like object, runs it through a chain of
// sox effects and returns the resulting Tensor together with the sample rate
// reported by the chain.
//
// libsox only operates on FILE pointers (that is what `sox` and `play` do for
// files, for stdin, and for URLs piped through wget), whereas here the bytes
// have to be fetched chunk by chunk, consumed and discarded:
//   1. sox_format_t is initialized with sox_open_mem_read on the first chunk.
//      This performs the header-based format detection, if necessary, and
//      fills the metadata. Internally sox_open_mem_read uses fmemopen, so the
//      resulting FILE* points at the provided byte buffer.
//   2. Every time sox has read from that FILE*, the buffer is rearranged so
//      that it starts with the unseen data followed by fresh data from
//      `fileobj`, which makes libsox believe it keeps reading one continuous
//      stream. See `fileobj_input_drain` in effects_chain.cpp.
//
// `normalize` and `channels_first` are treated as true when they are not set.
// `format`, when set, is handed to libsox as the file type.
std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
    py::object fileobj,
    std::vector<std::vector<std::string>> effects,
    c10::optional<bool> normalize,
    c10::optional<bool> channels_first,
    c10::optional<std::string> format) {
  // One buffer lives for the whole lifecycle of the effects chain.
  //
  // For some formats (FLAC, for example) libsox keeps reading while opening
  // the file until it hits EOF, even after the header has been parsed (with a
  // buffer of 8192 bytes, far bigger than the header, everything was consumed
  // at open time). The buffer therefore must always hold valid data, except
  // after EOF. Its size defaults to `sox_get_globals()->bufsiz` (adjustable
  // via `torchaudio.utils.sox_utils.set_buffer_size`) and is never below 256.
  auto capacity = sox_get_globals()->bufsiz;
  if (!(capacity > 256)) {
    capacity = 256;
  }
  std::string storage(capacity, '\0');
  char* const raw = &storage[0];

  // `read_fileobj` keeps calling `read` until it has `capacity` bytes or the
  // object is exhausted; a short count means all the audio data was fetched.
  auto mem_size = read_fileobj(&fileobj, capacity, raw);
  // libsox cannot read the header of a memory file shorter than 256 bytes.
  if (!(mem_size > 256)) {
    mem_size = 256;
  }

  // Opening the file already reads the header. Two functions may touch FILE*:
  //  * `auto_detect_format`
  //    https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L43
  //  * the `startread` handler of the detected format
  //    https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L574
  // Per-format handlers live in
  //    https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/<FORMAT>.c
  // e.g. vorbis:
  //    https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/vorbis.c#L97-L158
  const char* filetype = nullptr;
  if (format.has_value()) {
    filetype = format.value().c_str();
  }
  SoxFormat sf(sox_open_mem_read(
      raw,
      mem_size,
      /*signal=*/nullptr,
      /*encoding=*/nullptr,
      /*filetype=*/filetype));

  // In case of streamed data, length can be 0
  validate_input_memfile(sf);

  // Decoded samples are collected here.
  std::vector<sox_sample_t> decoded;
  decoded.reserve(sf->signal.length);

  // Assemble the chain: file-object input -> effects -> sample buffer.
  const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision);
  torchaudio::sox_effects_chain::SoxEffectsChainPyBind chain(
      /*input_encoding=*/sf->encoding,
      /*output_encoding=*/get_tensor_encodinginfo(dtype));
  chain.addInputFileObj(sf, raw, mem_size, &fileobj);
  for (const auto& effect_args : effects) {
    chain.addEffect(effect_args);
  }
  chain.addOutputBuffer(&decoded);
  chain.run();

  // Wrap the decoded samples in a Tensor.
  const bool want_channels_first = channels_first.value_or(true);
  auto tensor = convert_to_tensor(
      /*buffer=*/decoded.data(),
      /*num_samples=*/decoded.size(),
      /*num_channels=*/chain.getOutputNumChannels(),
      dtype,
      normalize.value_or(true),
      want_channels_first);

  return std::make_tuple(
      tensor, static_cast<int64_t>(chain.getOutputSampleRate()));
}
} // namespace sox_effects
} // namespace torchaudio
#ifndef TORCHAUDIO_PYBIND_SOX_EFFECTS_H
#define TORCHAUDIO_PYBIND_SOX_EFFECTS_H
#include <torch/extension.h>
namespace torchaudio {
namespace sox_effects {
/// Decodes audio read from the Python file-like object `fileobj`, applies
/// `effects` through a sox effects chain and returns the resulting Tensor
/// together with the sample rate reported by the chain.
///
/// `format` is forwarded to libsox as the file type when it is set.
/// `normalize` and `channels_first` are treated as true when they are not set.
std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
py::object fileobj,
std::vector<std::vector<std::string>> effects,
c10::optional<bool> normalize,
c10::optional<bool> channels_first,
c10::optional<std::string> format);
} // namespace sox_effects
} // namespace torchaudio
#endif
#include <sox.h>
#include <torchaudio/csrc/pybind/sox/effects_chain.h>
#include <torchaudio/csrc/pybind/sox/utils.h>
using namespace torchaudio::sox_utils;
namespace torchaudio {
namespace sox_effects_chain {
namespace {
/// Helper classes for passing a file-like object to SoxEffectsChain.
///
/// Private state of the "input_fileobj_object" effect.
struct FileObjInputPriv {
// sox format handle that `fileobj_input_drain` decodes from via `sox_read`.
sox_format_t* sf;
// Python object that `read_fileobj` pulls more encoded bytes from.
py::object* fileobj;
// Set once a refill returned fewer bytes than were requested.
bool eof_reached;
// Start of the memory region that is shifted and refilled on each drain call.
char* buffer;
// Upper bound used to validate the stream position within `buffer`.
uint64_t buffer_size;
};
/// Private state of the "output_fileobj_object" effect.
struct FileObjOutputPriv {
// sox format handle that `fileobj_output_flow` encodes into via `sox_write`.
sox_format_t* sf;
// Python object whose `write` attribute receives the encoded bytes.
py::object* fileobj;
// Address of the pointer to the encoded data; dereferenced on each flow call.
char** buffer;
// Address of the size that accompanies `buffer`; `fileobj_output_flow` does
// not read through this pointer.
size_t* buffer_size;
};
/// Drain callback of the "input_fileobj_object" effect: tops up the in-memory
/// buffer from the Python file-like object and then decodes samples from it.
/// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/sox.h#L1268-L1278
///
/// On entry `*osamp` is the number of samples `obuf` can take; on return it is
/// the number of samples actually decoded. Returns SOX_EOF once the file-like
/// object is exhausted and no further sample could be decoded, SOX_SUCCESS
/// otherwise. Throws std::runtime_error when `ftell` or `fseek` fails or when
/// the stream position lies beyond the buffer.
int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
  auto priv = static_cast<FileObjInputPriv*>(effp->priv);
  auto sf = priv->sf;
  auto buffer = priv->buffer;
  auto fp = static_cast<FILE*>(sf->fp);

  // 1. Refresh the buffer
  //
  // NOTE:
  //   Since the underlying FILE* was opened with `fmemopen`, the only way
  //   libsox detects EOF is by reaching the end of the buffer (a null byte
  //   won't help). Therefore the content has to be aligned to the end of the
  //   buffer, otherwise libsox keeps reading beyond the intended length.
  //
  // Before:
  //
  //     |<-------consumed------>|<---remaining--->|
  //     |***********************|-----------------|
  //                             ^ ftell
  //
  // After:
  //
  //     |<-offset->|<---remaining--->|<-new data->|
  //     |**********|-----------------|++++++++++++|
  //                ^ ftell
  //
  // NOTE:
  //   Do not use `sf->tell_off` here. Presumably, `tell_off` and `fseek` are
  //   supposed to be in sync, but there are cases (Vorbis) where they are not
  //   and `tell_off` has a seemingly uninitialized value, which makes
  //   num_remain negative and causes a segmentation fault in `memmove`.
  const auto tell = ftell(fp);
  if (tell < 0) {
    throw std::runtime_error("Internal Error: ftell failed.");
  }
  const auto num_consumed = static_cast<size_t>(tell);
  if (num_consumed > priv->buffer_size) {
    throw std::runtime_error("Internal Error: buffer overrun.");
  }
  const auto num_remain = priv->buffer_size - num_consumed;

  // 1.1. Fetch the data to see if there is data to fill the buffer.
  //      The staging string is only sized when a read actually takes place.
  size_t num_refill = 0;
  std::string chunk;
  if (num_consumed && !priv->eof_reached) {
    chunk.resize(num_consumed);
    num_refill = read_fileobj(priv->fileobj, num_consumed, &chunk[0]);
    if (num_refill < num_consumed) {
      priv->eof_reached = true;
    }
  }
  const auto offset = num_consumed - num_refill;

  // 1.2. Move the unconsumed data towards the beginning of buffer.
  //      Source and destination may overlap, hence memmove.
  if (num_remain) {
    memmove(buffer + offset, buffer + num_consumed, num_remain);
  }

  // 1.3. Refill the remaining buffer.
  if (num_refill) {
    memcpy(buffer + offset + num_remain, chunk.data(), num_refill);
  }

  // 1.4. Set the file pointer to the new offset.
  //      A failed seek would leave libsox reading from a stale position, so
  //      it is reported instead of being ignored.
  sf->tell_off = offset;
  if (fseek(fp, static_cast<long>(offset), SEEK_SET) != 0) {
    throw std::runtime_error("Internal Error: fseek failed.");
  }

  // 2. Perform decoding operation
  // The following part is practically same as "input" effect
  // https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/input.c#L30-L48

  // Ensure that it's a multiple of the number of channels
  *osamp -= *osamp % effp->out_signal.channels;

  // Read up to *osamp samples into obuf;
  // store the actual number read back to *osamp
  *osamp = sox_read(sf, obuf, *osamp);

  // Decoding is finished when fileobject is exhausted and sox can no longer
  // decode a sample.
  return (priv->eof_reached && !*osamp) ? SOX_EOF : SOX_SUCCESS;
}
/// Flow callback of the "output_fileobj_object" effect: encodes the incoming
/// samples with `sox_write` and forwards the bytes produced so far to the
/// `write` attribute of the Python file-like object.
///
/// `*osamp` is always set to 0 because this effect emits no samples
/// downstream. Returns SOX_EOF when fewer samples than `*isamp` were written
/// without a sox error code, SOX_SUCCESS otherwise. Throws std::runtime_error
/// when sox reports an error or when `ftell`/`fseek` on the stream fails.
int fileobj_output_flow(
    sox_effect_t* effp,
    sox_sample_t const* ibuf,
    sox_sample_t* obuf LSX_UNUSED,
    size_t* isamp,
    size_t* osamp) {
  *osamp = 0;
  if (*isamp) {
    auto priv = static_cast<FileObjOutputPriv*>(effp->priv);
    auto sf = priv->sf;
    auto fp = static_cast<FILE*>(sf->fp);
    auto fileobj = priv->fileobj;
    auto buffer = priv->buffer;

    // Encode chunk
    const auto num_samples_written = sox_write(sf, ibuf, *isamp);
    fflush(fp);

    // Copy the encoded chunk to python object.
    // ftell returns -1 on failure; passing that on as a length would be
    // interpreted as a huge size, so it is rejected here.
    const auto num_encoded = ftell(fp);
    if (num_encoded < 0) {
      throw std::runtime_error("Internal Error: ftell failed.");
    }
    fileobj->attr("write")(
        py::bytes(*buffer, static_cast<size_t>(num_encoded)));

    // Reset FILE*
    sf->tell_off = 0;
    if (fseek(fp, 0, SEEK_SET) != 0) {
      throw std::runtime_error("Internal Error: fseek failed.");
    }

    if (num_samples_written != *isamp) {
      if (sf->sox_errno) {
        std::ostringstream stream;
        stream << sf->sox_errstr << " " << sox_strerror(sf->sox_errno) << " "
               << sf->filename;
        throw std::runtime_error(stream.str());
      }
      return SOX_EOF;
    }
  }
  return SOX_SUCCESS;
}
/// Returns the handler describing the "input_fileobj_object" effect.
/// Only the `drain` callback is provided; the handler object has static
/// storage duration, so the returned pointer is the same on every call.
sox_effect_handler_t* get_fileobj_input_handler() {
  static sox_effect_handler_t handler{
      /*name=*/"input_fileobj_object",
      /*usage=*/nullptr,
      /*flags=*/SOX_EFF_MCHAN,
      /*getopts=*/nullptr,
      /*start=*/nullptr,
      /*flow=*/nullptr,
      /*drain=*/fileobj_input_drain,
      /*stop=*/nullptr,
      /*kill=*/nullptr,
      /*priv_size=*/sizeof(FileObjInputPriv)};
  return &handler;
}
/// Returns the handler describing the "output_fileobj_object" effect.
/// Only the `flow` callback is provided; the handler object has static
/// storage duration, so the returned pointer is the same on every call.
sox_effect_handler_t* get_fileobj_output_handler() {
  static sox_effect_handler_t handler{
      /*name=*/"output_fileobj_object",
      /*usage=*/nullptr,
      /*flags=*/SOX_EFF_MCHAN,
      /*getopts=*/nullptr,
      /*start=*/nullptr,
      /*flow=*/fileobj_output_flow,
      /*drain=*/nullptr,
      /*stop=*/nullptr,
      /*kill=*/nullptr,
      /*priv_size=*/sizeof(FileObjOutputPriv)};
  return &handler;
}
} // namespace
// Registers the "input_fileobj_object" effect on the chain.
//
// The input and intermediate signal descriptions are taken from `sf`. The
// pointers `sf`, `buffer` and `fileobj` are stored in the private state of
// the effect as-is (nothing is copied). Throws std::runtime_error when
// libsox refuses to add the effect.
void SoxEffectsChainPyBind::addInputFileObj(
    sox_format_t* sf,
    char* buffer,
    uint64_t buffer_size,
    py::object* fileobj) {
  interm_sig_ = in_sig_ = sf->signal;

  SoxEffect input_effect(sox_create_effect(get_fileobj_input_handler()));

  // Fill in the state consumed by `fileobj_input_drain`.
  auto* state = static_cast<FileObjInputPriv*>(input_effect->priv);
  state->buffer_size = buffer_size;
  state->buffer = buffer;
  state->eof_reached = false;
  state->fileobj = fileobj;
  state->sf = sf;

  const auto status =
      sox_add_effect(sec_, input_effect, &interm_sig_, &in_sig_);
  if (status == SOX_SUCCESS) {
    return;
  }
  throw std::runtime_error(
      "Internal Error: Failed to add effect: input fileobj");
}
// Registers the "output_fileobj_object" effect on the chain.
//
// The output signal description is taken from `sf`. The pointers `sf`,
// `buffer`, `buffer_size` and `fileobj` are stored in the private state of
// the effect as-is (nothing is copied). Throws std::runtime_error when
// libsox refuses to add the effect.
void SoxEffectsChainPyBind::addOutputFileObj(
    sox_format_t* sf,
    char** buffer,
    size_t* buffer_size,
    py::object* fileobj) {
  out_sig_ = sf->signal;

  SoxEffect output_effect(sox_create_effect(get_fileobj_output_handler()));

  // Fill in the state consumed by `fileobj_output_flow`.
  auto* state = static_cast<FileObjOutputPriv*>(output_effect->priv);
  state->buffer_size = buffer_size;
  state->buffer = buffer;
  state->fileobj = fileobj;
  state->sf = sf;

  const auto status =
      sox_add_effect(sec_, output_effect, &interm_sig_, &out_sig_);
  if (status == SOX_SUCCESS) {
    return;
  }
  throw std::runtime_error(
      "Internal Error: Failed to add effect: output fileobj");
}
} // namespace sox_effects_chain
} // namespace torchaudio
#ifndef TORCHAUDIO_PYBIND_SOX_EFFECTS_CHAIN_H
#define TORCHAUDIO_PYBIND_SOX_EFFECTS_CHAIN_H
#include <torch/extension.h>
#include <torchaudio/csrc/sox/effects_chain.h>
namespace torchaudio {
namespace sox_effects_chain {
/// Extension of SoxEffectsChain that can use Python file-like objects as the
/// input and the output of the chain.
class SoxEffectsChainPyBind : public SoxEffectsChain {
// Inherit the constructors of the base class. Inherited constructors keep the
// access they have in the base class, so declaring this before `public:` does
// not make them private.
using SoxEffectsChain::SoxEffectsChain;
public:
/// Adds the effect that decodes from `sf`, refilling `buffer` (of
/// `buffer_size` bytes) with data read from `fileobj`.
/// The pointers are stored, not copied — presumably the pointees have to stay
/// alive while the chain runs; confirm against callers.
void addInputFileObj(
sox_format_t* sf,
char* buffer,
uint64_t buffer_size,
py::object* fileobj);
/// Adds the effect that encodes into `sf` and passes the bytes found at
/// `*buffer` on to the `write` attribute of `fileobj`.
/// The pointers are stored, not copied — presumably the pointees have to stay
/// alive while the chain runs; confirm against callers.
void addOutputFileObj(
sox_format_t* sf,
char** buffer,
size_t* buffer_size,
py::object* fileobj);
};
} // namespace sox_effects_chain
} // namespace torchaudio
#endif
#include <torchaudio/csrc/pybind/sox/effects.h>
#include <torchaudio/csrc/pybind/sox/effects_chain.h>
#include <torchaudio/csrc/pybind/sox/io.h>
#include <torchaudio/csrc/pybind/sox/utils.h>
#include <torchaudio/csrc/sox/io.h>
#include <torchaudio/csrc/sox/types.h>
using namespace torchaudio::sox_utils;
namespace torchaudio {
namespace sox_io {
// Reads the beginning of a Python file-like object, lets libsox parse the
// header and returns
//   (sample rate, signal length / channels, channels, bits per sample,
//    encoding name).
// `format`, when set, is handed to libsox as the file type.
std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_fileobj(
    py::object fileobj,
    c10::optional<std::string> format) {
  // Prepare in-memory file object
  //
  // libsox reads the header while opening a file, and two functions might
  // touch FILE* (and the underlying buffer) at that point:
  //  * `auto_detect_format`
  //    https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L43
  //  * the `startread` handler of the detected format
  //    https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L574
  // Per-format handlers live in
  //    https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/<FORMAT>.c
  // e.g. vorbis:
  //    https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/vorbis.c#L97-L158
  //
  // `auto_detect_format` only needs 256 bytes, but a format-specific
  // `startread` handler might need more. The vorbis header, for instance, is
  // unbounded in size but typically at most 4kB:
  //
  //  "The header size is unbounded, although for streaming a rule-of-thumb of
  //  4kB or less is recommended (and Xiph.Org's Vorbis encoder follows this
  //  suggestion)."
  //
  // See:
  // https://xiph.org/vorbis/doc/Vorbis_I_spec.html
  const int kMinCapacityInBytes = 4096;
  auto capacity = sox_get_globals()->bufsiz;
  if (!(capacity > kMinCapacityInBytes)) {
    capacity = kMinCapacityInBytes;
  }
  std::string storage(capacity, '\0');
  char* const raw = &storage[0];

  auto mem_size = read_fileobj(&fileobj, capacity, raw);
  // libsox cannot read the header of a memory file shorter than 256 bytes.
  if (!(mem_size > 256)) {
    mem_size = 256;
  }

  const char* filetype = nullptr;
  if (format.has_value()) {
    filetype = format.value().c_str();
  }
  SoxFormat sf(sox_open_mem_read(
      raw,
      mem_size,
      /*signal=*/nullptr,
      /*encoding=*/nullptr,
      /*filetype=*/filetype));

  // In case of streamed data, length can be 0
  validate_input_memfile(sf);

  const auto& signal = sf->signal;
  return std::make_tuple(
      static_cast<int64_t>(signal.rate),
      static_cast<int64_t>(signal.length / signal.channels),
      static_cast<int64_t>(signal.channels),
      static_cast<int64_t>(sf->encoding.bits_per_sample),
      get_encoding(sf->encoding.encoding));
}
// Loads audio from a Python file-like object.
//
// The requested frame range is turned into a list of sox effects by
// `get_effects`, and the actual decoding is delegated to
// `sox_effects::apply_effects_fileobj`.
std::tuple<torch::Tensor, int64_t> load_audio_fileobj(
    py::object fileobj,
    c10::optional<int64_t> frame_offset,
    c10::optional<int64_t> num_frames,
    c10::optional<bool> normalize,
    c10::optional<bool> channels_first,
    c10::optional<std::string> format) {
  namespace fx = torchaudio::sox_effects;
  auto range_effects = get_effects(frame_offset, num_frames);
  return fx::apply_effects_fileobj(
      fileobj, range_effects, normalize, channels_first, format);
}
namespace {
// Holder for a malloc-style buffer and its size, used by save_audio_fileobj.
// The addresses of `ptr` and `size` are handed to sox_open_memstream_write;
// whatever `ptr` points to at destruction time is released with `free`.
// Neither copyable nor movable, so the buffer is released exactly once.
struct AutoReleaseBuffer {
  char* ptr = nullptr;
  size_t size = 0;

  AutoReleaseBuffer() {}
  ~AutoReleaseBuffer() {
    // `free` ignores a null pointer, so no explicit check is required.
    free(ptr);
  }

  AutoReleaseBuffer(const AutoReleaseBuffer&) = delete;
  AutoReleaseBuffer(AutoReleaseBuffer&&) = delete;
  AutoReleaseBuffer& operator=(const AutoReleaseBuffer&) = delete;
  AutoReleaseBuffer& operator=(AutoReleaseBuffer&&) = delete;
};
} // namespace
// Encodes `tensor` with libsox and writes the bytes to the Python file-like
// object `fileobj` through its `write` attribute.
//
// Throws std::runtime_error when
//  * `format` is not set,
//  * "amr-nb", "htk" or "gsm" is requested for audio that is not mono,
//  * "gsm" is requested with a sample rate other than 8000,
//  * the in-memory output stream cannot be opened.
void save_audio_fileobj(
    py::object fileobj,
    torch::Tensor tensor,
    int64_t sample_rate,
    bool channels_first,
    c10::optional<double> compression,
    c10::optional<std::string> format,
    c10::optional<std::string> encoding,
    c10::optional<int64_t> bits_per_sample) {
  validate_input_tensor(tensor);

  if (!format.has_value()) {
    throw std::runtime_error(
        "`format` is required when saving to file object.");
  }
  const std::string filetype = format.value();

  // Shared check for the formats that only accept a single channel.
  const auto require_mono = [&tensor, channels_first](const char* message) {
    if (tensor.size(channels_first ? 0 : 1) != 1) {
      throw std::runtime_error(message);
    }
  };

  if (filetype == "amr-nb") {
    require_mono("amr-nb format only supports single channel audio.");
  } else if (filetype == "htk") {
    require_mono("htk format only supports single channel audio.");
  } else if (filetype == "gsm") {
    require_mono("gsm format only supports single channel audio.");
    if (sample_rate != 8000) {
      throw std::runtime_error(
          "gsm format only supports a sampling rate of 8kHz.");
    }
  }

  const auto signal_info =
      get_signalinfo(&tensor, sample_rate, filetype, channels_first);
  const auto encoding_info = get_encodinginfo_for_save(
      filetype, tensor.dtype(), compression, encoding, bits_per_sample);

  // Declared before `sf` so that it is destroyed after it.
  AutoReleaseBuffer encoded;
  SoxFormat sf(sox_open_memstream_write(
      &encoded.ptr,
      &encoded.size,
      &signal_info,
      &encoding_info,
      filetype.c_str(),
      /*oob=*/nullptr));
  if (static_cast<sox_format_t*>(sf) == nullptr) {
    throw std::runtime_error(
        "Error saving audio file: failed to open memory stream.");
  }

  // Tensor input -> file-object output.
  torchaudio::sox_effects_chain::SoxEffectsChainPyBind chain(
      /*input_encoding=*/get_tensor_encodinginfo(tensor.dtype()),
      /*output_encoding=*/sf->encoding);
  chain.addInputTensor(&tensor, sample_rate, channels_first);
  chain.addOutputFileObj(sf, &encoded.ptr, &encoded.size, &fileobj);
  chain.run();

  // Closing the sox_format_t is necessary for flushing the last chunk to the
  // buffer
  sf.close();

  fileobj.attr("write")(py::bytes(encoded.ptr, encoded.size));
}
} // namespace sox_io
} // namespace torchaudio
#ifndef TORCHAUDIO_PYBIND_SOX_IO_H
#define TORCHAUDIO_PYBIND_SOX_IO_H
#include <torch/extension.h>
namespace torchaudio {
namespace sox_io {
/// Opens the data read from the Python file-like object `fileobj` with libsox
/// and returns (sample rate, signal length divided by the channel count,
/// channel count, bits per sample, encoding name).
/// `format` is forwarded to libsox as the file type when it is set.
std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_fileobj(
py::object fileobj,
c10::optional<std::string> format);
/// Loads audio from the Python file-like object `fileobj`.
/// `frame_offset` and `num_frames` are converted into sox effects by
/// `get_effects`; the remaining arguments are passed unchanged to
/// `sox_effects::apply_effects_fileobj`, whose result is returned.
std::tuple<torch::Tensor, int64_t> load_audio_fileobj(
py::object fileobj,
c10::optional<int64_t> frame_offset,
c10::optional<int64_t> num_frames,
c10::optional<bool> normalize,
c10::optional<bool> channels_first,
c10::optional<std::string> format);
/// Encodes `tensor` with libsox and writes the encoded bytes to the Python
/// file-like object `fileobj` through its `write` attribute.
/// Throws std::runtime_error when `format` is not set, when "amr-nb", "htk"
/// or "gsm" is requested for audio with more than one channel, when "gsm" is
/// requested with a sample rate other than 8000, or when the in-memory output
/// stream cannot be opened.
void save_audio_fileobj(
py::object fileobj,
torch::Tensor tensor,
int64_t sample_rate,
bool channels_first,
c10::optional<double> compression,
c10::optional<std::string> format,
c10::optional<std::string> encoding,
c10::optional<int64_t> bits_per_sample);
} // namespace sox_io
} // namespace torchaudio
#endif
#include <torchaudio/csrc/pybind/sox/utils.h>
namespace torchaudio {
namespace sox_utils {
/// Copies up to `size` bytes from the Python file-like object into `buffer`.
///
/// The `read` attribute of `fileobj` is called repeatedly, each time with the
/// number of bytes still missing, until `size` bytes have been copied or
/// `read` returns an empty result.
///
/// @param fileobj Python object providing `read`.
/// @param size    Maximum number of bytes to copy.
/// @param buffer  Destination; the caller provides room for `size` bytes.
/// @return Number of bytes copied (less than `size` only when `read`
///         returned an empty result).
/// @throws std::runtime_error when `read` returns more bytes than requested.
uint64_t read_fileobj(py::object* fileobj, const uint64_t size, char* buffer) {
  uint64_t num_read = 0;
  while (num_read < size) {
    const auto request = size - num_read;
    const auto chunk = static_cast<std::string>(
        static_cast<py::bytes>(fileobj->attr("read")(request)));
    const auto chunk_len = chunk.length();
    if (chunk_len == 0) {
      break;
    }
    // Copying more than requested would write past the end of `buffer`.
    if (chunk_len > request) {
      std::ostringstream message;
      message
          << "Requested up to " << request << " bytes, but "
          << "received " << chunk_len << " bytes. "
          << "The given object does not conform to read protocol of file object.";
      throw std::runtime_error(message.str());
    }
    memcpy(buffer, chunk.data(), chunk_len);
    buffer += chunk_len;
    num_read += chunk_len;
  }
  return num_read;
}
} // namespace sox_utils
} // namespace torchaudio
#ifndef TORCHAUDIO_PYBIND_SOX_UTILS_H
#define TORCHAUDIO_PYBIND_SOX_UTILS_H
#include <torch/extension.h>
namespace torchaudio {
namespace sox_utils {
/// Calls the `read` attribute of `fileobj` repeatedly until `size` bytes have
/// been copied into `buffer` or `read` returns an empty result, and returns
/// the number of bytes copied. Throws std::runtime_error when `read` returns
/// more bytes than were requested.
uint64_t read_fileobj(py::object* fileobj, uint64_t size, char* buffer);
} // namespace sox_utils
} // namespace torchaudio
#endif
......@@ -135,109 +135,6 @@ std::tuple<torch::Tensor, int64_t> apply_effects_file(
tensor, chain.getOutputSampleRate());
}
#ifdef TORCH_API_INCLUDE_EXTENSION_H
// Decodes audio from a Python file-like object, runs it through a chain of
// sox effects and returns the resulting Tensor together with the sample rate
// reported by the chain.
//
// libsox only operates on FILE pointers (that is what `sox` and `play` do for
// files, for stdin, and for URLs piped through wget), whereas here the bytes
// have to be fetched chunk by chunk, consumed and discarded:
//   1. sox_format_t is initialized with sox_open_mem_read on the first chunk.
//      This performs the header-based format detection, if necessary, and
//      fills the metadata. Internally sox_open_mem_read uses fmemopen, so the
//      resulting FILE* points at the provided byte buffer.
//   2. Every time sox has read from that FILE*, the buffer is rearranged so
//      that it starts with the unseen data followed by fresh data from
//      `fileobj`, which makes libsox believe it keeps reading one continuous
//      stream. See `fileobj_input_drain` in effects_chain.cpp.
//
// `normalize` and `channels_first` are treated as true when they are not set.
// `format`, when set, is handed to libsox as the file type.
std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
    py::object fileobj,
    std::vector<std::vector<std::string>> effects,
    c10::optional<bool> normalize,
    c10::optional<bool> channels_first,
    c10::optional<std::string> format) {
  // One buffer lives for the whole lifecycle of the effects chain.
  //
  // For some formats (FLAC, for example) libsox keeps reading while opening
  // the file until it hits EOF, even after the header has been parsed (with a
  // buffer of 8192 bytes, far bigger than the header, everything was consumed
  // at open time). The buffer therefore must always hold valid data, except
  // after EOF. Its size defaults to `sox_get_globals()->bufsiz` (adjustable
  // via `torchaudio.utils.sox_utils.set_buffer_size`) and is never below 256.
  auto capacity = sox_get_globals()->bufsiz;
  if (!(capacity > 256)) {
    capacity = 256;
  }
  std::string storage(capacity, '\0');
  char* const raw = &storage[0];

  // `read_fileobj` keeps calling `read` until it has `capacity` bytes or the
  // object is exhausted; a short count means all the audio data was fetched.
  auto mem_size = read_fileobj(&fileobj, capacity, raw);
  // libsox cannot read the header of a memory file shorter than 256 bytes.
  if (!(mem_size > 256)) {
    mem_size = 256;
  }

  // Opening the file already reads the header. Two functions may touch FILE*:
  //  * `auto_detect_format`
  //    https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L43
  //  * the `startread` handler of the detected format
  //    https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L574
  // Per-format handlers live in
  //    https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/<FORMAT>.c
  // e.g. vorbis:
  //    https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/vorbis.c#L97-L158
  const char* filetype = nullptr;
  if (format.has_value()) {
    filetype = format.value().c_str();
  }
  SoxFormat sf(sox_open_mem_read(
      raw,
      mem_size,
      /*signal=*/nullptr,
      /*encoding=*/nullptr,
      /*filetype=*/filetype));

  // In case of streamed data, length can be 0
  validate_input_memfile(sf);

  // Decoded samples are collected here.
  std::vector<sox_sample_t> decoded;
  decoded.reserve(sf->signal.length);

  // Assemble the chain: file-object input -> effects -> sample buffer.
  const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision);
  torchaudio::sox_effects_chain::SoxEffectsChain chain(
      /*input_encoding=*/sf->encoding,
      /*output_encoding=*/get_tensor_encodinginfo(dtype));
  chain.addInputFileObj(sf, raw, mem_size, &fileobj);
  for (const auto& effect_args : effects) {
    chain.addEffect(effect_args);
  }
  chain.addOutputBuffer(&decoded);
  chain.run();

  // Wrap the decoded samples in a Tensor.
  const bool want_channels_first = channels_first.value_or(true);
  auto tensor = convert_to_tensor(
      /*buffer=*/decoded.data(),
      /*num_samples=*/decoded.size(),
      /*num_channels=*/chain.getOutputNumChannels(),
      dtype,
      normalize.value_or(true),
      want_channels_first);

  return std::make_tuple(
      tensor, static_cast<int64_t>(chain.getOutputSampleRate()));
}
#endif // TORCH_API_INCLUDE_EXTENSION_H
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def(
"torchaudio::sox_effects_initialize_sox_effects",
......
#ifndef TORCHAUDIO_SOX_EFFECTS_H
#define TORCHAUDIO_SOX_EFFECTS_H
#ifdef TORCH_API_INCLUDE_EXTENSION_H
#include <torch/extension.h>
#endif // TORCH_API_INCLUDE_EXTENSION_H
#include <torch/script.h>
#include <torchaudio/csrc/sox/utils.h>
......@@ -28,17 +24,6 @@ std::tuple<torch::Tensor, int64_t> apply_effects_file(
c10::optional<bool> channels_first,
const c10::optional<std::string>& format);
#ifdef TORCH_API_INCLUDE_EXTENSION_H
/// Decodes audio read from the Python file-like object `fileobj`, applies
/// `effects` through a sox effects chain and returns the resulting Tensor
/// together with the sample rate reported by the chain.
/// Only declared when TORCH_API_INCLUDE_EXTENSION_H is defined.
std::tuple<torch::Tensor, int64_t> apply_effects_fileobj(
py::object fileobj,
std::vector<std::vector<std::string>> effects,
c10::optional<bool> normalize,
c10::optional<bool> channels_first,
c10::optional<std::string> format);
#endif // TORCH_API_INCLUDE_EXTENSION_H
} // namespace sox_effects
} // namespace torchaudio
......
......@@ -9,30 +9,6 @@ namespace sox_effects_chain {
namespace {
// Helper struct to safely release the sox_effect_t* pointer returned by
// sox_create_effect. The wrapped pointer is released with `free` on
// destruction. Neither copyable nor movable, so it is released exactly once.
struct SoxEffect {
  explicit SoxEffect(sox_effect_t* se) noexcept : se_(se) {}
  SoxEffect(const SoxEffect& other) = delete;
  // Declared with the canonical `SoxEffect&&` signature (was `const
  // SoxEffect&&`); moving stays disabled.
  SoxEffect(SoxEffect&& other) = delete;
  SoxEffect& operator=(const SoxEffect& other) = delete;
  SoxEffect& operator=(SoxEffect&& other) = delete;
  ~SoxEffect() {
    if (se_ != nullptr) {
      free(se_);
    }
  }
  // Implicit access to the raw pointer, e.g. for passing it to sox functions.
  operator sox_effect_t*() const {
    return se_;
  }
  sox_effect_t* operator->() noexcept {
    return se_;
  }

 private:
  sox_effect_t* se_;
};
/// helper classes for passing the location of input tensor and output buffer
///
/// drain/flow callback functions require plaing C style function signature and
......@@ -197,6 +173,22 @@ sox_effect_handler_t* get_file_output_handler() {
} // namespace
SoxEffect::SoxEffect(sox_effect_t* se) noexcept : se_(se) {}
SoxEffect::~SoxEffect() {
if (se_ != nullptr) {
free(se_);
}
}
SoxEffect::operator sox_effect_t*() const {
return se_;
}
sox_effect_t* SoxEffect::operator->() noexcept {
return se_;
}
SoxEffectsChain::SoxEffectsChain(
sox_encodinginfo_t input_encoding,
sox_encodinginfo_t output_encoding)
......@@ -327,227 +319,5 @@ int64_t SoxEffectsChain::getOutputSampleRate() {
return interm_sig_.rate;
}
#ifdef TORCH_API_INCLUDE_EXTENSION_H
namespace {
/// Helper classes for passing a file-like object to SoxEffectsChain.
///
/// Private state of the "input_fileobj_object" effect.
struct FileObjInputPriv {
// sox format handle that `fileobj_input_drain` decodes from via `sox_read`.
sox_format_t* sf;
// Python object that `read_fileobj` pulls more encoded bytes from.
py::object* fileobj;
// Set once a refill returned fewer bytes than were requested.
bool eof_reached;
// Start of the memory region that is shifted and refilled on each drain call.
char* buffer;
// Upper bound used to validate the stream position within `buffer`.
uint64_t buffer_size;
};
/// Private state of the "output_fileobj_object" effect.
struct FileObjOutputPriv {
// sox format handle that `fileobj_output_flow` encodes into via `sox_write`.
sox_format_t* sf;
// Python object whose `write` attribute receives the encoded bytes.
py::object* fileobj;
// Address of the pointer to the encoded data; dereferenced on each flow call.
char** buffer;
// Address of the size that accompanies `buffer`; `fileobj_output_flow` does
// not read through this pointer.
size_t* buffer_size;
};
/// Drain callback of the "input_fileobj_object" effect: tops up the in-memory
/// buffer from the Python file-like object and then decodes samples from it.
/// https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/sox.h#L1268-L1278
///
/// On entry `*osamp` is the number of samples `obuf` can take; on return it is
/// the number of samples actually decoded. Returns SOX_EOF once the file-like
/// object is exhausted and no further sample could be decoded, SOX_SUCCESS
/// otherwise. Throws std::runtime_error when `ftell` or `fseek` fails or when
/// the stream position lies beyond the buffer.
int fileobj_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
  auto priv = static_cast<FileObjInputPriv*>(effp->priv);
  auto sf = priv->sf;
  auto buffer = priv->buffer;
  auto fp = static_cast<FILE*>(sf->fp);

  // 1. Refresh the buffer
  //
  // NOTE:
  //   Since the underlying FILE* was opened with `fmemopen`, the only way
  //   libsox detects EOF is by reaching the end of the buffer (a null byte
  //   won't help). Therefore the content has to be aligned to the end of the
  //   buffer, otherwise libsox keeps reading beyond the intended length.
  //
  // Before:
  //
  //     |<-------consumed------>|<---remaining--->|
  //     |***********************|-----------------|
  //                             ^ ftell
  //
  // After:
  //
  //     |<-offset->|<---remaining--->|<-new data->|
  //     |**********|-----------------|++++++++++++|
  //                ^ ftell
  //
  // NOTE:
  //   Do not use `sf->tell_off` here. Presumably, `tell_off` and `fseek` are
  //   supposed to be in sync, but there are cases (Vorbis) where they are not
  //   and `tell_off` has a seemingly uninitialized value, which makes
  //   num_remain negative and causes a segmentation fault in `memmove`.
  const auto tell = ftell(fp);
  if (tell < 0) {
    throw std::runtime_error("Internal Error: ftell failed.");
  }
  const auto num_consumed = static_cast<size_t>(tell);
  if (num_consumed > priv->buffer_size) {
    throw std::runtime_error("Internal Error: buffer overrun.");
  }
  const auto num_remain = priv->buffer_size - num_consumed;

  // 1.1. Fetch the data to see if there is data to fill the buffer.
  //      The staging string is only sized when a read actually takes place.
  size_t num_refill = 0;
  std::string chunk;
  if (num_consumed && !priv->eof_reached) {
    chunk.resize(num_consumed);
    num_refill = read_fileobj(priv->fileobj, num_consumed, &chunk[0]);
    if (num_refill < num_consumed) {
      priv->eof_reached = true;
    }
  }
  const auto offset = num_consumed - num_refill;

  // 1.2. Move the unconsumed data towards the beginning of buffer.
  //      Source and destination may overlap, hence memmove.
  if (num_remain) {
    memmove(buffer + offset, buffer + num_consumed, num_remain);
  }

  // 1.3. Refill the remaining buffer.
  if (num_refill) {
    memcpy(buffer + offset + num_remain, chunk.data(), num_refill);
  }

  // 1.4. Set the file pointer to the new offset.
  //      A failed seek would leave libsox reading from a stale position, so
  //      it is reported instead of being ignored.
  sf->tell_off = offset;
  if (fseek(fp, static_cast<long>(offset), SEEK_SET) != 0) {
    throw std::runtime_error("Internal Error: fseek failed.");
  }

  // 2. Perform decoding operation
  // The following part is practically same as "input" effect
  // https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/input.c#L30-L48

  // Ensure that it's a multiple of the number of channels
  *osamp -= *osamp % effp->out_signal.channels;

  // Read up to *osamp samples into obuf;
  // store the actual number read back to *osamp
  *osamp = sox_read(sf, obuf, *osamp);

  // Decoding is finished when fileobject is exhausted and sox can no longer
  // decode a sample.
  return (priv->eof_reached && !*osamp) ? SOX_EOF : SOX_SUCCESS;
}
/// Flow callback of the "output_fileobj_object" effect: encodes the incoming
/// samples with `sox_write` and forwards the bytes produced so far to the
/// `write` attribute of the Python file-like object.
///
/// `*osamp` is always set to 0 because this effect emits no samples
/// downstream. Returns SOX_EOF when fewer samples than `*isamp` were written
/// without a sox error code, SOX_SUCCESS otherwise. Throws std::runtime_error
/// when sox reports an error or when `ftell`/`fseek` on the stream fails.
int fileobj_output_flow(
    sox_effect_t* effp,
    sox_sample_t const* ibuf,
    sox_sample_t* obuf LSX_UNUSED,
    size_t* isamp,
    size_t* osamp) {
  *osamp = 0;
  if (*isamp) {
    auto priv = static_cast<FileObjOutputPriv*>(effp->priv);
    auto sf = priv->sf;
    auto fp = static_cast<FILE*>(sf->fp);
    auto fileobj = priv->fileobj;
    auto buffer = priv->buffer;

    // Encode chunk
    const auto num_samples_written = sox_write(sf, ibuf, *isamp);
    fflush(fp);

    // Copy the encoded chunk to python object.
    // ftell returns -1 on failure; passing that on as a length would be
    // interpreted as a huge size, so it is rejected here.
    const auto num_encoded = ftell(fp);
    if (num_encoded < 0) {
      throw std::runtime_error("Internal Error: ftell failed.");
    }
    fileobj->attr("write")(
        py::bytes(*buffer, static_cast<size_t>(num_encoded)));

    // Reset FILE*
    sf->tell_off = 0;
    if (fseek(fp, 0, SEEK_SET) != 0) {
      throw std::runtime_error("Internal Error: fseek failed.");
    }

    if (num_samples_written != *isamp) {
      if (sf->sox_errno) {
        std::ostringstream stream;
        stream << sf->sox_errstr << " " << sox_strerror(sf->sox_errno) << " "
               << sf->filename;
        throw std::runtime_error(stream.str());
      }
      return SOX_EOF;
    }
  }
  return SOX_SUCCESS;
}
// Returns the handler describing the "input_fileobj_object" effect.
//
// Only `drain` is provided (fileobj_input_drain); every other callback is
// left null. The per-effect private storage is sized for FileObjInputPriv.
sox_effect_handler_t* get_fileobj_input_handler() {
  static sox_effect_handler_t input_handler = [] {
    // Value-initialization nulls every callback and zeroes the other fields;
    // only the fields that differ from that are assigned below.
    sox_effect_handler_t h{};
    h.name = "input_fileobj_object";
    h.flags = SOX_EFF_MCHAN;
    h.drain = fileobj_input_drain;
    h.priv_size = sizeof(FileObjInputPriv);
    return h;
  }();
  return &input_handler;
}
// Returns the handler describing the "output_fileobj_object" effect.
//
// Only `flow` is provided (fileobj_output_flow); every other callback is
// left null. The per-effect private storage is sized for FileObjOutputPriv.
sox_effect_handler_t* get_fileobj_output_handler() {
  static sox_effect_handler_t output_handler = [] {
    // Value-initialization nulls every callback and zeroes the other fields;
    // only the fields that differ from that are assigned below.
    sox_effect_handler_t h{};
    h.name = "output_fileobj_object";
    h.flags = SOX_EFF_MCHAN;
    h.flow = fileobj_output_flow;
    h.priv_size = sizeof(FileObjOutputPriv);
    return h;
  }();
  return &output_handler;
}
} // namespace
// Adds the "input_fileobj_object" effect to the chain.
//
// The input and intermediate signal info are both taken from `sf->signal`.
// The effect's private state records the format handle, the Python file
// object, and the caller-provided buffer with its size; `eof_reached` starts
// as false.
//
// Throws std::runtime_error if libsox refuses to add the effect.
void SoxEffectsChain::addInputFileObj(
    sox_format_t* sf,
    char* buffer,
    uint64_t buffer_size,
    py::object* fileobj) {
  interm_sig_ = in_sig_ = sf->signal;

  SoxEffect effect(sox_create_effect(get_fileobj_input_handler()));

  auto* state = static_cast<FileObjInputPriv*>(effect->priv);
  state->buffer = buffer;
  state->buffer_size = buffer_size;
  state->fileobj = fileobj;
  state->sf = sf;
  state->eof_reached = false;

  const auto status = sox_add_effect(sec_, effect, &interm_sig_, &in_sig_);
  if (status != SOX_SUCCESS) {
    throw std::runtime_error(
        "Internal Error: Failed to add effect: input fileobj");
  }
}
// Adds the "output_fileobj_object" effect to the chain.
//
// The output signal info is taken from `sf->signal`. The effect's private
// state records the format handle, the Python file object, and the addresses
// of the caller's buffer pointer and buffer size.
//
// Throws std::runtime_error if libsox refuses to add the effect.
void SoxEffectsChain::addOutputFileObj(
    sox_format_t* sf,
    char** buffer,
    size_t* buffer_size,
    py::object* fileobj) {
  out_sig_ = sf->signal;

  SoxEffect effect(sox_create_effect(get_fileobj_output_handler()));

  auto* state = static_cast<FileObjOutputPriv*>(effect->priv);
  state->buffer = buffer;
  state->buffer_size = buffer_size;
  state->fileobj = fileobj;
  state->sf = sf;

  const auto status = sox_add_effect(sec_, effect, &interm_sig_, &out_sig_);
  if (status != SOX_SUCCESS) {
    throw std::runtime_error(
        "Internal Error: Failed to add effect: output fileobj");
  }
}
#endif // TORCH_API_INCLUDE_EXTENSION_H
} // namespace sox_effects_chain
} // namespace torchaudio
......@@ -4,17 +4,32 @@
#include <sox.h>
#include <torchaudio/csrc/sox/utils.h>
#ifdef TORCH_API_INCLUDE_EXTENSION_H
#include <torch/extension.h>
#endif // TORCH_API_INCLUDE_EXTENSION_H
namespace torchaudio {
namespace sox_effects_chain {
// Helper struct to safely close sox_effect_t* pointer returned by
// sox_create_effect
//
// The wrapper can be neither copied nor moved: all four copy/move operations
// are deleted, so a wrapped pointer always belongs to exactly one SoxEffect.
// NOTE(review): the destructor is defined elsewhere; presumably it releases
// the wrapped effect - confirm against the implementation file.
struct SoxEffect {
  // Wraps `se`; explicit so that a raw pointer is never wrapped implicitly.
  explicit SoxEffect(sox_effect_t* se) noexcept;
  SoxEffect(const SoxEffect& other) = delete;
  // Takes a non-const rvalue reference, like the move assignment below.
  SoxEffect(SoxEffect&& other) = delete;
  SoxEffect& operator=(const SoxEffect& other) = delete;
  SoxEffect& operator=(SoxEffect&& other) = delete;
  ~SoxEffect();
  // Implicit conversion so the wrapper can be passed where libsox expects a
  // sox_effect_t*.
  operator sox_effect_t*() const;
  // Member access on the wrapped effect (e.g. `e->priv`).
  sox_effect_t* operator->() noexcept;

 private:
  sox_effect_t* se_;
};
// Helper struct to safely close sox_effects_chain_t with handy methods
//
// NOTE(review): part of this class is elided by the diff hunk marker below,
// so only the members visible here are documented.
class SoxEffectsChain {
// Encodings of the input and output side; const, so fixed for the object's
// lifetime.
const sox_encodinginfo_t in_enc_;
const sox_encodinginfo_t out_enc_;
protected:
// Signal info of the chain input; addInputFileObj copies it from the
// sox_format_t it is given.
sox_signalinfo_t in_sig_;
// Signal info handed to sox_add_effect as the running "input" side of each
// newly added effect.
sox_signalinfo_t interm_sig_;
// Signal info of the chain output; addOutputFileObj copies it from the
// sox_format_t it is given.
sox_signalinfo_t out_sig_;
......@@ -40,22 +55,6 @@ class SoxEffectsChain {
// Adds one effect described by a list of strings.
// NOTE(review): presumably the effect name followed by its arguments -
// confirm against the definition.
void addEffect(const std::vector<std::string> effect);
int64_t getOutputNumChannels();
int64_t getOutputSampleRate();
// The file-object variants use py::object and are therefore only declared
// when TORCH_API_INCLUDE_EXTENSION_H is defined.
#ifdef TORCH_API_INCLUDE_EXTENSION_H
// Adds the effect that feeds the chain from a Python file-like object, using
// `buffer` (of `buffer_size` bytes) as its staging area.
void addInputFileObj(
sox_format_t* sf,
char* buffer,
uint64_t buffer_size,
py::object* fileobj);
// Adds the effect that writes encoded output to a Python file-like object.
// `buffer` and `buffer_size` are the addresses of the caller's buffer pointer
// and buffer size.
void addOutputFileObj(
sox_format_t* sf,
char** buffer,
size_t* buffer_size,
py::object* fileobj);
#endif // TORCH_API_INCLUDE_EXTENSION_H
};
} // namespace sox_effects_chain
......
......@@ -29,8 +29,6 @@ std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_file(
get_encoding(sf->encoding.encoding));
}
namespace {
std::vector<std::vector<std::string>> get_effects(
const c10::optional<int64_t>& frame_offset,
const c10::optional<int64_t>& num_frames) {
......@@ -60,8 +58,6 @@ std::vector<std::vector<std::string>> get_effects(
return effects;
}
} // namespace
std::tuple<torch::Tensor, int64_t> load_audio_file(
const std::string& path,
const c10::optional<int64_t>& frame_offset,
......@@ -133,172 +129,6 @@ void save_audio_file(
chain.run();
}
#ifdef TORCH_API_INCLUDE_EXTENSION_H
// Reads the beginning of a Python file-like object and returns the audio
// metadata libsox extracts from it, as the tuple
//   (sample rate, number of frames, number of channels, bits per sample,
//    encoding name).
//
// `format`, when given, is passed to libsox as the file type; otherwise
// libsox is given a null file type.
std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_fileobj(
    py::object fileobj,
    c10::optional<std::string> format) {
  // Prepare in-memory file object
  //
  // Opening a file with libsox also parses its header. Two places may touch
  // the FILE* (and the buffer behind it) while doing so:
  //  * `auto_detect_format`
  //    https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L43
  //  * the `startread` handler of the detected format
  //    https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/formats.c#L574
  // The handler of a given format lives in
  //    https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/<FORMAT>.c
  // e.g. vorbis:
  //    https://github.com/dmkrepo/libsox/blob/b9dd1a86e71bbd62221904e3e59dfaa9e5e72046/src/vorbis.c#L97-L158
  //
  // `auto_detect_format` needs only 256 bytes, but a format-specific
  // `startread` may need more. For vorbis the header size is unbounded,
  // though typically at most 4kB:
  //
  //   "The header size is unbounded, although for streaming a rule-of-thumb
  //    of 4kB or less is recommended (and Xiph.Org's Vorbis encoder follows
  //    this suggestion)."
  //
  // See:
  // https://xiph.org/vorbis/doc/Vorbis_I_spec.html
  const int kDefaultCapacityInBytes = 4096;

  // Use the libsox buffer size, but never less than the default capacity.
  auto capacity = sox_get_globals()->bufsiz;
  if (capacity <= kDefaultCapacityInBytes) {
    capacity = kDefaultCapacityInBytes;
  }

  std::string header(capacity, '\0');
  auto* data = const_cast<char*>(header.data());
  const auto bytes_read = read_fileobj(&fileobj, capacity, data);

  // If the file is shorter than 256, then libsox cannot read the header.
  // Report at least 256 bytes; the storage is zero-filled and at least 4096
  // bytes long, so the extra bytes are valid.
  auto mem_size = bytes_read;
  if (mem_size <= 256) {
    mem_size = 256;
  }

  const char* filetype = format.has_value() ? format.value().c_str() : nullptr;
  SoxFormat sf(sox_open_mem_read(
      data,
      mem_size,
      /*signal=*/nullptr,
      /*encoding=*/nullptr,
      /*filetype=*/filetype));

  // In case of streamed data, length can be 0
  validate_input_memfile(sf);

  return std::make_tuple(
      static_cast<int64_t>(sf->signal.rate),
      static_cast<int64_t>(sf->signal.length / sf->signal.channels),
      static_cast<int64_t>(sf->signal.channels),
      static_cast<int64_t>(sf->encoding.bits_per_sample),
      get_encoding(sf->encoding.encoding));
}
// Loads audio from a Python file-like object.
//
// Builds an effect list from `frame_offset` / `num_frames` with get_effects
// and delegates decoding to sox_effects::apply_effects_fileobj, whose result
// is returned unchanged.
std::tuple<torch::Tensor, int64_t> load_audio_fileobj(
    py::object fileobj,
    c10::optional<int64_t> frame_offset,
    c10::optional<int64_t> num_frames,
    c10::optional<bool> normalize,
    c10::optional<bool> channels_first,
    c10::optional<std::string> format) {
  auto range_effects = get_effects(frame_offset, num_frames);
  return torchaudio::sox_effects::apply_effects_fileobj(
      fileobj, range_effects, normalize, channels_first, format);
}
namespace {
// Holder for a malloc-style buffer, to be used by save_audio_fileobj.
//
// Starts out empty (`ptr == nullptr`, `size == 0`). Whatever `ptr` points to
// when the holder is destroyed is released with free(). Copying and moving
// are disabled so the buffer cannot be released twice.
struct AutoReleaseBuffer {
  char* ptr = nullptr;
  size_t size = 0;

  AutoReleaseBuffer() = default;
  ~AutoReleaseBuffer() {
    // free() accepts a null pointer and does nothing in that case.
    free(ptr);
  }

  AutoReleaseBuffer(const AutoReleaseBuffer& other) = delete;
  AutoReleaseBuffer(AutoReleaseBuffer&& other) = delete;
  AutoReleaseBuffer& operator=(const AutoReleaseBuffer& other) = delete;
  AutoReleaseBuffer& operator=(AutoReleaseBuffer&& other) = delete;
};
} // namespace
// Encodes `tensor` and writes the result to a Python file-like object.
//
// `format` is mandatory. For "amr-nb", "htk" and "gsm" the audio must have a
// single channel, and "gsm" additionally requires a sample rate of 8000.
// Encoding goes through a libsox memory stream: the effects chain pushes the
// encoded chunks to `fileobj`, and whatever is in the buffer after the stream
// is closed is written last.
//
// Throws std::runtime_error when `format` is missing, when a format
// constraint above is violated, or when the memory stream cannot be opened.
void save_audio_fileobj(
    py::object fileobj,
    torch::Tensor tensor,
    int64_t sample_rate,
    bool channels_first,
    c10::optional<double> compression,
    c10::optional<std::string> format,
    c10::optional<std::string> encoding,
    c10::optional<int64_t> bits_per_sample) {
  validate_input_tensor(tensor);

  if (!format.has_value()) {
    throw std::runtime_error(
        "`format` is required when saving to file object.");
  }
  const auto filetype = format.value();

  // Shared single-channel check; the channel dimension depends on the layout.
  const auto require_mono = [&tensor, channels_first](const char* message) {
    if (tensor.size(channels_first ? 0 : 1) != 1) {
      throw std::runtime_error(message);
    }
  };

  if (filetype == "amr-nb") {
    require_mono("amr-nb format only supports single channel audio.");
  } else if (filetype == "htk") {
    require_mono("htk format only supports single channel audio.");
  } else if (filetype == "gsm") {
    require_mono("gsm format only supports single channel audio.");
    if (sample_rate != 8000) {
      throw std::runtime_error(
          "gsm format only supports a sampling rate of 8kHz.");
    }
  }

  const auto signal_info =
      get_signalinfo(&tensor, sample_rate, filetype, channels_first);
  const auto encoding_info = get_encodinginfo_for_save(
      filetype, tensor.dtype(), compression, encoding, bits_per_sample);

  // libsox fills in buffer.ptr / buffer.size; the holder frees the memory
  // when it goes out of scope.
  AutoReleaseBuffer buffer;

  SoxFormat sf(sox_open_memstream_write(
      &buffer.ptr,
      &buffer.size,
      &signal_info,
      &encoding_info,
      filetype.c_str(),
      /*oob=*/nullptr));
  if (static_cast<sox_format_t*>(sf) == nullptr) {
    throw std::runtime_error(
        "Error saving audio file: failed to open memory stream.");
  }

  torchaudio::sox_effects_chain::SoxEffectsChain chain(
      /*input_encoding=*/get_tensor_encodinginfo(tensor.dtype()),
      /*output_encoding=*/sf->encoding);
  chain.addInputTensor(&tensor, sample_rate, channels_first);
  chain.addOutputFileObj(sf, &buffer.ptr, &buffer.size, &fileobj);
  chain.run();

  // Closing the sox_format_t is necessary for flushing the last chunk to the
  // buffer
  sf.close();

  fileobj.attr("write")(py::bytes(buffer.ptr, buffer.size));
}
#endif // TORCH_API_INCLUDE_EXTENSION_H
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("torchaudio::sox_io_get_info", &torchaudio::sox_io::get_info_file);
m.def(
......
#ifndef TORCHAUDIO_SOX_IO_H
#define TORCHAUDIO_SOX_IO_H
#ifdef TORCH_API_INCLUDE_EXTENSION_H
#include <torch/extension.h>
#endif // TORCH_API_INCLUDE_EXTENSION_H
#include <torch/script.h>
#include <torchaudio/csrc/sox/utils.h>
namespace torchaudio {
namespace sox_io {
// Builds a list of effects (each a list of strings) from the optional frame
// offset and frame count.
// NOTE(review): the definition is not fully visible from here; presumably
// the result trims the audio to the requested range - confirm.
std::vector<std::vector<std::string>> get_effects(
const c10::optional<int64_t>& frame_offset,
const c10::optional<int64_t>& num_frames);
// Returns metadata of the audio file at `path` as a 5-tuple whose last
// element is a string.
// NOTE(review): presumably (sample rate, number of frames, number of
// channels, bits per sample, encoding name), mirroring get_info_fileobj -
// confirm against the definition.
std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_file(
const std::string& path,
const c10::optional<std::string>& format);
......@@ -33,32 +33,6 @@ void save_audio_file(
c10::optional<std::string> encoding,
c10::optional<int64_t> bits_per_sample);
#ifdef TORCH_API_INCLUDE_EXTENSION_H
// Returns metadata of the audio readable from a Python file-like object as
// (sample rate, number of frames, number of channels, bits per sample,
// encoding name). `format`, when given, is passed to libsox as the file type.
std::tuple<int64_t, int64_t, int64_t, int64_t, std::string> get_info_fileobj(
py::object fileobj,
c10::optional<std::string> format);
// Loads audio from a Python file-like object. `frame_offset` and `num_frames`
// are turned into an effect list with get_effects, and the decoding itself is
// delegated to sox_effects::apply_effects_fileobj.
// NOTE(review): the returned int64_t is presumably the sample rate - confirm.
std::tuple<torch::Tensor, int64_t> load_audio_fileobj(
py::object fileobj,
c10::optional<int64_t> frame_offset,
c10::optional<int64_t> num_frames,
c10::optional<bool> normalize,
c10::optional<bool> channels_first,
c10::optional<std::string> format);
// Encodes `tensor` and writes the result to a Python file-like object through
// its `write` method. `format` is declared optional but the implementation
// throws std::runtime_error when it is not provided.
void save_audio_fileobj(
py::object fileobj,
torch::Tensor tensor,
int64_t sample_rate,
bool channels_first,
c10::optional<double> compression,
c10::optional<std::string> format,
c10::optional<std::string> encoding,
c10::optional<int64_t> bits_per_sample);
#endif // TORCH_API_INCLUDE_EXTENSION_H
} // namespace sox_io
} // namespace torchaudio
......
......@@ -493,35 +493,6 @@ sox_encodinginfo_t get_encodinginfo_for_save(
/*opposite_endian=*/sox_false};
}
#ifdef TORCH_API_INCLUDE_EXTENSION_H
// Reads up to `size` bytes from a Python file-like object into `buffer`.
//
// `fileobj.read(n)` is called repeatedly with the number of bytes still
// missing, until either `size` bytes have been collected or a call returns an
// empty result. Returns the number of bytes actually stored in `buffer`.
//
// Throws std::runtime_error if a single `read` call returns more bytes than
// were requested.
uint64_t read_fileobj(py::object* fileobj, const uint64_t size, char* buffer) {
  uint64_t total = 0;
  char* cursor = buffer;
  while (total < size) {
    const auto wanted = size - total;
    const auto data = static_cast<std::string>(
        static_cast<py::bytes>(fileobj->attr("read")(wanted)));
    const auto received = data.length();

    // An empty read means the file object has nothing more to give.
    if (received == 0) {
      break;
    }

    // Guard the destination: never copy more than the space that is left.
    if (received > wanted) {
      std::ostringstream message;
      message
          << "Requested up to " << wanted << " bytes but, "
          << "received " << received << " bytes. "
          << "The given object does not confirm to read protocol of file object.";
      throw std::runtime_error(message.str());
    }

    memcpy(cursor, data.data(), received);
    cursor += received;
    total += received;
  }
  return total;
}
#endif // TORCH_API_INCLUDE_EXTENSION_H
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.def("torchaudio::sox_utils_set_seed", &torchaudio::sox_utils::set_seed);
m.def(
......
......@@ -4,10 +4,6 @@
#include <sox.h>
#include <torch/script.h>
#ifdef TORCH_API_INCLUDE_EXTENSION_H
#include <torch/extension.h>
#endif // TORCH_API_INCLUDE_EXTENSION_H
namespace torchaudio {
namespace sox_utils {
......@@ -117,12 +113,6 @@ sox_encodinginfo_t get_encodinginfo_for_save(
const c10::optional<std::string> encoding,
const c10::optional<int64_t> bits_per_sample);
#ifdef TORCH_API_INCLUDE_EXTENSION_H
// Reads up to `size` bytes from the Python file-like object into `buffer` by
// calling its `read` method repeatedly, and returns the number of bytes
// stored. `buffer` must have room for at least `size` bytes.
uint64_t read_fileobj(py::object* fileobj, uint64_t size, char* buffer);
#endif // TORCH_API_INCLUDE_EXTENSION_H
} // namespace sox_utils
} // namespace torchaudio
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment