Commit bc61f109 authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Move OutputStream init logic and simplify interface (#3105)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3105

Refactor the construction of Audio/VideoOutputStream

Reviewed By: nateanl

Differential Revision: D43613013

fbshipit-source-id: 0e112cb1bab2658be68a368099ed00ef318ea4f1
parent 5b0580ae
......@@ -2,18 +2,62 @@
namespace torchaudio::io {
namespace {
// Build a FilterGraph that converts incoming audio from `src_fmt` into the
// sample format the encoder expects. Returns nullptr when the two formats
// already match, i.e. no conversion is required.
std::unique_ptr<FilterGraph> get_audio_filter(
    AVSampleFormat src_fmt,
    AVCodecContext* codec_ctx) {
  const AVSampleFormat enc_fmt = codec_ctx->sample_fmt;
  if (enc_fmt == src_fmt) {
    // Input already matches the encoder format; skip filtering entirely.
    return {nullptr};
  }
  // "aformat" performs the sample-format conversion.
  std::stringstream filter_desc;
  filter_desc << "aformat=" << av_get_sample_fmt_name(enc_fmt);
  auto graph = std::make_unique<FilterGraph>(AVMEDIA_TYPE_AUDIO);
  graph->add_audio_src(
      src_fmt,
      codec_ctx->time_base,
      codec_ctx->sample_rate,
      codec_ctx->channel_layout);
  graph->add_sink();
  graph->add_process(filter_desc.str());
  graph->create_filter();
  return graph;
}
// Allocate the frame into which user-provided audio samples are staged
// before encoding. Uses the encoder's fixed frame_size when it has one,
// otherwise falls back to `default_frame_size` samples.
AVFramePtr get_audio_frame(
    AVSampleFormat src_fmt,
    AVCodecContext* codec_ctx,
    int default_frame_size = 10000) {
  AVFramePtr frame{};
  frame->format = src_fmt;
  frame->channel_layout = codec_ctx->channel_layout;
  frame->sample_rate = codec_ctx->sample_rate;
  const int num_samples =
      codec_ctx->frame_size ? codec_ctx->frame_size : default_frame_size;
  frame->nb_samples = num_samples;
  if (num_samples) {
    // Only allocate the data buffer when there is a nonzero sample count.
    const int ret = av_frame_get_buffer(frame, 0);
    TORCH_CHECK(
        ret >= 0,
        "Error allocating an audio buffer (",
        av_err2string(ret),
        ").");
  }
  return frame;
}
} // namespace
// AudioOutputStream constructor.
// NOTE(review): this span is a diff rendering that interleaves the
// pre-refactor and post-refactor constructor lines; it is not a single
// compilable definition. The refactor replaces the externally-built
// filter / src_frame / frame_capacity arguments with a `src_fmt` argument
// from which they are derived via get_audio_filter()/get_audio_frame().
AudioOutputStream::AudioOutputStream(
AVFormatContext* format_ctx,
// --- old parameters (removed by this diff) ---
AVCodecContextPtr&& codec_ctx,
std::unique_ptr<FilterGraph>&& filter,
AVFramePtr&& src_frame,
int64_t frame_capacity_)
// --- new parameters (added by this diff) ---
AVSampleFormat src_fmt,
AVCodecContextPtr&& codec_ctx_)
: OutputStream(
format_ctx,
// --- old initializer arguments (removed) ---
std::move(codec_ctx),
std::move(filter),
std::move(src_frame)),
frame_capacity(frame_capacity_) {}
// --- new initializer list (added) ---
// codec_ctx_ is only read here (to build the filter and the frame) and is
// moved into the codec_ctx member last; frame_capacity reads
// src_frame->nb_samples, which per the member order shown in the
// AudioOutputStream struct (src_frame, frame_capacity, codec_ctx) is
// initialized first.
codec_ctx_,
get_audio_filter(src_fmt, codec_ctx_)),
src_frame(get_audio_frame(src_fmt, codec_ctx_)),
frame_capacity(src_frame->nb_samples),
codec_ctx(std::move(codec_ctx_)) {}
namespace {
......
......@@ -4,14 +4,14 @@
namespace torchaudio::io {
struct AudioOutputStream : OutputStream {
AVFramePtr src_frame;
int64_t frame_capacity;
AVCodecContextPtr codec_ctx;
AudioOutputStream(
AVFormatContext* format_ctx,
AVCodecContextPtr&& codec_ctx,
std::unique_ptr<FilterGraph>&& filter,
AVFramePtr&& src_frame,
int64_t frame_capacity);
AVSampleFormat src_fmt,
AVCodecContextPtr&& codec_ctx);
void write_chunk(const torch::Tensor& waveform) override;
~AudioOutputStream() override = default;
......
......@@ -4,13 +4,11 @@ namespace torchaudio::io {
// Base-class constructor shared by Audio/VideoOutputStream.
// NOTE(review): this span is a diff rendering interleaving the old and new
// versions; it is not one compilable definition. The refactor changes the
// base class from owning the codec context (AVCodecContextPtr&&) to holding
// a non-owning AVCodecContext*, and moves src_frame ownership out of the
// base class into the subclasses.
OutputStream::OutputStream(
AVFormatContext* format_ctx,
// --- old parameters/initializers (removed by this diff) ---
AVCodecContextPtr&& codec_ctx_,
std::unique_ptr<FilterGraph>&& filter_,
AVFramePtr&& src_frame_)
: codec_ctx(std::move(codec_ctx_)),
// --- new parameters/initializers (added by this diff) ---
AVCodecContext* codec_ctx_,
std::unique_ptr<FilterGraph>&& filter_)
: codec_ctx(codec_ctx_),
encoder(format_ctx, codec_ctx),
filter(std::move(filter_)),
// src_frame initializer below belongs to the removed (old) version only.
src_frame(std::move(src_frame_)),
dst_frame(),
num_frames(0) {}
......
......@@ -8,14 +8,12 @@
namespace torchaudio::io {
struct OutputStream {
// Codec context
AVCodecContextPtr codec_ctx;
// Reference to codec context
AVCodecContext* codec_ctx;
// Encoder + Muxer
Encoder encoder;
// Filter for additional processing
std::unique_ptr<FilterGraph> filter;
// frame that user-provided input data is written
AVFramePtr src_frame;
// frame that output from FilterGraph is written
AVFramePtr dst_frame;
// The number of samples written so far
......@@ -23,9 +21,8 @@ struct OutputStream {
OutputStream(
AVFormatContext* format_ctx,
AVCodecContextPtr&& codec_ctx,
std::unique_ptr<FilterGraph>&& filter,
AVFramePtr&& src_frame);
AVCodecContext* codec_ctx,
std::unique_ptr<FilterGraph>&& filter);
virtual void write_chunk(const torch::Tensor& input) = 0;
void process_frame(AVFrame* src);
......
......@@ -282,47 +282,6 @@ void open_codec(
TORCH_CHECK(ret >= 0, "Failed to open codec: (", av_err2string(ret), ")");
}
// (Removed by this diff.) Allocate an audio AVFrame carrying `fmt` samples
// with the codec's channel layout and sample rate, sized to `frame_size`
// samples. Superseded by the get_audio_frame defined alongside
// AudioOutputStream, which derives the frame size from the codec context.
AVFramePtr get_audio_frame(
enum AVSampleFormat fmt,
AVCodecContextPtr& codec_ctx,
int frame_size) {
AVFramePtr frame{};
frame->format = fmt;
frame->channel_layout = codec_ctx->channel_layout;
frame->sample_rate = codec_ctx->sample_rate;
frame->nb_samples = frame_size;
// Buffer allocation is skipped when frame_size is 0.
if (frame->nb_samples) {
int ret = av_frame_get_buffer(frame, 0);
TORCH_CHECK(
ret >= 0,
"Error allocating an audio buffer (",
av_err2string(ret),
").");
}
return frame;
}
// (Removed by this diff.) Fetch a frame whose buffer comes from the
// encoder's CUDA hardware frame pool (codec_ctx->hw_frames_ctx).
AVFramePtr get_hw_video_frame(AVCodecContextPtr& codec_ctx) {
AVFramePtr frame{};
int ret = av_hwframe_get_buffer(codec_ctx->hw_frames_ctx, frame, 0);
TORCH_CHECK(ret >= 0, "Failed to fetch CUDA frame: ", av_err2string(ret));
return frame;
}
// (Removed by this diff.) Allocate a CPU video AVFrame of the encoder's
// width/height in pixel format `fmt`. Superseded by the device-aware
// get_video_frame defined alongside VideoOutputStream.
AVFramePtr get_video_frame(
enum AVPixelFormat fmt,
AVCodecContextPtr& codec_ctx) {
AVFramePtr frame{};
frame->format = fmt;
frame->width = codec_ctx->width;
frame->height = codec_ctx->height;
int ret = av_frame_get_buffer(frame, 0);
TORCH_CHECK(
ret >= 0, "Error allocating a video buffer (", av_err2string(ret), ").");
return frame;
}
AVCodecContextPtr get_codec_ctx(
enum AVMediaType type,
AVFORMAT_CONST AVOutputFormat* oformat,
......@@ -368,7 +327,78 @@ AVCodecContextPtr get_codec_ctx(
return AVCodecContextPtr(ctx);
}
enum AVSampleFormat _get_src_sample_fmt(const std::string& src) {
// Create, configure and open an audio encoder context for the given output
// container format. The returned context is ready for encoding.
AVCodecContextPtr get_audio_codec(
    AVFORMAT_CONST AVOutputFormat* oformat,
    int64_t sample_rate,
    int64_t num_channels,
    const c10::optional<std::string>& encoder,
    const c10::optional<OptionDict>& encoder_option,
    const c10::optional<std::string>& encoder_format) {
  // Allocation, configuration and opening must happen in this order.
  auto codec_ctx = get_codec_ctx(AVMEDIA_TYPE_AUDIO, oformat, encoder);
  configure_audio_codec(codec_ctx, sample_rate, num_channels, encoder_format);
  open_codec(codec_ctx, encoder_option);
  return codec_ctx;
}
// Create, configure and open a video encoder context. When the target
// device is CUDA (and USE_CUDA is defined), also creates the hardware
// device/frame contexts and rewires the codec context to produce
// AV_PIX_FMT_CUDA frames; the created AVBufferRefs are handed back to the
// caller through the hw_device_ctx / hw_frame_ctx out-parameters so their
// lifetime outlives the codec context.
AVCodecContextPtr get_video_codec(
AVFORMAT_CONST AVOutputFormat* oformat,
double frame_rate,
int64_t width,
int64_t height,
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format,
const torch::Device device,
AVBufferRefPtr& hw_device_ctx,
AVBufferRefPtr& hw_frame_ctx) {
AVCodecContextPtr ctx = get_codec_ctx(AVMEDIA_TYPE_VIDEO, oformat, encoder);
configure_video_codec(ctx, frame_rate, width, height, encoder_format);
#ifdef USE_CUDA
if (device.type() == c10::DeviceType::CUDA) {
// Create the CUDA hardware device context for the requested device index.
AVBufferRef* device_ctx = nullptr;
int ret = av_hwdevice_ctx_create(
&device_ctx,
AV_HWDEVICE_TYPE_CUDA,
std::to_string(device.index()).c_str(),
nullptr,
0);
TORCH_CHECK(
ret >= 0, "Failed to create CUDA device context: ", av_err2string(ret));
hw_device_ctx.reset(device_ctx);
// Allocate the frame pool bound to that device.
AVBufferRef* frames_ref = av_hwframe_ctx_alloc(device_ctx);
TORCH_CHECK(frames_ref, "Failed to create CUDA frame context.");
hw_frame_ctx.reset(frames_ref);
// Describe the pool: CUDA surface format on top, with the encoder's
// configured pixel format as the underlying software format.
AVHWFramesContext* frames_ctx = (AVHWFramesContext*)(frames_ref->data);
frames_ctx->format = AV_PIX_FMT_CUDA;
frames_ctx->sw_format = ctx->pix_fmt;
frames_ctx->width = ctx->width;
frames_ctx->height = ctx->height;
frames_ctx->initial_pool_size = 20;
// Swap the codec context to hardware pixel format, keeping the original
// format as sw_pix_fmt. Must happen before av_hwframe_ctx_init.
ctx->sw_pix_fmt = ctx->pix_fmt;
ctx->pix_fmt = AV_PIX_FMT_CUDA;
ret = av_hwframe_ctx_init(frames_ref);
TORCH_CHECK(
ret >= 0,
"Failed to initialize CUDA frame context: ",
av_err2string(ret));
// av_buffer_ref returns nullptr on failure; it does not set `ret`.
ctx->hw_frames_ctx = av_buffer_ref(frames_ref);
// NOTE(review): `ret` here is stale — it still holds the (successful)
// status of av_hwframe_ctx_init above, so this error message would print
// a misleading code if av_buffer_ref ever failed.
TORCH_CHECK(
ctx->hw_frames_ctx,
"Failed to attach CUDA frames to encoding context: ",
av_err2string(ret));
}
#endif
// Opening must come after the hardware wiring so the encoder sees the
// final pix_fmt / hw_frames_ctx.
open_codec(ctx, encoder_option);
return ctx;
}
enum AVSampleFormat get_src_sample_fmt(const std::string& src) {
auto fmt = av_get_sample_fmt(src.c_str());
TORCH_CHECK(fmt != AV_SAMPLE_FMT_NONE, "Unknown sample format: ", src);
TORCH_CHECK(
......@@ -391,7 +421,7 @@ enum AVSampleFormat _get_src_sample_fmt(const std::string& src) {
return fmt;
}
enum AVPixelFormat _get_src_pixel_fmt(const std::string& src) {
enum AVPixelFormat get_src_pixel_fmt(const std::string& src) {
auto fmt = av_get_pix_fmt(src.c_str());
switch (fmt) {
case AV_PIX_FMT_GRAY8:
......@@ -406,35 +436,6 @@ enum AVPixelFormat _get_src_pixel_fmt(const std::string& src) {
}
}
// (Removed by this diff.) Build an "aformat" FilterGraph converting `fmt`
// into the encoder's sample format. Unlike its replacement
// (get_audio_filter next to AudioOutputStream), this version does not do
// the same-format early return — the caller performed that check.
std::unique_ptr<FilterGraph> _get_audio_filter(
enum AVSampleFormat fmt,
AVCodecContextPtr& ctx) {
std::stringstream desc;
desc << "aformat=" << av_get_sample_fmt_name(ctx->sample_fmt);
auto p = std::make_unique<FilterGraph>(AVMEDIA_TYPE_AUDIO);
p->add_audio_src(fmt, ctx->time_base, ctx->sample_rate, ctx->channel_layout);
p->add_sink();
p->add_process(desc.str());
p->create_filter();
return p;
}
// (Removed by this diff.) Build a "format" FilterGraph converting `fmt`
// into the encoder's pixel format. Unlike its replacement
// (get_video_filter next to VideoOutputStream), the same-format / device
// checks were performed by the caller.
std::unique_ptr<FilterGraph> _get_video_filter(
enum AVPixelFormat fmt,
AVCodecContextPtr& ctx) {
std::stringstream desc;
desc << "format=" << av_get_pix_fmt_name(ctx->pix_fmt);
auto p = std::make_unique<FilterGraph>(AVMEDIA_TYPE_VIDEO);
p->add_video_src(
fmt, ctx->time_base, ctx->width, ctx->height, ctx->sample_aspect_ratio);
p->add_sink();
p->add_process(desc.str());
p->create_filter();
return p;
}
} // namespace
void StreamWriter::add_audio_stream(
......@@ -444,24 +445,16 @@ void StreamWriter::add_audio_stream(
const c10::optional<std::string>& encoder,
const c10::optional<OptionDict>& encoder_option,
const c10::optional<std::string>& encoder_format) {
AVCodecContextPtr ctx =
get_codec_ctx(AVMEDIA_TYPE_AUDIO, pFormatContext->oformat, encoder);
configure_audio_codec(ctx, sample_rate, num_channels, encoder_format);
open_codec(ctx, encoder_option);
enum AVSampleFormat src_fmt = _get_src_sample_fmt(format);
std::unique_ptr<FilterGraph> filter = src_fmt == ctx->sample_fmt
? std::unique_ptr<FilterGraph>(nullptr)
: _get_audio_filter(src_fmt, ctx);
static const int default_capacity = 10000;
int frame_capacity = ctx->frame_size ? ctx->frame_size : default_capacity;
AVFramePtr src_frame = get_audio_frame(src_fmt, ctx, frame_capacity);
streams.emplace_back(std::make_unique<AudioOutputStream>(
pFormatContext,
std::move(ctx),
std::move(filter),
std::move(src_frame),
frame_capacity));
get_src_sample_fmt(format),
get_audio_codec(
pFormatContext->oformat,
sample_rate,
num_channels,
encoder,
encoder_option,
encoder_format)));
}
void StreamWriter::add_video_stream(
......@@ -491,75 +484,28 @@ void StreamWriter::add_video_stream(
#endif
}();
AVCodecContextPtr ctx =
get_codec_ctx(AVMEDIA_TYPE_VIDEO, pFormatContext->oformat, encoder);
configure_video_codec(ctx, frame_rate, width, height, encoder_format);
AVBufferRefPtr hw_device_ctx{};
AVBufferRefPtr hw_frame_ctx{};
#ifdef USE_CUDA
if (device.type() == c10::DeviceType::CUDA) {
AVBufferRef* device_ctx = nullptr;
int ret = av_hwdevice_ctx_create(
&device_ctx,
AV_HWDEVICE_TYPE_CUDA,
std::to_string(device.index()).c_str(),
nullptr,
0);
TORCH_CHECK(
ret >= 0, "Failed to create CUDA device context: ", av_err2string(ret));
hw_device_ctx.reset(device_ctx);
AVBufferRef* frames_ref = av_hwframe_ctx_alloc(device_ctx);
TORCH_CHECK(frames_ref, "Failed to create CUDA frame context.");
hw_frame_ctx.reset(frames_ref);
AVCodecContextPtr ctx = get_video_codec(
pFormatContext->oformat,
frame_rate,
width,
height,
encoder,
encoder_option,
encoder_format,
device,
hw_device_ctx,
hw_frame_ctx);
AVHWFramesContext* frames_ctx = (AVHWFramesContext*)(frames_ref->data);
frames_ctx->format = AV_PIX_FMT_CUDA;
frames_ctx->sw_format = ctx->pix_fmt;
frames_ctx->width = ctx->width;
frames_ctx->height = ctx->height;
frames_ctx->initial_pool_size = 20;
ctx->sw_pix_fmt = ctx->pix_fmt;
ctx->pix_fmt = AV_PIX_FMT_CUDA;
ret = av_hwframe_ctx_init(frames_ref);
TORCH_CHECK(
ret >= 0,
"Failed to initialize CUDA frame context: ",
av_err2string(ret));
ctx->hw_frames_ctx = av_buffer_ref(frames_ref);
TORCH_CHECK(
ctx->hw_frames_ctx,
"Failed to attach CUDA frames to encoding context: ",
av_err2string(ret));
}
#endif
open_codec(ctx, encoder_option);
enum AVPixelFormat src_fmt = _get_src_pixel_fmt(format);
std::unique_ptr<FilterGraph> filter = [&]() {
if (src_fmt != ctx->pix_fmt && device.type() == c10::DeviceType::CPU) {
return _get_video_filter(src_fmt, ctx);
}
return std::unique_ptr<FilterGraph>(nullptr);
}();
AVFramePtr src_frame = [&]() {
if (device.type() == c10::DeviceType::CUDA) {
return get_hw_video_frame(ctx);
}
return get_video_frame(src_fmt, ctx);
}();
streams.emplace_back(std::make_unique<VideoOutputStream>(
pFormatContext,
get_src_pixel_fmt(format),
std::move(ctx),
std::move(filter),
std::move(src_frame),
std::move(hw_device_ctx),
std::move(hw_frame_ctx)));
std::move(hw_frame_ctx),
device));
}
void StreamWriter::set_metadata(const OptionDict& metadata) {
......
......@@ -6,20 +6,74 @@
namespace torchaudio::io {
namespace {
// Build a FilterGraph converting video frames from `src_fmt` into the
// pixel format the encoder expects. Returns nullptr when no conversion is
// needed, or when the target device is not CPU.
std::unique_ptr<FilterGraph> get_video_filter(
    AVPixelFormat src_fmt,
    AVCodecContext* codec_ctx,
    const torch::Device& device) {
  const bool needs_conversion = src_fmt != codec_ctx->pix_fmt;
  const bool on_cpu = device.type() == c10::DeviceType::CPU;
  if (!needs_conversion || !on_cpu) {
    return {nullptr};
  }
  // "format" performs the pixel-format conversion.
  std::stringstream filter_desc;
  filter_desc << "format=" << av_get_pix_fmt_name(codec_ctx->pix_fmt);
  auto graph = std::make_unique<FilterGraph>(AVMEDIA_TYPE_VIDEO);
  graph->add_video_src(
      src_fmt,
      codec_ctx->time_base,
      codec_ctx->width,
      codec_ctx->height,
      codec_ctx->sample_aspect_ratio);
  graph->add_sink();
  graph->add_process(filter_desc.str());
  graph->create_filter();
  return graph;
}
// Fetch a frame whose buffer is taken from the encoder's CUDA hardware
// frame pool (codec_ctx->hw_frames_ctx).
AVFramePtr get_hw_video_frame(AVCodecContext* codec_ctx) {
AVFramePtr frame{};
int ret = av_hwframe_get_buffer(codec_ctx->hw_frames_ctx, frame, 0);
TORCH_CHECK(ret >= 0, "Failed to fetch CUDA frame: ", av_err2string(ret));
return frame;
}
// Allocate the frame that receives user-provided video data. On CUDA the
// buffer comes from the encoder's hardware frame pool; on other devices a
// regular frame of the encoder's dimensions is allocated in `src_fmt`.
AVFramePtr get_video_frame(
    AVPixelFormat src_fmt,
    AVCodecContext* codec_ctx,
    const torch::Device& device) {
  if (device.type() == c10::DeviceType::CUDA) {
    // Hardware path: buffer is backed by the CUDA frame pool.
    return get_hw_video_frame(codec_ctx);
  }
  // Software path: plain frame with the encoder's width/height.
  AVFramePtr frame{};
  frame->format = src_fmt;
  frame->width = codec_ctx->width;
  frame->height = codec_ctx->height;
  const int ret = av_frame_get_buffer(frame, 0);
  TORCH_CHECK(
      ret >= 0, "Error allocating a video buffer (", av_err2string(ret), ").");
  return frame;
}
} // namespace
// VideoOutputStream constructor.
// NOTE(review): this span is a diff rendering that interleaves the
// pre-refactor and post-refactor constructor lines; it is not a single
// compilable definition. The refactor replaces the externally-built
// filter / src_frame arguments with src_fmt + device, from which they are
// derived via get_video_filter()/get_video_frame().
VideoOutputStream::VideoOutputStream(
AVFormatContext* format_ctx,
// --- old parameters (removed by this diff) ---
AVCodecContextPtr&& codec_ctx,
std::unique_ptr<FilterGraph>&& filter,
AVFramePtr&& src_frame,
// --- new parameters (added by this diff) ---
AVPixelFormat src_fmt,
AVCodecContextPtr&& codec_ctx_,
AVBufferRefPtr&& hw_device_ctx_,
AVBufferRefPtr&& hw_frame_ctx_)
AVBufferRefPtr&& hw_frame_ctx_,
const torch::Device& device)
: OutputStream(
format_ctx,
// --- old initializer arguments (removed) ---
std::move(codec_ctx),
std::move(filter),
std::move(src_frame)),
// --- new initializer list (added) ---
// codec_ctx_ is only read here to build the filter and the frame, and is
// moved into the codec_ctx member last (codec_ctx is the last-declared
// member per the VideoOutputStream struct shown in this diff).
codec_ctx_,
get_video_filter(src_fmt, codec_ctx_, device)),
src_frame(get_video_frame(src_fmt, codec_ctx_, device)),
hw_device_ctx(std::move(hw_device_ctx_)),
hw_frame_ctx(std::move(hw_frame_ctx_)) {}
hw_frame_ctx(std::move(hw_frame_ctx_)),
codec_ctx(std::move(codec_ctx_)) {}
namespace {
......
......@@ -4,16 +4,19 @@
namespace torchaudio::io {
struct VideoOutputStream : OutputStream {
AVFramePtr src_frame;
AVBufferRefPtr hw_device_ctx;
AVBufferRefPtr hw_frame_ctx;
AVCodecContextPtr codec_ctx;
VideoOutputStream(
AVFormatContext* format_ctx,
AVPixelFormat src_fmt,
AVCodecContextPtr&& codec_ctx,
std::unique_ptr<FilterGraph>&& filter,
AVFramePtr&& src_frame,
AVBufferRefPtr&& hw_device_ctx,
AVBufferRefPtr&& hw_frame_ctx);
AVBufferRefPtr&& hw_frame_ctx,
const torch::Device& device);
void write_chunk(const torch::Tensor& frames) override;
~VideoOutputStream() override = default;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment