Commit 4eac61a3 authored by Moto Hira, committed by Facebook GitHub Bot

Refactor the initialization of EncodeProcess (#3205)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3205

This commit refactors the initialization of EncodeProcess.

Interface-wise, the constructor of EncodeProcess has been simplified so that
it only takes rvalue references to its components, and the initialization of
those components has been moved to helper functions.
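
For reference, the new constructor declaration (as it appears in
EncodeProcess.h in this diff) is:

    EncodeProcess(
        TensorConverter&& converter,
        AVFramePtr&& frame,
        FilterGraph&& filter_graph,
        Encoder&& encoder,
        AVCodecContextPtr&& codec_ctx) noexcept;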

Implementation-wise, the order in which the components are initialized has
been revised, as has the source of their initialization parameters.

For example, the original implementation first creates AVCodecContext
and passes it around to create the other components. This relied on
the assumption that the parameters held by AVCodecContext (such as image
size and sample rate) match those of the source data. This is not always
true, and as we introduce custom filter graphs and allow on-the-fly
transformation of rates and dimensions, it will become even less so.

The new initialization constructs the source AVFrame, TensorConverter and
FilterGraph from the source attributes, which makes it easy to introduce
on-the-fly transforms.
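
Below is a minimal, self-contained sketch of that initialization order. It
uses placeholder types and a hypothetical make_audio_process helper rather
than the actual torchaudio/FFmpeg classes; it only illustrates the flow that
get_audio_encode_process follows in this diff.

    // Placeholder sketch: mirrors the initialization order of the new
    // factory function, but with stand-in types (not the real classes).
    #include <string>

    struct SrcFrame { int sample_rate; };   // stands in for AVFramePtr
    struct Converter { int chunk_size; };   // stands in for TensorConverter
    struct Graph { std::string desc; };     // stands in for FilterGraph
    struct CodecCtx { int frame_size; };    // stands in for AVCodecContextPtr

    struct Process {                        // stands in for EncodeProcess
      Converter converter;
      SrcFrame src_frame;
      Graph filter;
      CodecCtx codec_ctx;
    };

    Process make_audio_process(int src_sample_rate, const std::string& src_fmt) {
      // 1.-3. Validate the source attributes and pick the encoding format (omitted).
      // 4. Initialize the codec context.
      CodecCtx ctx{/*frame_size=*/1024};
      // 5. Build the filter graph from the source attributes and encoding format.
      Graph graph{"aformat=" + src_fmt};
      // 6. Allocate the source frame from the source attributes.
      SrcFrame frame{src_sample_rate};
      // 7. Build the tensor converter against the source frame.
      Converter conv{ctx.frame_size};
      // 8. Create the encoder (omitted here) and hand the pre-built components over.
      return Process{conv, frame, graph, ctx};
    }

    int main() {
      Process p = make_audio_process(44100, "fltp");
      (void)p;
      return 0;
    }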

Reviewed By: nateanl

Differential Revision: D44360650

fbshipit-source-id: bf0e77dc1a5a40fc8e9870c50d07339d812762e8
parent d8a37a21
@@ -2,180 +2,151 @@
namespace torchaudio::io {
namespace { ////////////////////////////////////////////////////////////////////////////////
// EncodeProcess Logic Implementation
AVCodecContextPtr get_codec_ctx( ////////////////////////////////////////////////////////////////////////////////
enum AVMediaType type,
AVFORMAT_CONST AVOutputFormat* oformat,
const c10::optional<std::string>& encoder) {
enum AVCodecID default_codec = [&]() {
switch (type) {
case AVMEDIA_TYPE_AUDIO:
return oformat->audio_codec;
case AVMEDIA_TYPE_VIDEO:
return oformat->video_codec;
default:
TORCH_CHECK(
false, "Unsupported media type: ", av_get_media_type_string(type));
}
}();
TORCH_CHECK( EncodeProcess::EncodeProcess(
default_codec != AV_CODEC_ID_NONE, TensorConverter&& converter,
"Format \"", AVFramePtr&& frame,
oformat->name, FilterGraph&& filter_graph,
"\" does not support ", Encoder&& encoder,
av_get_media_type_string(type), AVCodecContextPtr&& codec_ctx) noexcept
"."); : converter(std::move(converter)),
src_frame(std::move(frame)),
filter(std::move(filter_graph)),
encoder(std::move(encoder)),
codec_ctx(std::move(codec_ctx)) {}
const AVCodec* codec = [&]() { void EncodeProcess::process(
if (encoder) { const torch::Tensor& tensor,
const AVCodec* c = avcodec_find_encoder_by_name(encoder.value().c_str()); const c10::optional<double>& pts) {
TORCH_CHECK(c, "Unexpected codec: ", encoder.value()); if (pts) {
return c; AVRational tb = codec_ctx->time_base;
auto val = static_cast<int64_t>(std::round(pts.value() * tb.den / tb.num));
if (src_frame->pts > val) {
TORCH_WARN_ONCE(
"The provided PTS value is smaller than the next expected value.");
} }
const AVCodec* c = avcodec_find_encoder(default_codec); src_frame->pts = val;
TORCH_CHECK( }
c, "Encoder not found for codec: ", avcodec_get_name(default_codec)); for (const auto& frame : converter.convert(tensor)) {
return c; process_frame(frame);
}(); frame->pts += frame->nb_samples;
AVCodecContext* ctx = avcodec_alloc_context3(codec);
TORCH_CHECK(ctx, "Failed to allocate CodecContext.");
if (oformat->flags & AVFMT_GLOBALHEADER) {
ctx->flags |= AV_CODEC_FLAG_GLOBAL_HEADER;
} }
return AVCodecContextPtr(ctx);
} }
std::vector<int> get_supported_sample_rates(const AVCodec* codec) { void EncodeProcess::process_frame(AVFrame* src) {
std::vector<int> ret; int ret = filter.add_frame(src);
if (codec->supported_samplerates) { while (ret >= 0) {
const int* t = codec->supported_samplerates; ret = filter.get_frame(dst_frame);
while (*t) { if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
ret.push_back(*t); if (ret == AVERROR_EOF) {
++t; encoder.encode(nullptr);
}
break;
} }
if (ret >= 0) {
encoder.encode(dst_frame);
}
av_frame_unref(dst_frame);
} }
return ret;
} }
std::vector<std::string> get_supported_sample_fmts(const AVCodec* codec) { void EncodeProcess::flush() {
std::vector<std::string> ret; process_frame(nullptr);
if (codec->sample_fmts) {
const enum AVSampleFormat* t = codec->sample_fmts;
while (*t != AV_SAMPLE_FMT_NONE) {
ret.emplace_back(av_get_sample_fmt_name(*t));
++t;
}
}
return ret;
} }
std::vector<uint64_t> get_supported_channel_layouts(const AVCodec* codec) { ////////////////////////////////////////////////////////////////////////////////
std::vector<uint64_t> ret; // EncodeProcess Initialization helper functions
if (codec->channel_layouts) { ////////////////////////////////////////////////////////////////////////////////
const uint64_t* t = codec->channel_layouts;
while (*t) { namespace {
ret.push_back(*t);
++t; enum AVSampleFormat get_sample_fmt(const std::string& src) {
} auto fmt = av_get_sample_fmt(src.c_str());
if (fmt != AV_SAMPLE_FMT_NONE && !av_sample_fmt_is_planar(fmt)) {
return fmt;
} }
return ret; TORCH_CHECK(
false,
"Unsupported sample fotmat (",
src,
") was provided. Valid values are ",
[]() -> std::string {
std::vector<std::string> ret;
for (const auto& fmt :
{AV_SAMPLE_FMT_U8,
AV_SAMPLE_FMT_S16,
AV_SAMPLE_FMT_S32,
AV_SAMPLE_FMT_S64,
AV_SAMPLE_FMT_FLT,
AV_SAMPLE_FMT_DBL}) {
ret.emplace_back(av_get_sample_fmt_name(fmt));
}
return c10::Join(", ", ret);
}(),
".");
} }
void configure_audio_codec( enum AVPixelFormat get_pix_fmt(const std::string& src) {
AVCodecContextPtr& ctx, AVPixelFormat fmt = av_get_pix_fmt(src.c_str());
int64_t sample_rate, switch (fmt) {
int64_t num_channels, case AV_PIX_FMT_GRAY8:
const c10::optional<std::string>& format, case AV_PIX_FMT_RGB24:
const c10::optional<EncodingConfig>& config) { case AV_PIX_FMT_BGR24:
// TODO: Review options and make them configurable? case AV_PIX_FMT_YUV444P:
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00122 return fmt;
// - bit_rate default:;
// - bit_rate_tolerance }
TORCH_CHECK(
ctx->sample_rate = [&]() -> int { false,
auto rates = get_supported_sample_rates(ctx->codec); "Unsupported pixel format (",
if (rates.empty()) { src,
return static_cast<int>(sample_rate); ") was provided. Valid values are ",
} []() -> std::string {
for (const auto& it : rates) { std::vector<std::string> ret;
if (it == sample_rate) { for (const auto& fmt :
return static_cast<int>(sample_rate); {AV_PIX_FMT_GRAY8,
} AV_PIX_FMT_RGB24,
} AV_PIX_FMT_BGR24,
TORCH_CHECK( AV_PIX_FMT_YUV444P}) {
false, ret.emplace_back(av_get_pix_fmt_name(fmt));
ctx->codec->name, }
" does not support sample rate ", return c10::Join(", ", ret);
sample_rate, }(),
". Supported sample rates are: ", ".");
c10::Join(", ", rates)); }
}();
ctx->time_base = av_inv_q(av_d2q(sample_rate, 1 << 24));
ctx->sample_fmt = [&]() {
// Use default
if (!format) {
TORCH_CHECK(
ctx->codec->sample_fmts,
ctx->codec->name,
" does not have default sample format. Please specify one.");
return ctx->codec->sample_fmts[0];
}
// Use the given one.
auto fmt = format.value();
auto ret = av_get_sample_fmt(fmt.c_str());
auto fmts = get_supported_sample_fmts(ctx->codec);
if (fmts.empty()) {
TORCH_CHECK(
ret != AV_SAMPLE_FMT_NONE, "Unrecognized format: ", fmt, ". ");
return ret;
}
TORCH_CHECK(
std::count(fmts.begin(), fmts.end(), fmt),
"Unsupported sample format: ",
fmt,
". Supported values are ",
c10::Join(", ", fmts));
return ret;
}();
// validate and set channels ////////////////////////////////////////////////////////////////////////////////
ctx->channels = static_cast<int>(num_channels); // Codec & Codec context
auto layout = av_get_default_channel_layout(ctx->channels); ////////////////////////////////////////////////////////////////////////////////
auto layouts = get_supported_channel_layouts(ctx->codec); const AVCodec* get_codec(
if (!layouts.empty()) { AVCodecID default_codec,
if (!std::count(layouts.begin(), layouts.end(), layout)) { const c10::optional<std::string>& encoder) {
std::vector<std::string> tmp; if (encoder) {
for (const auto& it : layouts) { const AVCodec* c = avcodec_find_encoder_by_name(encoder.value().c_str());
tmp.push_back(std::to_string(av_get_channel_layout_nb_channels(it))); TORCH_CHECK(c, "Unexpected codec: ", encoder.value());
} return c;
TORCH_CHECK(
false,
"Unsupported channels: ",
num_channels,
". Supported channels are: ",
c10::Join(", ", tmp));
}
} }
ctx->channel_layout = static_cast<uint64_t>(layout); const AVCodec* c = avcodec_find_encoder(default_codec);
TORCH_CHECK(
c, "Encoder not found for codec: ", avcodec_get_name(default_codec));
return c;
}
// Set optional stuff AVCodecContextPtr get_codec_ctx(const AVCodec* codec, int flags) {
if (config) { AVCodecContext* ctx = avcodec_alloc_context3(codec);
auto& cfg = config.value(); TORCH_CHECK(ctx, "Failed to allocate CodecContext.");
if (cfg.bit_rate > 0) {
ctx->bit_rate = cfg.bit_rate; if (flags & AVFMT_GLOBALHEADER) {
} ctx->flags |= AV_CODEC_FLAG_GLOBAL_HEADER;
if (cfg.compression_level != -1) {
ctx->compression_level = cfg.compression_level;
}
} }
return AVCodecContextPtr(ctx);
} }
 void open_codec(
-    AVCodecContextPtr& codec_ctx,
+    AVCodecContext* codec_ctx,
     const c10::optional<OptionDict>& option) {
   AVDictionary* opt = get_option_dict(option);
@@ -214,186 +185,242 @@ void open_codec(
   TORCH_CHECK(ret >= 0, "Failed to open codec: (", av_err2string(ret), ")");
 }
AVCodecContextPtr get_audio_codec( ////////////////////////////////////////////////////////////////////////////////
AVFORMAT_CONST AVOutputFormat* oformat, // Audio codec
int64_t sample_rate, ////////////////////////////////////////////////////////////////////////////////
int64_t num_channels,
const c10::optional<std::string>& encoder, bool supported_sample_fmt(
const c10::optional<OptionDict>& encoder_option, const AVSampleFormat fmt,
const c10::optional<std::string>& encoder_format, const AVSampleFormat* sample_fmts) {
const c10::optional<EncodingConfig>& config) { if (!sample_fmts) {
AVCodecContextPtr ctx = get_codec_ctx(AVMEDIA_TYPE_AUDIO, oformat, encoder); return true;
configure_audio_codec(ctx, sample_rate, num_channels, encoder_format, config); }
open_codec(ctx, encoder_option); while (*sample_fmts != AV_SAMPLE_FMT_NONE) {
return ctx; if (fmt == *sample_fmts) {
return true;
}
++sample_fmts;
}
return false;
}
std::vector<std::string> get_supported_formats(
const AVSampleFormat* sample_fmts) {
std::vector<std::string> ret;
while (*sample_fmts != AV_SAMPLE_FMT_NONE) {
ret.emplace_back(av_get_sample_fmt_name(*sample_fmts));
++sample_fmts;
}
return ret;
} }
FilterGraph get_audio_filter( AVSampleFormat get_enc_fmt(
AVSampleFormat src_fmt, AVSampleFormat src_fmt,
AVCodecContext* codec_ctx) { const c10::optional<std::string>& encoder_format,
auto desc = [&]() -> std::string { const AVCodec* codec) {
if (src_fmt == codec_ctx->sample_fmt) { if (encoder_format) {
if (!codec_ctx->frame_size) { auto& enc_fmt_val = encoder_format.value();
return "anull"; auto fmt = av_get_sample_fmt(enc_fmt_val.c_str());
} else { TORCH_CHECK(
std::stringstream ss; fmt != AV_SAMPLE_FMT_NONE, "Unknown sample format: ", enc_fmt_val);
ss << "asetnsamples=n=" << codec_ctx->frame_size << ":p=0"; TORCH_CHECK(
return ss.str(); supported_sample_fmt(fmt, codec->sample_fmts),
} codec->name,
} else { " does not support ",
std::stringstream ss; encoder_format.value(),
ss << "aformat=" << av_get_sample_fmt_name(codec_ctx->sample_fmt); " format. Supported values are; ",
if (codec_ctx->frame_size) { c10::Join(", ", get_supported_formats(codec->sample_fmts)));
ss << ",asetnsamples=n=" << codec_ctx->frame_size << ":p=0"; return fmt;
} }
return ss.str(); if (codec->sample_fmts) {
return codec->sample_fmts[0];
}
return src_fmt;
};
bool supported_sample_rate(
const int sample_rate,
const int* supported_samplerates) {
if (!supported_samplerates) {
return true;
}
while (*supported_samplerates) {
if (sample_rate == *supported_samplerates) {
return true;
} }
}(); ++supported_samplerates;
}
return false;
}
FilterGraph p{AVMEDIA_TYPE_AUDIO}; std::vector<int> get_supported_samplerates(const int* supported_samplerates) {
p.add_audio_src( std::vector<int> ret;
src_fmt, if (supported_samplerates) {
codec_ctx->time_base, while (*supported_samplerates) {
codec_ctx->sample_rate, ret.push_back(*supported_samplerates);
codec_ctx->channel_layout); ++supported_samplerates;
p.add_sink(); }
p.add_process(desc); }
p.create_filter(); return ret;
return p;
} }
AVFramePtr get_audio_frame( void validate_sample_rate(int sample_rate, const AVCodec* codec) {
AVSampleFormat src_fmt, TORCH_CHECK(
supported_sample_rate(sample_rate, codec->supported_samplerates),
codec->name,
" does not support sample rate ",
sample_rate,
". Supported values are; ",
c10::Join(", ", get_supported_samplerates(codec->supported_samplerates)));
}
std::vector<std::string> get_supported_channels(
const uint64_t* channel_layouts) {
std::vector<std::string> ret;
while (*channel_layouts) {
ret.emplace_back(av_get_channel_name(*channel_layouts));
++channel_layouts;
}
return ret;
}
uint64_t get_channel_layout(int num_channels, const AVCodec* codec) {
if (!codec->channel_layouts) {
return static_cast<uint64_t>(av_get_default_channel_layout(num_channels));
}
for (const uint64_t* it = codec->channel_layouts; *it; ++it) {
if (av_get_channel_layout_nb_channels(*it) == num_channels) {
return *it;
}
}
TORCH_CHECK(
false,
"Codec ",
codec->name,
" does not support a channel layout consists of ",
num_channels,
" channels. Supported values are: ",
c10::Join(", ", get_supported_channels(codec->channel_layouts)));
}
void configure_audio_codec_ctx(
AVCodecContext* codec_ctx,
AVSampleFormat format,
int sample_rate, int sample_rate,
int num_channels, int num_channels,
AVCodecContext* codec_ctx, uint64_t channel_layout,
int default_frame_size = 10000) { const c10::optional<EncodingConfig>& config) {
AVFramePtr frame{}; codec_ctx->sample_fmt = format;
frame->pts = 0; codec_ctx->sample_rate = sample_rate;
frame->format = src_fmt; codec_ctx->time_base = av_inv_q(av_d2q(sample_rate, 1 << 24));
// Note: `channels` attribute is not required for encoding, but codec_ctx->channels = num_channels;
// TensorConverter refers to it codec_ctx->channel_layout = channel_layout;
frame->channels = num_channels;
frame->channel_layout = codec_ctx->channel_layout; // Set optional stuff
frame->sample_rate = sample_rate; if (config) {
frame->nb_samples = auto& cfg = config.value();
codec_ctx->frame_size ? codec_ctx->frame_size : default_frame_size; if (cfg.bit_rate > 0) {
if (frame->nb_samples) { codec_ctx->bit_rate = cfg.bit_rate;
int ret = av_frame_get_buffer(frame, 0); }
TORCH_CHECK( if (cfg.compression_level != -1) {
ret >= 0, codec_ctx->compression_level = cfg.compression_level;
"Error allocating an audio buffer (", }
av_err2string(ret),
").");
} }
return frame;
} }
std::vector<std::string> get_supported_pix_fmts(const AVCodec* codec) { ////////////////////////////////////////////////////////////////////////////////
std::vector<std::string> ret; // Video codec
if (codec->pix_fmts) { ////////////////////////////////////////////////////////////////////////////////
const enum AVPixelFormat* t = codec->pix_fmts;
while (*t != AV_PIX_FMT_NONE) { bool supported_pix_fmt(const AVPixelFormat fmt, const AVPixelFormat* pix_fmts) {
ret.emplace_back(av_get_pix_fmt_name(*t)); if (!pix_fmts) {
++t; return true;
}
while (*pix_fmts != AV_PIX_FMT_NONE) {
if (fmt == *pix_fmts) {
return true;
} }
++pix_fmts;
}
return false;
}
std::vector<std::string> get_supported_formats(const AVPixelFormat* pix_fmts) {
std::vector<std::string> ret;
while (*pix_fmts != AV_PIX_FMT_NONE) {
ret.emplace_back(av_get_pix_fmt_name(*pix_fmts));
++pix_fmts;
} }
return ret; return ret;
} }
std::vector<AVRational> get_supported_frame_rates(const AVCodec* codec) { AVPixelFormat get_enc_fmt(
std::vector<AVRational> ret; AVPixelFormat src_fmt,
if (codec->supported_framerates) { const c10::optional<std::string>& encoder_format,
const AVRational* t = codec->supported_framerates; const AVCodec* codec) {
while (!(t->num == 0 && t->den == 0)) { if (encoder_format) {
ret.push_back(*t); auto fmt = get_pix_fmt(encoder_format.value());
++t; TORCH_CHECK(
supported_pix_fmt(fmt, codec->pix_fmts),
codec->name,
" does not support ",
encoder_format.value(),
" format. Supported values are; ",
c10::Join(", ", get_supported_formats(codec->pix_fmts)));
return fmt;
}
if (codec->pix_fmts) {
return codec->pix_fmts[0];
}
return src_fmt;
}
bool supported_frame_rate(AVRational rate, const AVRational* rates) {
if (!rates) {
return true;
}
for (; !(rates->num == 0 && rates->den == 0); ++rates) {
if (av_cmp_q(rate, *rates) == 0) {
return true;
} }
} }
return ret; return false;
} }
// used to compare frame rate / sample rate. void validate_frame_rate(AVRational rate, const AVCodec* codec) {
// not a general purpose float comparison TORCH_CHECK(
bool is_rate_close(double rate, AVRational rational) { supported_frame_rate(rate, codec->supported_framerates),
double ref = codec->name,
static_cast<double>(rational.num) / static_cast<double>(rational.den); " does not support frame rate ",
// frame rates / sample rates c10::Join("/", std::array<int, 2>{rate.num, rate.den}),
static const double threshold = 0.001; ". Supported values are; ",
return fabs(rate - ref) < threshold; [&]() {
std::vector<std::string> ret;
for (auto r = codec->supported_framerates;
!(r->num == 0 && r->den == 0);
++r) {
ret.push_back(c10::Join("/", std::array<int, 2>{r->num, r->den}));
}
return c10::Join(", ", ret);
}());
} }
void configure_video_codec( void configure_video_codec_ctx(
AVCodecContextPtr& ctx, AVCodecContextPtr& ctx,
double frame_rate, AVPixelFormat format,
int64_t width, AVRational frame_rate,
int64_t height, int width,
const c10::optional<std::string>& format, int height,
const c10::optional<EncodingConfig>& config) { const c10::optional<EncodingConfig>& config) {
  // TODO: Review other options and make them configurable?
  // https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00147
// - bit_rate
  // - bit_rate_tolerance
// - gop_size
// - max_b_frames
  // - mb_decisions
ctx->width = static_cast<int>(width); ctx->pix_fmt = format;
ctx->height = static_cast<int>(height); ctx->width = width;
ctx->time_base = [&]() { ctx->height = height;
AVRational ret = av_inv_q(av_d2q(frame_rate, 1 << 24)); ctx->time_base = av_inv_q(frame_rate);
auto rates = get_supported_frame_rates(ctx->codec);
// Codec does not have constraint on frame rate
if (rates.empty()) {
return ret;
}
// Codec has list of supported frame rate.
for (const auto& t : rates) {
if (is_rate_close(frame_rate, t)) {
return ret;
}
}
// Given one is not supported.
std::vector<std::string> tmp;
for (const auto& t : rates) {
tmp.emplace_back(
t.den == 1 ? std::to_string(t.num)
: std::to_string(t.num) + "/" + std::to_string(t.den));
}
TORCH_CHECK(
false,
"Unsupported frame rate: ",
frame_rate,
". Supported values are ",
c10::Join(", ", tmp));
}();
ctx->pix_fmt = [&]() {
// Use default
if (!format) {
TORCH_CHECK(
ctx->codec->pix_fmts,
ctx->codec->name,
" does not have defaut pixel format. Please specify one.");
return ctx->codec->pix_fmts[0];
}
// Use the given one,
auto fmt = format.value();
auto ret = av_get_pix_fmt(fmt.c_str());
auto fmts = get_supported_pix_fmts(ctx->codec);
if (fmts.empty()) {
TORCH_CHECK(ret != AV_PIX_FMT_NONE, "Unrecognized format: ", fmt, ". ");
return ret;
}
if (!std::count(fmts.begin(), fmts.end(), fmt)) {
TORCH_CHECK(
false,
"Unsupported pixel format: ",
fmt,
". Supported values are ",
c10::Join(", ", fmts));
}
return ret;
}();
   // Set optional stuff
   if (config) {
@@ -416,9 +443,9 @@ void configure_video_codec(
 void configure_hw_accel(AVCodecContext* ctx, const std::string& hw_accel) {
   torch::Device device{hw_accel};
   TORCH_CHECK(
-      device.type() == c10::DeviceType::CUDA,
+      device.is_cuda(),
       "Only CUDA is supported for hardware acceleration. Found: ",
-      device.str());
+      device);

   // NOTES:
   // 1. Examples like
@@ -463,77 +490,118 @@ void configure_hw_accel(AVCodecContext* ctx, const std::string& hw_accel) {
       av_err2string(ret));
 }
AVCodecContextPtr get_video_codec( ////////////////////////////////////////////////////////////////////////////////
AVFORMAT_CONST AVOutputFormat* oformat, // AVStream
double frame_rate, ////////////////////////////////////////////////////////////////////////////////
int64_t width,
int64_t height,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config) {
AVCodecContextPtr ctx = get_codec_ctx(AVMEDIA_TYPE_VIDEO, oformat, encoder);
configure_video_codec(ctx, frame_rate, width, height, encoder_format, config);
if (hw_accel) { AVStream* get_stream(AVFormatContext* format_ctx, AVCodecContext* codec_ctx) {
#ifdef USE_CUDA AVStream* stream = avformat_new_stream(format_ctx, nullptr);
configure_hw_accel(ctx, hw_accel.value()); TORCH_CHECK(stream, "Failed to allocate stream.");
#else
TORCH_CHECK(
false,
"torchaudio is not compiled with CUDA support. ",
"Hardware acceleration is not available.");
#endif
}
open_codec(ctx, encoder_option); stream->time_base = codec_ctx->time_base;
return ctx; int ret = avcodec_parameters_from_context(stream->codecpar, codec_ctx);
TORCH_CHECK(
ret >= 0, "Failed to copy the stream parameter: ", av_err2string(ret));
return stream;
}
////////////////////////////////////////////////////////////////////////////////
// FilterGraph
////////////////////////////////////////////////////////////////////////////////
FilterGraph get_audio_filter_graph(
AVSampleFormat src_fmt,
int sample_rate,
uint64_t channel_layout,
AVSampleFormat enc_fmt,
int nb_samples) {
const std::string filter_desc = [&]() -> const std::string {
if (src_fmt == enc_fmt) {
if (nb_samples == 0) {
return "anull";
} else {
std::stringstream ss;
ss << "asetnsamples=n=" << nb_samples << ":p=0";
return ss.str();
}
} else {
std::stringstream ss;
ss << "aformat=" << av_get_sample_fmt_name(enc_fmt);
if (nb_samples > 0) {
ss << ",asetnsamples=n=" << nb_samples << ":p=0";
}
return ss.str();
}
}();
FilterGraph f{AVMEDIA_TYPE_AUDIO};
f.add_audio_src(src_fmt, {1, sample_rate}, sample_rate, channel_layout);
f.add_sink();
f.add_process(filter_desc);
f.create_filter();
return f;
} }
FilterGraph get_video_filter(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) { FilterGraph get_video_filter_graph(
AVPixelFormat src_fmt,
AVRational rate,
int width,
int height,
AVPixelFormat enc_fmt,
bool is_cuda) {
auto desc = [&]() -> std::string { auto desc = [&]() -> std::string {
if (src_fmt == codec_ctx->pix_fmt || if (src_fmt == enc_fmt || is_cuda) {
codec_ctx->pix_fmt == AV_PIX_FMT_CUDA) {
return "null"; return "null";
} else { } else {
std::stringstream ss; std::stringstream ss;
ss << "format=" << av_get_pix_fmt_name(codec_ctx->pix_fmt); ss << "format=" << av_get_pix_fmt_name(enc_fmt);
return ss.str(); return ss.str();
} }
}(); }();
FilterGraph p{AVMEDIA_TYPE_VIDEO}; FilterGraph f{AVMEDIA_TYPE_VIDEO};
p.add_video_src( f.add_video_src(src_fmt, av_inv_q(rate), rate, width, height, {1, 1});
src_fmt, f.add_sink();
codec_ctx->time_base, f.add_process(desc);
codec_ctx->framerate, f.create_filter();
codec_ctx->width, return f;
codec_ctx->height,
codec_ctx->sample_aspect_ratio);
p.add_sink();
p.add_process(desc);
p.create_filter();
return p;
} }
AVFramePtr get_video_frame(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) { ////////////////////////////////////////////////////////////////////////////////
// Source frame
////////////////////////////////////////////////////////////////////////////////
AVFramePtr get_audio_frame(
AVSampleFormat format,
int sample_rate,
int num_channels,
uint64_t channel_layout,
int nb_samples) {
AVFramePtr frame{}; AVFramePtr frame{};
if (codec_ctx->hw_frames_ctx) { frame->format = format;
int ret = av_hwframe_get_buffer(codec_ctx->hw_frames_ctx, frame, 0); frame->channel_layout = channel_layout;
TORCH_CHECK(ret >= 0, "Failed to fetch CUDA frame: ", av_err2string(ret)); frame->sample_rate = sample_rate;
} else { frame->nb_samples = nb_samples ? nb_samples : 1024;
frame->format = src_fmt; int ret = av_frame_get_buffer(frame, 0);
frame->width = codec_ctx->width; TORCH_CHECK(
frame->height = codec_ctx->height; ret >= 0, "Error allocating the source audio frame:", av_err2string(ret));
int ret = av_frame_get_buffer(frame, 0); // Note: `channels` attribute is not required for encoding, but
TORCH_CHECK( // TensorConverter refers to it
ret >= 0, frame->channels = num_channels;
"Error allocating a video buffer (", frame->pts = 0;
av_err2string(ret), return frame;
")."); }
}
AVFramePtr get_video_frame(AVPixelFormat src_fmt, int width, int height) {
AVFramePtr frame{};
frame->format = src_fmt;
frame->width = width;
frame->height = height;
int ret = av_frame_get_buffer(frame, 0);
TORCH_CHECK(
ret >= 0, "Error allocating a video buffer :", av_err2string(ret));
// Note: `nb_samples` attribute is not used for video, but we set it // Note: `nb_samples` attribute is not used for video, but we set it
// anyways so that we can make the logic of PTS increment agnostic to // anyways so that we can make the logic of PTS increment agnostic to
// audio and video. // audio and video.
@@ -544,99 +612,164 @@ AVFramePtr get_video_frame(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
} // namespace
EncodeProcess::EncodeProcess( ////////////////////////////////////////////////////////////////////////////////
// Finally, the extern-facing API
////////////////////////////////////////////////////////////////////////////////
EncodeProcess get_audio_encode_process(
AVFormatContext* format_ctx, AVFormatContext* format_ctx,
int sample_rate, int src_sample_rate,
int num_channels, int src_num_channels,
const enum AVSampleFormat format, const std::string& format,
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config) const c10::optional<EncodingConfig>& config) {
: codec_ctx(get_audio_codec( // 1. Check the source format, rate and channels
format_ctx->oformat, const AVSampleFormat src_fmt = get_sample_fmt(format);
sample_rate, TORCH_CHECK(
num_channels, src_sample_rate > 0,
encoder, "Sample rate must be positive. Found: ",
encoder_option, src_sample_rate);
encoder_format, TORCH_CHECK(
config)), src_num_channels > 0,
encoder(format_ctx, codec_ctx), "The number of channels must be positive. Found: ",
filter(get_audio_filter(format, codec_ctx)), src_num_channels);
src_frame(get_audio_frame(format, sample_rate, num_channels, codec_ctx)),
converter(AVMEDIA_TYPE_AUDIO, src_frame, src_frame->nb_samples) {}
EncodeProcess::EncodeProcess( // 2. Fetch codec from default or override
TORCH_CHECK(
format_ctx->oformat->audio_codec != AV_CODEC_ID_NONE,
format_ctx->oformat->name,
" does not support audio.");
const AVCodec* codec = get_codec(format_ctx->oformat->audio_codec, encoder);
// 3. Check that encoding sample format, sample rate and channels
// TODO: introduce encoder_sample_rate option and allow to change sample rate
const AVSampleFormat enc_fmt = get_enc_fmt(src_fmt, encoder_format, codec);
validate_sample_rate(src_sample_rate, codec);
uint64_t channel_layout = get_channel_layout(src_num_channels, codec);
// 4. Initialize codec context
AVCodecContextPtr codec_ctx =
get_codec_ctx(codec, format_ctx->oformat->flags);
configure_audio_codec_ctx(
codec_ctx,
enc_fmt,
src_sample_rate,
src_num_channels,
channel_layout,
config);
open_codec(codec_ctx, encoder_option);
// 5. Build filter graph
FilterGraph filter_graph = get_audio_filter_graph(
src_fmt, src_sample_rate, channel_layout, enc_fmt, codec_ctx->frame_size);
// 6. Instantiate source frame
AVFramePtr src_frame = get_audio_frame(
src_fmt,
src_sample_rate,
src_num_channels,
channel_layout,
codec_ctx->frame_size);
// 7. Instantiate Converter
TensorConverter converter{
AVMEDIA_TYPE_AUDIO, src_frame, src_frame->nb_samples};
// 8. encoder
// Note: get_stream modifies AVFormatContext and adds new stream.
// If anything after this throws, it will leave the StreamWriter in an
// invalid state.
Encoder enc{format_ctx, codec_ctx, get_stream(format_ctx, codec_ctx)};
return EncodeProcess{
std::move(converter),
std::move(src_frame),
std::move(filter_graph),
std::move(enc),
std::move(codec_ctx)};
}
EncodeProcess get_video_encode_process(
AVFormatContext* format_ctx, AVFormatContext* format_ctx,
double frame_rate, double frame_rate,
int width, int src_width,
int height, int src_height,
const enum AVPixelFormat format, const std::string& format,
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel, const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config) const c10::optional<EncodingConfig>& config) {
: codec_ctx(get_video_codec( // 1. Check the source format, rate and resolution
format_ctx->oformat, const AVPixelFormat src_fmt = get_pix_fmt(format);
frame_rate, AVRational src_rate = av_d2q(frame_rate, 1 << 24);
width, TORCH_CHECK(
height, src_rate.num > 0 && src_rate.den != 0,
encoder, "Frame rate must be positive and finite. Found: ",
encoder_option, frame_rate);
encoder_format, TORCH_CHECK(src_width > 0, "width must be positive. Found: ", src_width);
hw_accel, TORCH_CHECK(src_height > 0, "height must be positive. Found: ", src_height);
config)),
encoder(format_ctx, codec_ctx),
filter(get_video_filter(format, codec_ctx)),
src_frame(get_video_frame(format, codec_ctx)),
converter(AVMEDIA_TYPE_VIDEO, src_frame) {}
void EncodeProcess::process( // 2. Fetch codec from default or override
AVMediaType type,
const torch::Tensor& tensor,
const c10::optional<double>& pts) {
TORCH_CHECK( TORCH_CHECK(
codec_ctx->codec_type == type, format_ctx->oformat->video_codec != AV_CODEC_ID_NONE,
"Attempted to write ", format_ctx->oformat->name,
av_get_media_type_string(type), " does not support video.");
" to ", const AVCodec* codec = get_codec(format_ctx->oformat->video_codec, encoder);
av_get_media_type_string(codec_ctx->codec_type),
" stream."); // 3. Check that encoding format, rate
if (pts) { const AVPixelFormat enc_fmt = get_enc_fmt(src_fmt, encoder_format, codec);
AVRational tb = codec_ctx->time_base; validate_frame_rate(src_rate, codec);
auto val = static_cast<int64_t>(std::round(pts.value() * tb.den / tb.num));
if (src_frame->pts > val) { // 4. Initialize codec context
TORCH_WARN_ONCE( AVCodecContextPtr codec_ctx =
"The provided PTS value is smaller than the next expected value."); get_codec_ctx(codec, format_ctx->oformat->flags);
} configure_video_codec_ctx(
src_frame->pts = val; codec_ctx, enc_fmt, src_rate, src_width, src_height, config);
} if (hw_accel) {
for (const auto& frame : converter.convert(tensor)) { #ifdef USE_CUDA
process_frame(frame); configure_hw_accel(codec_ctx, hw_accel.value());
frame->pts += frame->nb_samples; #else
TORCH_CHECK(
false,
"torchaudio is not compiled with CUDA support. ",
"Hardware acceleration is not available.");
#endif
} }
} open_codec(codec_ctx, encoder_option);
void EncodeProcess::process_frame(AVFrame* src) { // 5. Build filter graph
int ret = filter.add_frame(src); FilterGraph filter_graph = get_video_filter_graph(
while (ret >= 0) { src_fmt, src_rate, src_width, src_height, enc_fmt, hw_accel.has_value());
ret = filter.get_frame(dst_frame);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) { // 6. Instantiate source frame
if (ret == AVERROR_EOF) { AVFramePtr src_frame = [&]() {
encoder.encode(nullptr); if (codec_ctx->hw_frames_ctx) {
} AVFramePtr frame{};
break; int ret = av_hwframe_get_buffer(codec_ctx->hw_frames_ctx, frame, 0);
TORCH_CHECK(ret >= 0, "Failed to fetch CUDA frame: ", av_err2string(ret));
return frame;
} }
if (ret >= 0) { return get_video_frame(src_fmt, src_width, src_height);
encoder.encode(dst_frame); }();
}
av_frame_unref(dst_frame);
}
}
void EncodeProcess::flush() { // 7. Converter
process_frame(nullptr); TensorConverter converter{AVMEDIA_TYPE_VIDEO, src_frame};
// 8. encoder
// Note: get_stream modifies AVFormatContext and adds new stream.
// If anything after this throws, it will leave the StreamWriter in an
// invalid state.
Encoder enc{format_ctx, codec_ctx, get_stream(format_ctx, codec_ctx)};
return EncodeProcess{
std::move(converter),
std::move(src_frame),
std::move(filter_graph),
std::move(enc),
std::move(codec_ctx)};
} }
} // namespace torchaudio::io } // namespace torchaudio::io
@@ -9,47 +9,50 @@
 namespace torchaudio::io {

 class EncodeProcess {
-  // In the reverse order of the process
-  AVCodecContextPtr codec_ctx;
-  Encoder encoder;
-  AVFramePtr dst_frame{};
-  FilterGraph filter;
-  AVFramePtr src_frame;
   TensorConverter converter;
+  AVFramePtr src_frame;
+  FilterGraph filter;
+  AVFramePtr dst_frame{};
+  Encoder encoder;
+  AVCodecContextPtr codec_ctx;

  public:
-  // Constructor for audio
-  EncodeProcess(
-      AVFormatContext* format_ctx,
-      int sample_rate,
-      int num_channels,
-      const enum AVSampleFormat format,
-      const c10::optional<std::string>& encoder,
-      const c10::optional<OptionDict>& encoder_option,
-      const c10::optional<std::string>& encoder_format,
-      const c10::optional<EncodingConfig>& config);
-
-  // constructor for video
-  EncodeProcess(
-      AVFormatContext* format_ctx,
-      double frame_rate,
-      int width,
-      int height,
-      const enum AVPixelFormat format,
-      const c10::optional<std::string>& encoder,
-      const c10::optional<OptionDict>& encoder_option,
-      const c10::optional<std::string>& encoder_format,
-      const c10::optional<std::string>& hw_accel,
-      const c10::optional<EncodingConfig>& config);
-
-  void process(
-      AVMediaType type,
-      const torch::Tensor& tensor,
-      const c10::optional<double>& pts);
+  EncodeProcess(
+      TensorConverter&& converter,
+      AVFramePtr&& frame,
+      FilterGraph&& filter_graph,
+      Encoder&& encoder,
+      AVCodecContextPtr&& codec_ctx) noexcept;
+
+  EncodeProcess(EncodeProcess&&) noexcept = default;
+
+  void process(const torch::Tensor& tensor, const c10::optional<double>& pts);

   void process_frame(AVFrame* src);

   void flush();
 };
+
+EncodeProcess get_audio_encode_process(
+    AVFormatContext* format_ctx,
+    int sample_rate,
+    int num_channels,
+    const std::string& format,
+    const c10::optional<std::string>& encoder,
+    const c10::optional<OptionDict>& encoder_option,
+    const c10::optional<std::string>& encoder_format,
+    const c10::optional<EncodingConfig>& config);
+
+EncodeProcess get_video_encode_process(
+    AVFormatContext* format_ctx,
+    double frame_rate,
+    int width,
+    int height,
+    const std::string& format,
+    const c10::optional<std::string>& encoder,
+    const c10::optional<OptionDict>& encoder_option,
+    const c10::optional<std::string>& encoder_format,
+    const c10::optional<std::string>& hw_accel,
+    const c10::optional<EncodingConfig>& config);

 }; // namespace torchaudio::io
@@ -2,28 +2,11 @@
 namespace torchaudio::io {

-namespace {
-
-AVStream* add_stream(AVFormatContext* format_ctx, AVCodecContext* codec_ctx) {
-  AVStream* stream = avformat_new_stream(format_ctx, nullptr);
-  TORCH_CHECK(stream, "Failed to allocate stream.");
-  stream->time_base = codec_ctx->time_base;
-  int ret = avcodec_parameters_from_context(stream->codecpar, codec_ctx);
-  TORCH_CHECK(
-      ret >= 0,
-      "Failed to copy the stream parameter. (",
-      av_err2string(ret),
-      ")");
-  return stream;
-}
-
-} // namespace
-
-Encoder::Encoder(AVFormatContext* format_ctx_, AVCodecContext* codec_ctx_)
-    : format_ctx(format_ctx_),
-      codec_ctx(codec_ctx_),
-      stream(add_stream(format_ctx, codec_ctx)) {}
+Encoder::Encoder(
+    AVFormatContext* format_ctx,
+    AVCodecContext* codec_ctx,
+    AVStream* stream) noexcept
+    : format_ctx(format_ctx), codec_ctx(codec_ctx), stream(stream) {}

 ///
 /// Encode the given AVFrame data
......
@@ -12,16 +12,17 @@ class Encoder {
   AVFormatContext* format_ctx;
   // Reference to codec context (encoder)
   AVCodecContext* codec_ctx;
-  // Stream object
-  // Encoder object creates AVStream, but it will be deallocated along with
-  // AVFormatContext, So Encoder does not own it.
+  // Stream object as reference. Owned by AVFormatContext.
   AVStream* stream;
   // Temporary object used during the encoding
   // Encoder owns it.
   AVPacketPtr packet{};

  public:
-  Encoder(AVFormatContext* format_ctx, AVCodecContext* codec_ctx);
+  Encoder(
+      AVFormatContext* format_ctx,
+      AVCodecContext* codec_ctx,
+      AVStream* stream) noexcept;
   void encode(AVFrame* frame);
 };
......
@@ -53,88 +53,54 @@ StreamWriter::StreamWriter(
     const c10::optional<std::string>& format)
     : StreamWriter(get_output_format_context(dst, format, nullptr)) {}

-namespace {
-
-enum AVSampleFormat get_src_sample_fmt(const std::string& src) {
-  auto fmt = av_get_sample_fmt(src.c_str());
-  TORCH_CHECK(fmt != AV_SAMPLE_FMT_NONE, "Unknown sample format: ", src);
-  TORCH_CHECK(
-      !av_sample_fmt_is_planar(fmt),
-      "Unexpected sample fotmat value. Valid values are ",
-      av_get_sample_fmt_name(AV_SAMPLE_FMT_U8),
-      ", ",
-      av_get_sample_fmt_name(AV_SAMPLE_FMT_S16),
-      ", ",
-      av_get_sample_fmt_name(AV_SAMPLE_FMT_S32),
-      ", ",
-      av_get_sample_fmt_name(AV_SAMPLE_FMT_S64),
-      ", ",
-      av_get_sample_fmt_name(AV_SAMPLE_FMT_FLT),
-      ", ",
-      av_get_sample_fmt_name(AV_SAMPLE_FMT_DBL),
-      ". ",
-      "Found: ",
-      src);
-  return fmt;
-}
-
-enum AVPixelFormat get_src_pixel_fmt(const std::string& src) {
-  auto fmt = av_get_pix_fmt(src.c_str());
-  switch (fmt) {
-    case AV_PIX_FMT_GRAY8:
-    case AV_PIX_FMT_RGB24:
-    case AV_PIX_FMT_BGR24:
-    case AV_PIX_FMT_YUV444P:
-      return fmt;
-    case AV_PIX_FMT_NONE:
-      TORCH_CHECK(false, "Unknown pixel format: ", src);
-    default:
-      TORCH_CHECK(false, "Unsupported pixel format: ", src);
-  }
-}
-
-} // namespace
-
 void StreamWriter::add_audio_stream(
-    int64_t sample_rate,
-    int64_t num_channels,
+    int sample_rate,
+    int num_channels,
     const std::string& format,
     const c10::optional<std::string>& encoder,
     const c10::optional<OptionDict>& encoder_option,
     const c10::optional<std::string>& encoder_format,
     const c10::optional<EncodingConfig>& config) {
-  processes.emplace_back(
+  TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
+  TORCH_INTERNAL_ASSERT(
+      pFormatContext->nb_streams == processes.size(),
+      "The number of encode process and the number of output streams do not match.");
+  processes.emplace_back(get_audio_encode_process(
       pFormatContext,
       sample_rate,
       num_channels,
-      get_src_sample_fmt(format),
+      format,
       encoder,
       encoder_option,
       encoder_format,
-      config);
+      config));
 }

 void StreamWriter::add_video_stream(
     double frame_rate,
-    int64_t width,
-    int64_t height,
+    int width,
+    int height,
     const std::string& format,
     const c10::optional<std::string>& encoder,
     const c10::optional<OptionDict>& encoder_option,
     const c10::optional<std::string>& encoder_format,
     const c10::optional<std::string>& hw_accel,
     const c10::optional<EncodingConfig>& config) {
-  processes.emplace_back(
+  TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
+  TORCH_INTERNAL_ASSERT(
+      pFormatContext->nb_streams == processes.size(),
+      "The number of encode process and the number of output streams do not match.");
+  processes.emplace_back(get_video_encode_process(
       pFormatContext,
       frame_rate,
       width,
       height,
-      get_src_pixel_fmt(format),
+      format,
       encoder,
       encoder_option,
       encoder_format,
       hw_accel,
-      config);
+      config));
 }

 void StreamWriter::set_metadata(const OptionDict& metadata) {
@@ -149,6 +115,10 @@ void StreamWriter::dump_format(int64_t i) {
 }

 void StreamWriter::open(const c10::optional<OptionDict>& option) {
+  TORCH_INTERNAL_ASSERT(
+      pFormatContext->nb_streams == processes.size(),
+      "The number of encode process and the number of output streams do not match.");
   int ret = 0;

   // Open the file if it was not provided by client code (i.e. when not
@@ -210,12 +180,17 @@ void StreamWriter::write_audio_chunk(
     const c10::optional<double>& pts) {
   TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
   TORCH_CHECK(
-      0 <= i && i < static_cast<int>(processes.size()),
+      0 <= i && i < static_cast<int>(pFormatContext->nb_streams),
       "Invalid stream index. Index must be in range of [0, ",
-      processes.size(),
+      pFormatContext->nb_streams,
       "). Found: ",
       i);
-  processes[i].process(AVMEDIA_TYPE_AUDIO, waveform, pts);
+  TORCH_CHECK(
+      pFormatContext->streams[i]->codecpar->codec_type == AVMEDIA_TYPE_AUDIO,
+      "Stream ",
+      i,
+      " is not audio type.");
+  processes[i].process(waveform, pts);
 }

 void StreamWriter::write_video_chunk(
@@ -224,12 +199,17 @@ void StreamWriter::write_video_chunk(
     const c10::optional<double>& pts) {
   TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
   TORCH_CHECK(
-      0 <= i && i < static_cast<int>(processes.size()),
+      0 <= i && i < static_cast<int>(pFormatContext->nb_streams),
       "Invalid stream index. Index must be in range of [0, ",
-      processes.size(),
+      pFormatContext->nb_streams,
       "). Found: ",
       i);
-  processes[i].process(AVMEDIA_TYPE_VIDEO, frames, pts);
+  TORCH_CHECK(
+      pFormatContext->streams[i]->codecpar->codec_type == AVMEDIA_TYPE_VIDEO,
+      "Stream ",
+      i,
+      " is not video type.");
+  processes[i].process(frames, pts);
 }

 void StreamWriter::flush() {
......
@@ -100,8 +100,8 @@ class StreamWriter {
   /// To list supported formats for the encoder, you can use
   /// ``ffmpeg -h encoder=<ENCODER>`` command.
   void add_audio_stream(
-      int64_t sample_rate,
-      int64_t num_channels,
+      int sample_rate,
+      int num_channels,
       const std::string& format,
       const c10::optional<std::string>& encoder,
       const c10::optional<OptionDict>& encoder_option,
@@ -139,8 +139,8 @@ class StreamWriter {
   /// @endparblock
   void add_video_stream(
       double frame_rate,
-      int64_t width,
-      int64_t height,
+      int width,
+      int height,
       const std::string& format,
       const c10::optional<std::string>& encoder,
       const c10::optional<OptionDict>& encoder_option,
......