Commit 5cac8de3 authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Extract Encoder from OutputStream (#3104)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3104

Continuation of StreamWriter refactoring

This commit extracts Encoder (+muxer) from OutputStream

Reviewed By: nateanl

Differential Revision: D43610887

fbshipit-source-id: 30a9862b1aabd5af331ce3f33a5815df1decbad1
parent 23231033
...@@ -16,6 +16,7 @@ set( ...@@ -16,6 +16,7 @@ set(
stream_reader/sink.cpp stream_reader/sink.cpp
stream_reader/stream_processor.cpp stream_reader/stream_processor.cpp
stream_reader/stream_reader.cpp stream_reader/stream_reader.cpp
stream_writer/encoder.cpp
stream_writer/output_stream.cpp stream_writer/output_stream.cpp
stream_writer/stream_writer.cpp stream_writer/stream_writer.cpp
compat.cpp compat.cpp
......
#include <torchaudio/csrc/ffmpeg/stream_writer/encoder.h>
namespace torchaudio::io {
namespace {
// Registers a new output stream on the muxer and copies the encoder
// parameters (codec id, dimensions, sample rate, etc.) onto it.
// The returned AVStream is owned by `format_ctx`, not by the caller.
AVStream* add_stream(AVFormatContext* format_ctx, AVCodecContext* codec_ctx) {
  AVStream* out = avformat_new_stream(format_ctx, nullptr);
  TORCH_CHECK(out, "Failed to allocate stream.");
  // Use the encoder's time base as the initial stream time base; the muxer
  // may adjust it when the header is written.
  out->time_base = codec_ctx->time_base;
  const int status = avcodec_parameters_from_context(out->codecpar, codec_ctx);
  TORCH_CHECK(
      status >= 0,
      "Failed to copy the stream parameter. (",
      av_err2string(status),
      ")");
  return out;
}
} // namespace
// Constructs an Encoder bound to the given muxer and codec context.
// Registers a fresh AVStream on the format context; that stream is
// deallocated together with the AVFormatContext, so it is not owned here.
Encoder::Encoder(AVFormatContext* format_ctx_, AVCodecContext* codec_ctx_)
    : format_ctx(format_ctx_),
      codec_ctx(codec_ctx_),
      // Initialize directly from the constructor arguments.
      stream(add_stream(format_ctx_, codec_ctx_)) {}
///
/// Encode the given AVFrame data and write the resulting packets to the muxer.
///
/// @param frame Frame data to encode. Passing nullptr puts the encoder in
///        draining mode: remaining buffered packets are emitted and, on
///        AVERROR_EOF, the muxer's interleaving queue is flushed as well.
void Encoder::encode(AVFrame* frame) {
  int ret = avcodec_send_frame(codec_ctx, frame);
  TORCH_CHECK(ret >= 0, "Failed to encode frame (", av_err2string(ret), ").");
  // Drain every packet the encoder can produce for this frame. One frame may
  // yield zero, one, or multiple packets.
  while (ret >= 0) {
    ret = avcodec_receive_packet(codec_ctx, packet);
    if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
      // EAGAIN: encoder needs more input before it can emit another packet.
      // EOF: encoder is fully drained (happens after encode(nullptr)).
      if (ret == AVERROR_EOF) {
        // Note:
        // av_interleaved_write_frame buffers the packets internally as needed
        // to make sure the packets in the output file are properly interleaved
        // in the order of increasing dts.
        // https://ffmpeg.org/doxygen/3.4/group__lavf__encoding.html#ga37352ed2c63493c38219d935e71db6c1
        // Passing nullptr will (forcefully) flush the queue, and this is
        // necessary if users mal-configure the streams.
        // Possible follow up: Add flush_buffer method?
        // An alternative is to use the `av_write_frame` function, but in that
        // case client code is responsible for ordering packets, which makes it
        // complicated to use StreamWriter
        ret = av_interleaved_write_frame(format_ctx, nullptr);
        TORCH_CHECK(
            ret >= 0, "Failed to flush packet (", av_err2string(ret), ").");
      }
      break;
    } else {
      // Any other negative value is a genuine encoding error.
      TORCH_CHECK(
          ret >= 0,
          "Failed to fetch encoded packet (",
          av_err2string(ret),
          ").");
    }
    // https://github.com/pytorch/audio/issues/2790
    // If this is not set, the last frame is not properly saved, as
    // the encoder cannot figure out when the packet should finish.
    if (packet->duration == 0 && codec_ctx->codec_type == AVMEDIA_TYPE_VIDEO) {
      // 1 means that 1 frame (in codec time base, which is the frame rate)
      // This has to be set before av_packet_rescale_ts below.
      packet->duration = 1;
    }
    // Convert timestamps from the encoder's time base to the stream's
    // time base before handing the packet to the muxer.
    av_packet_rescale_ts(packet, codec_ctx->time_base, stream->time_base);
    packet->stream_index = stream->index;
    ret = av_interleaved_write_frame(format_ctx, packet);
    TORCH_CHECK(ret >= 0, "Failed to write packet (", av_err2string(ret), ").");
  }
}
} // namespace torchaudio::io
#pragma once
#include <torch/types.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/filter_graph.h>
namespace torchaudio::io {
// Encoder + Muxer
// Encoder + Muxer
// Bundles the encoding and muxing halves of one output stream: frames passed
// to encode() are compressed with `codec_ctx` and the resulting packets are
// written (interleaved) into `format_ctx`.
class Encoder {
  // Reference to the AVFormatContext (muxer)
  // Not owned; must outlive this Encoder.
  AVFormatContext* format_ctx;
  // Reference to codec context (encoder)
  // Not owned; must outlive this Encoder.
  AVCodecContext* codec_ctx;
  // Stream object
  // Encoder object creates AVStream, but it will be deallocated along with
  // AVFormatContext, So Encoder does not own it.
  AVStream* stream;
  // Temporary object used during the encoding
  // Encoder owns it.
  AVPacketPtr packet{};
 public:
  // Registers a new AVStream on `format_ctx` configured from `codec_ctx`.
  Encoder(AVFormatContext* format_ctx, AVCodecContext* codec_ctx);
  // Encode one frame and write the resulting packets to the muxer.
  // Passing nullptr drains the encoder and flushes the muxer queue.
  void encode(AVFrame* frame);
};
} // namespace torchaudio::io
...@@ -7,147 +7,66 @@ ...@@ -7,147 +7,66 @@
namespace torchaudio::io { namespace torchaudio::io {
OutputStream::OutputStream( OutputStream::OutputStream(
AVFormatContext* format_ctx_, AVFormatContext* format_ctx,
AVStream* stream_,
AVCodecContextPtr&& codec_ctx_, AVCodecContextPtr&& codec_ctx_,
std::unique_ptr<FilterGraph>&& filter_, std::unique_ptr<FilterGraph>&& filter_,
AVFramePtr&& src_frame_) AVFramePtr&& src_frame_)
: format_ctx(format_ctx_), : codec_ctx(std::move(codec_ctx_)),
stream(stream_), encoder(format_ctx, codec_ctx),
codec_ctx(std::move(codec_ctx_)),
filter(std::move(filter_)), filter(std::move(filter_)),
src_frame(std::move(src_frame_)), src_frame(std::move(src_frame_)),
dst_frame(), dst_frame(),
num_frames(0), num_frames(0) {}
packet() {}
AudioOutputStream::AudioOutputStream( AudioOutputStream::AudioOutputStream(
AVFormatContext* format_ctx_, AVFormatContext* format_ctx,
AVStream* stream_, AVCodecContextPtr&& codec_ctx,
AVCodecContextPtr&& codec_ctx_, std::unique_ptr<FilterGraph>&& filter,
std::unique_ptr<FilterGraph>&& filter_, AVFramePtr&& src_frame,
AVFramePtr&& src_frame_,
int64_t frame_capacity_) int64_t frame_capacity_)
: OutputStream( : OutputStream(
format_ctx_, format_ctx,
stream_, std::move(codec_ctx),
std::move(codec_ctx_), std::move(filter),
std::move(filter_), std::move(src_frame)),
std::move(src_frame_)),
frame_capacity(frame_capacity_) {} frame_capacity(frame_capacity_) {}
VideoOutputStream::VideoOutputStream( VideoOutputStream::VideoOutputStream(
AVFormatContext* format_ctx_, AVFormatContext* format_ctx,
AVStream* stream_, AVCodecContextPtr&& codec_ctx,
AVCodecContextPtr&& codec_ctx_, std::unique_ptr<FilterGraph>&& filter,
std::unique_ptr<FilterGraph>&& filter_, AVFramePtr&& src_frame,
AVFramePtr&& src_frame_,
AVBufferRefPtr&& hw_device_ctx_, AVBufferRefPtr&& hw_device_ctx_,
AVBufferRefPtr&& hw_frame_ctx_) AVBufferRefPtr&& hw_frame_ctx_)
: OutputStream( : OutputStream(
format_ctx_, format_ctx,
stream_, std::move(codec_ctx),
std::move(codec_ctx_), std::move(filter),
std::move(filter_), std::move(src_frame)),
std::move(src_frame_)),
hw_device_ctx(std::move(hw_device_ctx_)), hw_device_ctx(std::move(hw_device_ctx_)),
hw_frame_ctx(std::move(hw_frame_ctx_)) {} hw_frame_ctx(std::move(hw_frame_ctx_)) {}
namespace { void OutputStream::process_frame(AVFrame* src) {
/// if (!filter) {
/// Encode the given AVFrame data encoder.encode(src);
/// return;
/// @param frame Frame data to encode
/// @param format Output format context
/// @param stream Output stream in the output format context
/// @param codec Encoding context
/// @param packet Temporaly packet used during encoding.
void _encode(
AVFrame* frame,
AVFormatContext* format,
AVStream* stream,
AVCodecContext* codec,
AVPacket* packet) {
int ret = avcodec_send_frame(codec, frame);
TORCH_CHECK(ret >= 0, "Failed to encode frame (", av_err2string(ret), ").");
while (ret >= 0) {
ret = avcodec_receive_packet(codec, packet);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
if (ret == AVERROR_EOF) {
// Note:
// av_interleaved_write_frame buffers the packets internally as needed
// to make sure the packets in the output file are properly interleaved
// in the order of increasing dts.
// https://ffmpeg.org/doxygen/3.4/group__lavf__encoding.html#ga37352ed2c63493c38219d935e71db6c1
// Passing nullptr will (forcefully) flush the queue, and this is
// necessary if users mal-configure the streams.
// Possible follow up: Add flush_buffer method?
// An alternative is to use `av_write_frame` functoin, but in that case
// client code is responsible for ordering packets, which makes it
// complicated to use StreamWriter
ret = av_interleaved_write_frame(format, nullptr);
TORCH_CHECK(
ret >= 0, "Failed to flush packet (", av_err2string(ret), ").");
}
break;
} else {
TORCH_CHECK(
ret >= 0,
"Failed to fetch encoded packet (",
av_err2string(ret),
").");
}
// https://github.com/pytorch/audio/issues/2790
// If this is not set, the last frame is not properly saved, as
// the encoder cannot figure out when the packet should finish.
if (packet->duration == 0 && codec->codec_type == AVMEDIA_TYPE_VIDEO) {
// 1 means that 1 frame (in codec time base, which is the frame rate)
// This has to be set before av_packet_rescale_ts bellow.
packet->duration = 1;
}
av_packet_rescale_ts(packet, codec->time_base, stream->time_base);
packet->stream_index = stream->index;
ret = av_interleaved_write_frame(format, packet);
TORCH_CHECK(ret >= 0, "Failed to write packet (", av_err2string(ret), ").");
} }
} int ret = filter->add_frame(src);
void _process(
AVFrame* src_frame,
std::unique_ptr<FilterGraph>& filter,
AVFrame* dst_frame,
AVFormatContext* format,
AVStream* stream,
AVCodecContext* codec,
AVPacket* packet) {
int ret = filter->add_frame(src_frame);
while (ret >= 0) { while (ret >= 0) {
ret = filter->get_frame(dst_frame); ret = filter->get_frame(dst_frame);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) { if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
if (ret == AVERROR_EOF) { if (ret == AVERROR_EOF) {
_encode(nullptr, format, stream, codec, packet); encoder.encode(nullptr);
} }
break; break;
} }
if (ret >= 0) { if (ret >= 0) {
_encode(dst_frame, format, stream, codec, packet); encoder.encode(dst_frame);
} }
av_frame_unref(dst_frame); av_frame_unref(dst_frame);
} }
} }
} // namespace
void OutputStream::process_frame(AVFrame* src) {
if (filter) {
_process(src, filter, dst_frame, format_ctx, stream, codec_ctx, packet);
} else {
_encode(src, format_ctx, stream, codec_ctx, packet);
}
}
void OutputStream::flush() { void OutputStream::flush() {
process_frame(nullptr); process_frame(nullptr);
} }
......
...@@ -3,16 +3,15 @@ ...@@ -3,16 +3,15 @@
#include <torch/types.h> #include <torch/types.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h> #include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/filter_graph.h> #include <torchaudio/csrc/ffmpeg/filter_graph.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/encoder.h>
namespace torchaudio::io { namespace torchaudio::io {
struct OutputStream { struct OutputStream {
// Reference to the AVFormatContext that this stream belongs to // Codec context
AVFormatContext* format_ctx;
// Stream object that OutputStream is responsible for managing
AVStream* stream;
// Codec context (encoder)
AVCodecContextPtr codec_ctx; AVCodecContextPtr codec_ctx;
// Encoder + Muxer
Encoder encoder;
// Filter for additional processing // Filter for additional processing
std::unique_ptr<FilterGraph> filter; std::unique_ptr<FilterGraph> filter;
// frame that user-provided input data is written // frame that user-provided input data is written
...@@ -21,12 +20,9 @@ struct OutputStream { ...@@ -21,12 +20,9 @@ struct OutputStream {
AVFramePtr dst_frame; AVFramePtr dst_frame;
// The number of samples written so far // The number of samples written so far
int64_t num_frames; int64_t num_frames;
// Temporary object used during the encoding
AVPacketPtr packet;
OutputStream( OutputStream(
AVFormatContext* format_ctx, AVFormatContext* format_ctx,
AVStream* stream,
AVCodecContextPtr&& codec_ctx, AVCodecContextPtr&& codec_ctx,
std::unique_ptr<FilterGraph>&& filter, std::unique_ptr<FilterGraph>&& filter,
AVFramePtr&& src_frame); AVFramePtr&& src_frame);
...@@ -42,7 +38,6 @@ struct AudioOutputStream : OutputStream { ...@@ -42,7 +38,6 @@ struct AudioOutputStream : OutputStream {
AudioOutputStream( AudioOutputStream(
AVFormatContext* format_ctx, AVFormatContext* format_ctx,
AVStream* stream,
AVCodecContextPtr&& codec_ctx, AVCodecContextPtr&& codec_ctx,
std::unique_ptr<FilterGraph>&& filter, std::unique_ptr<FilterGraph>&& filter,
AVFramePtr&& src_frame, AVFramePtr&& src_frame,
...@@ -59,7 +54,6 @@ struct VideoOutputStream : OutputStream { ...@@ -59,7 +54,6 @@ struct VideoOutputStream : OutputStream {
VideoOutputStream( VideoOutputStream(
AVFormatContext* format_ctx, AVFormatContext* format_ctx,
AVStream* stream,
AVCodecContextPtr&& codec_ctx, AVCodecContextPtr&& codec_ctx,
std::unique_ptr<FilterGraph>&& filter, std::unique_ptr<FilterGraph>&& filter,
AVFramePtr&& src_frame, AVFramePtr&& src_frame,
......
...@@ -442,14 +442,12 @@ void StreamWriter::add_audio_stream( ...@@ -442,14 +442,12 @@ void StreamWriter::add_audio_stream(
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format) { const c10::optional<std::string>& encoder_format) {
enum AVSampleFormat src_fmt = _get_src_sample_fmt(format);
AVCodecContextPtr ctx = AVCodecContextPtr ctx =
get_codec_ctx(AVMEDIA_TYPE_AUDIO, pFormatContext->oformat, encoder); get_codec_ctx(AVMEDIA_TYPE_AUDIO, pFormatContext->oformat, encoder);
configure_audio_codec(ctx, sample_rate, num_channels, encoder_format); configure_audio_codec(ctx, sample_rate, num_channels, encoder_format);
open_codec(ctx, encoder_option); open_codec(ctx, encoder_option);
AVStream* stream = add_stream(ctx);
enum AVSampleFormat src_fmt = _get_src_sample_fmt(format);
std::unique_ptr<FilterGraph> filter = src_fmt == ctx->sample_fmt std::unique_ptr<FilterGraph> filter = src_fmt == ctx->sample_fmt
? std::unique_ptr<FilterGraph>(nullptr) ? std::unique_ptr<FilterGraph>(nullptr)
: _get_audio_filter(src_fmt, ctx); : _get_audio_filter(src_fmt, ctx);
...@@ -458,7 +456,6 @@ void StreamWriter::add_audio_stream( ...@@ -458,7 +456,6 @@ void StreamWriter::add_audio_stream(
AVFramePtr src_frame = get_audio_frame(src_fmt, ctx, frame_capacity); AVFramePtr src_frame = get_audio_frame(src_fmt, ctx, frame_capacity);
streams.emplace_back(std::make_unique<AudioOutputStream>( streams.emplace_back(std::make_unique<AudioOutputStream>(
pFormatContext, pFormatContext,
stream,
std::move(ctx), std::move(ctx),
std::move(filter), std::move(filter),
std::move(src_frame), std::move(src_frame),
...@@ -491,7 +488,6 @@ void StreamWriter::add_video_stream( ...@@ -491,7 +488,6 @@ void StreamWriter::add_video_stream(
"torchaudio is not compiled with CUDA support. Hardware acceleration is not available."); "torchaudio is not compiled with CUDA support. Hardware acceleration is not available.");
#endif #endif
}(); }();
enum AVPixelFormat src_fmt = _get_src_pixel_fmt(format);
AVCodecContextPtr ctx = AVCodecContextPtr ctx =
get_codec_ctx(AVMEDIA_TYPE_VIDEO, pFormatContext->oformat, encoder); get_codec_ctx(AVMEDIA_TYPE_VIDEO, pFormatContext->oformat, encoder);
...@@ -540,8 +536,8 @@ void StreamWriter::add_video_stream( ...@@ -540,8 +536,8 @@ void StreamWriter::add_video_stream(
#endif #endif
open_codec(ctx, encoder_option); open_codec(ctx, encoder_option);
AVStream* stream = add_stream(ctx);
enum AVPixelFormat src_fmt = _get_src_pixel_fmt(format);
std::unique_ptr<FilterGraph> filter = [&]() { std::unique_ptr<FilterGraph> filter = [&]() {
if (src_fmt != ctx->pix_fmt && device.type() == c10::DeviceType::CPU) { if (src_fmt != ctx->pix_fmt && device.type() == c10::DeviceType::CPU) {
return _get_video_filter(src_fmt, ctx); return _get_video_filter(src_fmt, ctx);
...@@ -557,7 +553,6 @@ void StreamWriter::add_video_stream( ...@@ -557,7 +553,6 @@ void StreamWriter::add_video_stream(
}(); }();
streams.emplace_back(std::make_unique<VideoOutputStream>( streams.emplace_back(std::make_unique<VideoOutputStream>(
pFormatContext, pFormatContext,
stream,
std::move(ctx), std::move(ctx),
std::move(filter), std::move(filter),
std::move(src_frame), std::move(src_frame),
...@@ -565,20 +560,6 @@ void StreamWriter::add_video_stream( ...@@ -565,20 +560,6 @@ void StreamWriter::add_video_stream(
std::move(hw_frame_ctx))); std::move(hw_frame_ctx)));
} }
AVStream* StreamWriter::add_stream(AVCodecContextPtr& codec_ctx) {
AVStream* stream = avformat_new_stream(pFormatContext, nullptr);
TORCH_CHECK(stream, "Failed to allocate stream.");
stream->time_base = codec_ctx->time_base;
int ret = avcodec_parameters_from_context(stream->codecpar, codec_ctx);
TORCH_CHECK(
ret >= 0,
"Failed to copy the stream parameter. (",
av_err2string(ret),
")");
return stream;
}
void StreamWriter::set_metadata(const OptionDict& metadata) { void StreamWriter::set_metadata(const OptionDict& metadata) {
av_dict_free(&pFormatContext->metadata); av_dict_free(&pFormatContext->metadata);
for (auto const& [key, value] : metadata) { for (auto const& [key, value] : metadata) {
...@@ -653,7 +634,7 @@ void StreamWriter::validate_stream(int i, enum AVMediaType type) { ...@@ -653,7 +634,7 @@ void StreamWriter::validate_stream(int i, enum AVMediaType type) {
i); i);
TORCH_CHECK( TORCH_CHECK(
streams[i]->stream->codecpar->codec_type == type, streams[i]->codec_ctx->codec_type == type,
"Stream ", "Stream ",
i, i,
" is not ", " is not ",
......
...@@ -146,9 +146,6 @@ class StreamWriter { ...@@ -146,9 +146,6 @@ class StreamWriter {
/// @param metadata metadata. /// @param metadata metadata.
void set_metadata(const OptionDict& metadata); void set_metadata(const OptionDict& metadata);
private:
AVStream* add_stream(AVCodecContextPtr& ctx);
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
// Write methods // Write methods
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment