Commit 41e3b93d authored by Moto Hira's avatar Moto Hira Committed by Facebook GitHub Bot
Browse files

Fix HW accelerated encoder (#3140)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3140

https://github.com/pytorch/audio/pull/3120 introduced a regression in the GPU encoder.

This happened because previously both the source AVPixelFormat (the expected
channel order of the input tensor) and the AVCodecContext (the encoding format)
were passed to the converter (the module that copies the input tensor to the
buffer), even though the converter does not need to know about the encoding
format.

This commit fixes the issue and makes sure that the converter does not receive
the codec context.

Reviewed By: nateanl

Differential Revision: D43759162

fbshipit-source-id: f5f191cb54ecc82bd882aececdcae16921250261
parent d359f887
......@@ -164,11 +164,10 @@ torch::Tensor init_planar(const torch::Tensor& tensor) {
return tensor.contiguous();
}
std::pair<InitFunc, ConvertFunc> get_func(
enum AVPixelFormat pix_fmt,
enum AVPixelFormat sw_pix_fmt) {
using namespace std::placeholders;
if (pix_fmt == AV_PIX_FMT_CUDA) {
std::pair<InitFunc, ConvertFunc> get_func(AVFrame* buffer) {
if (buffer->hw_frames_ctx) {
auto frames_ctx = (AVHWFramesContext*)(buffer->hw_frames_ctx->data);
auto sw_pix_fmt = frames_ctx->sw_format;
switch (sw_pix_fmt) {
case AV_PIX_FMT_RGB0:
case AV_PIX_FMT_BGR0: {
......@@ -195,6 +194,7 @@ std::pair<InitFunc, ConvertFunc> get_func(
}
}
auto pix_fmt = static_cast<AVPixelFormat>(buffer->format);
switch (pix_fmt) {
case AV_PIX_FMT_GRAY8:
case AV_PIX_FMT_RGB24:
......@@ -214,37 +214,17 @@ std::pair<InitFunc, ConvertFunc> get_func(
}
}
// Creates the frame that is returned to the caller as an AVFramePtr.
//
// src_fmt:   pixel format written to `frame->format` on the non-CUDA path.
// codec_ctx: supplies `pix_fmt` (used to pick the path), `hw_frames_ctx`
//            (CUDA path) and `width`/`height` (non-CUDA path).
//
// Fails via TORCH_CHECK when either FFmpeg allocation call returns a
// negative value.
AVFramePtr get_video_frame(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
AVFramePtr frame{};
// The path is selected from the codec's pixel format: AV_PIX_FMT_CUDA means
// the buffer is obtained from the codec's hardware frames context.
if (codec_ctx->pix_fmt == AV_PIX_FMT_CUDA) {
int ret = av_hwframe_get_buffer(codec_ctx->hw_frames_ctx, frame, 0);
TORCH_CHECK(ret >= 0, "Failed to fetch CUDA frame: ", av_err2string(ret));
} else {
// Describe the frame (format and dimensions) before requesting its buffer.
frame->format = src_fmt;
frame->width = codec_ctx->width;
frame->height = codec_ctx->height;
int ret = av_frame_get_buffer(frame, 0);
TORCH_CHECK(
ret >= 0,
"Error allocating a video buffer (",
av_err2string(ret),
").");
}
// The presentation timestamp starts at zero on both paths.
frame->pts = 0;
return frame;
}
void validate_video_input(
enum AVPixelFormat fmt,
AVCodecContext* ctx,
const torch::Tensor& t) {
if (fmt == AV_PIX_FMT_CUDA) {
TORCH_CHECK(t.device().is_cuda(), "Input tensor has to be on CUDA.");
fmt = ctx->sw_pix_fmt;
} else {
TORCH_CHECK(t.device().is_cpu(), "Input tensor has to be on CPU.");
}
void validate_video_input(AVFrame* buffer, const torch::Tensor& t) {
auto fmt = [&]() -> AVPixelFormat {
if (buffer->hw_frames_ctx) {
TORCH_CHECK(t.device().is_cuda(), "Input tensor has to be on CUDA.");
auto frames_ctx = (AVHWFramesContext*)(buffer->hw_frames_ctx->data);
return frames_ctx->sw_format;
} else {
TORCH_CHECK(t.device().is_cpu(), "Input tensor has to be on CPU.");
return static_cast<AVPixelFormat>(buffer->format);
}
}();
auto dtype = t.dtype().toScalarType();
TORCH_CHECK(dtype == c10::ScalarType::Byte, "Expected Tensor of uint8 type.");
......@@ -254,37 +234,28 @@ void validate_video_input(
// For example, YUV420P has only two planes. U and V are in the second plane.
int num_color_components = av_pix_fmt_desc_get(fmt)->nb_components;
const auto channels = t.size(1);
const auto height = t.size(2);
const auto width = t.size(3);
const auto c = t.size(1), h = t.size(2), w = t.size(3);
TORCH_CHECK(
channels == num_color_components && height == ctx->height &&
width == ctx->width,
c == num_color_components && h == buffer->height && w == buffer->width,
"Expected tensor with shape (N, ",
num_color_components,
", ",
ctx->height,
buffer->height,
", ",
ctx->width,
buffer->width,
") (NCHW format). Found ",
t.sizes());
}
} // namespace
VideoTensorConverter::VideoTensorConverter(
enum AVPixelFormat src_fmt_,
AVCodecContext* codec_ctx_)
: src_fmt(src_fmt_),
codec_ctx(codec_ctx_),
buffer(get_video_frame(src_fmt_, codec_ctx_)) {
std::tie(init_func, convert_func) = get_func(src_fmt, codec_ctx->sw_pix_fmt);
// Stores the given frame pointer in `buffer` and selects the init/convert
// function pair for it via get_func(buffer).
// NOTE(review): only the raw pointer is kept here; presumably the caller keeps
// the frame alive for the lifetime of this converter — confirm.
VideoTensorConverter::VideoTensorConverter(AVFrame* buf) : buffer(buf) {
std::tie(init_func, convert_func) = get_func(buffer);
}
SlicingTensorConverter VideoTensorConverter::convert(
const torch::Tensor& frames) {
validate_video_input(src_fmt, codec_ctx, frames);
return SlicingTensorConverter{init_func(frames), buffer, convert_func};
// Validates `t` against `buffer` with validate_video_input, then returns a
// SlicingTensorConverter built from init_func(t), the stored buffer and
// convert_func.
SlicingTensorConverter VideoTensorConverter::convert(const torch::Tensor& t) {
// Validation runs before init_func is applied to the tensor.
validate_video_input(buffer, t);
return SlicingTensorConverter{init_func(t), buffer, convert_func};
}
} // namespace torchaudio::io
......@@ -17,15 +17,13 @@ class VideoTensorConverter {
using InitFunc = std::function<torch::Tensor(const torch::Tensor&)>;
private:
enum AVPixelFormat src_fmt;
AVCodecContext* codec_ctx;
AVFramePtr buffer;
AVFrame* buffer;
InitFunc init_func{};
SlicingTensorConverter::ConvertFunc convert_func{};
public:
VideoTensorConverter(enum AVPixelFormat src_fmt, AVCodecContext* codec_ctx);
explicit VideoTensorConverter(AVFrame* buffer);
SlicingTensorConverter convert(const torch::Tensor& frames);
};
} // namespace torchaudio::io
......@@ -33,6 +33,27 @@ FilterGraph get_video_filter(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
return p;
}
// Creates the frame that is returned to the caller as an AVFramePtr.
//
// src_fmt:   pixel format written to `frame->format` on the path without a
//            hardware frames context.
// codec_ctx: supplies `hw_frames_ctx` (used to pick the path and as the
//            source of the hardware buffer) and `width`/`height`.
//
// Fails via TORCH_CHECK when either FFmpeg allocation call returns a
// negative value.
AVFramePtr get_video_frame(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
AVFramePtr frame{};
// The path is selected by whether the codec context has a hardware frames
// context attached, not by a pixel-format comparison.
if (codec_ctx->hw_frames_ctx) {
int ret = av_hwframe_get_buffer(codec_ctx->hw_frames_ctx, frame, 0);
TORCH_CHECK(ret >= 0, "Failed to fetch CUDA frame: ", av_err2string(ret));
} else {
// Describe the frame (format and dimensions) before requesting its buffer.
frame->format = src_fmt;
frame->width = codec_ctx->width;
frame->height = codec_ctx->height;
int ret = av_frame_get_buffer(frame, 0);
TORCH_CHECK(
ret >= 0,
"Error allocating a video buffer (",
av_err2string(ret),
").");
}
// The presentation timestamp starts at zero on both paths.
frame->pts = 0;
return frame;
}
} // namespace
VideoOutputStream::VideoOutputStream(
......@@ -45,7 +66,8 @@ VideoOutputStream::VideoOutputStream(
format_ctx,
codec_ctx_,
get_video_filter(src_fmt, codec_ctx_)),
converter(src_fmt, codec_ctx_),
buffer(get_video_frame(src_fmt, codec_ctx_)),
converter(buffer),
hw_device_ctx(std::move(hw_device_ctx_)),
hw_frame_ctx(std::move(hw_frame_ctx_)),
codec_ctx(std::move(codec_ctx_)) {}
......
......@@ -5,6 +5,7 @@
namespace torchaudio::io {
struct VideoOutputStream : OutputStream {
AVFramePtr buffer;
VideoTensorConverter converter;
AVBufferRefPtr hw_device_ctx;
AVBufferRefPtr hw_frame_ctx;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment