Commit 41e3b93d authored by Moto Hira, committed by Facebook GitHub Bot

Fix HW accelerated encoder (#3140)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3140

https://github.com/pytorch/audio/pull/3120 introduced a regression in the GPU encoder.

This happened because the converter (the module that copies the input tensor
into the frame buffer) previously received both the source AVPixelFormat (the
expected channel order of the input tensor) and the AVCodecContext (the
encoding format), even though the converter does not need to know about the
encoding format.

This commit fixes the issue and makes sure that the converter does not
receive the codec context.
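
To illustrate the new design, here is a minimal sketch (illustrative only;
`get_src_pix_fmt` is a hypothetical name, not part of the patch) of the idea
the diff implements: everything the converter needs can be derived from the
AVFrame buffer itself.

```cpp
extern "C" {
#include <libavutil/frame.h>
#include <libavutil/hwcontext.h>
}

// Derive the pixel format that describes the expected input tensor layout
// from the frame buffer alone, without consulting the codec context.
AVPixelFormat get_src_pix_fmt(const AVFrame* buffer) {
  if (buffer->hw_frames_ctx) {
    // HW-accelerated path: the tensor layout follows the *software* format
    // of the hardware frames context, not the encoder's pixel format.
    auto frames_ctx = (AVHWFramesContext*)(buffer->hw_frames_ctx->data);
    return frames_ctx->sw_format;
  }
  // CPU path: the frame's own format describes the expected channel order.
  return static_cast<AVPixelFormat>(buffer->format);
}
```

Accordingly, frame allocation (`get_video_frame`) moves out of the converter
and into `VideoOutputStream`, which declares the new `buffer` member before
`converter` so that, per C++ member-initialization order, the frame exists
before the converter that borrows it.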

Reviewed By: nateanl

Differential Revision: D43759162

fbshipit-source-id: f5f191cb54ecc82bd882aececdcae16921250261
parent d359f887
@@ -164,11 +164,10 @@ torch::Tensor init_planar(const torch::Tensor& tensor) {
   return tensor.contiguous();
 }
 
-std::pair<InitFunc, ConvertFunc> get_func(
-    enum AVPixelFormat pix_fmt,
-    enum AVPixelFormat sw_pix_fmt) {
-  using namespace std::placeholders;
-  if (pix_fmt == AV_PIX_FMT_CUDA) {
+std::pair<InitFunc, ConvertFunc> get_func(AVFrame* buffer) {
+  if (buffer->hw_frames_ctx) {
+    auto frames_ctx = (AVHWFramesContext*)(buffer->hw_frames_ctx->data);
+    auto sw_pix_fmt = frames_ctx->sw_format;
     switch (sw_pix_fmt) {
       case AV_PIX_FMT_RGB0:
       case AV_PIX_FMT_BGR0: {
@@ -195,6 +194,7 @@ std::pair<InitFunc, ConvertFunc> get_func(
     }
   }
 
+  auto pix_fmt = static_cast<AVPixelFormat>(buffer->format);
   switch (pix_fmt) {
     case AV_PIX_FMT_GRAY8:
     case AV_PIX_FMT_RGB24:
@@ -214,37 +214,17 @@ std::pair<InitFunc, ConvertFunc> get_func(
   }
 }
 
-AVFramePtr get_video_frame(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
-  AVFramePtr frame{};
-  if (codec_ctx->pix_fmt == AV_PIX_FMT_CUDA) {
-    int ret = av_hwframe_get_buffer(codec_ctx->hw_frames_ctx, frame, 0);
-    TORCH_CHECK(ret >= 0, "Failed to fetch CUDA frame: ", av_err2string(ret));
-  } else {
-    frame->format = src_fmt;
-    frame->width = codec_ctx->width;
-    frame->height = codec_ctx->height;
-    int ret = av_frame_get_buffer(frame, 0);
-    TORCH_CHECK(
-        ret >= 0,
-        "Error allocating a video buffer (",
-        av_err2string(ret),
-        ").");
-  }
-  frame->pts = 0;
-  return frame;
-}
-
-void validate_video_input(
-    enum AVPixelFormat fmt,
-    AVCodecContext* ctx,
-    const torch::Tensor& t) {
-  if (fmt == AV_PIX_FMT_CUDA) {
-    TORCH_CHECK(t.device().is_cuda(), "Input tensor has to be on CUDA.");
-    fmt = ctx->sw_pix_fmt;
-  } else {
-    TORCH_CHECK(t.device().is_cpu(), "Input tensor has to be on CPU.");
-  }
+void validate_video_input(AVFrame* buffer, const torch::Tensor& t) {
+  auto fmt = [&]() -> AVPixelFormat {
+    if (buffer->hw_frames_ctx) {
+      TORCH_CHECK(t.device().is_cuda(), "Input tensor has to be on CUDA.");
+      auto frames_ctx = (AVHWFramesContext*)(buffer->hw_frames_ctx->data);
+      return frames_ctx->sw_format;
+    } else {
+      TORCH_CHECK(t.device().is_cpu(), "Input tensor has to be on CPU.");
+      return static_cast<AVPixelFormat>(buffer->format);
+    }
+  }();
 
   auto dtype = t.dtype().toScalarType();
   TORCH_CHECK(dtype == c10::ScalarType::Byte, "Expected Tensor of uint8 type.");
@@ -254,37 +234,28 @@ void validate_video_input(
   // For example, YUV420P has only two planes. U and V are in the second plane.
   int num_color_components = av_pix_fmt_desc_get(fmt)->nb_components;
 
-  const auto channels = t.size(1);
-  const auto height = t.size(2);
-  const auto width = t.size(3);
+  const auto c = t.size(1), h = t.size(2), w = t.size(3);
   TORCH_CHECK(
-      channels == num_color_components && height == ctx->height &&
-          width == ctx->width,
+      c == num_color_components && h == buffer->height && w == buffer->width,
       "Expected tensor with shape (N, ",
       num_color_components,
       ", ",
-      ctx->height,
+      buffer->height,
      ", ",
-      ctx->width,
+      buffer->width,
       ") (NCHW format). Found ",
       t.sizes());
 }
 
 } // namespace
 
-VideoTensorConverter::VideoTensorConverter(
-    enum AVPixelFormat src_fmt_,
-    AVCodecContext* codec_ctx_)
-    : src_fmt(src_fmt_),
-      codec_ctx(codec_ctx_),
-      buffer(get_video_frame(src_fmt_, codec_ctx_)) {
-  std::tie(init_func, convert_func) = get_func(src_fmt, codec_ctx->sw_pix_fmt);
+VideoTensorConverter::VideoTensorConverter(AVFrame* buf) : buffer(buf) {
+  std::tie(init_func, convert_func) = get_func(buffer);
 }
 
-SlicingTensorConverter VideoTensorConverter::convert(
-    const torch::Tensor& frames) {
-  validate_video_input(src_fmt, codec_ctx, frames);
-  return SlicingTensorConverter{init_func(frames), buffer, convert_func};
+SlicingTensorConverter VideoTensorConverter::convert(const torch::Tensor& t) {
+  validate_video_input(buffer, t);
+  return SlicingTensorConverter{init_func(t), buffer, convert_func};
 }
 
 } // namespace torchaudio::io
@@ -17,15 +17,13 @@ class VideoTensorConverter {
   using InitFunc = std::function<torch::Tensor(const torch::Tensor&)>;
 
  private:
-  enum AVPixelFormat src_fmt;
-  AVCodecContext* codec_ctx;
-  AVFramePtr buffer;
+  AVFrame* buffer;
   InitFunc init_func{};
   SlicingTensorConverter::ConvertFunc convert_func{};
 
  public:
-  VideoTensorConverter(enum AVPixelFormat src_fmt, AVCodecContext* codec_ctx);
+  explicit VideoTensorConverter(AVFrame* buffer);
   SlicingTensorConverter convert(const torch::Tensor& frames);
 };
 
 } // namespace torchaudio::io
@@ -33,6 +33,27 @@ FilterGraph get_video_filter(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
   return p;
 }
 
+AVFramePtr get_video_frame(AVPixelFormat src_fmt, AVCodecContext* codec_ctx) {
+  AVFramePtr frame{};
+  if (codec_ctx->hw_frames_ctx) {
+    int ret = av_hwframe_get_buffer(codec_ctx->hw_frames_ctx, frame, 0);
+    TORCH_CHECK(ret >= 0, "Failed to fetch CUDA frame: ", av_err2string(ret));
+  } else {
+    frame->format = src_fmt;
+    frame->width = codec_ctx->width;
+    frame->height = codec_ctx->height;
+    int ret = av_frame_get_buffer(frame, 0);
+    TORCH_CHECK(
+        ret >= 0,
+        "Error allocating a video buffer (",
+        av_err2string(ret),
+        ").");
+  }
+  frame->pts = 0;
+  return frame;
+}
+
 } // namespace
 
 VideoOutputStream::VideoOutputStream(

@@ -45,7 +66,8 @@ VideoOutputStream::VideoOutputStream(
           format_ctx,
           codec_ctx_,
           get_video_filter(src_fmt, codec_ctx_)),
-      converter(src_fmt, codec_ctx_),
+      buffer(get_video_frame(src_fmt, codec_ctx_)),
+      converter(buffer),
       hw_device_ctx(std::move(hw_device_ctx_)),
       hw_frame_ctx(std::move(hw_frame_ctx_)),
      codec_ctx(std::move(codec_ctx_)) {}
@@ -5,6 +5,7 @@
 namespace torchaudio::io {
 
 struct VideoOutputStream : OutputStream {
+  AVFramePtr buffer;
   VideoTensorConverter converter;
   AVBufferRefPtr hw_device_ctx;
   AVBufferRefPtr hw_frame_ctx;