Commit 7ea69e61 authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Abstract away AVFormatContext from StreamReader/Writer constructor (#3007)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3007

Simplify the construction of StreamReader/Writer in C++.

Currently these classes require client code to build AVFormatContext
manually. This is tedious and not user freindly.

Some client code actually uses the same helper function that
TorchAudio codebase uses.

This commit moves the helper logic inside of the constructor of
StreamReader/Writer, so that the signatures of these constructors
are easy to use and similar to Python interface.

Reviewed By: xiaohui-zhang

Differential Revision: D42662520

fbshipit-source-id: d95e5236810c48d7d9bd2d89c05d4f60a44b3ba1
parent 2f5fcf4f
...@@ -20,7 +20,6 @@ set( ...@@ -20,7 +20,6 @@ set(
stream_reader/stream_reader_binding.cpp stream_reader/stream_reader_binding.cpp
stream_reader/stream_reader_tensor_binding.cpp stream_reader/stream_reader_tensor_binding.cpp
stream_writer/stream_writer.cpp stream_writer/stream_writer.cpp
stream_writer/stream_writer_wrapper.cpp
stream_writer/stream_writer_binding.cpp stream_writer/stream_writer_binding.cpp
utils.cpp utils.cpp
) )
......
...@@ -8,8 +8,7 @@ namespace ffmpeg { ...@@ -8,8 +8,7 @@ namespace ffmpeg {
namespace { namespace {
PYBIND11_MODULE(_torchaudio_ffmpeg, m) { PYBIND11_MODULE(_torchaudio_ffmpeg, m) {
py::class_<StreamWriterFileObj, c10::intrusive_ptr<StreamWriterFileObj>>( py::class_<StreamWriterFileObj>(m, "StreamWriterFileObj")
m, "StreamWriterFileObj")
.def(py::init<py::object, const c10::optional<std::string>&, int64_t>()) .def(py::init<py::object, const c10::optional<std::string>&, int64_t>())
.def("set_metadata", &StreamWriterFileObj::set_metadata) .def("set_metadata", &StreamWriterFileObj::set_metadata)
.def("add_audio_stream", &StreamWriterFileObj::add_audio_stream) .def("add_audio_stream", &StreamWriterFileObj::add_audio_stream)
......
...@@ -28,11 +28,7 @@ StreamReaderFileObj::StreamReaderFileObj( ...@@ -28,11 +28,7 @@ StreamReaderFileObj::StreamReaderFileObj(
const c10::optional<std::map<std::string, std::string>>& option, const c10::optional<std::map<std::string, std::string>>& option,
int64_t buffer_size) int64_t buffer_size)
: FileObj(fileobj_, static_cast<int>(buffer_size), false), : FileObj(fileobj_, static_cast<int>(buffer_size), false),
StreamReaderBinding(get_input_format_context( StreamReaderBinding(pAVIO, format, map2dict(option)) {}
static_cast<std::string>(py::str(fileobj_.attr("__str__")())),
format,
map2dict(option),
pAVIO)) {}
std::map<std::string, std::string> StreamReaderFileObj::get_metadata() const { std::map<std::string, std::string> StreamReaderFileObj::get_metadata() const {
return dict2map(StreamReader::get_metadata()); return dict2map(StreamReader::get_metadata());
......
...@@ -8,10 +8,7 @@ StreamWriterFileObj::StreamWriterFileObj( ...@@ -8,10 +8,7 @@ StreamWriterFileObj::StreamWriterFileObj(
const c10::optional<std::string>& format, const c10::optional<std::string>& format,
int64_t buffer_size) int64_t buffer_size)
: FileObj(fileobj_, static_cast<int>(buffer_size), true), : FileObj(fileobj_, static_cast<int>(buffer_size), true),
StreamWriterBinding(get_output_format_context( StreamWriter(pAVIO, format) {}
static_cast<std::string>(py::str(fileobj_.attr("__str__")())),
format,
pAVIO)) {}
void StreamWriterFileObj::set_metadata( void StreamWriterFileObj::set_metadata(
const std::map<std::string, std::string>& metadata) { const std::map<std::string, std::string>& metadata) {
......
#pragma once #pragma once
#include <torchaudio/csrc/ffmpeg/pybind/typedefs.h> #include <torchaudio/csrc/ffmpeg/pybind/typedefs.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/stream_writer_wrapper.h> #include <torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h>
namespace torchaudio { namespace torchaudio {
namespace ffmpeg { namespace ffmpeg {
class StreamWriterFileObj : protected FileObj, public StreamWriterBinding { class StreamWriterFileObj : private FileObj, public StreamWriter {
public: public:
StreamWriterFileObj( StreamWriterFileObj(
py::object fileobj, py::object fileobj,
......
...@@ -10,6 +10,81 @@ namespace ffmpeg { ...@@ -10,6 +10,81 @@ namespace ffmpeg {
using KeyType = StreamProcessor::KeyType; using KeyType = StreamProcessor::KeyType;
//////////////////////////////////////////////////////////////////////////////
// Initialization / resource allocations
//////////////////////////////////////////////////////////////////////////////
namespace {
AVFormatContext* get_input_format_context(
const std::string& src,
const c10::optional<std::string>& format,
const c10::optional<OptionDict>& option,
AVIOContext* io_ctx) {
AVFormatContext* p = avformat_alloc_context();
TORCH_CHECK(p, "Failed to allocate AVFormatContext.");
if (io_ctx) {
p->pb = io_ctx;
}
auto* pInputFormat = [&format]() -> AVFORMAT_CONST AVInputFormat* {
if (format.has_value()) {
std::string format_str = format.value();
AVFORMAT_CONST AVInputFormat* pInput =
av_find_input_format(format_str.c_str());
TORCH_CHECK(pInput, "Unsupported device/format: \"", format_str, "\"");
return pInput;
}
return nullptr;
}();
AVDictionary* opt = get_option_dict(option);
int ret = avformat_open_input(&p, src.c_str(), pInputFormat, &opt);
clean_up_dict(opt);
TORCH_CHECK(
ret >= 0,
"Failed to open the input \"",
src,
"\" (",
av_err2string(ret),
").");
return p;
}
} // namespace
StreamReader::StreamReader(AVFormatContext* p) : pFormatContext(p) {
int ret = avformat_find_stream_info(pFormatContext, nullptr);
TORCH_CHECK(
ret >= 0, "Failed to find stream information: ", av_err2string(ret));
processors =
std::vector<std::unique_ptr<StreamProcessor>>(pFormatContext->nb_streams);
for (int i = 0; i < pFormatContext->nb_streams; ++i) {
switch (pFormatContext->streams[i]->codecpar->codec_type) {
case AVMEDIA_TYPE_AUDIO:
case AVMEDIA_TYPE_VIDEO:
break;
default:
pFormatContext->streams[i]->discard = AVDISCARD_ALL;
}
}
}
StreamReader::StreamReader(
AVIOContext* io_ctx,
const c10::optional<std::string>& format,
const c10::optional<OptionDict>& option)
: StreamReader(get_input_format_context(
"Custom Input Context",
format,
option,
io_ctx)) {}
StreamReader::StreamReader(
const std::string& src,
const c10::optional<std::string>& format,
const c10::optional<OptionDict>& option)
: StreamReader(get_input_format_context(src, format, option, nullptr)) {}
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
// Helper methods // Helper methods
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
...@@ -41,28 +116,6 @@ void StreamReader::validate_src_stream_type(int i, AVMediaType type) { ...@@ -41,28 +116,6 @@ void StreamReader::validate_src_stream_type(int i, AVMediaType type) {
" stream."); " stream.");
} }
//////////////////////////////////////////////////////////////////////////////
// Initialization / resource allocations
//////////////////////////////////////////////////////////////////////////////
StreamReader::StreamReader(AVFormatInputContextPtr&& p)
: pFormatContext(std::move(p)) {
int ret = avformat_find_stream_info(pFormatContext, nullptr);
TORCH_CHECK(
ret >= 0, "Failed to find stream information: ", av_err2string(ret));
processors =
std::vector<std::unique_ptr<StreamProcessor>>(pFormatContext->nb_streams);
for (int i = 0; i < pFormatContext->nb_streams; ++i) {
switch (pFormatContext->streams[i]->codecpar->codec_type) {
case AVMEDIA_TYPE_AUDIO:
case AVMEDIA_TYPE_VIDEO:
break;
default:
pFormatContext->streams[i]->discard = AVDISCARD_ALL;
}
}
}
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// Query methods // Query methods
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
......
#pragma once #pragma once
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/decoder.h> #include <torchaudio/csrc/ffmpeg/stream_reader/decoder.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/stream_processor.h> #include <torchaudio/csrc/ffmpeg/stream_reader/stream_processor.h>
#include <torchaudio/csrc/ffmpeg/stream_reader/typedefs.h> #include <torchaudio/csrc/ffmpeg/stream_reader/typedefs.h>
...@@ -37,10 +38,37 @@ class StreamReader { ...@@ -37,10 +38,37 @@ class StreamReader {
/// ///
///@{ ///@{
/// @todo Introduce a constructor that takes std::string and abstracts away /// Construct StreamReader from already initialized AVFormatContext.
/// ffmpeg-native structs /// This is a low level constructor interact with FFmpeg directly.
/// One can provide custom AVFormatContext in case the other constructor
/// does not meet a requirement.
/// @param AVFormatContext An initialized AVFormatContext. StreamReader will
/// own the resources and release it at the end.
explicit StreamReader(AVFormatContext* pFormatContext);
/// Construct media processor from soruce URI.
///
/// @param src URL of source media, in the format FFmpeg can understand.
/// @param format Specifies format (such as mp4) or device (such as lavfi and
/// avfoundation)
/// @param option Custom option passed when initializing format context
/// (opening source).
explicit StreamReader(
const std::string& src,
const c10::optional<std::string>& format = {},
const c10::optional<OptionDict>& option = {});
/// Concstruct media processor from custom IO.
/// ///
explicit StreamReader(AVFormatInputContextPtr&& p); /// @param io_ctx Custom IO Context.
/// @param format Specifies format, such as mp4.
/// @param option Custom option passed when initializing format context
/// (opening source).
// TODO: Move this to wrapper class
explicit StreamReader(
AVIOContext* io_ctx,
const c10::optional<std::string>& format = {},
const c10::optional<OptionDict>& option = {});
///@} ///@}
......
...@@ -7,14 +7,6 @@ namespace ffmpeg { ...@@ -7,14 +7,6 @@ namespace ffmpeg {
namespace { namespace {
c10::intrusive_ptr<StreamReaderBinding> init(
const std::string& src,
const c10::optional<std::string>& device,
const c10::optional<OptionDict>& option) {
return c10::make_intrusive<StreamReaderBinding>(
get_input_format_context(src, device, option));
}
using S = const c10::intrusive_ptr<StreamReaderBinding>&; using S = const c10::intrusive_ptr<StreamReaderBinding>&;
TORCH_LIBRARY_FRAGMENT(torchaudio, m) { TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
...@@ -26,7 +18,11 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) { ...@@ -26,7 +18,11 @@ TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
av_log_set_level(static_cast<int>(level)); av_log_set_level(static_cast<int>(level));
}); });
m.class_<StreamReaderBinding>("ffmpeg_StreamReader") m.class_<StreamReaderBinding>("ffmpeg_StreamReader")
.def(torch::init<>(init)) .def(torch::init<>([](const std::string& src,
const c10::optional<std::string>& device,
const c10::optional<OptionDict>& option) {
return c10::make_intrusive<StreamReaderBinding>(src, device, option);
}))
.def("num_src_streams", [](S self) { return self->num_src_streams(); }) .def("num_src_streams", [](S self) { return self->num_src_streams(); })
.def("num_out_streams", [](S self) { return self->num_out_streams(); }) .def("num_out_streams", [](S self) { return self->num_out_streams(); })
.def("get_metadata", [](S self) { return self->get_metadata(); }) .def("get_metadata", [](S self) { return self->get_metadata(); })
......
#include <torchaudio/csrc/ffmpeg/stream_reader/stream_reader_tensor_binding.h> #include <torchaudio/csrc/ffmpeg/stream_reader/stream_reader_wrapper.h>
namespace torchaudio { namespace torchaudio {
namespace ffmpeg { namespace ffmpeg {
namespace { namespace {
//////////////////////////////////////////////////////////////////////////////
// TensorIndexer
//////////////////////////////////////////////////////////////////////////////
// Helper structure to keep track of until where the decoding has happened
struct TensorIndexer {
torch::Tensor src;
size_t index = 0;
const uint8_t* data;
const size_t numel;
AVIOContextPtr pAVIO;
TensorIndexer(const torch::Tensor& src, int buffer_size);
};
static int read_function(void* opaque, uint8_t* buf, int buf_size) { static int read_function(void* opaque, uint8_t* buf, int buf_size) {
TensorIndexer* tensorobj = static_cast<TensorIndexer*>(opaque); TensorIndexer* tensorobj = static_cast<TensorIndexer*>(opaque);
...@@ -56,13 +70,6 @@ AVIOContext* get_io_context(TensorIndexer* opaque, int buffer_size) { ...@@ -56,13 +70,6 @@ AVIOContext* get_io_context(TensorIndexer* opaque, int buffer_size) {
return av_io_ctx; return av_io_ctx;
} }
std::string get_id(const torch::Tensor& src) {
std::stringstream ss;
ss << "Tensor <" << static_cast<const void*>(src.data_ptr<uint8_t>()) << ">";
return ss.str();
}
} // namespace
TensorIndexer::TensorIndexer(const torch::Tensor& src, int buffer_size) TensorIndexer::TensorIndexer(const torch::Tensor& src, int buffer_size)
: src(src), : src(src),
data([&]() -> uint8_t* { data([&]() -> uint8_t* {
...@@ -83,31 +90,38 @@ TensorIndexer::TensorIndexer(const torch::Tensor& src, int buffer_size) ...@@ -83,31 +90,38 @@ TensorIndexer::TensorIndexer(const torch::Tensor& src, int buffer_size)
numel(src.numel()), numel(src.numel()),
pAVIO(get_io_context(this, buffer_size)) {} pAVIO(get_io_context(this, buffer_size)) {}
StreamReaderTensorBinding::StreamReaderTensorBinding( //////////////////////////////////////////////////////////////////////////////
// StreamReaderTensorBinding
//////////////////////////////////////////////////////////////////////////////
// Structure to implement wrapper API around StreamReader and input Tensor
struct StreamReaderTensorBinding : protected TensorIndexer,
public StreamReaderBinding {
StreamReaderTensorBinding(
const torch::Tensor& src, const torch::Tensor& src,
const c10::optional<std::string>& device, const c10::optional<std::string>& device,
const c10::optional<OptionDict>& option, const c10::optional<OptionDict>& option,
int buffer_size);
};
StreamReaderTensorBinding::StreamReaderTensorBinding(
const torch::Tensor& src,
const c10::optional<std::string>& format,
const c10::optional<OptionDict>& option,
int buffer_size) int buffer_size)
: TensorIndexer(src, buffer_size), : TensorIndexer(src, buffer_size),
StreamReaderBinding( StreamReaderBinding(pAVIO, format, option) {}
get_input_format_context(get_id(src), device, option, pAVIO)) {}
namespace { using S = const c10::intrusive_ptr<StreamReaderTensorBinding>&;
c10::intrusive_ptr<StreamReaderTensorBinding> init( TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
const torch::Tensor& src, m.class_<StreamReaderTensorBinding>("ffmpeg_StreamReaderTensor")
.def(torch::init<>([](const torch::Tensor& src,
const c10::optional<std::string>& device, const c10::optional<std::string>& device,
const c10::optional<OptionDict>& option, const c10::optional<OptionDict>& option,
int64_t buffer_size) { int64_t buffer_size) {
return c10::make_intrusive<StreamReaderTensorBinding>( return c10::make_intrusive<StreamReaderTensorBinding>(
src, device, option, static_cast<int>(buffer_size)); src, device, option, static_cast<int>(buffer_size));
} }))
using S = const c10::intrusive_ptr<StreamReaderTensorBinding>&;
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.class_<StreamReaderTensorBinding>("ffmpeg_StreamReaderTensor")
.def(torch::init<>(init))
.def("num_src_streams", [](S self) { return self->num_src_streams(); }) .def("num_src_streams", [](S self) { return self->num_src_streams(); })
.def("num_out_streams", [](S self) { return self->num_out_streams(); }) .def("num_out_streams", [](S self) { return self->num_out_streams(); })
.def("get_metadata", [](S self) { return self->get_metadata(); }) .def("get_metadata", [](S self) { return self->get_metadata(); })
......
#pragma once
#include <torchaudio/csrc/ffmpeg/stream_reader/stream_reader_wrapper.h>
namespace torchaudio {
namespace ffmpeg {
// Helper structure to keep track of until where the decoding has happened
struct TensorIndexer {
torch::Tensor src;
size_t index = 0;
const uint8_t* data;
const size_t numel;
AVIOContextPtr pAVIO;
TensorIndexer(const torch::Tensor& src, int buffer_size);
};
// Structure to implement wrapper API around StreamReader, which is more
// suitable for Binding the code (i.e. it receives/returns pritimitves)
struct StreamReaderTensorBinding : protected TensorIndexer,
public StreamReaderBinding {
StreamReaderTensorBinding(
const torch::Tensor& src,
const c10::optional<std::string>& device,
const c10::optional<OptionDict>& option,
int buffer_size);
};
} // namespace ffmpeg
} // namespace torchaudio
...@@ -27,40 +27,17 @@ OutInfo convert(OutputStreamInfo osi) { ...@@ -27,40 +27,17 @@ OutInfo convert(OutputStreamInfo osi) {
} }
} // namespace } // namespace
AVFormatInputContextPtr get_input_format_context( StreamReaderBinding::StreamReaderBinding(
const std::string& src, const std::string& src,
const c10::optional<std::string>& device, const c10::optional<std::string>& format,
const c10::optional<OptionDict>& option, const c10::optional<OptionDict>& option)
AVIOContext* io_ctx) { : StreamReader(src, format, option) {}
AVFormatContext* pFormat = avformat_alloc_context();
TORCH_CHECK(pFormat, "Failed to allocate AVFormatContext.");
if (io_ctx) {
pFormat->pb = io_ctx;
}
auto* pInput = [&]() -> AVFORMAT_CONST AVInputFormat* {
if (device.has_value()) {
std::string device_str = device.value();
AVFORMAT_CONST AVInputFormat* p =
av_find_input_format(device_str.c_str());
TORCH_CHECK(p, "Unsupported device/format: \"", device_str, "\"");
return p;
}
return nullptr;
}();
AVDictionary* opt = get_option_dict(option);
int ret = avformat_open_input(&pFormat, src.c_str(), pInput, &opt);
clean_up_dict(opt);
TORCH_CHECK(
ret >= 0,
"Failed to open the input \"" + src + "\" (" + av_err2string(ret) + ").");
return AVFormatInputContextPtr(pFormat);
}
StreamReaderBinding::StreamReaderBinding(AVFormatInputContextPtr&& p) StreamReaderBinding::StreamReaderBinding(
: StreamReader(std::move(p)) {} AVIOContext* io_ctx,
const c10::optional<std::string>& format,
const c10::optional<OptionDict>& option)
: StreamReader(io_ctx, format, option) {}
SrcInfo StreamReaderBinding::get_src_stream_info(int64_t i) { SrcInfo StreamReaderBinding::get_src_stream_info(int64_t i) {
return convert(StreamReader::get_src_stream_info(static_cast<int>(i))); return convert(StreamReader::get_src_stream_info(static_cast<int>(i)));
......
...@@ -5,13 +5,6 @@ ...@@ -5,13 +5,6 @@
namespace torchaudio { namespace torchaudio {
namespace ffmpeg { namespace ffmpeg {
// create format context for reading media
AVFormatInputContextPtr get_input_format_context(
const std::string& src,
const c10::optional<std::string>& device,
const c10::optional<OptionDict>& option,
AVIOContext* io_ctx = nullptr);
// Because TorchScript requires c10::Dict type to pass dict, // Because TorchScript requires c10::Dict type to pass dict,
// while PyBind11 requires std::map type to pass dict, // while PyBind11 requires std::map type to pass dict,
// we duplicate the return tuple. // we duplicate the return tuple.
...@@ -67,18 +60,27 @@ using ChunkData = std::tuple<torch::Tensor, double>; ...@@ -67,18 +60,27 @@ using ChunkData = std::tuple<torch::Tensor, double>;
// suitable for Binding the code (i.e. it receives/returns pritimitves) // suitable for Binding the code (i.e. it receives/returns pritimitves)
struct StreamReaderBinding : public StreamReader, struct StreamReaderBinding : public StreamReader,
public torch::CustomClassHolder { public torch::CustomClassHolder {
explicit StreamReaderBinding(AVFormatInputContextPtr&& p); StreamReaderBinding(
const std::string& src,
const c10::optional<std::string>& device,
const c10::optional<OptionDict>& option);
StreamReaderBinding(
AVIOContext* io_ctx,
const c10::optional<std::string>& device,
const c10::optional<OptionDict>& option);
SrcInfo get_src_stream_info(int64_t i); SrcInfo get_src_stream_info(int64_t i);
OutInfo get_out_stream_info(int64_t i); OutInfo get_out_stream_info(int64_t i);
int64_t process_packet( int64_t process_packet(
const c10::optional<double>& timeout = c10::optional<double>(), const c10::optional<double>& timeout = {},
const double backoff = 10.); const double backoff = 10.);
void process_all_packets(); void process_all_packets();
int64_t fill_buffer( int64_t fill_buffer(
const c10::optional<double>& timeout = c10::optional<double>(), const c10::optional<double>& timeout = {},
const double backoff = 10.); const double backoff = 10.);
std::vector<c10::optional<ChunkData>> pop_chunks(); std::vector<c10::optional<ChunkData>> pop_chunks();
......
...@@ -7,6 +7,51 @@ ...@@ -7,6 +7,51 @@
namespace torchaudio { namespace torchaudio {
namespace ffmpeg { namespace ffmpeg {
namespace { namespace {
AVFormatContext* get_output_format_context(
const std::string& dst,
const c10::optional<std::string>& format,
AVIOContext* io_ctx) {
if (io_ctx) {
TORCH_CHECK(
format,
"`format` must be provided when the input is file-like object.");
}
AVFormatContext* p = nullptr;
int ret = avformat_alloc_output_context2(
&p, nullptr, format ? format.value().c_str() : nullptr, dst.c_str());
TORCH_CHECK(
ret >= 0,
"Failed to open output \"",
dst,
"\" (",
av_err2string(ret),
").");
if (io_ctx) {
p->pb = io_ctx;
p->flags |= AVFMT_FLAG_CUSTOM_IO;
}
return p;
}
} // namespace
StreamWriter::StreamWriter(AVFormatContext* p) : pFormatContext(p) {}
StreamWriter::StreamWriter(
AVIOContext* io_ctx,
const c10::optional<std::string>& format)
: StreamWriter(
get_output_format_context("Custom Output Context", format, io_ctx)) {}
StreamWriter::StreamWriter(
const std::string& dst,
const c10::optional<std::string>& format)
: StreamWriter(get_output_format_context(dst, format, nullptr)) {}
namespace {
std::vector<std::string> get_supported_pix_fmts(const AVCodec* codec) { std::vector<std::string> get_supported_pix_fmts(const AVCodec* codec) {
std::vector<std::string> ret; std::vector<std::string> ret;
if (codec->pix_fmts) { if (codec->pix_fmts) {
...@@ -77,12 +122,6 @@ std::vector<uint64_t> get_supported_channel_layouts(const AVCodec* codec) { ...@@ -77,12 +122,6 @@ std::vector<uint64_t> get_supported_channel_layouts(const AVCodec* codec) {
return ret; return ret;
} }
} // namespace
StreamWriter::StreamWriter(AVFormatOutputContextPtr&& p)
: pFormatContext(std::move(p)), streams(), pkt() {}
namespace {
void configure_audio_codec( void configure_audio_codec(
AVCodecContextPtr& ctx, AVCodecContextPtr& ctx,
int64_t sample_rate, int64_t sample_rate,
......
...@@ -31,8 +31,28 @@ class StreamWriter { ...@@ -31,8 +31,28 @@ class StreamWriter {
std::vector<OutputStream> streams; std::vector<OutputStream> streams;
AVPacketPtr pkt; AVPacketPtr pkt;
protected:
explicit StreamWriter(AVFormatContext*);
public: public:
explicit StreamWriter(AVFormatOutputContextPtr&& p); /// Construct StreamWriter from destination URI
///
/// @param dst Destination where encoded data are written.
/// @param format Specify output format. If not provided, it is guessed from
/// ``dst``.
explicit StreamWriter(
const std::string& dst,
const c10::optional<std::string>& format = {});
/// Construct StreamWriter from custom IO
///
/// @param io_ctx Custom IO.
/// @param format Specify output format.
// TODO: Move this into wrapper class.
explicit StreamWriter(
AVIOContext* io_ctx,
const c10::optional<std::string>& format);
// Non-copyable // Non-copyable
StreamWriter(const StreamWriter&) = delete; StreamWriter(const StreamWriter&) = delete;
StreamWriter& operator=(const StreamWriter&) = delete; StreamWriter& operator=(const StreamWriter&) = delete;
......
#include <torch/script.h> #include <torch/script.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/stream_writer_wrapper.h> #include <torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h>
namespace torchaudio { namespace torchaudio {
namespace ffmpeg { namespace ffmpeg {
namespace { namespace {
c10::intrusive_ptr<StreamWriterBinding> init( class StreamWriterBinding : public StreamWriter,
public torch::CustomClassHolder {
public:
StreamWriterBinding(
const std::string& dst, const std::string& dst,
const c10::optional<std::string>& format) { const c10::optional<std::string>& format)
return c10::make_intrusive<StreamWriterBinding>( : StreamWriter(dst, format) {}
get_output_format_context(dst, format)); };
}
using S = const c10::intrusive_ptr<StreamWriterBinding>&; using S = const c10::intrusive_ptr<StreamWriterBinding>&;
TORCH_LIBRARY_FRAGMENT(torchaudio, m) { TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
m.class_<StreamWriterBinding>("ffmpeg_StreamWriter") m.class_<StreamWriterBinding>("ffmpeg_StreamWriter")
.def(torch::init<>(init)) .def(torch::init<>(
[](const std::string& dst, const c10::optional<std::string>& format) {
return c10::make_intrusive<StreamWriterBinding>(dst, format);
}))
.def( .def(
"add_audio_stream", "add_audio_stream",
[](S s, [](S s,
......
#include <torchaudio/csrc/ffmpeg/stream_writer/stream_writer_wrapper.h>
namespace torchaudio {
namespace ffmpeg {
AVFormatOutputContextPtr get_output_format_context(
const std::string& dst,
const c10::optional<std::string>& format,
AVIOContext* io_ctx) {
if (io_ctx) {
TORCH_CHECK(
format,
"`format` must be provided when the input is file-like object.");
}
AVFormatContext* p = nullptr;
int ret = avformat_alloc_output_context2(
&p, nullptr, format ? format.value().c_str() : nullptr, dst.c_str());
TORCH_CHECK(
ret >= 0,
"Failed to open output \"",
dst,
"\" (",
av_err2string(ret),
").");
if (io_ctx) {
p->pb = io_ctx;
p->flags |= AVFMT_FLAG_CUSTOM_IO;
}
return AVFormatOutputContextPtr(p);
}
StreamWriterBinding::StreamWriterBinding(AVFormatOutputContextPtr&& p)
: StreamWriter(std::move(p)) {}
} // namespace ffmpeg
} // namespace torchaudio
#pragma once
#include <torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h>
namespace torchaudio {
namespace ffmpeg {
// create format context for writing media
AVFormatOutputContextPtr get_output_format_context(
const std::string& dst,
const c10::optional<std::string>& format,
AVIOContext* io_ctx = nullptr);
class StreamWriterBinding : public StreamWriter,
public torch::CustomClassHolder {
public:
explicit StreamWriterBinding(AVFormatOutputContextPtr&& p);
};
} // namespace ffmpeg
} // namespace torchaudio
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment