Commit 4eac61a3 authored by Moto Hira, committed by Facebook GitHub Bot

Refactor the initialization of EncodeProcess (#3205)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3205

This commit refactors the initialization of EncodeProcess.

Interface-wise, the signature of the EncodeProcess constructor has been
simplified to take only rvalues of its components, and the initialization
of those components has been moved to helper functions.

Implementation-wise, the order in which the components are initialized has
been revised, as has the source of the initialization parameters.

For example, the original implementation first creates the AVCodecContext
and passes it around to create the other components. This relied on the
assumption that the parameters held by the AVCodecContext (such as image size
and sample rate) are the same as those of the source data. This is not always
true, and as we introduce custom filter graphs and allow on-the-fly
transforms of rates and dimensions, it will become even less correct.

The new initialization constructs the source AVFrame, TensorConverter and
FilterGraph from the source attributes, which makes it easier to introduce
such on-the-fly transforms.
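
To make the shape of the refactoring concrete, below is a self-contained sketch of the pattern. The types here (Converter, Frame, Graph, Process) are hypothetical stand-ins, not the torchaudio classes; the real implementation is in the diff that follows. Helper functions build each component from the source attributes, and the process object's constructor only takes ownership of them as rvalues.

#include <iostream>
#include <string>
#include <utility>

// Hypothetical stand-ins for TensorConverter, AVFramePtr and FilterGraph.
struct Converter { int frames_per_chunk; };
struct Frame { int sample_rate; std::string fmt; };
struct Graph { std::string desc; };

class Process {
  Converter converter;
  Frame src_frame;
  Graph filter;

 public:
  // The constructor only accepts rvalues; no parameters are derived here.
  Process(Converter&& c, Frame&& f, Graph&& g) noexcept
      : converter(std::move(c)), src_frame(std::move(f)), filter(std::move(g)) {}

  void describe() const {
    std::cout << filter.desc << " @ " << src_frame.sample_rate << " Hz\n";
  }
};

// Helpers derive each component from the *source* attributes
// (sample rate, sample format), not from a pre-configured codec context.
Frame make_frame(int sample_rate, const std::string& fmt) {
  return Frame{sample_rate, fmt};
}

Graph make_graph(const std::string& src_fmt, const std::string& enc_fmt) {
  return Graph{src_fmt == enc_fmt ? std::string("anull") : "aformat=" + enc_fmt};
}

Process make_process(int sample_rate, const std::string& src_fmt, const std::string& enc_fmt) {
  Frame frame = make_frame(sample_rate, src_fmt);
  Graph graph = make_graph(src_fmt, enc_fmt);
  Converter converter{1024};
  return Process{std::move(converter), std::move(frame), std::move(graph)};
}

int main() {
  make_process(44100, "flt", "s16").describe();  // prints: aformat=s16 @ 44100 Hz
}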

Reviewed By: nateanl

Differential Revision: D44360650

fbshipit-source-id: bf0e77dc1a5a40fc8e9870c50d07339d812762e8
parent d8a37a21
...@@ -2,180 +2,151 @@
namespace torchaudio::io { namespace torchaudio::io {
namespace { ////////////////////////////////////////////////////////////////////////////////
// EncodeProcess Logic Implementation
AVCodecContextPtr get_codec_ctx( ////////////////////////////////////////////////////////////////////////////////
enum AVMediaType type,
AVFORMAT_CONST AVOutputFormat* oformat,
const c10::optional<std::string>& encoder) {
enum AVCodecID default_codec = [&]() {
switch (type) {
case AVMEDIA_TYPE_AUDIO:
return oformat->audio_codec;
case AVMEDIA_TYPE_VIDEO:
return oformat->video_codec;
default:
TORCH_CHECK(
false, "Unsupported media type: ", av_get_media_type_string(type));
}
}();
TORCH_CHECK( EncodeProcess::EncodeProcess(
default_codec != AV_CODEC_ID_NONE, TensorConverter&& converter,
"Format \"", AVFramePtr&& frame,
oformat->name, FilterGraph&& filter_graph,
"\" does not support ", Encoder&& encoder,
av_get_media_type_string(type), AVCodecContextPtr&& codec_ctx) noexcept
"."); : converter(std::move(converter)),
src_frame(std::move(frame)),
filter(std::move(filter_graph)),
encoder(std::move(encoder)),
codec_ctx(std::move(codec_ctx)) {}
const AVCodec* codec = [&]() { void EncodeProcess::process(
if (encoder) { const torch::Tensor& tensor,
const AVCodec* c = avcodec_find_encoder_by_name(encoder.value().c_str()); const c10::optional<double>& pts) {
TORCH_CHECK(c, "Unexpected codec: ", encoder.value()); if (pts) {
return c; AVRational tb = codec_ctx->time_base;
auto val = static_cast<int64_t>(std::round(pts.value() * tb.den / tb.num));
if (src_frame->pts > val) {
TORCH_WARN_ONCE(
"The provided PTS value is smaller than the next expected value.");
} }
const AVCodec* c = avcodec_find_encoder(default_codec); src_frame->pts = val;
TORCH_CHECK( }
c, "Encoder not found for codec: ", avcodec_get_name(default_codec)); for (const auto& frame : converter.convert(tensor)) {
return c; process_frame(frame);
}(); frame->pts += frame->nb_samples;
AVCodecContext* ctx = avcodec_alloc_context3(codec);
TORCH_CHECK(ctx, "Failed to allocate CodecContext.");
if (oformat->flags & AVFMT_GLOBALHEADER) {
ctx->flags |= AV_CODEC_FLAG_GLOBAL_HEADER;
} }
return AVCodecContextPtr(ctx);
} }
std::vector<int> get_supported_sample_rates(const AVCodec* codec) { void EncodeProcess::process_frame(AVFrame* src) {
std::vector<int> ret; int ret = filter.add_frame(src);
if (codec->supported_samplerates) { while (ret >= 0) {
const int* t = codec->supported_samplerates; ret = filter.get_frame(dst_frame);
while (*t) { if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) {
ret.push_back(*t); if (ret == AVERROR_EOF) {
++t; encoder.encode(nullptr);
} }
break;
} }
return ret; if (ret >= 0) {
} encoder.encode(dst_frame);
std::vector<std::string> get_supported_sample_fmts(const AVCodec* codec) {
std::vector<std::string> ret;
if (codec->sample_fmts) {
const enum AVSampleFormat* t = codec->sample_fmts;
while (*t != AV_SAMPLE_FMT_NONE) {
ret.emplace_back(av_get_sample_fmt_name(*t));
++t;
} }
av_frame_unref(dst_frame);
} }
return ret;
} }
std::vector<uint64_t> get_supported_channel_layouts(const AVCodec* codec) { void EncodeProcess::flush() {
std::vector<uint64_t> ret; process_frame(nullptr);
if (codec->channel_layouts) {
const uint64_t* t = codec->channel_layouts;
while (*t) {
ret.push_back(*t);
++t;
}
}
return ret;
} }
void configure_audio_codec( ////////////////////////////////////////////////////////////////////////////////
AVCodecContextPtr& ctx, // EncodeProcess Initialization helper functions
int64_t sample_rate, ////////////////////////////////////////////////////////////////////////////////
int64_t num_channels,
const c10::optional<std::string>& format,
const c10::optional<EncodingConfig>& config) {
// TODO: Review options and make them configurable?
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00122
// - bit_rate
// - bit_rate_tolerance
ctx->sample_rate = [&]() -> int { namespace {
auto rates = get_supported_sample_rates(ctx->codec);
if (rates.empty()) { enum AVSampleFormat get_sample_fmt(const std::string& src) {
return static_cast<int>(sample_rate); auto fmt = av_get_sample_fmt(src.c_str());
} if (fmt != AV_SAMPLE_FMT_NONE && !av_sample_fmt_is_planar(fmt)) {
for (const auto& it : rates) { return fmt;
if (it == sample_rate) {
return static_cast<int>(sample_rate);
}
} }
TORCH_CHECK( TORCH_CHECK(
false, false,
ctx->codec->name, "Unsupported sample format (",
" does not support sample rate ", src,
sample_rate, ") was provided. Valid values are ",
". Supported sample rates are: ", []() -> std::string {
c10::Join(", ", rates)); std::vector<std::string> ret;
}(); for (const auto& fmt :
ctx->time_base = av_inv_q(av_d2q(sample_rate, 1 << 24)); {AV_SAMPLE_FMT_U8,
ctx->sample_fmt = [&]() { AV_SAMPLE_FMT_S16,
// Use default AV_SAMPLE_FMT_S32,
if (!format) { AV_SAMPLE_FMT_S64,
TORCH_CHECK( AV_SAMPLE_FMT_FLT,
ctx->codec->sample_fmts, AV_SAMPLE_FMT_DBL}) {
ctx->codec->name, ret.emplace_back(av_get_sample_fmt_name(fmt));
" does not have default sample format. Please specify one."); }
return ctx->codec->sample_fmts[0]; return c10::Join(", ", ret);
} }(),
// Use the given one. ".");
auto fmt = format.value(); }
auto ret = av_get_sample_fmt(fmt.c_str());
auto fmts = get_supported_sample_fmts(ctx->codec);
if (fmts.empty()) {
TORCH_CHECK(
ret != AV_SAMPLE_FMT_NONE, "Unrecognized format: ", fmt, ". ");
return ret;
}
TORCH_CHECK(
std::count(fmts.begin(), fmts.end(), fmt),
"Unsupported sample format: ",
fmt,
". Supported values are ",
c10::Join(", ", fmts));
return ret;
}();
// validate and set channels enum AVPixelFormat get_pix_fmt(const std::string& src) {
ctx->channels = static_cast<int>(num_channels); AVPixelFormat fmt = av_get_pix_fmt(src.c_str());
auto layout = av_get_default_channel_layout(ctx->channels); switch (fmt) {
auto layouts = get_supported_channel_layouts(ctx->codec); case AV_PIX_FMT_GRAY8:
if (!layouts.empty()) { case AV_PIX_FMT_RGB24:
if (!std::count(layouts.begin(), layouts.end(), layout)) { case AV_PIX_FMT_BGR24:
std::vector<std::string> tmp; case AV_PIX_FMT_YUV444P:
for (const auto& it : layouts) { return fmt;
tmp.push_back(std::to_string(av_get_channel_layout_nb_channels(it))); default:;
} }
TORCH_CHECK( TORCH_CHECK(
false, false,
"Unsupported channels: ", "Unsupported pixel format (",
num_channels, src,
". Supported channels are: ", ") was provided. Valid values are ",
c10::Join(", ", tmp)); []() -> std::string {
} std::vector<std::string> ret;
} for (const auto& fmt :
ctx->channel_layout = static_cast<uint64_t>(layout); {AV_PIX_FMT_GRAY8,
AV_PIX_FMT_RGB24,
AV_PIX_FMT_BGR24,
AV_PIX_FMT_YUV444P}) {
ret.emplace_back(av_get_pix_fmt_name(fmt));
}
return c10::Join(", ", ret);
}(),
".");
}
// Set optional stuff ////////////////////////////////////////////////////////////////////////////////
if (config) { // Codec & Codec context
auto& cfg = config.value(); ////////////////////////////////////////////////////////////////////////////////
if (cfg.bit_rate > 0) { const AVCodec* get_codec(
ctx->bit_rate = cfg.bit_rate; AVCodecID default_codec,
} const c10::optional<std::string>& encoder) {
if (cfg.compression_level != -1) { if (encoder) {
ctx->compression_level = cfg.compression_level; const AVCodec* c = avcodec_find_encoder_by_name(encoder.value().c_str());
TORCH_CHECK(c, "Unexpected codec: ", encoder.value());
return c;
} }
const AVCodec* c = avcodec_find_encoder(default_codec);
TORCH_CHECK(
c, "Encoder not found for codec: ", avcodec_get_name(default_codec));
return c;
}
AVCodecContextPtr get_codec_ctx(const AVCodec* codec, int flags) {
AVCodecContext* ctx = avcodec_alloc_context3(codec);
TORCH_CHECK(ctx, "Failed to allocate CodecContext.");
if (flags & AVFMT_GLOBALHEADER) {
ctx->flags |= AV_CODEC_FLAG_GLOBAL_HEADER;
} }
return AVCodecContextPtr(ctx);
} }
void open_codec( void open_codec(
AVCodecContextPtr& codec_ctx, AVCodecContext* codec_ctx,
const c10::optional<OptionDict>& option) { const c10::optional<OptionDict>& option) {
AVDictionary* opt = get_option_dict(option); AVDictionary* opt = get_option_dict(option);
...@@ -214,186 +185,242 @@ void open_codec(
TORCH_CHECK(ret >= 0, "Failed to open codec: (", av_err2string(ret), ")"); TORCH_CHECK(ret >= 0, "Failed to open codec: (", av_err2string(ret), ")");
} }
AVCodecContextPtr get_audio_codec( ////////////////////////////////////////////////////////////////////////////////
AVFORMAT_CONST AVOutputFormat* oformat, // Audio codec
int64_t sample_rate, ////////////////////////////////////////////////////////////////////////////////
int64_t num_channels,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config) {
AVCodecContextPtr ctx = get_codec_ctx(AVMEDIA_TYPE_AUDIO, oformat, encoder);
configure_audio_codec(ctx, sample_rate, num_channels, encoder_format, config);
open_codec(ctx, encoder_option);
return ctx;
}
FilterGraph get_audio_filter( bool supported_sample_fmt(
AVSampleFormat src_fmt, const AVSampleFormat fmt,
AVCodecContext* codec_ctx) { const AVSampleFormat* sample_fmts) {
auto desc = [&]() -> std::string { if (!sample_fmts) {
if (src_fmt == codec_ctx->sample_fmt) { return true;
if (!codec_ctx->frame_size) {
return "anull";
} else {
std::stringstream ss;
ss << "asetnsamples=n=" << codec_ctx->frame_size << ":p=0";
return ss.str();
} }
} else { while (*sample_fmts != AV_SAMPLE_FMT_NONE) {
std::stringstream ss; if (fmt == *sample_fmts) {
ss << "aformat=" << av_get_sample_fmt_name(codec_ctx->sample_fmt); return true;
if (codec_ctx->frame_size) {
ss << ",asetnsamples=n=" << codec_ctx->frame_size << ":p=0";
} }
return ss.str(); ++sample_fmts;
} }
}(); return false;
}
FilterGraph p{AVMEDIA_TYPE_AUDIO}; std::vector<std::string> get_supported_formats(
p.add_audio_src( const AVSampleFormat* sample_fmts) {
src_fmt, std::vector<std::string> ret;
codec_ctx->time_base, while (*sample_fmts != AV_SAMPLE_FMT_NONE) {
codec_ctx->sample_rate, ret.emplace_back(av_get_sample_fmt_name(*sample_fmts));
codec_ctx->channel_layout); ++sample_fmts;
p.add_sink(); }
p.add_process(desc); return ret;
p.create_filter();
return p;
} }
AVFramePtr get_audio_frame( AVSampleFormat get_enc_fmt(
AVSampleFormat src_fmt, AVSampleFormat src_fmt,
int sample_rate, const c10::optional<std::string>& encoder_format,
int num_channels, const AVCodec* codec) {
AVCodecContext* codec_ctx, if (encoder_format) {
int default_frame_size = 10000) { auto& enc_fmt_val = encoder_format.value();
AVFramePtr frame{}; auto fmt = av_get_sample_fmt(enc_fmt_val.c_str());
frame->pts = 0;
frame->format = src_fmt;
// Note: `channels` attribute is not required for encoding, but
// TensorConverter refers to it
frame->channels = num_channels;
frame->channel_layout = codec_ctx->channel_layout;
frame->sample_rate = sample_rate;
frame->nb_samples =
codec_ctx->frame_size ? codec_ctx->frame_size : default_frame_size;
if (frame->nb_samples) {
int ret = av_frame_get_buffer(frame, 0);
TORCH_CHECK( TORCH_CHECK(
ret >= 0, fmt != AV_SAMPLE_FMT_NONE, "Unknown sample format: ", enc_fmt_val);
"Error allocating an audio buffer (", TORCH_CHECK(
av_err2string(ret), supported_sample_fmt(fmt, codec->sample_fmts),
")."); codec->name,
" does not support ",
encoder_format.value(),
" format. Supported values are; ",
c10::Join(", ", get_supported_formats(codec->sample_fmts)));
return fmt;
} }
return frame; if (codec->sample_fmts) {
} return codec->sample_fmts[0];
}
return src_fmt;
};
std::vector<std::string> get_supported_pix_fmts(const AVCodec* codec) { bool supported_sample_rate(
std::vector<std::string> ret; const int sample_rate,
if (codec->pix_fmts) { const int* supported_samplerates) {
const enum AVPixelFormat* t = codec->pix_fmts; if (!supported_samplerates) {
while (*t != AV_PIX_FMT_NONE) { return true;
ret.emplace_back(av_get_pix_fmt_name(*t));
++t;
} }
while (*supported_samplerates) {
if (sample_rate == *supported_samplerates) {
return true;
} }
return ret; ++supported_samplerates;
}
return false;
} }
std::vector<AVRational> get_supported_frame_rates(const AVCodec* codec) { std::vector<int> get_supported_samplerates(const int* supported_samplerates) {
std::vector<AVRational> ret; std::vector<int> ret;
if (codec->supported_framerates) { if (supported_samplerates) {
const AVRational* t = codec->supported_framerates; while (*supported_samplerates) {
while (!(t->num == 0 && t->den == 0)) { ret.push_back(*supported_samplerates);
ret.push_back(*t); ++supported_samplerates;
++t;
} }
} }
return ret; return ret;
} }
// used to compare frame rate / sample rate. void validate_sample_rate(int sample_rate, const AVCodec* codec) {
// not a general purpose float comparison TORCH_CHECK(
bool is_rate_close(double rate, AVRational rational) { supported_sample_rate(sample_rate, codec->supported_samplerates),
double ref = codec->name,
static_cast<double>(rational.num) / static_cast<double>(rational.den); " does not support sample rate ",
// frame rates / sample rates sample_rate,
static const double threshold = 0.001; ". Supported values are; ",
return fabs(rate - ref) < threshold; c10::Join(", ", get_supported_samplerates(codec->supported_samplerates)));
} }
void configure_video_codec( std::vector<std::string> get_supported_channels(
AVCodecContextPtr& ctx, const uint64_t* channel_layouts) {
double frame_rate, std::vector<std::string> ret;
int64_t width, while (*channel_layouts) {
int64_t height, ret.emplace_back(av_get_channel_name(*channel_layouts));
const c10::optional<std::string>& format, ++channel_layouts;
const c10::optional<EncodingConfig>& config) {
// TODO: Review other options and make them configurable?
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00147
// - bit_rate
// - bit_rate_tolerance
// - gop_size
// - max_b_frames
// - mb_decisions
ctx->width = static_cast<int>(width);
ctx->height = static_cast<int>(height);
ctx->time_base = [&]() {
AVRational ret = av_inv_q(av_d2q(frame_rate, 1 << 24));
auto rates = get_supported_frame_rates(ctx->codec);
// Codec does not have constraint on frame rate
if (rates.empty()) {
return ret;
} }
// Codec has list of supported frame rate.
for (const auto& t : rates) {
if (is_rate_close(frame_rate, t)) {
return ret; return ret;
}
uint64_t get_channel_layout(int num_channels, const AVCodec* codec) {
if (!codec->channel_layouts) {
return static_cast<uint64_t>(av_get_default_channel_layout(num_channels));
} }
for (const uint64_t* it = codec->channel_layouts; *it; ++it) {
if (av_get_channel_layout_nb_channels(*it) == num_channels) {
return *it;
} }
// Given one is not supported.
std::vector<std::string> tmp;
for (const auto& t : rates) {
tmp.emplace_back(
t.den == 1 ? std::to_string(t.num)
: std::to_string(t.num) + "/" + std::to_string(t.den));
} }
TORCH_CHECK( TORCH_CHECK(
false, false,
"Unsupported frame rate: ", "Codec ",
frame_rate, codec->name,
". Supported values are ", " does not support a channel layout consists of ",
c10::Join(", ", tmp)); num_channels,
}(); " channels. Supported values are: ",
ctx->pix_fmt = [&]() { c10::Join(", ", get_supported_channels(codec->channel_layouts)));
// Use default }
if (!format) {
TORCH_CHECK( void configure_audio_codec_ctx(
ctx->codec->pix_fmts, AVCodecContext* codec_ctx,
ctx->codec->name, AVSampleFormat format,
" does not have defaut pixel format. Please specify one."); int sample_rate,
return ctx->codec->pix_fmts[0]; int num_channels,
} uint64_t channel_layout,
// Use the given one, const c10::optional<EncodingConfig>& config) {
auto fmt = format.value(); codec_ctx->sample_fmt = format;
auto ret = av_get_pix_fmt(fmt.c_str()); codec_ctx->sample_rate = sample_rate;
auto fmts = get_supported_pix_fmts(ctx->codec); codec_ctx->time_base = av_inv_q(av_d2q(sample_rate, 1 << 24));
if (fmts.empty()) { codec_ctx->channels = num_channels;
TORCH_CHECK(ret != AV_PIX_FMT_NONE, "Unrecognized format: ", fmt, ". "); codec_ctx->channel_layout = channel_layout;
// Set optional stuff
if (config) {
auto& cfg = config.value();
if (cfg.bit_rate > 0) {
codec_ctx->bit_rate = cfg.bit_rate;
}
if (cfg.compression_level != -1) {
codec_ctx->compression_level = cfg.compression_level;
}
}
}
////////////////////////////////////////////////////////////////////////////////
// Video codec
////////////////////////////////////////////////////////////////////////////////
bool supported_pix_fmt(const AVPixelFormat fmt, const AVPixelFormat* pix_fmts) {
if (!pix_fmts) {
return true;
}
while (*pix_fmts != AV_PIX_FMT_NONE) {
if (fmt == *pix_fmts) {
return true;
}
++pix_fmts;
}
return false;
}
std::vector<std::string> get_supported_formats(const AVPixelFormat* pix_fmts) {
std::vector<std::string> ret;
while (*pix_fmts != AV_PIX_FMT_NONE) {
ret.emplace_back(av_get_pix_fmt_name(*pix_fmts));
++pix_fmts;
}
return ret; return ret;
}
AVPixelFormat get_enc_fmt(
AVPixelFormat src_fmt,
const c10::optional<std::string>& encoder_format,
const AVCodec* codec) {
if (encoder_format) {
auto fmt = get_pix_fmt(encoder_format.value());
TORCH_CHECK(
supported_pix_fmt(fmt, codec->pix_fmts),
codec->name,
" does not support ",
encoder_format.value(),
" format. Supported values are; ",
c10::Join(", ", get_supported_formats(codec->pix_fmts)));
return fmt;
}
if (codec->pix_fmts) {
return codec->pix_fmts[0];
}
return src_fmt;
}
bool supported_frame_rate(AVRational rate, const AVRational* rates) {
if (!rates) {
return true;
} }
if (!std::count(fmts.begin(), fmts.end(), fmt)) { for (; !(rates->num == 0 && rates->den == 0); ++rates) {
if (av_cmp_q(rate, *rates) == 0) {
return true;
}
}
return false;
}
void validate_frame_rate(AVRational rate, const AVCodec* codec) {
TORCH_CHECK( TORCH_CHECK(
false, supported_frame_rate(rate, codec->supported_framerates),
"Unsupported pixel format: ", codec->name,
fmt, " does not support frame rate ",
". Supported values are ", c10::Join("/", std::array<int, 2>{rate.num, rate.den}),
c10::Join(", ", fmts)); ". Supported values are; ",
[&]() {
std::vector<std::string> ret;
for (auto r = codec->supported_framerates;
!(r->num == 0 && r->den == 0);
++r) {
ret.push_back(c10::Join("/", std::array<int, 2>{r->num, r->den}));
} }
return ret; return c10::Join(", ", ret);
}(); }());
}
void configure_video_codec_ctx(
AVCodecContextPtr& ctx,
AVPixelFormat format,
AVRational frame_rate,
int width,
int height,
const c10::optional<EncodingConfig>& config) {
// TODO: Review other options and make them configurable?
// https://ffmpeg.org/doxygen/4.1/muxing_8c_source.html#l00147
// - bit_rate_tolerance
// - mb_decisions
ctx->pix_fmt = format;
ctx->width = width;
ctx->height = height;
ctx->time_base = av_inv_q(frame_rate);
// Set optional stuff // Set optional stuff
if (config) { if (config) {
...@@ -416,9 +443,9 @@ void configure_video_codec(
void configure_hw_accel(AVCodecContext* ctx, const std::string& hw_accel) { void configure_hw_accel(AVCodecContext* ctx, const std::string& hw_accel) {
torch::Device device{hw_accel}; torch::Device device{hw_accel};
TORCH_CHECK( TORCH_CHECK(
device.type() == c10::DeviceType::CUDA, device.is_cuda(),
"Only CUDA is supported for hardware acceleration. Found: ", "Only CUDA is supported for hardware acceleration. Found: ",
device.str()); device);
// NOTES: // NOTES:
// 1. Examples like // 1. Examples like
...@@ -463,77 +490,118 @@ void configure_hw_accel(AVCodecContext* ctx, const std::string& hw_accel) {
av_err2string(ret)); av_err2string(ret));
} }
AVCodecContextPtr get_video_codec( ////////////////////////////////////////////////////////////////////////////////
AVFORMAT_CONST AVOutputFormat* oformat, // AVStream
double frame_rate, ////////////////////////////////////////////////////////////////////////////////
int64_t width,
int64_t height,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config) {
AVCodecContextPtr ctx = get_codec_ctx(AVMEDIA_TYPE_VIDEO, oformat, encoder);
configure_video_codec(ctx, frame_rate, width, height, encoder_format, config);
if (hw_accel) { AVStream* get_stream(AVFormatContext* format_ctx, AVCodecContext* codec_ctx) {
#ifdef USE_CUDA AVStream* stream = avformat_new_stream(format_ctx, nullptr);
configure_hw_accel(ctx, hw_accel.value()); TORCH_CHECK(stream, "Failed to allocate stream.");
#else
stream->time_base = codec_ctx->time_base;
int ret = avcodec_parameters_from_context(stream->codecpar, codec_ctx);
TORCH_CHECK( TORCH_CHECK(
false, ret >= 0, "Failed to copy the stream parameter: ", av_err2string(ret));
"torchaudio is not compiled with CUDA support. ", return stream;
"Hardware acceleration is not available."); }
#endif
////////////////////////////////////////////////////////////////////////////////
// FilterGraph
////////////////////////////////////////////////////////////////////////////////
FilterGraph get_audio_filter_graph(
AVSampleFormat src_fmt,
int sample_rate,
uint64_t channel_layout,
AVSampleFormat enc_fmt,
int nb_samples) {
const std::string filter_desc = [&]() -> const std::string {
if (src_fmt == enc_fmt) {
if (nb_samples == 0) {
return "anull";
} else {
std::stringstream ss;
ss << "asetnsamples=n=" << nb_samples << ":p=0";
return ss.str();
} }
} else {
std::stringstream ss;
ss << "aformat=" << av_get_sample_fmt_name(enc_fmt);
if (nb_samples > 0) {
ss << ",asetnsamples=n=" << nb_samples << ":p=0";
}
return ss.str();
}
}();
open_codec(ctx, encoder_option); FilterGraph f{AVMEDIA_TYPE_AUDIO};
return ctx; f.add_audio_src(src_fmt, {1, sample_rate}, sample_rate, channel_layout);
f.add_sink();
f.add_process(filter_desc);
f.create_filter();
return f;
} }
FilterGraph get_video_filter(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) { FilterGraph get_video_filter_graph(
AVPixelFormat src_fmt,
AVRational rate,
int width,
int height,
AVPixelFormat enc_fmt,
bool is_cuda) {
auto desc = [&]() -> std::string { auto desc = [&]() -> std::string {
if (src_fmt == codec_ctx->pix_fmt || if (src_fmt == enc_fmt || is_cuda) {
codec_ctx->pix_fmt == AV_PIX_FMT_CUDA) {
return "null"; return "null";
} else { } else {
std::stringstream ss; std::stringstream ss;
ss << "format=" << av_get_pix_fmt_name(codec_ctx->pix_fmt); ss << "format=" << av_get_pix_fmt_name(enc_fmt);
return ss.str(); return ss.str();
} }
}(); }();
FilterGraph p{AVMEDIA_TYPE_VIDEO}; FilterGraph f{AVMEDIA_TYPE_VIDEO};
p.add_video_src( f.add_video_src(src_fmt, av_inv_q(rate), rate, width, height, {1, 1});
src_fmt, f.add_sink();
codec_ctx->time_base, f.add_process(desc);
codec_ctx->framerate, f.create_filter();
codec_ctx->width, return f;
codec_ctx->height, }
codec_ctx->sample_aspect_ratio);
p.add_sink(); ////////////////////////////////////////////////////////////////////////////////
p.add_process(desc); // Source frame
p.create_filter(); ////////////////////////////////////////////////////////////////////////////////
return p;
AVFramePtr get_audio_frame(
AVSampleFormat format,
int sample_rate,
int num_channels,
uint64_t channel_layout,
int nb_samples) {
AVFramePtr frame{};
frame->format = format;
frame->channel_layout = channel_layout;
frame->sample_rate = sample_rate;
frame->nb_samples = nb_samples ? nb_samples : 1024;
int ret = av_frame_get_buffer(frame, 0);
TORCH_CHECK(
ret >= 0, "Error allocating the source audio frame:", av_err2string(ret));
// Note: `channels` attribute is not required for encoding, but
// TensorConverter refers to it
frame->channels = num_channels;
frame->pts = 0;
return frame;
} }
AVFramePtr get_video_frame(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) { AVFramePtr get_video_frame(AVPixelFormat src_fmt, int width, int height) {
AVFramePtr frame{}; AVFramePtr frame{};
if (codec_ctx->hw_frames_ctx) {
int ret = av_hwframe_get_buffer(codec_ctx->hw_frames_ctx, frame, 0);
TORCH_CHECK(ret >= 0, "Failed to fetch CUDA frame: ", av_err2string(ret));
} else {
frame->format = src_fmt; frame->format = src_fmt;
frame->width = codec_ctx->width; frame->width = width;
frame->height = codec_ctx->height; frame->height = height;
int ret = av_frame_get_buffer(frame, 0); int ret = av_frame_get_buffer(frame, 0);
TORCH_CHECK( TORCH_CHECK(
ret >= 0, ret >= 0, "Error allocating a video buffer :", av_err2string(ret));
"Error allocating a video buffer (",
av_err2string(ret),
").");
}
// Note: `nb_samples` attribute is not used for video, but we set it // Note: `nb_samples` attribute is not used for video, but we set it
// anyways so that we can make the logic of PTS increment agnostic to // anyways so that we can make the logic of PTS increment agnostic to
// audio and video. // audio and video.
...@@ -544,99 +612,164 @@ AVFramePtr get_video_frame(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
} // namespace } // namespace
EncodeProcess::EncodeProcess( ////////////////////////////////////////////////////////////////////////////////
// Finally, the extern-facing API
////////////////////////////////////////////////////////////////////////////////
EncodeProcess get_audio_encode_process(
AVFormatContext* format_ctx, AVFormatContext* format_ctx,
int sample_rate, int src_sample_rate,
int num_channels, int src_num_channels,
const enum AVSampleFormat format, const std::string& format,
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config) const c10::optional<EncodingConfig>& config) {
: codec_ctx(get_audio_codec( // 1. Check the source format, rate and channels
format_ctx->oformat, const AVSampleFormat src_fmt = get_sample_fmt(format);
sample_rate, TORCH_CHECK(
num_channels, src_sample_rate > 0,
encoder, "Sample rate must be positive. Found: ",
encoder_option, src_sample_rate);
encoder_format, TORCH_CHECK(
config)), src_num_channels > 0,
encoder(format_ctx, codec_ctx), "The number of channels must be positive. Found: ",
filter(get_audio_filter(format, codec_ctx)), src_num_channels);
src_frame(get_audio_frame(format, sample_rate, num_channels, codec_ctx)),
converter(AVMEDIA_TYPE_AUDIO, src_frame, src_frame->nb_samples) {}
EncodeProcess::EncodeProcess( // 2. Fetch codec from default or override
TORCH_CHECK(
format_ctx->oformat->audio_codec != AV_CODEC_ID_NONE,
format_ctx->oformat->name,
" does not support audio.");
const AVCodec* codec = get_codec(format_ctx->oformat->audio_codec, encoder);
// 3. Check the encoding sample format, sample rate and channels
// TODO: introduce encoder_sample_rate option and allow changing the sample rate
const AVSampleFormat enc_fmt = get_enc_fmt(src_fmt, encoder_format, codec);
validate_sample_rate(src_sample_rate, codec);
uint64_t channel_layout = get_channel_layout(src_num_channels, codec);
// 4. Initialize codec context
AVCodecContextPtr codec_ctx =
get_codec_ctx(codec, format_ctx->oformat->flags);
configure_audio_codec_ctx(
codec_ctx,
enc_fmt,
src_sample_rate,
src_num_channels,
channel_layout,
config);
open_codec(codec_ctx, encoder_option);
// 5. Build filter graph
FilterGraph filter_graph = get_audio_filter_graph(
src_fmt, src_sample_rate, channel_layout, enc_fmt, codec_ctx->frame_size);
// 6. Instantiate source frame
AVFramePtr src_frame = get_audio_frame(
src_fmt,
src_sample_rate,
src_num_channels,
channel_layout,
codec_ctx->frame_size);
// 7. Instantiate Converter
TensorConverter converter{
AVMEDIA_TYPE_AUDIO, src_frame, src_frame->nb_samples};
// 8. encoder
// Note: get_stream modifies AVFormatContext and adds new stream.
// If anything after this throws, it will leave the StreamWriter in an
// invalid state.
Encoder enc{format_ctx, codec_ctx, get_stream(format_ctx, codec_ctx)};
return EncodeProcess{
std::move(converter),
std::move(src_frame),
std::move(filter_graph),
std::move(enc),
std::move(codec_ctx)};
}
EncodeProcess get_video_encode_process(
AVFormatContext* format_ctx, AVFormatContext* format_ctx,
double frame_rate, double frame_rate,
int width, int src_width,
int height, int src_height,
const enum AVPixelFormat format, const std::string& format,
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel, const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config) const c10::optional<EncodingConfig>& config) {
: codec_ctx(get_video_codec( // 1. Check the source format, rate and resolution
format_ctx->oformat, const AVPixelFormat src_fmt = get_pix_fmt(format);
frame_rate, AVRational src_rate = av_d2q(frame_rate, 1 << 24);
width, TORCH_CHECK(
height, src_rate.num > 0 && src_rate.den != 0,
encoder, "Frame rate must be positive and finite. Found: ",
encoder_option, frame_rate);
encoder_format, TORCH_CHECK(src_width > 0, "width must be positive. Found: ", src_width);
hw_accel, TORCH_CHECK(src_height > 0, "height must be positive. Found: ", src_height);
config)),
encoder(format_ctx, codec_ctx),
filter(get_video_filter(format, codec_ctx)),
src_frame(get_video_frame(format, codec_ctx)),
converter(AVMEDIA_TYPE_VIDEO, src_frame) {}
void EncodeProcess::process( // 2. Fetch codec from default or override
AVMediaType type,
const torch::Tensor& tensor,
const c10::optional<double>& pts) {
TORCH_CHECK( TORCH_CHECK(
codec_ctx->codec_type == type, format_ctx->oformat->video_codec != AV_CODEC_ID_NONE,
"Attempted to write ", format_ctx->oformat->name,
av_get_media_type_string(type), " does not support video.");
" to ", const AVCodec* codec = get_codec(format_ctx->oformat->video_codec, encoder);
av_get_media_type_string(codec_ctx->codec_type),
" stream."); // 3. Check that encoding format, rate
if (pts) { const AVPixelFormat enc_fmt = get_enc_fmt(src_fmt, encoder_format, codec);
AVRational tb = codec_ctx->time_base; validate_frame_rate(src_rate, codec);
auto val = static_cast<int64_t>(std::round(pts.value() * tb.den / tb.num));
if (src_frame->pts > val) { // 4. Initialize codec context
TORCH_WARN_ONCE( AVCodecContextPtr codec_ctx =
"The provided PTS value is smaller than the next expected value."); get_codec_ctx(codec, format_ctx->oformat->flags);
} configure_video_codec_ctx(
src_frame->pts = val; codec_ctx, enc_fmt, src_rate, src_width, src_height, config);
} if (hw_accel) {
for (const auto& frame : converter.convert(tensor)) { #ifdef USE_CUDA
process_frame(frame); configure_hw_accel(codec_ctx, hw_accel.value());
frame->pts += frame->nb_samples; #else
TORCH_CHECK(
false,
"torchaudio is not compiled with CUDA support. ",
"Hardware acceleration is not available.");
#endif
} }
} open_codec(codec_ctx, encoder_option);
void EncodeProcess::process_frame(AVFrame* src) { // 5. Build filter graph
int ret = filter.add_frame(src); FilterGraph filter_graph = get_video_filter_graph(
while (ret >= 0) { src_fmt, src_rate, src_width, src_height, enc_fmt, hw_accel.has_value());
ret = filter.get_frame(dst_frame);
if (ret == AVERROR(EAGAIN) || ret == AVERROR_EOF) { // 6. Instantiate source frame
if (ret == AVERROR_EOF) { AVFramePtr src_frame = [&]() {
encoder.encode(nullptr); if (codec_ctx->hw_frames_ctx) {
} AVFramePtr frame{};
break; int ret = av_hwframe_get_buffer(codec_ctx->hw_frames_ctx, frame, 0);
} TORCH_CHECK(ret >= 0, "Failed to fetch CUDA frame: ", av_err2string(ret));
if (ret >= 0) { return frame;
encoder.encode(dst_frame);
}
av_frame_unref(dst_frame);
} }
} return get_video_frame(src_fmt, src_width, src_height);
}();
void EncodeProcess::flush() { // 7. Converter
process_frame(nullptr); TensorConverter converter{AVMEDIA_TYPE_VIDEO, src_frame};
// 8. encoder
// Note: get_stream modifies AVFormatContext and adds new stream.
// If anything after this throws, it will leave the StreamWriter in an
// invalid state.
Encoder enc{format_ctx, codec_ctx, get_stream(format_ctx, codec_ctx)};
return EncodeProcess{
std::move(converter),
std::move(src_frame),
std::move(filter_graph),
std::move(enc),
std::move(codec_ctx)};
} }
} // namespace torchaudio::io } // namespace torchaudio::io
...@@ -9,47 +9,50 @@
namespace torchaudio::io { namespace torchaudio::io {
class EncodeProcess { class EncodeProcess {
// In the reverse order of the process
AVCodecContextPtr codec_ctx;
Encoder encoder;
AVFramePtr dst_frame{};
FilterGraph filter;
AVFramePtr src_frame;
TensorConverter converter; TensorConverter converter;
AVFramePtr src_frame;
FilterGraph filter;
AVFramePtr dst_frame{};
Encoder encoder;
AVCodecContextPtr codec_ctx;
public: public:
// Constructor for audio
EncodeProcess( EncodeProcess(
TensorConverter&& converter,
AVFramePtr&& frame,
FilterGraph&& filter_graph,
Encoder&& encoder,
AVCodecContextPtr&& codec_ctx) noexcept;
EncodeProcess(EncodeProcess&&) noexcept = default;
void process(const torch::Tensor& tensor, const c10::optional<double>& pts);
void process_frame(AVFrame* src);
void flush();
};
EncodeProcess get_audio_encode_process(
AVFormatContext* format_ctx, AVFormatContext* format_ctx,
int sample_rate, int sample_rate,
int num_channels, int num_channels,
const enum AVSampleFormat format, const std::string& format,
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config); const c10::optional<EncodingConfig>& config);
// constructor for video EncodeProcess get_video_encode_process(
EncodeProcess(
AVFormatContext* format_ctx, AVFormatContext* format_ctx,
double frame_rate, double frame_rate,
int width, int width,
int height, int height,
const enum AVPixelFormat format, const std::string& format,
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel, const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config); const c10::optional<EncodingConfig>& config);
void process(
AVMediaType type,
const torch::Tensor& tensor,
const c10::optional<double>& pts);
void process_frame(AVFrame* src);
void flush();
};
}; // namespace torchaudio::io }; // namespace torchaudio::io
...@@ -2,28 +2,11 @@
namespace torchaudio::io { namespace torchaudio::io {
namespace { Encoder::Encoder(
AVFormatContext* format_ctx,
AVStream* add_stream(AVFormatContext* format_ctx, AVCodecContext* codec_ctx) { AVCodecContext* codec_ctx,
AVStream* stream = avformat_new_stream(format_ctx, nullptr); AVStream* stream) noexcept
TORCH_CHECK(stream, "Failed to allocate stream."); : format_ctx(format_ctx), codec_ctx(codec_ctx), stream(stream) {}
stream->time_base = codec_ctx->time_base;
int ret = avcodec_parameters_from_context(stream->codecpar, codec_ctx);
TORCH_CHECK(
ret >= 0,
"Failed to copy the stream parameter. (",
av_err2string(ret),
")");
return stream;
}
} // namespace
Encoder::Encoder(AVFormatContext* format_ctx_, AVCodecContext* codec_ctx_)
: format_ctx(format_ctx_),
codec_ctx(codec_ctx_),
stream(add_stream(format_ctx, codec_ctx)) {}
/// ///
/// Encode the given AVFrame data /// Encode the given AVFrame data
......
...@@ -12,16 +12,17 @@ class Encoder {
AVFormatContext* format_ctx; AVFormatContext* format_ctx;
// Reference to codec context (encoder) // Reference to codec context (encoder)
AVCodecContext* codec_ctx; AVCodecContext* codec_ctx;
// Stream object // Stream object as reference. Owned by AVFormatContext.
// Encoder object creates AVStream, but it will be deallocated along with
// AVFormatContext, So Encoder does not own it.
AVStream* stream; AVStream* stream;
// Temporary object used during the encoding // Temporary object used during the encoding
// Encoder owns it. // Encoder owns it.
AVPacketPtr packet{}; AVPacketPtr packet{};
public: public:
Encoder(AVFormatContext* format_ctx, AVCodecContext* codec_ctx); Encoder(
AVFormatContext* format_ctx,
AVCodecContext* codec_ctx,
AVStream* stream) noexcept;
void encode(AVFrame* frame); void encode(AVFrame* frame);
}; };
......
...@@ -53,88 +53,54 @@ StreamWriter::StreamWriter(
const c10::optional<std::string>& format) const c10::optional<std::string>& format)
: StreamWriter(get_output_format_context(dst, format, nullptr)) {} : StreamWriter(get_output_format_context(dst, format, nullptr)) {}
namespace {
enum AVSampleFormat get_src_sample_fmt(const std::string& src) {
auto fmt = av_get_sample_fmt(src.c_str());
TORCH_CHECK(fmt != AV_SAMPLE_FMT_NONE, "Unknown sample format: ", src);
TORCH_CHECK(
!av_sample_fmt_is_planar(fmt),
"Unexpected sample fotmat value. Valid values are ",
av_get_sample_fmt_name(AV_SAMPLE_FMT_U8),
", ",
av_get_sample_fmt_name(AV_SAMPLE_FMT_S16),
", ",
av_get_sample_fmt_name(AV_SAMPLE_FMT_S32),
", ",
av_get_sample_fmt_name(AV_SAMPLE_FMT_S64),
", ",
av_get_sample_fmt_name(AV_SAMPLE_FMT_FLT),
", ",
av_get_sample_fmt_name(AV_SAMPLE_FMT_DBL),
". ",
"Found: ",
src);
return fmt;
}
enum AVPixelFormat get_src_pixel_fmt(const std::string& src) {
auto fmt = av_get_pix_fmt(src.c_str());
switch (fmt) {
case AV_PIX_FMT_GRAY8:
case AV_PIX_FMT_RGB24:
case AV_PIX_FMT_BGR24:
case AV_PIX_FMT_YUV444P:
return fmt;
case AV_PIX_FMT_NONE:
TORCH_CHECK(false, "Unknown pixel format: ", src);
default:
TORCH_CHECK(false, "Unsupported pixel format: ", src);
}
}
} // namespace
void StreamWriter::add_audio_stream( void StreamWriter::add_audio_stream(
int64_t sample_rate, int sample_rate,
int64_t num_channels, int num_channels,
const std::string& format, const std::string& format,
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<EncodingConfig>& config) { const c10::optional<EncodingConfig>& config) {
processes.emplace_back( TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
"The number of encode process and the number of output streams do not match.");
processes.emplace_back(get_audio_encode_process(
pFormatContext, pFormatContext,
sample_rate, sample_rate,
num_channels, num_channels,
get_src_sample_fmt(format), format,
encoder, encoder,
encoder_option, encoder_option,
encoder_format, encoder_format,
config); config));
} }
void StreamWriter::add_video_stream( void StreamWriter::add_video_stream(
double frame_rate, double frame_rate,
int64_t width, int width,
int64_t height, int height,
const std::string& format, const std::string& format,
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format, const c10::optional<std::string>& encoder_format,
const c10::optional<std::string>& hw_accel, const c10::optional<std::string>& hw_accel,
const c10::optional<EncodingConfig>& config) { const c10::optional<EncodingConfig>& config) {
processes.emplace_back( TORCH_CHECK(!is_open, "Output is already opened. Cannot add a new stream.");
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
"The number of encode process and the number of output streams do not match.");
processes.emplace_back(get_video_encode_process(
pFormatContext, pFormatContext,
frame_rate, frame_rate,
width, width,
height, height,
get_src_pixel_fmt(format), format,
encoder, encoder,
encoder_option, encoder_option,
encoder_format, encoder_format,
hw_accel, hw_accel,
config); config));
} }
void StreamWriter::set_metadata(const OptionDict& metadata) { void StreamWriter::set_metadata(const OptionDict& metadata) {
...@@ -149,6 +115,10 @@ void StreamWriter::dump_format(int64_t i) {
} }
void StreamWriter::open(const c10::optional<OptionDict>& option) { void StreamWriter::open(const c10::optional<OptionDict>& option) {
TORCH_INTERNAL_ASSERT(
pFormatContext->nb_streams == processes.size(),
"The number of encode process and the number of output streams do not match.");
int ret = 0; int ret = 0;
// Open the file if it was not provided by client code (i.e. when not // Open the file if it was not provided by client code (i.e. when not
...@@ -210,12 +180,17 @@ void StreamWriter::write_audio_chunk(
const c10::optional<double>& pts) { const c10::optional<double>& pts) {
TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?"); TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
TORCH_CHECK( TORCH_CHECK(
0 <= i && i < static_cast<int>(processes.size()), 0 <= i && i < static_cast<int>(pFormatContext->nb_streams),
"Invalid stream index. Index must be in range of [0, ", "Invalid stream index. Index must be in range of [0, ",
processes.size(), pFormatContext->nb_streams,
"). Found: ", "). Found: ",
i); i);
processes[i].process(AVMEDIA_TYPE_AUDIO, waveform, pts); TORCH_CHECK(
pFormatContext->streams[i]->codecpar->codec_type == AVMEDIA_TYPE_AUDIO,
"Stream ",
i,
" is not audio type.");
processes[i].process(waveform, pts);
} }
void StreamWriter::write_video_chunk( void StreamWriter::write_video_chunk(
...@@ -224,12 +199,17 @@ void StreamWriter::write_video_chunk(
const c10::optional<double>& pts) { const c10::optional<double>& pts) {
TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?"); TORCH_CHECK(is_open, "Output is not opened. Did you call `open` method?");
TORCH_CHECK( TORCH_CHECK(
0 <= i && i < static_cast<int>(processes.size()), 0 <= i && i < static_cast<int>(pFormatContext->nb_streams),
"Invalid stream index. Index must be in range of [0, ", "Invalid stream index. Index must be in range of [0, ",
processes.size(), pFormatContext->nb_streams,
"). Found: ", "). Found: ",
i); i);
processes[i].process(AVMEDIA_TYPE_VIDEO, frames, pts); TORCH_CHECK(
pFormatContext->streams[i]->codecpar->codec_type == AVMEDIA_TYPE_VIDEO,
"Stream ",
i,
" is not video type.");
processes[i].process(frames, pts);
} }
void StreamWriter::flush() { void StreamWriter::flush() {
......
...@@ -100,8 +100,8 @@ class StreamWriter {
/// To list supported formats for the encoder, you can use /// To list supported formats for the encoder, you can use
/// ``ffmpeg -h encoder=<ENCODER>`` command. /// ``ffmpeg -h encoder=<ENCODER>`` command.
void add_audio_stream( void add_audio_stream(
int64_t sample_rate, int sample_rate,
int64_t num_channels, int num_channels,
const std::string& format, const std::string& format,
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
...@@ -139,8 +139,8 @@ class StreamWriter {
/// @endparblock /// @endparblock
void add_video_stream( void add_video_stream(
double frame_rate, double frame_rate,
int64_t width, int width,
int64_t height, int height,
const std::string& format, const std::string& format,
const c10::optional<std::string>& encoder, const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option, const c10::optional<OptionDict>& encoder_option,
......
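
For context, here is a hedged caller-side sketch of how the refactored path is exercised through the C++ StreamWriter API shown above. The method signatures follow the headers in this diff (add_audio_stream, open, write_audio_chunk, flush), but the include path, output file name, and tensor layout are assumptions for illustration and are not part of this commit.

#include <torch/torch.h>
// Assumed include path; adjust to where stream_writer.h lives in the tree.
#include <torchaudio/csrc/ffmpeg/stream_writer/stream_writer.h>

int main() {
  using namespace torchaudio::io;

  // Assumed tensor layout: float32, shape (num_frames, num_channels).
  torch::Tensor waveform = torch::zeros({44100, 2}, torch::kFloat32);

  // Assuming the (dst, format) constructor shown at the top of the
  // StreamWriter.cpp hunk; nullopt lets the muxer be inferred from the path.
  StreamWriter writer{"out.wav", c10::nullopt};

  // add_audio_stream now builds an EncodeProcess via get_audio_encode_process.
  writer.add_audio_stream(
      /*sample_rate=*/44100,
      /*num_channels=*/2,
      /*format=*/"flt",
      /*encoder=*/c10::nullopt,
      /*encoder_option=*/c10::nullopt,
      /*encoder_format=*/c10::nullopt,
      /*config=*/c10::nullopt);

  writer.open(c10::nullopt);  // all streams must be added before open()
  writer.write_audio_chunk(0, waveform, c10::nullopt);
  writer.flush();
  return 0;
}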