Commit d0dca115 authored by moto, committed by Facebook GitHub Bot
Browse files

Refactor YUV handling functions (#2946)

Summary:
* Split `convert_[yuv420p|nv12|nv12_cuda]` functions into allocation
and data write functions.
* Merge the `get_[interlaced|planar]_image_buffer` functions into
`get_buffer` and `get_image_buffer`.
* Inline the `convert_[interlaced|planar]_video` helper functions into
`convert_image`.

Pull Request resolved: https://github.com/pytorch/audio/pull/2946

Reviewed By: nateanl

Differential Revision: D42287501

Pulled By: mthrok

fbshipit-source-id: b8dd0d52fd563a112a16887b643bf497f77dfb80
parent 276dcd69
......@@ -85,18 +85,45 @@ torch::Tensor convert_audio(AVFrame* pFrame) {
return t;
}
torch::Tensor get_interlaced_image_buffer(AVFrame* pFrame) {
int width = pFrame->width;
int height = pFrame->height;
int channel = av_pix_fmt_desc_get(static_cast<AVPixelFormat>(pFrame->format))
->nb_components;
torch::Tensor get_buffer(at::IntArrayRef shape, const torch::Device& device) {
auto options = torch::TensorOptions()
.dtype(torch::kUInt8)
.layout(torch::kStrided)
.device(torch::kCPU);
.device(device.type(), device.index());
return torch::empty(shape, options);
}
torch::Tensor get_image_buffer(AVFrame* frame, const torch::Device& device) {
auto fmt = static_cast<AVPixelFormat>(frame->format);
const AVPixFmtDescriptor* desc = [&]() {
if (fmt == AV_PIX_FMT_CUDA) {
AVHWFramesContext* hwctx = (AVHWFramesContext*)frame->hw_frames_ctx->data;
return av_pix_fmt_desc_get(hwctx->sw_format);
}
return av_pix_fmt_desc_get(fmt);
}();
int channels = desc->nb_components;
// Note
// AVPixFmtDescriptor::nb_components represents the number of
// color components. This is different from the number of planes.
//
// For example, NV12 has three color components Y, U and V, but
// U and V are interleaved in the same plane, so there are only
// two planes.
//
// In our application, we cannot express the bare YUV420P or NV12 as a
// single tensor, so we convert it to a 3-channel tensor.
// For this reason, we use nb_components for the number of channels,
// instead of the number of planes.
//
// The actual number of planes can be retrieved with
// av_pix_fmt_count_planes.
return torch::empty({1, height, width, channel}, options);
if (desc->flags & AV_PIX_FMT_FLAG_PLANAR) {
return get_buffer({1, channels, frame->height, frame->width}, device);
}
return get_buffer({1, frame->height, frame->width, channels}, device);
}
void write_interlaced_image(AVFrame* pFrame, torch::Tensor& frame) {
......@@ -111,19 +138,6 @@ void write_interlaced_image(AVFrame* pFrame, torch::Tensor& frame) {
}
}
torch::Tensor get_planar_image_buffer(AVFrame* pFrame) {
int width = pFrame->width;
int height = pFrame->height;
int num_planes =
av_pix_fmt_count_planes(static_cast<AVPixelFormat>(pFrame->format));
auto options = torch::TensorOptions()
.dtype(torch::kUInt8)
.layout(torch::kStrided)
.device(torch::kCPU);
return torch::empty({1, num_planes, height, width}, options);
}
void write_planar_image(AVFrame* pFrame, torch::Tensor& frame) {
int num_planes = static_cast<int>(frame.size(1));
int height = static_cast<int>(frame.size(2));
......@@ -141,32 +155,13 @@ void write_planar_image(AVFrame* pFrame, torch::Tensor& frame) {
}
}
namespace {
torch::Tensor convert_interlaced_video(AVFrame* pFrame) {
torch::Tensor frame = get_interlaced_image_buffer(pFrame);
write_interlaced_image(pFrame, frame);
return frame.permute({0, 3, 1, 2});
}
torch::Tensor convert_planar_video(AVFrame* pFrame) {
torch::Tensor frame = get_planar_image_buffer(pFrame);
write_planar_image(pFrame, frame);
return frame;
}
torch::Tensor convert_yuv420p(AVFrame* pFrame) {
int width = pFrame->width;
int height = pFrame->height;
void write_yuv420p(AVFrame* pFrame, torch::Tensor& yuv) {
int height = static_cast<int>(yuv.size(2));
int width = static_cast<int>(yuv.size(3));
auto options = torch::TensorOptions()
.dtype(torch::kUInt8)
.layout(torch::kStrided)
.device(torch::kCPU);
torch::Tensor y = torch::empty({1, 1, height, width}, options);
// Write Y plane directly
{
uint8_t* tgt = y.data_ptr<uint8_t>();
uint8_t* tgt = yuv.data_ptr<uint8_t>();
uint8_t* src = pFrame->data[0];
int linesize = pFrame->linesize[0];
for (int h = 0; h < height; ++h) {
......@@ -175,7 +170,9 @@ torch::Tensor convert_yuv420p(AVFrame* pFrame) {
src += linesize;
}
}
torch::Tensor uv = torch::empty({1, 2, height / 2, width / 2}, options);
// Prepare intermediate UV plane
torch::Tensor uv = get_buffer({1, 2, height / 2, width / 2});
{
uint8_t* tgt = uv.data_ptr<uint8_t>();
uint8_t* src = pFrame->data[1];
......@@ -195,26 +192,24 @@ torch::Tensor convert_yuv420p(AVFrame* pFrame) {
}
// Upsample width and height
namespace F = torch::nn::functional;
using namespace torch::indexing;
uv = F::interpolate(
uv,
F::InterpolateFuncOptions()
.mode(torch::kNearest)
.size(std::vector<int64_t>({height, width})));
return torch::cat({y, uv}, 1);
// Write to the UV plane
// yuv[:, 1:] = uv
yuv.index_put_({Slice(), Slice(1)}, uv);
}
torch::Tensor convert_nv12_cpu(AVFrame* pFrame) {
int width = pFrame->width;
int height = pFrame->height;
auto options = torch::TensorOptions()
.dtype(torch::kUInt8)
.layout(torch::kStrided)
.device(torch::kCPU);
void write_nv12_cpu(AVFrame* pFrame, torch::Tensor& yuv) {
int height = static_cast<int>(yuv.size(2));
int width = static_cast<int>(yuv.size(3));
torch::Tensor y = torch::empty({1, height, width, 1}, options);
// Write Y plane directly
{
uint8_t* tgt = y.data_ptr<uint8_t>();
uint8_t* tgt = yuv.data_ptr<uint8_t>();
uint8_t* src = pFrame->data[0];
int linesize = pFrame->linesize[0];
for (int h = 0; h < height; ++h) {
......@@ -223,7 +218,9 @@ torch::Tensor convert_nv12_cpu(AVFrame* pFrame) {
src += linesize;
}
}
torch::Tensor uv = torch::empty({1, height / 2, width / 2, 2}, options);
// Prepare intermediate UV plane
torch::Tensor uv = get_buffer({1, height / 2, width / 2, 2});
{
uint8_t* tgt = uv.data_ptr<uint8_t>();
uint8_t* src = pFrame->data[1];
......@@ -234,31 +231,28 @@ torch::Tensor convert_nv12_cpu(AVFrame* pFrame) {
src += linesize;
}
}
// Upsample width and height
namespace F = torch::nn::functional;
using namespace torch::indexing;
uv = F::interpolate(
uv.view({1, 1, height / 2, width / 2, 2}),
uv.permute({0, 3, 1, 2}),
F::InterpolateFuncOptions()
.mode(torch::kNearest)
.size(std::vector<int64_t>({height, width, 2})));
torch::Tensor t = torch::cat({y, uv[0]}, -1);
return t.permute({0, 3, 1, 2}); // NCHW
.size(std::vector<int64_t>({height, width})));
// Write to the UV plane
// yuv[:, 1:] = uv
yuv.index_put_({Slice(), Slice(1)}, uv);
}
#ifdef USE_CUDA
torch::Tensor convert_nv12_cuda(AVFrame* pFrame, const torch::Device& device) {
int width = pFrame->width;
int height = pFrame->height;
void write_nv12_cuda(AVFrame* pFrame, torch::Tensor& yuv) {
int height = static_cast<int>(yuv.size(2));
int width = static_cast<int>(yuv.size(3));
auto options = torch::TensorOptions()
.dtype(torch::kUInt8)
.layout(torch::kStrided)
.device(torch::kCUDA)
.device_index(device.index());
torch::Tensor y = torch::empty({1, height, width, 1}, options);
// Write Y plane directly
{
uint8_t* tgt = y.data_ptr<uint8_t>();
uint8_t* tgt = yuv.data_ptr<uint8_t>();
CUdeviceptr src = (CUdeviceptr)pFrame->data[0];
int linesize = pFrame->linesize[0];
TORCH_CHECK(
......@@ -273,7 +267,8 @@ torch::Tensor convert_nv12_cuda(AVFrame* pFrame, const torch::Device& device) {
cudaMemcpyDeviceToDevice),
"Failed to copy Y plane to Cuda tensor.");
}
torch::Tensor uv = torch::empty({1, height / 2, width / 2, 2}, options);
// Prepare intermediate UV planes
torch::Tensor uv = get_buffer({1, height / 2, width / 2, 2}, yuv.device());
{
uint8_t* tgt = uv.data_ptr<uint8_t>();
CUdeviceptr src = (CUdeviceptr)pFrame->data[1];
......@@ -292,22 +287,24 @@ torch::Tensor convert_nv12_cuda(AVFrame* pFrame, const torch::Device& device) {
}
// Upsample width and height
namespace F = torch::nn::functional;
using namespace torch::indexing;
uv = F::interpolate(
uv.view({1, 1, height / 2, width / 2, 2}),
uv.permute({0, 3, 1, 2}),
F::InterpolateFuncOptions()
.mode(torch::kNearest)
.size(std::vector<int64_t>({height, width, 2})));
torch::Tensor t = torch::cat({y, uv[0]}, -1);
return t.permute({0, 3, 1, 2}); // NCHW
.size(std::vector<int64_t>({height, width})));
// Write to the UV plane
// yuv[:, 1:] = uv
yuv.index_put_({Slice(), Slice(1)}, uv);
}
#endif
} // namespace
torch::Tensor convert_image(AVFrame* pFrame, const torch::Device& device) {
torch::Tensor convert_image(AVFrame* frame, const torch::Device& device) {
// ref:
// https://ffmpeg.org/doxygen/4.1/filtering__video_8c_source.html#l00179
// https://ffmpeg.org/doxygen/4.1/decode__video_8c_source.html#l00038
AVPixelFormat format = static_cast<AVPixelFormat>(pFrame->format);
AVPixelFormat format = static_cast<AVPixelFormat>(frame->format);
torch::Tensor buf = get_image_buffer(frame, device);
switch (format) {
case AV_PIX_FMT_RGB24:
case AV_PIX_FMT_BGR24:
......@@ -315,25 +312,34 @@ torch::Tensor convert_image(AVFrame* pFrame, const torch::Device& device) {
case AV_PIX_FMT_RGBA:
case AV_PIX_FMT_ABGR:
case AV_PIX_FMT_BGRA:
case AV_PIX_FMT_GRAY8:
return convert_interlaced_video(pFrame);
case AV_PIX_FMT_YUV444P:
return convert_planar_video(pFrame);
case AV_PIX_FMT_YUV420P:
return convert_yuv420p(pFrame);
case AV_PIX_FMT_NV12:
return convert_nv12_cpu(pFrame);
case AV_PIX_FMT_GRAY8: {
write_interlaced_image(frame, buf);
return buf.permute({0, 3, 1, 2});
}
case AV_PIX_FMT_YUV444P: {
write_planar_image(frame, buf);
return buf;
}
case AV_PIX_FMT_YUV420P: {
write_yuv420p(frame, buf);
return buf;
}
case AV_PIX_FMT_NV12: {
write_nv12_cpu(frame, buf);
return buf;
}
#ifdef USE_CUDA
case AV_PIX_FMT_CUDA: {
AVHWFramesContext* hwctx =
(AVHWFramesContext*)pFrame->hw_frames_ctx->data;
AVHWFramesContext* hwctx = (AVHWFramesContext*)frame->hw_frames_ctx->data;
AVPixelFormat sw_format = hwctx->sw_format;
// cuvid decoder (nvdec frontend of ffmpeg) only supports the following
// output formats
// https://github.com/FFmpeg/FFmpeg/blob/072101bd52f7f092ee976f4e6e41c19812ad32fd/libavcodec/cuviddec.c#L1121-L1124
switch (sw_format) {
case AV_PIX_FMT_NV12:
return convert_nv12_cuda(pFrame, device);
case AV_PIX_FMT_NV12: {
write_nv12_cuda(frame, buf);
return buf;
}
case AV_PIX_FMT_P010:
case AV_PIX_FMT_P016:
TORCH_CHECK(
......
......@@ -11,8 +11,16 @@ namespace detail {
//////////////////////////////////////////////////////////////////////////////
torch::Tensor convert_audio(AVFrame* frame);
torch::Tensor get_interlaced_image_buffer(AVFrame* pFrame);
torch::Tensor get_planar_image_buffer(AVFrame* pFrame);
torch::Tensor get_buffer(
at::IntArrayRef shape,
const torch::Device& device = torch::Device(torch::kCPU));
torch::Tensor get_image_buffer(AVFrame* pFrame, const torch::Device& device);
void write_yuv420p(AVFrame* pFrame, torch::Tensor& yuv);
void write_nv12_cpu(AVFrame* pFrame, torch::Tensor& yuv);
#ifdef USE_CUDA
void write_nv12_cuda(AVFrame* pFrame, torch::Tensor& yuv);
#endif
void write_interlaced_image(AVFrame* pFrame, torch::Tensor& frame);
void write_planar_image(AVFrame* pFrame, torch::Tensor& frame);
torch::Tensor convert_image(AVFrame* frame, const torch::Device& device);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment