Commit 8d7268f1 authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Refactor StreamReader/Writer PyBinding (#3296)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/3296

Reviewed By: hwangjeff

Differential Revision: D45503774

fbshipit-source-id: 806c22bd0f54fd0cea43d61ef3dbedd67ffeb012
parent 007cca23
......@@ -42,7 +42,6 @@ torchaudio_library(
if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
set(
ext_sources
pybind/fileobj.cpp
pybind/pybind.cpp
)
torchaudio_extension(
......
#include <torchaudio/csrc/ffmpeg/pybind/fileobj.h>
namespace torchaudio {
namespace io {
namespace {
static int read_function(void* opaque, uint8_t* buf, int buf_size) {
FileObj* fileobj = static_cast<FileObj*>(opaque);
buf_size = FFMIN(buf_size, fileobj->buffer_size);
int num_read = 0;
while (num_read < buf_size) {
int request = buf_size - num_read;
auto chunk = static_cast<std::string>(
static_cast<py::bytes>(fileobj->fileobj.attr("read")(request)));
auto chunk_len = chunk.length();
if (chunk_len == 0) {
break;
}
TORCH_CHECK(
chunk_len <= request,
"Requested up to ",
request,
" bytes but, received ",
chunk_len,
" bytes. The given object does not confirm to read protocol of file object.");
memcpy(buf, chunk.data(), chunk_len);
buf += chunk_len;
num_read += static_cast<int>(chunk_len);
}
return num_read == 0 ? AVERROR_EOF : num_read;
}
static int write_function(void* opaque, uint8_t* buf, int buf_size) {
FileObj* fileobj = static_cast<FileObj*>(opaque);
buf_size = FFMIN(buf_size, fileobj->buffer_size);
py::bytes b(reinterpret_cast<const char*>(buf), buf_size);
// TODO: check the return value to check
fileobj->fileobj.attr("write")(b);
return buf_size;
}
static int64_t seek_function(void* opaque, int64_t offset, int whence) {
// We do not know the file size.
if (whence == AVSEEK_SIZE) {
return AVERROR(EIO);
}
FileObj* fileobj = static_cast<FileObj*>(opaque);
return py::cast<int64_t>(fileobj->fileobj.attr("seek")(offset, whence));
}
AVIOContextPtr get_io_context(FileObj* opaque, int buffer_size, bool writable) {
if (writable) {
TORCH_CHECK(
py::hasattr(opaque->fileobj, "write"),
"`write` method is not available.");
} else {
TORCH_CHECK(
py::hasattr(opaque->fileobj, "read"),
"`read` method is not available.");
}
uint8_t* buffer = static_cast<uint8_t*>(av_malloc(buffer_size));
TORCH_CHECK(buffer, "Failed to allocate buffer.");
// If avio_alloc_context succeeds, then buffer will be cleaned up by
// AVIOContextPtr destructor.
// If avio_alloc_context fails, we need to clean up by ourselves.
AVIOContext* av_io_ctx = avio_alloc_context(
buffer,
buffer_size,
writable ? 1 : 0,
static_cast<void*>(opaque),
&read_function,
writable ? &write_function : nullptr,
py::hasattr(opaque->fileobj, "seek") ? &seek_function : nullptr);
if (!av_io_ctx) {
av_freep(&buffer);
TORCH_CHECK(false, "Failed to allocate AVIO context.");
}
return AVIOContextPtr{av_io_ctx};
}
} // namespace
FileObj::FileObj(py::object fileobj_, int buffer_size, bool writable)
: fileobj(fileobj_),
buffer_size(buffer_size),
pAVIO(get_io_context(this, buffer_size, writable)) {}
} // namespace io
} // namespace torchaudio
#pragma once
#include <torch/extension.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
namespace torchaudio {
namespace io {
struct FileObj {
py::object fileobj;
int buffer_size;
AVIOContextPtr pAVIO;
FileObj(py::object fileobj, int buffer_size, bool writable);
};
} // namespace io
} // namespace torchaudio
#include <torch/extension.h>
#include <torchaudio/csrc/ffmpeg/hw_context.h>
#include <torchaudio/csrc/ffmpeg/pybind/fileobj.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/stream_reader.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h>
......@@ -99,26 +98,93 @@ std::string get_build_config() {
return avcodec_configuration();
}
// The reason we inherit FileObj instead of making it an attribute
// is so that FileObj is instantiated first.
// AVIOContext must be initialized before AVFormat, and outlive AVFormat.
struct StreamReaderFileObj : private FileObj, public StreamReader {
//////////////////////////////////////////////////////////////////////////////
// StreamReader/Writer FileObj
//////////////////////////////////////////////////////////////////////////////
struct FileObj {
py::object fileobj;
int buffer_size;
};
namespace {
static int read_func(void* opaque, uint8_t* buf, int buf_size) {
FileObj* fileobj = static_cast<FileObj*>(opaque);
buf_size = FFMIN(buf_size, fileobj->buffer_size);
int num_read = 0;
while (num_read < buf_size) {
int request = buf_size - num_read;
auto chunk = static_cast<std::string>(
static_cast<py::bytes>(fileobj->fileobj.attr("read")(request)));
auto chunk_len = chunk.length();
if (chunk_len == 0) {
break;
}
TORCH_CHECK(
chunk_len <= request,
"Requested up to ",
request,
" bytes but, received ",
chunk_len,
" bytes. The given object does not confirm to read protocol of file object.");
memcpy(buf, chunk.data(), chunk_len);
buf += chunk_len;
num_read += static_cast<int>(chunk_len);
}
return num_read == 0 ? AVERROR_EOF : num_read;
}
static int write_func(void* opaque, uint8_t* buf, int buf_size) {
FileObj* fileobj = static_cast<FileObj*>(opaque);
buf_size = FFMIN(buf_size, fileobj->buffer_size);
py::bytes b(reinterpret_cast<const char*>(buf), buf_size);
// TODO: check the return value
fileobj->fileobj.attr("write")(b);
return buf_size;
}
static int64_t seek_func(void* opaque, int64_t offset, int whence) {
// We do not know the file size.
if (whence == AVSEEK_SIZE) {
return AVERROR(EIO);
}
FileObj* fileobj = static_cast<FileObj*>(opaque);
return py::cast<int64_t>(fileobj->fileobj.attr("seek")(offset, whence));
}
} // namespace
struct StreamReaderFileObj : private FileObj, public StreamReaderCustomIO {
StreamReaderFileObj(
py::object fileobj,
const c10::optional<std::string>& format,
const c10::optional<std::map<std::string, std::string>>& option,
int64_t buffer_size)
: FileObj(fileobj, static_cast<int>(buffer_size), false),
StreamReader(pAVIO, format, option) {}
int buffer_size)
: FileObj{fileobj, buffer_size},
StreamReaderCustomIO(
this,
format,
buffer_size,
read_func,
py::hasattr(fileobj, "seek") ? &seek_func : nullptr,
option) {}
};
struct StreamWriterFileObj : private FileObj, public StreamWriter {
struct StreamWriterFileObj : private FileObj, public StreamWriterCustomIO {
StreamWriterFileObj(
py::object fileobj,
const c10::optional<std::string>& format,
int64_t buffer_size)
: FileObj(fileobj, static_cast<int>(buffer_size), true),
StreamWriter(pAVIO, format) {}
int buffer_size)
: FileObj{fileobj, buffer_size},
StreamWriterCustomIO(
this,
format,
buffer_size,
write_func,
py::hasattr(fileobj, "seek") ? &seek_func : nullptr) {}
};
PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment