Commit 28da8b84 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Add file-like object support to StreamWriter (#2648)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2648

Reviewed By: nateanl

Differential Revision: D38976874

Pulled By: mthrok

fbshipit-source-id: 0541dea2a633d97000b4b8609ff6b83f6b82c864
parent 76fca37a
import torch
import torchaudio
from parameterized import parameterized
from parameterized import parameterized, parameterized_class
from torchaudio_unittest.common_utils import (
get_asset_path,
is_ffmpeg_available,
......@@ -42,8 +42,42 @@ def get_video_chunk(fmt, frame_rate, *, width, height):
return chunk
################################################################################
# Helper decorator and Mixin to duplicate the tests for fileobj
_media_source = parameterized_class(
("test_fileobj",),
[(False,), (True,)],
class_name_func=lambda cls, _, params: f'{cls.__name__}{"_fileobj" if params["test_fileobj"] else "_path"}',
)
class _MediaSourceMixin:
def setUp(self):
super().setUp()
self.src = None
def get_dst(self, path):
if not self.test_fileobj:
return path
if self.src is not None:
raise ValueError("get_dst can be called only once.")
self.src = open(path, "wb")
return self.src
def tearDown(self):
if self.src is not None:
self.src.flush()
self.src.close()
super().tearDown()
################################################################################
@skipIfNoFFmpeg
class StreamWriterInterfaceTest(TempDirMixin, TorchaudioTestCase):
@_media_source
class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
......@@ -55,7 +89,7 @@ class StreamWriterInterfaceTest(TempDirMixin, TorchaudioTestCase):
super().tearDownClass()
def get_dst(self, path):
return self.get_temp_path(path)
return super().get_dst(self.get_temp_path(path))
def get_buf(self, path):
with open(self.get_temp_path(path), "rb") as fileobj:
......@@ -70,8 +104,8 @@ class StreamWriterInterfaceTest(TempDirMixin, TorchaudioTestCase):
sample_rate = 8000
num_channels = 1
path = self.get_dst("test.mp3")
s = StreamWriter(path, format="mp3")
dst = self.get_dst("test.mp3")
s = StreamWriter(dst, format="mp3")
s.set_metadata(metadata={"artist": "torchaudio", "title": "foo"})
s.set_metadata(metadata={"title": self.id()})
s.add_audio_stream(sample_rate, num_channels, format=src_fmt)
......@@ -80,6 +114,7 @@ class StreamWriterInterfaceTest(TempDirMixin, TorchaudioTestCase):
with s.open():
s.write_audio_chunk(0, chunk)
path = self.get_temp_path("test.mp3")
tag = TinyTag.get(path)
assert tag.artist is None
assert tag.title == self.id()
......@@ -203,6 +238,8 @@ class StreamWriterInterfaceTest(TempDirMixin, TorchaudioTestCase):
s.write_video_chunk(0, chunk)
# Fetch the written data
if self.test_fileobj:
dst.flush()
buf = self.get_buf(filename)
result = torch.frombuffer(buf, dtype=torch.uint8)
if encoder_fmt.endswith("p"):
......
......@@ -203,6 +203,7 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
ffmpeg/pybind/typedefs.cpp
ffmpeg/pybind/pybind.cpp
ffmpeg/pybind/stream_reader.cpp
ffmpeg/pybind/stream_writer.cpp
)
torchaudio_extension(
_torchaudio_ffmpeg
......
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <torchaudio/csrc/ffmpeg/pybind/stream_reader.h>
#include <torchaudio/csrc/ffmpeg/pybind/stream_writer.h>
namespace torchaudio {
namespace ffmpeg {
namespace {
PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
py::class_<StreamWriterFileObj, c10::intrusive_ptr<StreamWriterFileObj>>(
m, "StreamWriterFileObj")
.def(py::init<py::object, const c10::optional<std::string>&, int64_t>())
.def("set_metadata", &StreamWriterFileObj::set_metadata)
.def("add_audio_stream", &StreamWriterFileObj::add_audio_stream)
.def("add_video_stream", &StreamWriterFileObj::add_video_stream)
.def("dump_format", &StreamWriterFileObj::dump_format)
.def("open", &StreamWriterFileObj::open)
.def("write_audio_chunk", &StreamWriterFileObj::write_audio_chunk)
.def("write_video_chunk", &StreamWriterFileObj::write_video_chunk)
.def("flush", &StreamWriterFileObj::flush)
.def("close", &StreamWriterFileObj::close);
py::class_<StreamReaderFileObj, c10::intrusive_ptr<StreamReaderFileObj>>(
m, "StreamReaderFileObj")
.def(py::init<
py::object,
const c10::optional<std::string>&,
const c10::optional<std::map<std::string, std::string>>&,
const c10::optional<OptionMap>&,
int64_t>())
.def("num_src_streams", &StreamReaderFileObj::num_src_streams)
.def("num_out_streams", &StreamReaderFileObj::num_out_streams)
......
......@@ -27,7 +27,7 @@ StreamReaderFileObj::StreamReaderFileObj(
const c10::optional<std::string>& format,
const c10::optional<std::map<std::string, std::string>>& option,
int64_t buffer_size)
: FileObj(fileobj_, static_cast<int>(buffer_size)),
: FileObj(fileobj_, static_cast<int>(buffer_size), false),
StreamReaderBinding(get_input_format_context(
static_cast<std::string>(py::str(fileobj_.attr("__str__")())),
format,
......
#include <torchaudio/csrc/ffmpeg/pybind/stream_writer.h>
namespace torchaudio {
namespace ffmpeg {
StreamWriterFileObj::StreamWriterFileObj(
py::object fileobj_,
const c10::optional<std::string>& format,
int64_t buffer_size)
: FileObj(fileobj_, static_cast<int>(buffer_size), true),
StreamWriterBinding(get_output_format_context(
static_cast<std::string>(py::str(fileobj_.attr("__str__")())),
format,
pAVIO)) {}
void StreamWriterFileObj::set_metadata(
const std::map<std::string, std::string>& metadata) {
StreamWriter::set_metadata(map2dict(metadata));
}
void StreamWriterFileObj::add_audio_stream(
int64_t sample_rate,
int64_t num_channels,
std::string format,
const c10::optional<std::string>& encoder,
const c10::optional<std::map<std::string, std::string>>& encoder_option,
const c10::optional<std::string>& encoder_format) {
StreamWriter::add_audio_stream(
sample_rate,
num_channels,
format,
encoder,
map2dict(encoder_option),
encoder_format);
}
void StreamWriterFileObj::add_video_stream(
double frame_rate,
int64_t width,
int64_t height,
std::string format,
const c10::optional<std::string>& encoder,
const c10::optional<std::map<std::string, std::string>>& encoder_option,
const c10::optional<std::string>& encoder_format) {
StreamWriter::add_video_stream(
frame_rate,
width,
height,
format,
encoder,
map2dict(encoder_option),
encoder_format);
}
} // namespace ffmpeg
} // namespace torchaudio
#pragma once
#include <torchaudio/csrc/ffmpeg/pybind/typedefs.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/stream_writer_wrapper.h>
namespace torchaudio {
namespace ffmpeg {
class StreamWriterFileObj : protected FileObj, public StreamWriterBinding {
public:
StreamWriterFileObj(
py::object fileobj,
const c10::optional<std::string>& format,
int64_t buffer_size);
void set_metadata(const std::map<std::string, std::string>&);
void add_audio_stream(
int64_t sample_rate,
int64_t num_channels,
std::string format,
const c10::optional<std::string>& encoder,
const c10::optional<std::map<std::string, std::string>>& encoder_option,
const c10::optional<std::string>& encoder_format);
void add_video_stream(
double frame_rate,
int64_t width,
int64_t height,
std::string format,
const c10::optional<std::string>& encoder,
const c10::optional<std::map<std::string, std::string>>& encoder_option,
const c10::optional<std::string>& encoder_format);
};
} // namespace ffmpeg
} // namespace torchaudio
......@@ -31,6 +31,16 @@ static int read_function(void* opaque, uint8_t* buf, int buf_size) {
return num_read == 0 ? AVERROR_EOF : num_read;
}
static int write_function(void* opaque, uint8_t* buf, int buf_size) {
FileObj* fileobj = static_cast<FileObj*>(opaque);
buf_size = FFMIN(buf_size, fileobj->buffer_size);
py::bytes b(reinterpret_cast<const char*>(buf), buf_size);
// TODO: check the return value to check
fileobj->fileobj.attr("write")(b);
return buf_size;
}
static int64_t seek_function(void* opaque, int64_t offset, int whence) {
// We do not know the file size.
if (whence == AVSEEK_SIZE) {
......@@ -40,7 +50,17 @@ static int64_t seek_function(void* opaque, int64_t offset, int whence) {
return py::cast<int64_t>(fileobj->fileobj.attr("seek")(offset, whence));
}
AVIOContextPtr get_io_context(FileObj* opaque, int buffer_size) {
AVIOContextPtr get_io_context(FileObj* opaque, int buffer_size, bool writable) {
if (writable) {
TORCH_CHECK(
py::hasattr(opaque->fileobj, "write"),
"`write` method is not available.");
} else {
TORCH_CHECK(
py::hasattr(opaque->fileobj, "read"),
"`read` method is not available.");
}
uint8_t* buffer = static_cast<uint8_t*>(av_malloc(buffer_size));
TORCH_CHECK(buffer, "Failed to allocate buffer.");
......@@ -50,10 +70,10 @@ AVIOContextPtr get_io_context(FileObj* opaque, int buffer_size) {
AVIOContext* av_io_ctx = avio_alloc_context(
buffer,
buffer_size,
0,
writable ? 1 : 0,
static_cast<void*>(opaque),
&read_function,
nullptr,
writable ? &write_function : nullptr,
py::hasattr(opaque->fileobj, "seek") ? &seek_function : nullptr);
if (!av_io_ctx) {
......@@ -64,25 +84,28 @@ AVIOContextPtr get_io_context(FileObj* opaque, int buffer_size) {
}
} // namespace
FileObj::FileObj(py::object fileobj_, int buffer_size)
FileObj::FileObj(py::object fileobj_, int buffer_size, bool writable)
: fileobj(fileobj_),
buffer_size(buffer_size),
pAVIO(get_io_context(this, buffer_size)) {}
pAVIO(get_io_context(this, buffer_size, writable)) {}
c10::optional<OptionDict> map2dict(
const c10::optional<std::map<std::string, std::string>>& src) {
if (!src) {
return {};
}
OptionDict map2dict(const OptionMap& src) {
OptionDict dict;
for (const auto& it : src.value()) {
for (const auto& it : src) {
dict.insert(it.first.c_str(), it.second.c_str());
}
return c10::optional<OptionDict>{dict};
return dict;
}
c10::optional<OptionDict> map2dict(const c10::optional<OptionMap>& src) {
if (src) {
return c10::optional<OptionDict>{map2dict(src.value())};
}
return {};
}
std::map<std::string, std::string> dict2map(const OptionDict& src) {
std::map<std::string, std::string> ret;
OptionMap dict2map(const OptionDict& src) {
OptionMap ret;
for (const auto& it : src) {
ret.insert({it.key(), it.value()});
}
......
......@@ -9,13 +9,16 @@ struct FileObj {
py::object fileobj;
int buffer_size;
AVIOContextPtr pAVIO;
FileObj(py::object fileobj, int buffer_size);
FileObj(py::object fileobj, int buffer_size, bool writable);
};
c10::optional<OptionDict> map2dict(
const c10::optional<std::map<std::string, std::string>>& src);
using OptionMap = std::map<std::string, std::string>;
std::map<std::string, std::string> dict2map(const OptionDict& src);
OptionDict map2dict(const OptionMap& src);
c10::optional<OptionDict> map2dict(const c10::optional<OptionMap>& src);
OptionMap dict2map(const OptionDict& src);
} // namespace ffmpeg
} // namespace torchaudio
......@@ -485,7 +485,8 @@ void StreamWriter::open(const c10::optional<OptionDict>& option) {
// file-like object)
AVFORMAT_CONST AVOutputFormat* fmt = pFormatContext->oformat;
AVDictionary* opt = get_option_dict(option);
if (!(fmt->flags & AVFMT_NOFILE)) {
if (!(fmt->flags & AVFMT_NOFILE) &&
!(pFormatContext->flags & AVFMT_FLAG_CUSTOM_IO)) {
ret = avio_open2(
&pFormatContext->pb,
pFormatContext->url,
......@@ -524,7 +525,8 @@ void StreamWriter::close() {
// Close the file if it was not provided by client code (i.e. when not
// file-like object)
AVFORMAT_CONST AVOutputFormat* fmt = pFormatContext->oformat;
if (!(fmt->flags & AVFMT_NOFILE)) {
if (!(fmt->flags & AVFMT_NOFILE) &&
!(pFormatContext->flags & AVFMT_FLAG_CUSTOM_IO)) {
// avio_closep can be only applied to AVIOContext opened by avio_open
avio_closep(&(pFormatContext->pb));
}
......
......@@ -5,7 +5,14 @@ namespace ffmpeg {
AVFormatOutputContextPtr get_output_format_context(
const std::string& dst,
const c10::optional<std::string>& format) {
const c10::optional<std::string>& format,
AVIOContext* io_ctx) {
if (io_ctx) {
TORCH_CHECK(
format,
"`format` must be provided when the input is file-like object.");
}
AVFormatContext* p = avformat_alloc_context();
TORCH_CHECK(p, "Failed to allocate AVFormatContext.");
......@@ -19,6 +26,11 @@ AVFormatOutputContextPtr get_output_format_context(
av_err2string(ret),
").");
if (io_ctx) {
p->pb = io_ctx;
p->flags |= AVFMT_FLAG_CUSTOM_IO;
}
return AVFormatOutputContextPtr(p);
}
......
......@@ -7,7 +7,8 @@ namespace ffmpeg {
// create format context for writing media
AVFormatOutputContextPtr get_output_format_context(
const std::string& dst,
const c10::optional<std::string>& format);
const c10::optional<std::string>& format,
AVIOContext* io_ctx = nullptr);
class StreamWriterBinding : public StreamWriter,
public torch::CustomClassHolder {
......
from typing import Dict, Optional
import torch
import torchaudio
def _format_doc(**kwargs):
......@@ -50,7 +51,16 @@ class StreamWriter:
Args:
dst (str): The destination where the encoded data are written.
The supported value depends on the FFmpeg found in the system.
If string-type, it must be a resource indicator that FFmpeg can
handle. The supported value depends on the FFmpeg found in the system.
If file-like object, it must support `write` method with the signature
`write(data: bytes) -> int`.
Please refer to the following for the expected signature and behavior of
`write` method.
- https://docs.python.org/3/library/io.html#io.BufferedIOBase.write
format (str or None, optional):
Override the output format, or specify the output media device.
......@@ -81,14 +91,25 @@ class StreamWriter:
https://ffmpeg.org/ffmpeg-devices.html#Output-Devices
Use `ffmpeg -devices` to list the values available in the current environment.
buffer_size (int):
The internal buffer size in byte. Used only when `dst` is a file-like object.
Default: `4096`.
"""
def __init__(
self,
dst: str,
format: Optional[str] = None,
buffer_size: int = 4096,
):
self._s = torch.classes.torchaudio.ffmpeg_StreamWriter(dst, format)
if isinstance(dst, str):
self._s = torch.classes.torchaudio.ffmpeg_StreamWriter(dst, format)
elif hasattr(dst, "write"):
self._s = torchaudio._torchaudio_ffmpeg.StreamWriterFileObj(dst, format, buffer_size)
else:
raise ValueError("`dst` must be either a string or a file-like object.")
self._is_open = False
@_format_common_args
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment