Commit b374cc7b authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Move FileObj to dedicated source (#2427)

Summary:
Extract from https://github.com/pytorch/audio/issues/2419. Move the `FileObj` definition to dedicated file, so that it can be reused from files other than StreamReader.

Pull Request resolved: https://github.com/pytorch/audio/pull/2427

Reviewed By: carolineechen

Differential Revision: D36794367

Pulled By: mthrok

fbshipit-source-id: 999658f3f4d833566d933c9223e7a5d49d300574
parent b56f60bf
......@@ -288,6 +288,7 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
if(USE_FFMPEG)
set(
FFMPEG_EXTENSION_SOURCES
ffmpeg/pybind/typedefs.cpp
ffmpeg/pybind/pybind.cpp
ffmpeg/pybind/stream_reader.cpp
)
......
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/pybind/stream_reader.h>
namespace torchaudio {
namespace ffmpeg {
namespace {
static int read_function(void* opaque, uint8_t* buf, int buf_size) {
FileObj* fileobj = static_cast<FileObj*>(opaque);
buf_size = FFMIN(buf_size, fileobj->buffer_size);
int num_read = 0;
while (num_read < buf_size) {
int request = buf_size - num_read;
auto chunk = static_cast<std::string>(
static_cast<py::bytes>(fileobj->fileobj.attr("read")(request)));
auto chunk_len = chunk.length();
if (chunk_len == 0) {
break;
}
if (chunk_len > request) {
std::ostringstream message;
message
<< "Requested up to " << request << " bytes but, "
<< "received " << chunk_len << " bytes. "
<< "The given object does not confirm to read protocol of file object.";
throw std::runtime_error(message.str());
}
memcpy(buf, chunk.data(), chunk_len);
buf += chunk_len;
num_read += chunk_len;
}
return num_read == 0 ? AVERROR_EOF : num_read;
}
static int64_t seek_function(void* opaque, int64_t offset, int whence) {
// We do not know the file size.
if (whence == AVSEEK_SIZE) {
return AVERROR(EIO);
}
FileObj* fileobj = static_cast<FileObj*>(opaque);
return py::cast<int64_t>(fileobj->fileobj.attr("seek")(offset, whence));
}
AVIOContextPtr get_io_context(FileObj* opaque, int buffer_size) {
uint8_t* buffer = static_cast<uint8_t*>(av_malloc(buffer_size));
if (!buffer) {
throw std::runtime_error("Failed to allocate buffer.");
}
// If avio_alloc_context succeeds, then buffer will be cleaned up by
// AVIOContextPtr destructor.
// If avio_alloc_context fails, we need to clean up by ourselves.
AVIOContext* av_io_ctx = avio_alloc_context(
buffer,
buffer_size,
0,
static_cast<void*>(opaque),
&read_function,
nullptr,
py::hasattr(opaque->fileobj, "seek") ? &seek_function : nullptr);
if (!av_io_ctx) {
av_freep(&buffer);
throw std::runtime_error("Failed to allocate AVIO context.");
}
return AVIOContextPtr{av_io_ctx};
}
} // namespace
FileObj::FileObj(py::object fileobj_, int buffer_size)
: fileobj(fileobj_),
buffer_size(buffer_size),
pAVIO(get_io_context(this, buffer_size)) {}
StreamReaderFileObj::StreamReaderFileObj(
py::object fileobj_,
......
#pragma once
#include <torch/extension.h>
#include <torchaudio/csrc/ffmpeg/pybind/typedefs.h>
#include <torchaudio/csrc/ffmpeg/stream_reader_wrapper.h>
namespace torchaudio {
namespace ffmpeg {
struct FileObj {
py::object fileobj;
int buffer_size;
AVIOContextPtr pAVIO;
FileObj(py::object fileobj, int buffer_size);
};
// The reason we inherit FileObj instead of making it an attribute
// is so that FileObj is instantiated first.
// AVIOContext must be initialized before AVFormat, and outlive AVFormat.
......
#include <torchaudio/csrc/ffmpeg/pybind/typedefs.h>
namespace torchaudio {
namespace ffmpeg {
namespace {
static int read_function(void* opaque, uint8_t* buf, int buf_size) {
FileObj* fileobj = static_cast<FileObj*>(opaque);
buf_size = FFMIN(buf_size, fileobj->buffer_size);
int num_read = 0;
while (num_read < buf_size) {
int request = buf_size - num_read;
auto chunk = static_cast<std::string>(
static_cast<py::bytes>(fileobj->fileobj.attr("read")(request)));
auto chunk_len = chunk.length();
if (chunk_len == 0) {
break;
}
if (chunk_len > request) {
std::ostringstream message;
message
<< "Requested up to " << request << " bytes but, "
<< "received " << chunk_len << " bytes. "
<< "The given object does not confirm to read protocol of file object.";
throw std::runtime_error(message.str());
}
memcpy(buf, chunk.data(), chunk_len);
buf += chunk_len;
num_read += static_cast<int>(chunk_len);
}
return num_read == 0 ? AVERROR_EOF : num_read;
}
static int64_t seek_function(void* opaque, int64_t offset, int whence) {
// We do not know the file size.
if (whence == AVSEEK_SIZE) {
return AVERROR(EIO);
}
FileObj* fileobj = static_cast<FileObj*>(opaque);
return py::cast<int64_t>(fileobj->fileobj.attr("seek")(offset, whence));
}
AVIOContextPtr get_io_context(FileObj* opaque, int buffer_size) {
uint8_t* buffer = static_cast<uint8_t*>(av_malloc(buffer_size));
if (!buffer) {
throw std::runtime_error("Failed to allocate buffer.");
}
// If avio_alloc_context succeeds, then buffer will be cleaned up by
// AVIOContextPtr destructor.
// If avio_alloc_context fails, we need to clean up by ourselves.
AVIOContext* av_io_ctx = avio_alloc_context(
buffer,
buffer_size,
0,
static_cast<void*>(opaque),
&read_function,
nullptr,
py::hasattr(opaque->fileobj, "seek") ? &seek_function : nullptr);
if (!av_io_ctx) {
av_freep(&buffer);
throw std::runtime_error("Failed to allocate AVIO context.");
}
return AVIOContextPtr{av_io_ctx};
}
} // namespace
FileObj::FileObj(py::object fileobj_, int buffer_size)
: fileobj(fileobj_),
buffer_size(buffer_size),
pAVIO(get_io_context(this, buffer_size)) {}
} // namespace ffmpeg
} // namespace torchaudio
#pragma once
#include <torch/extension.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
namespace torchaudio {
namespace ffmpeg {
struct FileObj {
py::object fileobj;
int buffer_size;
AVIOContextPtr pAVIO;
FileObj(py::object fileobj, int buffer_size);
};
} // namespace ffmpeg
} // namespace torchaudio
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment