Commit 4eac61a3 authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Refactor the initialization of EncodeProcess (#3205)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3205

This commit refactors the initialization of EncodeProcess.

Interface-wise, the signature of the constructor of EncodeProcess
has made simpler just to take rvalues of its components, and the
initialization of the components have been moved to helper functions.

Implementat-wise, the order that the components are initialized is
revised, and the source of initialization parameters is also revised.

For example, the original implementation first creates AVCodecContext,
and passes it around to create the other components. This relied on
an assumption that parameters AVCodecContext has (such as image size
and sample rate) are same as the source data. This is not always right,
and as we will introduce custom filter graph and allow on-the-fly
transform of rates and dimensions, it will become even less correct.

The new initialization constructs source AVFrame, TensorConverter and
FilterGraph from source attributes. This makes it easy to introduce
on-the-fly transform.

Reviewed By: nateanl

Differential Revision: D44360650

fbshipit-source-id: bf0e77dc1a5a40fc8e9870c50d07339d812762e8
parent d8a37a21
......@@ -2,180 +2,151 @@
namespace torchaudio::io {
namespace {
AVCodecContextPtr get_codec_ctx(
enum AVMediaType type,
AVFORMAT_CONST AVOutputFormat* oformat,
const c10::optional<std::string>& encoder) {
enum AVCodecID default_codec = [&]() {
switch (type) {
case AVMEDIA_TYPE_AUDIO:
return oformat->audio_codec;
case AVMEDIA_TYPE_VIDEO:
return oformat->video_codec;
default:
TORCH_CHECK(
false, "Unsupported media type: ", av_get_media_type_string(type));
}
}();
////////////////////////////////////////////////////////////////////////////////
// EncodeProcess Logic Implementation
////////////////////////////////////////////////////////////////////////////////
TORCH_CHECK(
default_codec != AV_CODEC_ID_NONE,
"Format \"",
oformat->name,
"\" does not support ",
av_get_media_type_string(type),
".");
EncodeProcess::EncodeProcess(
TensorConverter&& converter,
AVFramePtr&& frame,
FilterGraph&& filter_graph,
Encoder&& encoder,
AVCodecContextPtr&& codec_ctx) noexcept
: converter(std::move(converter)),
src_frame(std::move(frame)),
filter(std::move(filter_graph)),
encoder(std::move(encoder)),
codec_ctx(std::move(codec_ctx)) {}
const AVCodec* codec = [&]() {
if (encoder) {
const AVCodec* c = avcodec_find_encoder_by_name(encoder.value().c_str());
TORCH_CHECK(c, "Unexpected codec: ", encoder.value());
return c;
void EncodeProcess::process(
const torch::Tensor& tensor,
const c10::optional<double>& pts) {
if (pts) {
AVRational tb = codec_ctx->time_base;
auto val = static_cast<int64_t>(std::round(pts.value() * tb.den / tb.num));
if (src_frame->pts > val) {
TORCH_WARN_ONCE(
"The provided PTS value is smaller than the next expected value.");
}
const AVCodec* c = avcodec_find_encoder(default_codec);
TORCH_CHECK(
c, "Encoder not found for codec: ", avcodec_get_name(default_codec));
return c;
}();
AVCodecContext* ctx = avcodec_alloc_context3(codec);
TORCH_CHECK(ctx, "Failed to allocate CodecContext.");
if (oformat->flags & AVFMT_GLOBALHEADER) {
ctx->flags |= AV_CODEC_FLAG_GLOBAL_HEADER;
src_frame->pts = val;
}
for (const auto& frame : converter.convert(tensor)) {
process_frame(frame);
frame->pts += frame->nb_samples;
}
return AVCodecContextPtr(ctx);
}
std::vector<int> get_supported_sample_rates(const AVCodec* codec) {
std::vector<int> ret;
if (codec->supported_samplerates) {
const int* t = codec->supported_samplerates;
while (*t) {
ret.push_back(*t);
++t;
void EncodeProcess::process_frame(AVFrame* src) {
int ret = filter.add_frame(src);
while (ret >= 0) {
ret = filter.get_frame(dst_frame);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
if (ret == AVERROR_EOF) {
encoder.encode(nullptr);
}
break;
}
if (ret >= 0) {
encoder.encode(dst_frame);
}
av_frame_unref(dst_frame);
}
return ret;
}
std::vector<std::string> get_supported_sample_fmts(const AVCodec* codec) {
std::vector<std::string> ret;
if (codec->sample_fmts) {
const enum AVSampleFormat* t = codec->sample_fmts;
while (*t != AV_SAMPLE_FMT_NONE) {
ret.emplace_back(av_get_sample_fmt_name(*t));
++t;
}
}
return ret;
void EncodeProcess::flush() {
process_frame(nullptr);
}
std::vector<uint64_t> get_supported_channel_layouts(const AVCodec* codec) {
std::vector<uint64_t> ret;
if (codec->channel_layouts) {
const uint64_t* t = codec->channel_layouts;
while (*t) {
ret.push_back(*t);
++t;
}
////////////////////////////////////////////////////////////////////////////////
// EncodeProcess Initialization helper functions
////////////////////////////////////////////////////////////////////////////////
namespace {
enum AVSampleFormat get_sample_fmt(const std::string& src) {
auto fmt = av_get_sample_fmt(src.c_str());
if (fmt != AV_SAMPLE_FMT_NONE && !av_sample_fmt_is_planar(fmt)) {
return fmt;
}
return ret;
TORCH_CHECK(
false,
"Unsupported sample fotmat (",
src,
") was provided. Valid values are ",
[]() -> std::string {
std::vector<std::string> ret;
for (const auto& fmt :
{AV_SAMPLE_FMT_U8,
AV_SAMPLE_FMT_S16,
AV_SAMPLE_FMT_S32,
AV_SAMPLE_FMT_S64,
AV_SAMPLE_FMT_FLT,
AV_SAMPLE_FMT_DBL}) {
ret.emplace_back(av_get_sample_fmt_name(fmt));
}
return c10::Join(", ", ret);
}(),
".");
}
void configure_audio_codec(
AVCodecContextPtr& ctx,
int64_t sample_rate,
int64_t num_channels,
const c10::optional<std::string>& format,
const c10::optional<EncodingConfig>& config) {
// TODO: Review options and make them configurable?
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00122
// - bit_rate
// - bit_rate_tolerance
ctx->sample_rate = [&]() -> int {
auto rates = get_supported_sample_rates(ctx->codec);
if (rates.empty()) {
return static_cast<int>(sample_rate);
}
for (const auto& it : rates) {
if (it == sample_rate) {
return static_cast<int>(sample_rate);
}
}
TORCH_CHECK(
false,
ctx->codec->name,
" does not support sample rate ",
sample_rate,
". Supported sample rates are: ",
c10::Join(", ", rates));
}();
ctx->time_base = av_inv_q(av_d2q(sample_rate, 1 << 24));
ctx->sample_fmt = [&]() {
// Use default
if (!format) {
TORCH_CHECK(
ctx->codec->sample_fmts,
ctx->codec->name,
" does not have default sample format. Please specify one.");
return ctx->codec->sample_fmts[0];
}
// Use the given one.
auto fmt = format.value();
auto ret = av_get_sample_fmt(fmt.c_str());
auto fmts = get_supported_sample_fmts(ctx->codec);
if (fmts.empty()) {
TORCH_CHECK(
ret != AV_SAMPLE_FMT_NONE, "Unrecognized format: ", fmt, ". ");
return ret;
}
TORCH_CHECK(
std::count(fmts.begin(), fmts.end(), fmt),
"Unsupported sample format: ",
fmt,
". Supported values are ",
c10::Join(", ", fmts));
return ret;
}();
enum AVPixelFormat get_pix_fmt(const std::string& src) {
AVPixelFormat fmt = av_get_pix_fmt(src.c_str());
switch (fmt) {
case AV_PIX_FMT_GRAY8:
case AV_PIX_FMT_RGB24:
case AV_PIX_FMT_BGR24:
case AV_PIX_FMT_YUV444P:
return fmt;
default:;
}
TORCH_CHECK(
false,
"Unsupported pixel format (",
src,
") was provided. Valid values are ",
[]() -> std::string {
std::vector<std::string> ret;
for (const auto& fmt :
{AV_PIX_FMT_GRAY8,
AV_PIX_FMT_RGB24,
AV_PIX_FMT_BGR24,
AV_PIX_FMT_YUV444P}) {
ret.emplace_back(av_get_pix_fmt_name(fmt));
}
return c10::Join(", ", ret);
}(),
".");
}
// validate and set channels
ctx->channels = static_cast<int>(num_channels);
auto layout = av_get_default_channel_layout(ctx->channels);
auto layouts = get_supported_channel_layouts(ctx->codec);
if (!layouts.empty()) {
if (!std::count(layouts.begin(), layouts.end(), layout)) {
std::vector<std::string> tmp;
for (const auto& it : layouts) {
tmp.push_back(std::to_string(av_get_channel_layout_nb_channels(it)));
}
TORCH_CHECK(
false,
"Unsupported channels: ",
num_channels,
". Supported channels are: ",
c10::Join(", ", tmp));
}
////////////////////////////////////////////////////////////////////////////////
// Codec & Codec context
////////////////////////////////////////////////////////////////////////////////
const AVCodec* get_codec(
AVCodecID default_codec,
const c10::optional<std::string>& encoder) {
if (encoder) {
const AVCodec* c = avcodec_find_encoder_by_name(encoder.value().c_str());
TORCH_CHECK(c, "Unexpected codec: ", encoder.value());
return c;
}
ctx->channel_layout = static_cast<uint64_t>(layout);
const AVCodec* c = avcodec_find_encoder(default_codec);
TORCH_CHECK(
c, "Encoder not found for codec: ", avcodec_get_name(default_codec));
return c;
}
// Set optional stuff
if (config) {
auto& cfg = config.value();
if (cfg.bit_rate > 0) {
ctx->bit_rate = cfg.bit_rate;
}
if (cfg.compression_level != -1) {
ctx->compression_level = cfg.compression_level;
}
AVCodecContextPtr get_codec_ctx(const AVCodec* codec, int flags) {
AVCodecContext* ctx = avcodec_alloc_context3(codec);
TORCH_CHECK(ctx, "Failed to allocate CodecContext.");
if (flags & AVFMT_GLOBALHEADER) {
ctx->flags |= AV_CODEC_FLAG_GLOBAL_HEADER;
}
return AVCodecContextPtr(ctx);
}
void open_codec(
AVCodecContextPtr& codec_ctx,
AVCodecContext* codec_ctx,
const c10::optional<OptionDict>& option) {
AVDictionary* opt = get_option_dict(option);
......@@ -214,186 +185,242 @@ void open_codec(
TORCH_CHECK(ret >= 0, "Failed to open codec: (", av_err2string(ret), ")");
}
AVCodecContextPtr get_audio_codec(
AVFORMAT_CONST AVOutputFormat* oformat,
int64_t sample_rate,
int64_t num_channels,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config) {
AVCodecContextPtr ctx = get_codec_ctx(AVMEDIA_TYPE_AUDIO, oformat, encoder);
configure_audio_codec(ctx, sample_rate, num_channels, encoder_format, config);
open_codec(ctx, encoder_option);
return ctx;
////////////////////////////////////////////////////////////////////////////////
// Audio codec
////////////////////////////////////////////////////////////////////////////////
bool supported_sample_fmt(
const AVSampleFormat fmt,
const AVSampleFormat* sample_fmts) {
if (!sample_fmts) {
return true;
}
while (*sample_fmts != AV_SAMPLE_FMT_NONE) {
if (fmt == *sample_fmts) {
return true;
}
++sample_fmts;
}
return false;
}
std::vector<std::string> get_supported_formats(
const AVSampleFormat* sample_fmts) {
std::vector<std::string> ret;
while (*sample_fmts != AV_SAMPLE_FMT_NONE) {
ret.emplace_back(av_get_sample_fmt_name(*sample_fmts));
++sample_fmts;
}
return ret;
}
FilterGraph get_audio_filter(
AVSampleFormat get_enc_fmt(
AVSampleFormat src_fmt,
AVCodecContext* codec_ctx) {
auto desc = [&]() -> std::string {
if (src_fmt == codec_ctx->sample_fmt) {
if (!codec_ctx->frame_size) {
return "anull";
} else {
std::stringstream ss;
ss << "asetnsamples=n=" << codec_ctx->frame_size << ":p=0";
return ss.str();
}
} else {
std::stringstream ss;
ss << "aformat=" << av_get_sample_fmt_name(codec_ctx->sample_fmt);
if (codec_ctx->frame_size) {
ss << ",asetnsamples=n=" << codec_ctx->frame_size << ":p=0";
}
return ss.str();
const c10::optional<std::string>& encoder_format,
const AVCodec* codec) {
if (encoder_format) {
auto& enc_fmt_val = encoder_format.value();
auto fmt = av_get_sample_fmt(enc_fmt_val.c_str());
TORCH_CHECK(
fmt != AV_SAMPLE_FMT_NONE, "Unknown sample format: ", enc_fmt_val);
TORCH_CHECK(
supported_sample_fmt(fmt, codec->sample_fmts),
codec->name,
" does not support ",
encoder_format.value(),
" format. Supported values are; ",
c10::Join(", ", get_supported_formats(codec->sample_fmts)));
return fmt;
}
if (codec->sample_fmts) {
return codec->sample_fmts[0];
}
return src_fmt;
};
bool supported_sample_rate(
const int sample_rate,
const int* supported_samplerates) {
if (!supported_samplerates) {
return true;
}
while (*supported_samplerates) {
if (sample_rate == *supported_samplerates) {
return true;
}
}();
++supported_samplerates;
}
return false;
}
FilterGraph p{AVMEDIA_TYPE_AUDIO};
p.add_audio_src(
src_fmt,
codec_ctx->time_base,
codec_ctx->sample_rate,
codec_ctx->channel_layout);
p.add_sink();
p.add_process(desc);
p.create_filter();
return p;
std::vector<int> get_supported_samplerates(const int* supported_samplerates) {
std::vector<int> ret;
if (supported_samplerates) {
while (*supported_samplerates) {
ret.push_back(*supported_samplerates);
++supported_samplerates;
}
}
return ret;
}
AVFramePtr get_audio_frame(
AVSampleFormat src_fmt,
void validate_sample_rate(int sample_rate, const AVCodec* codec) {
TORCH_CHECK(
supported_sample_rate(sample_rate, codec->supported_samplerates),
codec->name,
" does not support sample rate ",
sample_rate,
". Supported values are; ",
c10::Join(", ", get_supported_samplerates(codec->supported_samplerates)));
}
std::vector<std::string> get_supported_channels(
const uint64_t* channel_layouts) {
std::vector<std::string> ret;
while (*channel_layouts) {
ret.emplace_back(av_get_channel_name(*channel_layouts));
++channel_layouts;
}
return ret;
}
uint64_t get_channel_layout(int num_channels, const AVCodec* codec) {
if (!codec->channel_layouts) {
return static_cast<uint64_t>(av_get_default_channel_layout(num_channels));
}
for (const uint64_t* it = codec->channel_layouts; *it; ++it) {
if (av_get_channel_layout_nb_channels(*it) == num_channels) {
return *it;
}
}
TORCH_CHECK(
false,
"Codec ",
codec->name,
" does not support a channel layout consists of ",
num_channels,
" channels. Supported values are: ",
c10::Join(", ", get_supported_channels(codec->channel_layouts)));
}
void configure_audio_codec_ctx(
AVCodecContext* codec_ctx,
AVSampleFormat format,
int sample_rate,
int num_channels,
AVCodecContext* codec_ctx,
int default_frame_size = 10000) {
AVFramePtr frame{};
frame->pts = 0;
frame->format = src_fmt;
// Note: `channels` attribute is not required for encoding, but
// TensorConverter refers to it
frame->channels = num_channels;
frame->channel_layout = codec_ctx->channel_layout;
frame->sample_rate = sample_rate;
frame->nb_samples =
codec_ctx->frame_size ? codec_ctx->frame_size : default_frame_size;
if (frame->nb_samples) {
int ret = av_frame_get_buffer(frame, 0);
TORCH_CHECK(
ret >= 0,
"Error allocating an audio buffer (",
av_err2string(ret),
").");
uint64_t channel_layout,
const c10::optional<EncodingConfig>& config) {
codec_ctx->sample_fmt = format;
codec_ctx->sample_rate = sample_rate;
codec_ctx->time_base = av_inv_q(av_d2q(sample_rate, 1 << 24));
codec_ctx->channels = num_channels;
codec_ctx->channel_layout = channel_layout;
// Set optional stuff
if (config) {
auto& cfg = config.value();
if (cfg.bit_rate > 0) {
codec_ctx->bit_rate = cfg.bit_rate;
}
if (cfg.compression_level != -1) {
codec_ctx->compression_level = cfg.compression_level;
}
}
return frame;
}
std::vector<std::string> get_supported_pix_fmts(const AVCodec* codec) {
std::vector<std::string> ret;
if (codec->pix_fmts) {
const enum AVPixelFormat* t = codec->pix_fmts;
while (*t != AV_PIX_FMT_NONE) {
ret.emplace_back(av_get_pix_fmt_name(*t));
++t;
////////////////////////////////////////////////////////////////////////////////
// Video codec
////////////////////////////////////////////////////////////////////////////////
bool supported_pix_fmt(const AVPixelFormat fmt, const AVPixelFormat* pix_fmts) {
if (!pix_fmts) {
return true;
}
while (*pix_fmts != AV_PIX_FMT_NONE) {
if (fmt == *pix_fmts) {
return true;
}
++pix_fmts;
}
return false;
}
std::vector<std::string> get_supported_formats(const AVPixelFormat* pix_fmts) {
std::vector<std::string> ret;
while (*pix_fmts != AV_PIX_FMT_NONE) {
ret.emplace_back(av_get_pix_fmt_name(*pix_fmts));
++pix_fmts;
}
return ret;
}
std::vector<AVRational> get_supported_frame_rates(const AVCodec* codec) {
std::vector<AVRational> ret;
if (codec->supported_framerates) {
const AVRational* t = codec->supported_framerates;
while (!(t->num == 0 && t->den == 0)) {
ret.push_back(*t);
++t;
AVPixelFormat get_enc_fmt(
AVPixelFormat src_fmt,
const c10::optional<std::string>& encoder_format,
const AVCodec* codec) {
if (encoder_format) {
auto fmt = get_pix_fmt(encoder_format.value());
TORCH_CHECK(
supported_pix_fmt(fmt, codec->pix_fmts),
codec->name,
" does not support ",
encoder_format.value(),
" format. Supported values are; ",
c10::Join(", ", get_supported_formats(codec->pix_fmts)));
return fmt;
}
if (codec->pix_fmts) {
return codec->pix_fmts[0];
}
return src_fmt;
}
bool supported_frame_rate(AVRational rate, const AVRational* rates) {
if (!rates) {
return true;
}
for (; !(rates->num == 0 && rates->den == 0); ++rates) {
if (av_cmp_q(rate, *rates) == 0) {
return true;
}
}
return ret;
return false;
}
// used to compare frame rate / sample rate.
// not a general purpose float comparison
bool is_rate_close(double rate, AVRational rational) {
double ref =
static_cast<double>(rational.num) / static_cast<double>(rational.den);
// frame rates / sample rates
static const double threshold = 0.001;
return fabs(rate - ref) < threshold;
void validate_frame_rate(AVRational rate, const AVCodec* codec) {
TORCH_CHECK(
supported_frame_rate(rate, codec->supported_framerates),
codec->name,
" does not support frame rate ",
c10::Join("/", std::array<int, 2>{rate.num, rate.den}),
". Supported values are; ",
[&]() {
std::vector<std::string> ret;
for (auto r = codec->supported_framerates;
!(r->num == 0 && r->den == 0);
++r) {
ret.push_back(c10::Join("/", std::array<int, 2>{r->num, r->den}));
}
return c10::Join(", ", ret);
}());
}
void configure_video_codec(
void configure_video_codec_ctx(
AVCodecContextPtr& ctx,
double frame_rate,
int64_t width,
int64_t height,
const c10::optional<std::string>& format,
AVPixelFormat format,
AVRational frame_rate,
int width,
int height,
const c10::optional<EncodingConfig>& config) {
// TODO: Review other options and make them configurable?
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00147
// - bit_rate
// - bit_rate_tolerance
// - gop_size
// - max_b_frames
// - mb_decisions
ctx->width = static_cast<int>(width);
ctx->height = static_cast<int>(height);
ctx->time_base = [&]() {
AVRational ret = av_inv_q(av_d2q(frame_rate, 1 << 24));
auto rates = get_supported_frame_rates(ctx->codec);
// Codec does not have constraint on frame rate
if (rates.empty()) {
return ret;
}
// Codec has list of supported frame rate.
for (const auto& t : rates) {
if (is_rate_close(frame_rate, t)) {
return ret;
}
}
// Given one is not supported.
std::vector<std::string> tmp;
for (const auto& t : rates) {
tmp.emplace_back(
t.den == 1 ? std::to_string(t.num)
: std::to_string(t.num) + "/" + std::to_string(t.den));
}
TORCH_CHECK(
false,
"Unsupported frame rate: ",
frame_rate,
". Supported values are ",
c10::Join(", ", tmp));
}();
ctx->pix_fmt = [&]() {
// Use default
if (!format) {
TORCH_CHECK(
ctx->codec->pix_fmts,
ctx->codec->name,
" does not have defaut pixel format. Please specify one.");
return ctx->codec->pix_fmts[0];
}
// Use the given one,
auto fmt = format.value();
auto ret = av_get_pix_fmt(fmt.c_str());
auto fmts = get_supported_pix_fmts(ctx->codec);
if (fmts.empty()) {
TORCH_CHECK(ret != AV_PIX_FMT_NONE, "Unrecognized format: ", fmt, ". ");
return ret;
}
if (!std::count(fmts.begin(), fmts.end(), fmt)) {
TORCH_CHECK(
false,
"Unsupported pixel format: ",
fmt,
". Supported values are ",
c10::Join(", ", fmts));
}
return ret;
}();
ctx->pix_fmt = format;
ctx->width = width;
ctx->height = height;
ctx->time_base = av_inv_q(frame_rate);
// Set optional stuff
if (config) {
......@@ -416,9 +443,9 @@ void configure_video_codec(
void configure_hw_accel(AVCodecContext* ctx, const std::string& hw_accel) {
torch::Device device{hw_accel};
TORCH_CHECK(
device.type() == c10::DeviceType::CUDA,
device.is_cuda(),
"Only CUDA is supported for hardware acceleration. Found: ",
device.str());
device);
// NOTES:
// 1. Examples like
......@@ -463,77 +490,118 @@ void configure_hw_accel(AVCodecContext* ctx, const std::string& hw_accel) {
av_err2string(ret));
}
AVCodecContextPtr get_video_codec(
AVFORMAT_CONST AVOutputFormat* oformat,
double frame_rate,
int64_t width,
int64_t height,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config) {
AVCodecContextPtr ctx = get_codec_ctx(AVMEDIA_TYPE_VIDEO, oformat, encoder);
configure_video_codec(ctx, frame_rate, width, height, encoder_format, config);
////////////////////////////////////////////////////////////////////////////////
// AVStream
////////////////////////////////////////////////////////////////////////////////
if (hw_accel) {
#ifdef USE_CUDA
configure_hw_accel(ctx, hw_accel.value());
#else
TORCH_CHECK(
false,
"torchaudio is not compiled with CUDA support. ",
"Hardware acceleration is not available.");
#endif
}
AVStream* get_stream(AVFormatContext* format_ctx, AVCodecContext* codec_ctx) {
AVStream* stream = avformat_new_stream(format_ctx, nullptr);
TORCH_CHECK(stream, "Failed to allocate stream.");
open_codec(ctx, encoder_option);
return ctx;
stream->time_base = codec_ctx->time_base;
int ret = avcodec_parameters_from_context(stream->codecpar, codec_ctx);
TORCH_CHECK(
ret >= 0, "Failed to copy the stream parameter: ", av_err2string(ret));
return stream;
}
////////////////////////////////////////////////////////////////////////////////
// FilterGraph
////////////////////////////////////////////////////////////////////////////////
FilterGraph get_audio_filter_graph(
AVSampleFormat src_fmt,
int sample_rate,
uint64_t channel_layout,
AVSampleFormat enc_fmt,
int nb_samples) {
const std::string filter_desc = [&]() -> const std::string {
if (src_fmt == enc_fmt) {
if (nb_samples == 0) {
return "anull";
} else {
std::stringstream ss;
ss << "asetnsamples=n=" << nb_samples << ":p=0";
return ss.str();
}
} else {
std::stringstream ss;
ss << "aformat=" << av_get_sample_fmt_name(enc_fmt);
if (nb_samples > 0) {
ss << ",asetnsamples=n=" << nb_samples << ":p=0";
}
return ss.str();
}
}();
FilterGraph f{AVMEDIA_TYPE_AUDIO};
f.add_audio_src(src_fmt, {1, sample_rate}, sample_rate, channel_layout);
f.add_sink();
f.add_process(filter_desc);
f.create_filter();
return f;
}
FilterGraph get_video_filter(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
FilterGraph get_video_filter_graph(
AVPixelFormat src_fmt,
AVRational rate,
int width,
int height,
AVPixelFormat enc_fmt,
bool is_cuda) {
auto desc = [&]() -> std::string {
if (src_fmt == codec_ctx->pix_fmt ||
codec_ctx->pix_fmt == AV_PIX_FMT_CUDA) {
if (src_fmt == enc_fmt || is_cuda) {
return "null";
} else {
std::stringstream ss;
ss << "format=" << av_get_pix_fmt_name(codec_ctx->pix_fmt);
ss << "format=" << av_get_pix_fmt_name(enc_fmt);
return ss.str();
}
}();
FilterGraph p{AVMEDIA_TYPE_VIDEO};
p.add_video_src(
src_fmt,
codec_ctx->time_base,
codec_ctx->framerate,
codec_ctx->width,
codec_ctx->height,
codec_ctx->sample_aspect_ratio);
p.add_sink();
p.add_process(desc);
p.create_filter();
return p;
FilterGraph f{AVMEDIA_TYPE_VIDEO};
f.add_video_src(src_fmt, av_inv_q(rate), rate, width, height, {1, 1});
f.add_sink();
f.add_process(desc);
f.create_filter();
return f;
}
AVFramePtr get_video_frame(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
////////////////////////////////////////////////////////////////////////////////
// Source frame
////////////////////////////////////////////////////////////////////////////////
AVFramePtr get_audio_frame(
AVSampleFormat format,
int sample_rate,
int num_channels,
uint64_t channel_layout,
int nb_samples) {
AVFramePtr frame{};
if (codec_ctx->hw_frames_ctx) {
int ret = av_hwframe_get_buffer(codec_ctx->hw_frames_ctx, frame, 0);
TORCH_CHECK(ret >= 0, "Failed to fetch CUDA frame: ", av_err2string(ret));
} else {
frame->format = src_fmt;
frame->width = codec_ctx->width;
frame->height = codec_ctx->height;
int ret = av_frame_get_buffer(frame, 0);
TORCH_CHECK(
ret >= 0,
"Error allocating a video buffer (",
av_err2string(ret),
").");
}
frame->format = format;
frame->channel_layout = channel_layout;
frame->sample_rate = sample_rate;
frame->nb_samples = nb_samples ? nb_samples : 1024;
int ret = av_frame_get_buffer(frame, 0);
TORCH_CHECK(
ret >= 0, "Error allocating the source audio frame:", av_err2string(ret));
// Note: `channels` attribute is not required for encoding, but
// TensorConverter refers to it
frame->channels = num_channels;
frame->pts = 0;
return frame;
}
AVFramePtr get_video_frame(AVPixelFormat src_fmt, int width, int height) {
AVFramePtr frame{};
frame->format = src_fmt;
frame->width = width;
frame->height = height;
int ret = av_frame_get_buffer(frame, 0);
TORCH_CHECK(
ret >= 0, "Error allocating a video buffer :", av_err2string(ret));
// Note: `nb_samples` attribute is not used for video, but we set it
// anyways so that we can make the logic of PTS increment agnostic to
// audio and video.
......@@ -544,99 +612,164 @@ AVFramePtr get_video_frame(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
} // namespace
EncodeProcess::EncodeProcess(
////////////////////////////////////////////////////////////////////////////////
// Finally, the extern-facing API
////////////////////////////////////////////////////////////////////////////////
EncodeProcess get_audio_encode_process(
AVFormatContext* format_ctx,
int sample_rate,
int num_channels,
const enum AVSampleFormat format,
int src_sample_rate,
int src_num_channels,
const std::string& format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config)
: codec_ctx(get_audio_codec(
format_ctx->oformat,
sample_rate,
num_channels,
encoder,
encoder_option,
encoder_format,
config)),
encoder(format_ctx, codec_ctx),
filter(get_audio_filter(format, codec_ctx)),
src_frame(get_audio_frame(format, sample_rate, num_channels, codec_ctx)),
converter(AVMEDIA_TYPE_AUDIO, src_frame, src_frame->nb_samples) {}
const c10::optional<EncodingConfig>& config) {
// 1. Check the source format, rate and channels
const AVSampleFormat src_fmt = get_sample_fmt(format);
TORCH_CHECK(
src_sample_rate > 0,
"Sample rate must be positive. Found: ",
src_sample_rate);
TORCH_CHECK(
src_num_channels > 0,
"The number of channels must be positive. Found: ",
src_num_channels);
EncodeProcess::EncodeProcess(
// 2. Fetch codec from default or override
TORCH_CHECK(
format_ctx->oformat->audio_codec != AV_CODEC_ID_NONE,
format_ctx->oformat->name,
" does not support audio.");
const AVCodec* codec = get_codec(format_ctx->oformat->audio_codec, encoder);
// 3. Check that encoding sample format, sample rate and channels
// TODO: introduce encoder_sampel_rate option and allow to change sample rate
const AVSampleFormat enc_fmt = get_enc_fmt(src_fmt, encoder_format, codec);
validate_sample_rate(src_sample_rate, codec);
uint64_t channel_layout = get_channel_layout(src_num_channels, codec);
// 4. Initialize codec context
AVCodecContextPtr codec_ctx =
get_codec_ctx(codec, format_ctx->oformat->flags);
configure_audio_codec_ctx(
codec_ctx,
enc_fmt,
src_sample_rate,
src_num_channels,
channel_layout,
config);
open_codec(codec_ctx, encoder_option);
// 5. Build filter graph
FilterGraph filter_graph = get_audio_filter_graph(
src_fmt, src_sample_rate, channel_layout, enc_fmt, codec_ctx->frame_size);
// 6. Instantiate source frame
AVFramePtr src_frame = get_audio_frame(
src_fmt,
src_sample_rate,
src_num_channels,
channel_layout,
codec_ctx->frame_size);
// 7. Instantiate Converter
TensorConverter converter{
AVMEDIA_TYPE_AUDIO, src_frame, src_frame->nb_samples};
// 8. encoder
// Note: get_stream modifies AVFormatContext and adds new stream.
// If anything after this throws, it will leave the StreamWriter in an
// invalid state.
Encoder enc{format_ctx, codec_ctx, get_stream(format_ctx, codec_ctx)};
return EncodeProcess{
std::move(converter),
std::move(src_frame),
std::move(filter_graph),
std::move(enc),
std::move(codec_ctx)};
}
EncodeProcess get_video_encode_process(
AVFormatContext* format_ctx,
double frame_rate,
int width,
int height,
const enum AVPixelFormat format,
int src_width,
int src_height,
const std::string& format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config)
: codec_ctx(get_video_codec(
format_ctx->oformat,
frame_rate,
width,
height,
encoder,
encoder_option,
encoder_format,
hw_accel,
config)),
encoder(format_ctx, codec_ctx),
filter(get_video_filter(format, codec_ctx)),
src_frame(get_video_frame(format, codec_ctx)),
converter(AVMEDIA_TYPE_VIDEO, src_frame) {}
const c10::optional<EncodingConfig>& config) {
// 1. Checkc the source format, rate and resolution
const AVPixelFormat src_fmt = get_pix_fmt(format);
AVRational src_rate = av_d2q(frame_rate, 1 << 24);
TORCH_CHECK(
src_rate.num > 0 && src_rate.den != 0,
"Frame rate must be positive and finite. Found: ",
frame_rate);
TORCH_CHECK(src_width > 0, "width must be positive. Found: ", src_width);
TORCH_CHECK(src_height > 0, "height must be positive. Found: ", src_height);
void EncodeProcess::process(
AVMediaType type,
const torch::Tensor& tensor,
const c10::optional<double>& pts) {
// 2. Fetch codec from default or override
TORCH_CHECK(
codec_ctx->codec_type == type,
"Attempted to write ",
av_get_media_type_string(type),
" to ",
av_get_media_type_string(codec_ctx->codec_type),
" stream.");
if (pts) {
AVRational tb = codec_ctx->time_base;
auto val = static_cast<int64_t>(std::round(pts.value() * tb.den / tb.num));
if (src_frame->pts > val) {
TORCH_WARN_ONCE(
"The provided PTS value is smaller than the next expected value.");
}
src_frame->pts = val;
}
for (const auto& frame : converter.convert(tensor)) {
process_frame(frame);
frame->pts += frame->nb_samples;
format_ctx->oformat->video_codec != AV_CODEC_ID_NONE,
format_ctx->oformat->name,
" does not support video.");
const AVCodec* codec = get_codec(format_ctx->oformat->video_codec, encoder);
// 3. Check that encoding format, rate
const AVPixelFormat enc_fmt = get_enc_fmt(src_fmt, encoder_format, codec);
validate_frame_rate(src_rate, codec);
// 4. Initialize codec context
AVCodecContextPtr codec_ctx =
get_codec_ctx(codec, format_ctx->oformat->flags);
configure_video_codec_ctx(
codec_ctx, enc_fmt, src_rate, src_width, src_height, config);
if (hw_accel) {
#ifdef USE_CUDA
configure_hw_accel(codec_ctx, hw_accel.value());
#else
TORCH_CHECK(
false,
"torchaudio is not compiled with CUDA support. ",
"Hardware acceleration is not available.");
#endif
}
}
void EncodeProcess::process_frame(AVFrame* src) {
int ret = filter.add_frame(src);
while (ret >= 0) {
ret = filter.get_frame(dst_frame);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
if (ret == AVERROR_EOF) {
encoder.encode(nullptr);
}
break;
open_codec(codec_ctx, encoder_option);
// 5. Build filter graph
FilterGraph filter_graph = get_video_filter_graph(
src_fmt, src_rate, src_width, src_height, enc_fmt, hw_accel.has_value());
// 6. Instantiate source frame
AVFramePtr src_frame = [&]() {
if (codec_ctx->hw_frames_ctx) {
AVFramePtr frame{};
int ret = av_hwframe_get_buffer(codec_ctx->hw_frames_ctx, frame, 0);
TORCH_CHECK(ret >= 0, "Failed to fetch CUDA frame: ", av_err2string(ret));
return frame;
}
if (ret >= 0) {
encoder.encode(dst_frame);
}
av_frame_unref(dst_frame);
}
}
return get_video_frame(src_fmt, src_width, src_height);
}();
void EncodeProcess::flush() {
process_frame(nullptr);
// 7. Converter
TensorConverter converter{AVMEDIA_TYPE_VIDEO, src_frame};
// 8. encoder
// Note: get_stream modifies AVFormatContext and adds new stream.
// If anything after this throws, it will leave the StreamWriter in an
// invalid state.
Encoder enc{format_ctx, codec_ctx, get_stream(format_ctx, codec_ctx)};
return EncodeProcess{
std::move(converter),
std::move(src_frame),
std::move(filter_graph),
std::move(enc),
std::move(codec_ctx)};
}
} // namespace torchaudio::io
......@@ -9,47 +9,50 @@
namespace torchaudio::io {
class EncodeProcess {
// In the reverse order of the process
AVCodecContextPtr codec_ctx;
Encoder encoder;
AVFramePtr dst_frame{};
FilterGraph filter;
AVFramePtr src_frame;
TensorConverter converter;
AVFramePtr src_frame;
FilterGraph filter;
AVFramePtr dst_frame{};
Encoder encoder;
AVCodecContextPtr codec_ctx;
public:
// Constructor for audio
EncodeProcess(
AVFormatContext* format_ctx,
int sample_rate,
int num_channels,
const enum AVSampleFormat format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config);
// constructor for video
EncodeProcess(
AVFormatContext* format_ctx,
double frame_rate,
int width,
int height,
const enum AVPixelFormat format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config);
void process(
AVMediaType type,
const torch::Tensor& tensor,
const c10::optional<double>& pts);
TensorConverter&& converter,
AVFramePtr&& frame,
FilterGraph&& filter_graph,
Encoder&& encoder,
AVCodecContextPtr&& codec_ctx) noexcept;
EncodeProcess(EncodeProcess&&) noexcept = default;
void process(const torch::Tensor& tensor, const c10::optional<double>& pts);
void process_frame(AVFrame* src);
void flush();
};
EncodeProcess get_audio_encode_process(
AVFormatContext* format_ctx,
int sample_rate,
int num_channels,
const std::string& format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config);
EncodeProcess get_video_encode_process(
AVFormatContext* format_ctx,
double frame_rate,
int width,
int height,
const std::string& format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config);
}; // namespace torchaudio::io
......@@ -2,28 +2,11 @@
namespace torchaudio::io {
namespace {
AVStream* add_stream(AVFormatContext* format_ctx, AVCodecContext* codec_ctx) {
AVStream* stream = avformat_new_stream(format_ctx, nullptr);
TORCH_CHECK(stream, "Failed to allocate stream.");
stream->time_base = codec_ctx->time_base;
int ret = avcodec_parameters_from_context(stream->codecpar, codec_ctx);
TORCH_CHECK(
ret >= 0,
"Failed to copy the stream parameter. (",
av_err2string(ret),
")");
return stream;
}
} // namespace
Encoder::Encoder(AVFormatContext* format_ctx_, AVCodecContext* codec_ctx_)
: format_ctx(format_ctx_),
codec_ctx(codec_ctx_),
stream(add_stream(format_ctx, codec_ctx)) {}
Encoder::Encoder(
AVFormatContext* format_ctx,
AVCodecContext* codec_ctx,
AVStream* stream) noexcept
: format_ctx(format_ctx), codec_ctx(codec_ctx), stream(stream) {}
///
/// Encode the given AVFrame data
......
......@@ -12,16 +12,17 @@ class Encoder {
AVFormatContext* format_ctx;
// Reference to codec context (encoder)
AVCodecContext* codec_ctx;
// Stream object
// Encoder object creates AVStream, but it will be deallocated along with
// AVFormatContext, So Encoder does not own it.
// Stream object as reference. Owned by AVFormatContext.
AVStream* stream;
// Temporary object used during the encoding
// Encoder owns it.
AVPacketPtr packet{};
public:
Encoder(AVFormatContext* format_ctx, AVCodecContext* codec_ctx);
Encoder(
AVFormatContext* format_ctx,
AVCodecContext* codec_ctx,
AVStream* stream) noexcept;
void encode(AVFrame* frame);
};
......
......@@ -53,88 +53,54 @@ StreamWriter::StreamWriter(
const c10::optional<std::string>& format)
: StreamWriter(get_output_format_context(dst, format, nullptr)) {}
namespace {
enum AVSampleFormat get_src_sample_fmt(const std::string& src) {
auto fmt = av_get_sample_fmt(src.c_str());
TORCH_CHECK(fmt != AV_SAMPLE_FMT_NONE, "Unknown sample format: ", src);
TORCH_CHECK(
!av_sample_fmt_is_planar(fmt),
"Unexpected sample fotmat value. Valid values are ",
av_get_sample_fmt_name(AV_SAMPLE_FMT_U8),
", ",
av_get_sample_fmt_name(AV_SAMPLE_FMT_S16),
", ",
av_get_sample_fmt_name(AV_SAMPLE_FMT_S32),
", ",
av_get_sample_fmt_name(AV_SAMPLE_FMT_S64),
", ",
av_get_sample_fmt_name(AV_SAMPLE_FMT_FLT),
", ",
av_get_sample_fmt_name(AV_SAMPLE_FMT_DBL),
". ",
"Found: ",
src);
return fmt;
}
enum AVPixelFormat get_src_pixel_fmt(const std::string& src) {
auto fmt = av_get_pix_fmt(src.c_str());
switch (fmt) {
case AV_PIX_FMT_GRAY8:
case AV_PIX_FMT_RGB24:
case AV_PIX_FMT_BGR24:
case AV_PIX_FMT_YUV444P:
return fmt;
case AV_PIX_FMT_NONE:
TORCH_CHECK(false, "Unknown pixel format: ", src);
default:
TORCH_CHECK(false, "Unsupported pixel format: ", src);
}
}
} // namespace
void StreamWriter::add_audio_stream(
int64_t sample_rate,
int64_t num_channels,
int sample_rate,
int num_channels,
const std::string& format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config) {
processes.emplace_back(
TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
"The number of encode process and the number of output streams do not match.");
processes.emplace_back(get_audio_encode_process(
pFormatContext,
sample_rate,
num_channels,
get_src_sample_fmt(format),
format,
encoder,
encoder_option,
encoder_format,
config);
config));
}
void StreamWriter::add_video_stream(
double frame_rate,
int64_t width,
int64_t height,
int width,
int height,
const std::string& format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config) {
processes.emplace_back(
TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
"The number of encode process and the number of output streams do not match.");
processes.emplace_back(get_video_encode_process(
pFormatContext,
frame_rate,
width,
height,
get_src_pixel_fmt(format),
format,
encoder,
encoder_option,
encoder_format,
hw_accel,
config);
config));
}
void StreamWriter::set_metadata(const OptionDict& metadata) {
......@@ -149,6 +115,10 @@ void StreamWriter::dump_format(int64_t i) {
}
void StreamWriter::open(const c10::optional<OptionDict>& option) {
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
"The number of encode process and the number of output streams do not match.");
int ret = 0;
// Open the file if it was not provided by client code (i.e. when not
......@@ -210,12 +180,17 @@ void StreamWriter::write_audio_chunk(
const c10::optional<double>& pts) {
TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
TORCH_CHECK(
0 <= i && i < static_cast<int>(processes.size()),
0 <= i && i < static_cast<int>(pFormatContext->nb_streams),
"Invalid stream index. Index must be in range of [0, ",
processes.size(),
pFormatContext->nb_streams,
"). Found: ",
i);
processes[i].process(AVMEDIA_TYPE_AUDIO, waveform, pts);
TORCH_CHECK(
pFormatContext->streams[i]->codecpar->codec_type == AVMEDIA_TYPE_AUDIO,
"Stream ",
i,
" is not audio type.");
processes[i].process(waveform, pts);
}
void StreamWriter::write_video_chunk(
......@@ -224,12 +199,17 @@ void StreamWriter::write_video_chunk(
const c10::optional<double>& pts) {
TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
TORCH_CHECK(
0 <= i && i < static_cast<int>(processes.size()),
0 <= i && i < static_cast<int>(pFormatContext->nb_streams),
"Invalid stream index. Index must be in range of [0, ",
processes.size(),
pFormatContext->nb_streams,
"). Found: ",
i);
processes[i].process(AVMEDIA_TYPE_VIDEO, frames, pts);
TORCH_CHECK(
pFormatContext->streams[i]->codecpar->codec_type == AVMEDIA_TYPE_VIDEO,
"Stream ",
i,
" is not video type.");
processes[i].process(frames, pts);
}
void StreamWriter::flush() {
......
......@@ -100,8 +100,8 @@ class StreamWriter {
/// To list supported formats for the encoder, you can use
/// ``ffmpeg -h encoder=<ENCODER>`` command.
void add_audio_stream(
int64_t sample_rate,
int64_t num_channels,
int sample_rate,
int num_channels,
const std::string& format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
......@@ -139,8 +139,8 @@ class StreamWriter {
/// @endparblock
void add_video_stream(
double frame_rate,
int64_t width,
int64_t height,
int width,
int height,
const std::string& format,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment