Commit 5cac8de3 authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Extract Encoder from OutputStream (#3104)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3104

Continuation of StreamWriter refactoring

This commit extracts the Encoder (+muxer) from OutputStream

Reviewed By: nateanl

Differential Revision: D43610887

fbshipit-source-id: 30a9862b1aabd5af331ce3f33a5815df1decbad1
parent 23231033
......@@ -16,6 +16,7 @@ set(
stream_reader/sink.cpp
stream_reader/stream_processor.cpp
stream_reader/stream_reader.cpp
stream_writer/encoder.cpp
stream_writer/output_stream.cpp
stream_writer/stream_writer.cpp
compat.cpp
......
#include <torchaudio/csrc/ffmpeg/stream_writer/encoder.h>
namespace torchaudio::io {
namespace {
// Register a new output stream on the muxer and copy the encoder parameters
// into it. The returned AVStream is owned by the AVFormatContext, not the
// caller.
AVStream* add_stream(AVFormatContext* format_ctx, AVCodecContext* codec_ctx) {
  AVStream* s = avformat_new_stream(format_ctx, nullptr);
  TORCH_CHECK(s, "Failed to allocate stream.");
  s->time_base = codec_ctx->time_base;
  const int err = avcodec_parameters_from_context(s->codecpar, codec_ctx);
  TORCH_CHECK(
      err >= 0,
      "Failed to copy the stream parameter. (",
      av_err2string(err),
      ")");
  return s;
}
} // namespace
// Construct an Encoder that borrows the muxer (format_ctx_) and the encoder
// context (codec_ctx_) — both must outlive this object — and registers a new
// output stream on the muxer via add_stream.
// Note: `stream` is initialized last, after `format_ctx` and `codec_ctx`,
// matching the member declaration order, so add_stream sees valid members.
Encoder::Encoder(AVFormatContext* format_ctx_, AVCodecContext* codec_ctx_)
    : format_ctx(format_ctx_),
      codec_ctx(codec_ctx_),
      stream(add_stream(format_ctx, codec_ctx)) {}
///
/// Encode the given AVFrame data and write the resulting packets to the
/// output format context.
///
/// @param frame Frame data to encode. Passing `nullptr` drains the encoder
///        (end-of-stream flush).
void Encoder::encode(AVFrame* frame) {
  int ret = avcodec_send_frame(codec_ctx, frame);
  TORCH_CHECK(ret >= 0, "Failed to encode frame (", av_err2string(ret), ").");
  // One input frame may produce zero, one, or multiple packets; drain them
  // all. `packet` is reused across iterations (receive_packet resets it).
  while (ret >= 0) {
    ret = avcodec_receive_packet(codec_ctx, packet);
    if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
      if (ret == AVERROR_EOF) {
        // Note:
        // av_interleaved_write_frame buffers the packets internally as needed
        // to make sure the packets in the output file are properly interleaved
        // in the order of increasing dts.
        // https://ffmpeg.org/doxygen/3.4/group__lavf__encoding.html#ga37352ed2c63493c38219d935e71db6c1
        // Passing nullptr will (forcefully) flush the queue, and this is
        // necessary if users mal-configure the streams.
        // Possible follow up: Add flush_buffer method?
        // An alternative is to use the `av_write_frame` function, but in that
        // case client code is responsible for ordering packets, which makes
        // it complicated to use StreamWriter
        ret = av_interleaved_write_frame(format_ctx, nullptr);
        TORCH_CHECK(
            ret >= 0, "Failed to flush packet (", av_err2string(ret), ").");
      }
      // EAGAIN: encoder needs more input before it can emit another packet.
      break;
    } else {
      TORCH_CHECK(
          ret >= 0,
          "Failed to fetch encoded packet (",
          av_err2string(ret),
          ").");
    }
    // https://github.com/pytorch/audio/issues/2790
    // If this is not set, the last frame is not properly saved, as
    // the encoder cannot figure out when the packet should finish.
    if (packet->duration == 0 && codec_ctx->codec_type == AVMEDIA_TYPE_VIDEO) {
      // 1 means that 1 frame (in codec time base, which is the frame rate)
      // This has to be set before av_packet_rescale_ts below.
      packet->duration = 1;
    }
    // Convert packet timestamps from the encoder's time base to the muxer
    // stream's time base before handing the packet to the muxer.
    av_packet_rescale_ts(packet, codec_ctx->time_base, stream->time_base);
    packet->stream_index = stream->index;
    ret = av_interleaved_write_frame(format_ctx, packet);
    TORCH_CHECK(ret >= 0, "Failed to write packet (", av_err2string(ret), ").");
  }
}
} // namespace torchaudio::io
#pragma once
#include <torch/types.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/filter_graph.h>
namespace torchaudio::io {
// Encoder + Muxer
// Encodes AVFrames with the codec context and muxes the resulting packets
// into the format context. Holds only non-owning references to both; the
// caller must keep them alive for the lifetime of this object.
class Encoder {
  // Reference to the AVFormatContext (muxer). Not owned.
  AVFormatContext* format_ctx;
  // Reference to codec context (encoder). Not owned.
  AVCodecContext* codec_ctx;
  // Stream object
  // Encoder object creates AVStream, but it will be deallocated along with
  // AVFormatContext, So Encoder does not own it.
  AVStream* stream;
  // Temporary object used during the encoding, reused across encode() calls.
  // Encoder owns it.
  AVPacketPtr packet{};

 public:
  // Registers a new stream on format_ctx from codec_ctx's parameters.
  Encoder(AVFormatContext* format_ctx, AVCodecContext* codec_ctx);
  // Encode one frame and write the resulting packets; nullptr flushes.
  void encode(AVFrame* frame);
};
} // namespace torchaudio::io
......@@ -7,147 +7,66 @@
namespace torchaudio::io {
OutputStream::OutputStream(
AVFormatContext* format_ctx_,
AVStream* stream_,
AVFormatContext* format_ctx,
AVCodecContextPtr&& codec_ctx_,
std::unique_ptr<FilterGraph>&& filter_,
AVFramePtr&& src_frame_)
: format_ctx(format_ctx_),
stream(stream_),
codec_ctx(std::move(codec_ctx_)),
: codec_ctx(std::move(codec_ctx_)),
encoder(format_ctx, codec_ctx),
filter(std::move(filter_)),
src_frame(std::move(src_frame_)),
dst_frame(),
num_frames(0),
packet() {}
num_frames(0) {}
AudioOutputStream::AudioOutputStream(
AVFormatContext* format_ctx_,
AVStream* stream_,
AVCodecContextPtr&& codec_ctx_,
std::unique_ptr<FilterGraph>&& filter_,
AVFramePtr&& src_frame_,
AVFormatContext* format_ctx,
AVCodecContextPtr&& codec_ctx,
std::unique_ptr<FilterGraph>&& filter,
AVFramePtr&& src_frame,
int64_t frame_capacity_)
: OutputStream(
format_ctx_,
stream_,
std::move(codec_ctx_),
std::move(filter_),
std::move(src_frame_)),
format_ctx,
std::move(codec_ctx),
std::move(filter),
std::move(src_frame)),
frame_capacity(frame_capacity_) {}
VideoOutputStream::VideoOutputStream(
AVFormatContext* format_ctx_,
AVStream* stream_,
AVCodecContextPtr&& codec_ctx_,
std::unique_ptr<FilterGraph>&& filter_,
AVFramePtr&& src_frame_,
AVFormatContext* format_ctx,
AVCodecContextPtr&& codec_ctx,
std::unique_ptr<FilterGraph>&& filter,
AVFramePtr&& src_frame,
AVBufferRefPtr&& hw_device_ctx_,
AVBufferRefPtr&& hw_frame_ctx_)
: OutputStream(
format_ctx_,
stream_,
std::move(codec_ctx_),
std::move(filter_),
std::move(src_frame_)),
format_ctx,
std::move(codec_ctx),
std::move(filter),
std::move(src_frame)),
hw_device_ctx(std::move(hw_device_ctx_)),
hw_frame_ctx(std::move(hw_frame_ctx_)) {}
namespace {
///
/// Encode the given AVFrame data
///
/// @param frame Frame data to encode
/// @param format Output format context
/// @param stream Output stream in the output format context
/// @param codec Encoding context
/// @param packet Temporaly packet used during encoding.
void _encode(
AVFrame* frame,
AVFormatContext* format,
AVStream* stream,
AVCodecContext* codec,
AVPacket* packet) {
int ret = avcodec_send_frame(codec, frame);
TORCH_CHECK(ret >= 0, "Failed to encode frame (", av_err2string(ret), ").");
while (ret >= 0) {
ret = avcodec_receive_packet(codec, packet);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
if (ret == AVERROR_EOF) {
// Note:
// av_interleaved_write_frame buffers the packets internally as needed
// to make sure the packets in the output file are properly interleaved
// in the order of increasing dts.
// https://ffmpeg.org/doxygen/3.4/group__lavf__encoding.html#ga37352ed2c63493c38219d935e71db6c1
// Passing nullptr will (forcefully) flush the queue, and this is
// necessary if users mal-configure the streams.
// Possible follow up: Add flush_buffer method?
// An alternative is to use `av_write_frame` functoin, but in that case
// client code is responsible for ordering packets, which makes it
// complicated to use StreamWriter
ret = av_interleaved_write_frame(format, nullptr);
TORCH_CHECK(
ret >= 0, "Failed to flush packet (", av_err2string(ret), ").");
}
break;
} else {
TORCH_CHECK(
ret >= 0,
"Failed to fetch encoded packet (",
av_err2string(ret),
").");
}
// https://github.com/pytorch/audio/issues/2790
// If this is not set, the last frame is not properly saved, as
// the encoder cannot figure out when the packet should finish.
if (packet->duration == 0 && codec->codec_type == AVMEDIA_TYPE_VIDEO) {
// 1 means that 1 frame (in codec time base, which is the frame rate)
// This has to be set before av_packet_rescale_ts bellow.
packet->duration = 1;
}
av_packet_rescale_ts(packet, codec->time_base, stream->time_base);
packet->stream_index = stream->index;
ret = av_interleaved_write_frame(format, packet);
TORCH_CHECK(ret >= 0, "Failed to write packet (", av_err2string(ret), ").");
void OutputStream::process_frame(AVFrame* src) {
if (!filter) {
encoder.encode(src);
return;
}
}
void _process(
AVFrame* src_frame,
std::unique_ptr<FilterGraph>& filter,
AVFrame* dst_frame,
AVFormatContext* format,
AVStream* stream,
AVCodecContext* codec,
AVPacket* packet) {
int ret = filter->add_frame(src_frame);
int ret = filter->add_frame(src);
while (ret >= 0) {
ret = filter->get_frame(dst_frame);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
if (ret == AVERROR_EOF) {
_encode(nullptr, format, stream, codec, packet);
encoder.encode(nullptr);
}
break;
}
if (ret >= 0) {
_encode(dst_frame, format, stream, codec, packet);
encoder.encode(dst_frame);
}
av_frame_unref(dst_frame);
}
}
} // namespace
void OutputStream::process_frame(AVFrame* src) {
if (filter) {
_process(src, filter, dst_frame, format_ctx, stream, codec_ctx, packet);
} else {
_encode(src, format_ctx, stream, codec_ctx, packet);
}
}
void OutputStream::flush() {
process_frame(nullptr);
}
......
......@@ -3,16 +3,15 @@
#include <torch/types.h>
#include <torchaudio/csrc/ffmpeg/ffmpeg.h>
#include <torchaudio/csrc/ffmpeg/filter_graph.h>
#include <torchaudio/csrc/ffmpeg/stream_writer/encoder.h>
namespace torchaudio::io {
struct OutputStream {
// Reference to the AVFormatContext that this stream belongs to
AVFormatContext* format_ctx;
// Stream object that OutputStream is responsible for managing
AVStream* stream;
// Codec context (encoder)
// Codec context
AVCodecContextPtr codec_ctx;
// Encoder + Muxer
Encoder encoder;
// Filter for additional processing
std::unique_ptr<FilterGraph> filter;
// frame that user-provided input data is written
......@@ -21,12 +20,9 @@ struct OutputStream {
AVFramePtr dst_frame;
// The number of samples written so far
int64_t num_frames;
// Temporary object used during the encoding
AVPacketPtr packet;
OutputStream(
AVFormatContext* format_ctx,
AVStream* stream,
AVCodecContextPtr&& codec_ctx,
std::unique_ptr<FilterGraph>&& filter,
AVFramePtr&& src_frame);
......@@ -42,7 +38,6 @@ struct AudioOutputStream : OutputStream {
AudioOutputStream(
AVFormatContext* format_ctx,
AVStream* stream,
AVCodecContextPtr&& codec_ctx,
std::unique_ptr<FilterGraph>&& filter,
AVFramePtr&& src_frame,
......@@ -59,7 +54,6 @@ struct VideoOutputStream : OutputStream {
VideoOutputStream(
AVFormatContext* format_ctx,
AVStream* stream,
AVCodecContextPtr&& codec_ctx,
std::unique_ptr<FilterGraph>&& filter,
AVFramePtr&& src_frame,
......
......@@ -442,14 +442,12 @@ void StreamWriter::add_audio_stream(
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format) {
enum AVSampleFormat src_fmt = _get_src_sample_fmt(format);
AVCodecContextPtr ctx =
get_codec_ctx(AVMEDIA_TYPE_AUDIO, pFormatContext->oformat, encoder);
configure_audio_codec(ctx, sample_rate, num_channels, encoder_format);
open_codec(ctx, encoder_option);
AVStream* stream = add_stream(ctx);
enum AVSampleFormat src_fmt = _get_src_sample_fmt(format);
std::unique_ptr<FilterGraph> filter = src_fmt == ctx->sample_fmt
? std::unique_ptr<FilterGraph>(nullptr)
: _get_audio_filter(src_fmt, ctx);
......@@ -458,7 +456,6 @@ void StreamWriter::add_audio_stream(
AVFramePtr src_frame = get_audio_frame(src_fmt, ctx, frame_capacity);
streams.emplace_back(std::make_unique<AudioOutputStream>(
pFormatContext,
stream,
std::move(ctx),
std::move(filter),
std::move(src_frame),
......@@ -491,7 +488,6 @@ void StreamWriter::add_video_stream(
"torchaudio is not compiled with CUDA support. Hardware acceleration is not available.");
#endif
}();
enum AVPixelFormat src_fmt = _get_src_pixel_fmt(format);
AVCodecContextPtr ctx =
get_codec_ctx(AVMEDIA_TYPE_VIDEO, pFormatContext->oformat, encoder);
......@@ -540,8 +536,8 @@ void StreamWriter::add_video_stream(
#endif
open_codec(ctx, encoder_option);
AVStream* stream = add_stream(ctx);
enum AVPixelFormat src_fmt = _get_src_pixel_fmt(format);
std::unique_ptr<FilterGraph> filter = [&]() {
if (src_fmt != ctx->pix_fmt && device.type() == c10::DeviceType::CPU) {
return _get_video_filter(src_fmt, ctx);
......@@ -557,7 +553,6 @@ void StreamWriter::add_video_stream(
}();
streams.emplace_back(std::make_unique<VideoOutputStream>(
pFormatContext,
stream,
std::move(ctx),
std::move(filter),
std::move(src_frame),
......@@ -565,20 +560,6 @@ void StreamWriter::add_video_stream(
std::move(hw_frame_ctx)));
}
AVStream* StreamWriter::add_stream(AVCodecContextPtr& codec_ctx) {
AVStream* stream = avformat_new_stream(pFormatContext, nullptr);
TORCH_CHECK(stream, "Failed to allocate stream.");
stream->time_base = codec_ctx->time_base;
int ret = avcodec_parameters_from_context(stream->codecpar, codec_ctx);
TORCH_CHECK(
ret >= 0,
"Failed to copy the stream parameter. (",
av_err2string(ret),
")");
return stream;
}
void StreamWriter::set_metadata(const OptionDict& metadata) {
av_dict_free(&pFormatContext->metadata);
for (auto const& [key, value] : metadata) {
......@@ -653,7 +634,7 @@ void StreamWriter::validate_stream(int i, enum AVMediaType type) {
i);
TORCH_CHECK(
streams[i]->stream->codecpar->codec_type == type,
streams[i]->codec_ctx->codec_type == type,
"Stream ",
i,
" is not ",
......
......@@ -146,9 +146,6 @@ class StreamWriter {
/// @param metadata metadata.
void set_metadata(const OptionDict& metadata);
private:
AVStream* add_stream(AVCodecContextPtr& ctx);
//////////////////////////////////////////////////////////////////////////////
// Write methods
//////////////////////////////////////////////////////////////////////////////
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment