Commit 28da8b84, authored by moto and committed by Facebook GitHub Bot
Browse files

Add file-like object support to StreamWriter (#2648)

Summary: Pull Request resolved: https://github.com/pytorch/audio/pull/2648

Reviewed By: nateanl

Differential Revision: D38976874

Pulled By: mthrok

fbshipit-source-id: 0541dea2a633d97000b4b8609ff6b83f6b82c864
parent 76fca37a
import torch import torch
import torchaudio import torchaudio
from parameterized import parameterized from parameterized import parameterized, parameterized_class
from torchaudio_unittest.common_utils import ( from torchaudio_unittest.common_utils import (
get_asset_path, get_asset_path,
is_ffmpeg_available, is_ffmpeg_available,
...@@ -42,8 +42,42 @@ def get_video_chunk(fmt, frame_rate, *, width, height): ...@@ -42,8 +42,42 @@ def get_video_chunk(fmt, frame_rate, *, width, height):
return chunk return chunk
################################################################################
# Helper decorator and Mixin to duplicate the tests for fileobj
# Class decorator that expands a test class into two generated variants, one
# per value of the `test_fileobj` class attribute. `_MediaSourceMixin.get_dst`
# reads that attribute to decide whether to hand out a path or an open file.
_media_source = parameterized_class(
# Name of the attribute set on each generated class.
("test_fileobj",),
# One generated class per tuple: `test_fileobj=False`, then `test_fileobj=True`.
[(False,), (True,)],
# Suffix the generated class name with "_fileobj" or "_path" so the two
# variants can be told apart in test reports.
class_name_func=lambda cls, _, params: f'{cls.__name__}{"_fileobj" if params["test_fileobj"] else "_path"}',
)
class _MediaSourceMixin:
    """Mixin that lets one test body run against a path or a file-like object.

    The concrete test class is expected to carry a boolean ``test_fileobj``
    attribute. When it is false, ``get_dst`` returns the path untouched; when
    it is true, ``get_dst`` opens the path for binary writing and returns the
    file object, which is closed again in ``tearDown``.
    """

    def setUp(self):
        super().setUp()
        # File object handed out by `get_dst`; None until `get_dst` opens one.
        self.src = None

    def get_dst(self, path):
        """Return the destination to pass to the writer under test.

        Args:
            path: Path of the output file.

        Returns:
            ``path`` itself in path mode, otherwise a file object opened on
            ``path`` in ``"wb"`` mode.

        Raises:
            ValueError: If called a second time in file-object mode, since
                only one open file is tracked per test.
        """
        if not self.test_fileobj:
            return path
        if self.src is not None:
            raise ValueError("get_dst can be called only once.")
        self.src = open(path, "wb")
        return self.src

    def tearDown(self):
        """Flush and close the file opened by ``get_dst``, then chain up.

        The file is closed and ``super().tearDown()`` is run even when
        ``flush`` raises, so a failing flush can neither leak the file handle
        nor skip the cleanup performed by the other base classes.
        """
        # Detach first so a repeated tearDown does not touch a closed file.
        fileobj, self.src = self.src, None
        try:
            if fileobj is not None:
                # The context manager closes the file on the way out, also
                # when flush() raises.
                with fileobj:
                    fileobj.flush()
        finally:
            super().tearDown()
################################################################################
@skipIfNoFFmpeg @skipIfNoFFmpeg
class StreamWriterInterfaceTest(TempDirMixin, TorchaudioTestCase): @_media_source
class StreamWriterInterfaceTest(_MediaSourceMixin, TempDirMixin, TorchaudioTestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
super().setUpClass() super().setUpClass()
...@@ -55,7 +89,7 @@ class StreamWriterInterfaceTest(TempDirMixin, TorchaudioTestCase): ...@@ -55,7 +89,7 @@ class StreamWriterInterfaceTest(TempDirMixin, TorchaudioTestCase):
super().tearDownClass() super().tearDownClass()
def get_dst(self, path): def get_dst(self, path):
return self.get_temp_path(path) return super().get_dst(self.get_temp_path(path))
def get_buf(self, path): def get_buf(self, path):
with open(self.get_temp_path(path), "rb") as fileobj: with open(self.get_temp_path(path), "rb") as fileobj:
...@@ -70,8 +104,8 @@ class StreamWriterInterfaceTest(TempDirMixin, TorchaudioTestCase): ...@@ -70,8 +104,8 @@ class StreamWriterInterfaceTest(TempDirMixin, TorchaudioTestCase):
sample_rate = 8000 sample_rate = 8000
num_channels = 1 num_channels = 1
path = self.get_dst("test.mp3") dst = self.get_dst("test.mp3")
s = StreamWriter(path, format="mp3") s = StreamWriter(dst, format="mp3")
s.set_metadata(metadata={"artist": "torchaudio", "title": "foo"}) s.set_metadata(metadata={"artist": "torchaudio", "title": "foo"})
s.set_metadata(metadata={"title": self.id()}) s.set_metadata(metadata={"title": self.id()})
s.add_audio_stream(sample_rate, num_channels, format=src_fmt) s.add_audio_stream(sample_rate, num_channels, format=src_fmt)
...@@ -80,6 +114,7 @@ class StreamWriterInterfaceTest(TempDirMixin, TorchaudioTestCase): ...@@ -80,6 +114,7 @@ class StreamWriterInterfaceTest(TempDirMixin, TorchaudioTestCase):
with s.open(): with s.open():
s.write_audio_chunk(0, chunk) s.write_audio_chunk(0, chunk)
path = self.get_temp_path("test.mp3")
tag = TinyTag.get(path) tag = TinyTag.get(path)
assert tag.artist is None assert tag.artist is None
assert tag.title == self.id() assert tag.title == self.id()
...@@ -203,6 +238,8 @@ class StreamWriterInterfaceTest(TempDirMixin, TorchaudioTestCase): ...@@ -203,6 +238,8 @@ class StreamWriterInterfaceTest(TempDirMixin, TorchaudioTestCase):
s.write_video_chunk(0, chunk) s.write_video_chunk(0, chunk)
# Fetch the written data # Fetch the written data
if self.test_fileobj:
dst.flush()
buf = self.get_buf(filename) buf = self.get_buf(filename)
result = torch.frombuffer(buf, dtype=torch.uint8) result = torch.frombuffer(buf, dtype=torch.uint8)
if encoder_fmt.endswith("p"): if encoder_fmt.endswith("p"):
......
...@@ -203,6 +203,7 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION) ...@@ -203,6 +203,7 @@ if (BUILD_TORCHAUDIO_PYTHON_EXTENSION)
ffmpeg/pybind/typedefs.cpp ffmpeg/pybind/typedefs.cpp
ffmpeg/pybind/pybind.cpp ffmpeg/pybind/pybind.cpp
ffmpeg/pybind/stream_reader.cpp ffmpeg/pybind/stream_reader.cpp
ffmpeg/pybind/stream_writer.cpp
) )
torchaudio_extension( torchaudio_extension(
_torchaudio_ffmpeg _torchaudio_ffmpeg
......
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <torchaudio/csrc/ffmpeg/pybind/stream_reader.h> #include <torchaudio/csrc/ffmpeg/pybind/stream_reader.h>
#include <torchaudio/csrc/ffmpeg/pybind/stream_writer.h>
namespace torchaudio { namespace torchaudio {
namespace ffmpeg { namespace ffmpeg {
namespace { namespace {
PYBIND11_MODULE(_torchaudio_ffmpeg, m) { PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
py::class_<StreamWriterFileObj, c10::intrusive_ptr<StreamWriterFileObj>>(
m, "StreamWriterFileObj")
.def(py::init<py::object, const c10::optional<std::string>&, int64_t>())
.def("set_metadata", &StreamWriterFileObj::set_metadata)
.def("add_audio_stream", &StreamWriterFileObj::add_audio_stream)
.def("add_video_stream", &StreamWriterFileObj::add_video_stream)
.def("dump_format", &StreamWriterFileObj::dump_format)
.def("open", &StreamWriterFileObj::open)
.def("write_audio_chunk", &StreamWriterFileObj::write_audio_chunk)
.def("write_video_chunk", &StreamWriterFileObj::write_video_chunk)
.def("flush", &StreamWriterFileObj::flush)
.def("close", &StreamWriterFileObj::close);
py::class_<StreamReaderFileObj, c10::intrusive_ptr<StreamReaderFileObj>>( py::class_<StreamReaderFileObj, c10::intrusive_ptr<StreamReaderFileObj>>(
m, "StreamReaderFileObj") m, "StreamReaderFileObj")
.def(py::init< .def(py::init<
py::object, py::object,
const c10::optional<std::string>&, const c10::optional<std::string>&,
const c10::optional<std::map<std::string, std::string>>&, const c10::optional<OptionMap>&,
int64_t>()) int64_t>())
.def("num_src_streams", &StreamReaderFileObj::num_src_streams) .def("num_src_streams", &StreamReaderFileObj::num_src_streams)
.def("num_out_streams", &StreamReaderFileObj::num_out_streams) .def("num_out_streams", &StreamReaderFileObj::num_out_streams)
......
...@@ -27,7 +27,7 @@ StreamReaderFileObj::StreamReaderFileObj( ...@@ -27,7 +27,7 @@ StreamReaderFileObj::StreamReaderFileObj(
const c10::optional<std::string>& format, const c10::optional<std::string>& format,
const c10::optional<std::map<std::string, std::string>>& option, const c10::optional<std::map<std::string, std::string>>& option,
int64_t buffer_size) int64_t buffer_size)
: FileObj(fileobj_, static_cast<int>(buffer_size)), : FileObj(fileobj_, static_cast<int>(buffer_size), false),
StreamReaderBinding(get_input_format_context( StreamReaderBinding(get_input_format_context(
static_cast<std::string>(py::str(fileobj_.attr("__str__")())), static_cast<std::string>(py::str(fileobj_.attr("__str__")())),
format, format,
......
#include <torchaudio/csrc/ffmpeg/pybind/stream_writer.h>
namespace torchaudio {
namespace ffmpeg {
// Construct a StreamWriter that writes encoded data to a Python file-like
// object instead of a path.
//
// fileobj_    : Python object handed to the FileObj base, which is created in
//               writable mode (third argument `true`).
// format      : Output format name, forwarded to get_output_format_context.
// buffer_size : Size passed to FileObj for its I/O buffer; narrowed from
//               int64_t to int here.
//
// FileObj is listed before StreamWriterBinding in the class's base list, so
// `pAVIO` is already constructed when it is passed to
// get_output_format_context below.
StreamWriterFileObj::StreamWriterFileObj(
py::object fileobj_,
const c10::optional<std::string>& format,
int64_t buffer_size)
: FileObj(fileobj_, static_cast<int>(buffer_size), true),
StreamWriterBinding(get_output_format_context(
// The result of calling the object's `__str__` is used as the destination
// string.
static_cast<std::string>(py::str(fileobj_.attr("__str__")())),
format,
pAVIO)) {}
// Set container-level metadata.
//
// Converts the std::map received from the Python binding with map2dict and
// forwards the result to StreamWriter::set_metadata.
void StreamWriterFileObj::set_metadata(
    const std::map<std::string, std::string>& metadata) {
  const auto metadata_dict = map2dict(metadata);
  StreamWriter::set_metadata(metadata_dict);
}
// Add an output audio stream.
//
// All arguments are forwarded unchanged to StreamWriter::add_audio_stream,
// except `encoder_option`, which is first converted from an optional std::map
// to an optional OptionDict with map2dict.
void StreamWriterFileObj::add_audio_stream(
    int64_t sample_rate,
    int64_t num_channels,
    std::string format,
    const c10::optional<std::string>& encoder,
    const c10::optional<std::map<std::string, std::string>>& encoder_option,
    const c10::optional<std::string>& encoder_format) {
  const auto encoder_option_dict = map2dict(encoder_option);
  StreamWriter::add_audio_stream(
      sample_rate,
      num_channels,
      format,
      encoder,
      encoder_option_dict,
      encoder_format);
}
// Add an output video stream.
//
// All arguments are forwarded unchanged to StreamWriter::add_video_stream,
// except `encoder_option`, which is first converted from an optional std::map
// to an optional OptionDict with map2dict.
void StreamWriterFileObj::add_video_stream(
    double frame_rate,
    int64_t width,
    int64_t height,
    std::string format,
    const c10::optional<std::string>& encoder,
    const c10::optional<std::map<std::string, std::string>>& encoder_option,
    const c10::optional<std::string>& encoder_format) {
  const auto encoder_option_dict = map2dict(encoder_option);
  StreamWriter::add_video_stream(
      frame_rate,
      width,
      height,
      format,
      encoder,
      encoder_option_dict,
      encoder_format);
}
} // namespace ffmpeg
} // namespace torchaudio
#pragma once
#include <torchaudio/csrc/ffmpeg/pybind/typedefs.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/stream_writer_wrapper.h>
namespace torchaudio {
namespace ffmpeg {
// StreamWriter variant that writes to a Python file-like object.
//
// FileObj (a protected base) holds the Python object and the custom AVIO
// context built on it; StreamWriterBinding supplies the writer itself. The
// methods declared here take std::map based arguments and convert them before
// delegating to the StreamWriter methods of the same name.
class StreamWriterFileObj : protected FileObj, public StreamWriterBinding {
 public:
  // fileobj     : Python object to write the encoded data to.
  // format      : Output format name.
  // buffer_size : Size of the I/O buffer used with `fileobj`.
  StreamWriterFileObj(
      py::object fileobj,
      const c10::optional<std::string>& format,
      int64_t buffer_size);

  // Set metadata from a string-to-string map.
  void set_metadata(const std::map<std::string, std::string>& metadata);

  // Add an audio stream; `encoder_option` is an optional string-to-string map.
  void add_audio_stream(
      int64_t sample_rate,
      int64_t num_channels,
      std::string format,
      const c10::optional<std::string>& encoder,
      const c10::optional<std::map<std::string, std::string>>& encoder_option,
      const c10::optional<std::string>& encoder_format);

  // Add a video stream; `encoder_option` is an optional string-to-string map.
  void add_video_stream(
      double frame_rate,
      int64_t width,
      int64_t height,
      std::string format,
      const c10::optional<std::string>& encoder,
      const c10::optional<std::map<std::string, std::string>>& encoder_option,
      const c10::optional<std::string>& encoder_format);
};
} // namespace ffmpeg
} // namespace torchaudio
...@@ -31,6 +31,16 @@ static int read_function(void* opaque, uint8_t* buf, int buf_size) { ...@@ -31,6 +31,16 @@ static int read_function(void* opaque, uint8_t* buf, int buf_size) {
return num_read == 0 ? AVERROR_EOF : num_read; return num_read == 0 ? AVERROR_EOF : num_read;
} }
static int write_function(void* opaque, uint8_t* buf, int buf_size) {
FileObj* fileobj = static_cast<FileObj*>(opaque);
buf_size = FFMIN(buf_size, fileobj->buffer_size);
py::bytes b(reinterpret_cast<const char*>(buf), buf_size);
// TODO: check the return value to check
fileobj->fileobj.attr("write")(b);
return buf_size;
}
static int64_t seek_function(void* opaque, int64_t offset, int whence) { static int64_t seek_function(void* opaque, int64_t offset, int whence) {
// We do not know the file size. // We do not know the file size.
if (whence == AVSEEK_SIZE) { if (whence == AVSEEK_SIZE) {
...@@ -40,7 +50,17 @@ static int64_t seek_function(void* opaque, int64_t offset, int whence) { ...@@ -40,7 +50,17 @@ static int64_t seek_function(void* opaque, int64_t offset, int whence) {
return py::cast<int64_t>(fileobj->fileobj.attr("seek")(offset, whence)); return py::cast<int64_t>(fileobj->fileobj.attr("seek")(offset, whence));
} }
AVIOContextPtr get_io_context(FileObj* opaque, int buffer_size) { AVIOContextPtr get_io_context(FileObj* opaque, int buffer_size, bool writable) {
if (writable) {
TORCH_CHECK(
py::hasattr(opaque->fileobj, "write"),
"`write` method is not available.");
} else {
TORCH_CHECK(
py::hasattr(opaque->fileobj, "read"),
"`read` method is not available.");
}
uint8_t* buffer = static_cast<uint8_t*>(av_malloc(buffer_size)); uint8_t* buffer = static_cast<uint8_t*>(av_malloc(buffer_size));
TORCH_CHECK(buffer, "Failed to allocate buffer."); TORCH_CHECK(buffer, "Failed to allocate buffer.");
...@@ -50,10 +70,10 @@ AVIOContextPtr get_io_context(FileObj* opaque, int buffer_size) { ...@@ -50,10 +70,10 @@ AVIOContextPtr get_io_context(FileObj* opaque, int buffer_size) {
AVIOContext* av_io_ctx = avio_alloc_context( AVIOContext* av_io_ctx = avio_alloc_context(
buffer, buffer,
buffer_size, buffer_size,
0, writable ? 1 : 0,
static_cast<void*>(opaque), static_cast<void*>(opaque),
&read_function, &read_function,
nullptr, writable ? &write_function : nullptr,
py::hasattr(opaque->fileobj, "seek") ? &seek_function : nullptr); py::hasattr(opaque->fileobj, "seek") ? &seek_function : nullptr);
if (!av_io_ctx) { if (!av_io_ctx) {
...@@ -64,25 +84,28 @@ AVIOContextPtr get_io_context(FileObj* opaque, int buffer_size) { ...@@ -64,25 +84,28 @@ AVIOContextPtr get_io_context(FileObj* opaque, int buffer_size) {
} }
} // namespace } // namespace
FileObj::FileObj(py::object fileobj_, int buffer_size) FileObj::FileObj(py::object fileobj_, int buffer_size, bool writable)
: fileobj(fileobj_), : fileobj(fileobj_),
buffer_size(buffer_size), buffer_size(buffer_size),
pAVIO(get_io_context(this, buffer_size)) {} pAVIO(get_io_context(this, buffer_size, writable)) {}
c10::optional<OptionDict> map2dict( OptionDict map2dict(const OptionMap& src) {
const c10::optional<std::map<std::string, std::string>>& src) {
if (!src) {
return {};
}
OptionDict dict; OptionDict dict;
for (const auto& it : src.value()) { for (const auto& it : src) {
dict.insert(it.first.c_str(), it.second.c_str()); dict.insert(it.first.c_str(), it.second.c_str());
} }
return c10::optional<OptionDict>{dict}; return dict;
}
c10::optional<OptionDict> map2dict(const c10::optional<OptionMap>& src) {
if (src) {
return c10::optional<OptionDict>{map2dict(src.value())};
}
return {};
} }
std::map<std::string, std::string> dict2map(const OptionDict& src) { OptionMap dict2map(const OptionDict& src) {
std::map<std::string, std::string> ret; OptionMap ret;
for (const auto& it : src) { for (const auto& it : src) {
ret.insert({it.key(), it.value()}); ret.insert({it.key(), it.value()});
} }
......
...@@ -9,13 +9,16 @@ struct FileObj { ...@@ -9,13 +9,16 @@ struct FileObj {
py::object fileobj; py::object fileobj;
int buffer_size; int buffer_size;
AVIOContextPtr pAVIO; AVIOContextPtr pAVIO;
FileObj(py::object fileobj, int buffer_size); FileObj(py::object fileobj, int buffer_size, bool writable);
}; };
c10::optional<OptionDict> map2dict( using OptionMap = std::map<std::string, std::string>;
const c10::optional<std::map<std::string, std::string>>& src);
std::map<std::string, std::string> dict2map(const OptionDict& src); OptionDict map2dict(const OptionMap& src);
c10::optional<OptionDict> map2dict(const c10::optional<OptionMap>& src);
OptionMap dict2map(const OptionDict& src);
} // namespace ffmpeg } // namespace ffmpeg
} // namespace torchaudio } // namespace torchaudio
...@@ -485,7 +485,8 @@ void StreamWriter::open(const c10::optional<OptionDict>& option) { ...@@ -485,7 +485,8 @@ void StreamWriter::open(const c10::optional<OptionDict>& option) {
// file-like object) // file-like object)
AVFORMAT_CONST AVOutputFormat* fmt = pFormatContext->oformat; AVFORMAT_CONST AVOutputFormat* fmt = pFormatContext->oformat;
AVDictionary* opt = get_option_dict(option); AVDictionary* opt = get_option_dict(option);
if (!(fmt->flags & AVFMT_NOFILE)) { if (!(fmt->flags & AVFMT_NOFILE) &&
!(pFormatContext->flags & AVFMT_FLAG_CUSTOM_IO)) {
ret = avio_open2( ret = avio_open2(
&pFormatContext->pb, &pFormatContext->pb,
pFormatContext->url, pFormatContext->url,
...@@ -524,7 +525,8 @@ void StreamWriter::close() { ...@@ -524,7 +525,8 @@ void StreamWriter::close() {
// Close the file if it was not provided by client code (i.e. when not // Close the file if it was not provided by client code (i.e. when not
// file-like object) // file-like object)
AVFORMAT_CONST AVOutputFormat* fmt = pFormatContext->oformat; AVFORMAT_CONST AVOutputFormat* fmt = pFormatContext->oformat;
if (!(fmt->flags & AVFMT_NOFILE)) { if (!(fmt->flags & AVFMT_NOFILE) &&
!(pFormatContext->flags & AVFMT_FLAG_CUSTOM_IO)) {
// avio_closep can be only applied to AVIOContext opened by avio_open // avio_closep can be only applied to AVIOContext opened by avio_open
avio_closep(&(pFormatContext->pb)); avio_closep(&(pFormatContext->pb));
} }
......
...@@ -5,7 +5,14 @@ namespace ffmpeg { ...@@ -5,7 +5,14 @@ namespace ffmpeg {
AVFormatOutputContextPtr get_output_format_context( AVFormatOutputContextPtr get_output_format_context(
const std::string& dst, const std::string& dst,
const c10::optional<std::string>& format) { const c10::optional<std::string>& format,
AVIOContext* io_ctx) {
if (io_ctx) {
TORCH_CHECK(
format,
"`format` must be provided when the input is file-like object.");
}
AVFormatContext* p = avformat_alloc_context(); AVFormatContext* p = avformat_alloc_context();
TORCH_CHECK(p, "Failed to allocate AVFormatContext."); TORCH_CHECK(p, "Failed to allocate AVFormatContext.");
...@@ -19,6 +26,11 @@ AVFormatOutputContextPtr get_output_format_context( ...@@ -19,6 +26,11 @@ AVFormatOutputContextPtr get_output_format_context(
av_err2string(ret), av_err2string(ret),
")."); ").");
if (io_ctx) {
p->pb = io_ctx;
p->flags |= AVFMT_FLAG_CUSTOM_IO;
}
return AVFormatOutputContextPtr(p); return AVFormatOutputContextPtr(p);
} }
......
...@@ -7,7 +7,8 @@ namespace ffmpeg { ...@@ -7,7 +7,8 @@ namespace ffmpeg {
// create format context for writing media // create format context for writing media
AVFormatOutputContextPtr get_output_format_context( AVFormatOutputContextPtr get_output_format_context(
const std::string& dst, const std::string& dst,
const c10::optional<std::string>& format); const c10::optional<std::string>& format,
AVIOContext* io_ctx = nullptr);
class StreamWriterBinding : public StreamWriter, class StreamWriterBinding : public StreamWriter,
public torch::CustomClassHolder { public torch::CustomClassHolder {
......
from typing import Dict, Optional from typing import Dict, Optional
import torch import torch
import torchaudio
def _format_doc(**kwargs): def _format_doc(**kwargs):
...@@ -50,7 +51,16 @@ class StreamWriter: ...@@ -50,7 +51,16 @@ class StreamWriter:
Args: Args:
dst (str): The destination where the encoded data are written. dst (str): The destination where the encoded data are written.
The supported value depends on the FFmpeg found in the system. If string-type, it must be a resource indicator that FFmpeg can
handle. The supported value depends on the FFmpeg found in the system.
If file-like object, it must support `write` method with the signature
`write(data: bytes) -> int`.
Please refer to the following for the expected signature and behavior of
`write` method.
- https://docs.python.org/3/library/io.html#io.BufferedIOBase.write
format (str or None, optional): format (str or None, optional):
Override the output format, or specify the output media device. Override the output format, or specify the output media device.
...@@ -81,14 +91,25 @@ class StreamWriter: ...@@ -81,14 +91,25 @@ class StreamWriter:
https://ffmpeg.org/ffmpeg-devices.html#Output-Devices https://ffmpeg.org/ffmpeg-devices.html#Output-Devices
Use `ffmpeg -devices` to list the values available in the current environment. Use `ffmpeg -devices` to list the values available in the current environment.
buffer_size (int):
The internal buffer size in byte. Used only when `dst` is a file-like object.
Default: `4096`.
""" """
def __init__( def __init__(
self, self,
dst: str, dst: str,
format: Optional[str] = None, format: Optional[str] = None,
buffer_size: int = 4096,
): ):
self._s = torch.classes.torchaudio.ffmpeg_StreamWriter(dst, format) if isinstance(dst, str):
self._s = torch.classes.torchaudio.ffmpeg_StreamWriter(dst, format)
elif hasattr(dst, "write"):
self._s = torchaudio._torchaudio_ffmpeg.StreamWriterFileObj(dst, format, buffer_size)
else:
raise ValueError("`dst` must be either a string or a file-like object.")
self._is_open = False self._is_open = False
@_format_common_args @_format_common_args
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment