"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "5c4ea00de772f9af456e68f30f830c7d7a158846"
Commit d0dca115 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Refactor YUV handling functions (#2946)

Summary:
* Split `convert_[yuv420p|nv12|nv12_cuda]` functions into allocation
and data write functions.
* Merge the `get_[interlaced|planar]_image_buffer` functions into
`get_buffer` and `get_image_buffer`.
* Disassemble `convert_XXX_image` helper functions.

Pull Request resolved: https://github.com/pytorch/audio/pull/2946

Reviewed By: nateanl

Differential Revision: D42287501

Pulled By: mthrok

fbshipit-source-id: b8dd0d52fd563a112a16887b643bf497f77dfb80
parent 276dcd69
...@@ -85,18 +85,45 @@ torch::Tensor convert_audio(AVFrame* pFrame) { ...@@ -85,18 +85,45 @@ torch::Tensor convert_audio(AVFrame* pFrame) {
return t; return t;
} }
torch::Tensor get_interlaced_image_buffer(AVFrame* pFrame) { torch::Tensor get_buffer(at::IntArrayRef shape, const torch::Device& device) {
int width = pFrame->width;
int height = pFrame->height;
int channel = av_pix_fmt_desc_get(static_cast<AVPixelFormat>(pFrame->format))
->nb_components;
auto options = torch::TensorOptions() auto options = torch::TensorOptions()
.dtype(torch::kUInt8) .dtype(torch::kUInt8)
.layout(torch::kStrided) .layout(torch::kStrided)
.device(torch::kCPU); .device(device.type(), device.index());
return torch::empty(shape, options);
}
torch::Tensor get_image_buffer(AVFrame* frame, const torch::Device& device) {
auto fmt = static_cast<AVPixelFormat>(frame->format);
const AVPixFmtDescriptor* desc = [&]() {
if (fmt == AV_PIX_FMT_CUDA) {
AVHWFramesContext* hwctx = (AVHWFramesContext*)frame->hw_frames_ctx->data;
return av_pix_fmt_desc_get(hwctx->sw_format);
}
return av_pix_fmt_desc_get(fmt);
}();
int channels = desc->nb_components;
// Note
// AVPixFmtDescriptor::nb_components represents the number of
// color components. This is different from the number of planes.
//
// For example, YUV420P has three color components Y, U and V, but
// U and V are squashed into the same plane, so there are only
// two planes.
//
// In our application, we cannot express the bare YUV420P as a
// single tensor, so we convert it to 3 channel tensor.
// For this reason, we use nb_components for the number of channels,
// instead of the number of planes.
//
// The actual number of planes can be retrieved with
// av_pix_fmt_count_planes.
return torch::empty({1, height, width, channel}, options); if (desc->flags & AV_PIX_FMT_FLAG_PLANAR) {
return get_buffer({1, channels, frame->height, frame->width}, device);
}
return get_buffer({1, frame->height, frame->width, channels}, device);
} }
void write_interlaced_image(AVFrame* pFrame, torch::Tensor& frame) { void write_interlaced_image(AVFrame* pFrame, torch::Tensor& frame) {
...@@ -111,19 +138,6 @@ void write_interlaced_image(AVFrame* pFrame, torch::Tensor& frame) { ...@@ -111,19 +138,6 @@ void write_interlaced_image(AVFrame* pFrame, torch::Tensor& frame) {
} }
} }
torch::Tensor get_planar_image_buffer(AVFrame* pFrame) {
int width = pFrame->width;
int height = pFrame->height;
int num_planes =
av_pix_fmt_count_planes(static_cast<AVPixelFormat>(pFrame->format));
auto options = torch::TensorOptions()
.dtype(torch::kUInt8)
.layout(torch::kStrided)
.device(torch::kCPU);
return torch::empty({1, num_planes, height, width}, options);
}
void write_planar_image(AVFrame* pFrame, torch::Tensor& frame) { void write_planar_image(AVFrame* pFrame, torch::Tensor& frame) {
int num_planes = static_cast<int>(frame.size(1)); int num_planes = static_cast<int>(frame.size(1));
int height = static_cast<int>(frame.size(2)); int height = static_cast<int>(frame.size(2));
...@@ -141,32 +155,13 @@ void write_planar_image(AVFrame* pFrame, torch::Tensor& frame) { ...@@ -141,32 +155,13 @@ void write_planar_image(AVFrame* pFrame, torch::Tensor& frame) {
} }
} }
namespace { void write_yuv420p(AVFrame* pFrame, torch::Tensor& yuv) {
int height = static_cast<int>(yuv.size(2));
torch::Tensor convert_interlaced_video(AVFrame* pFrame) { int width = static_cast<int>(yuv.size(3));
torch::Tensor frame = get_interlaced_image_buffer(pFrame);
write_interlaced_image(pFrame, frame);
return frame.permute({0, 3, 1, 2});
}
torch::Tensor convert_planar_video(AVFrame* pFrame) {
torch::Tensor frame = get_planar_image_buffer(pFrame);
write_planar_image(pFrame, frame);
return frame;
}
torch::Tensor convert_yuv420p(AVFrame* pFrame) {
int width = pFrame->width;
int height = pFrame->height;
auto options = torch::TensorOptions() // Write Y plane directly
.dtype(torch::kUInt8)
.layout(torch::kStrided)
.device(torch::kCPU);
torch::Tensor y = torch::empty({1, 1, height, width}, options);
{ {
uint8_t* tgt = y.data_ptr<uint8_t>(); uint8_t* tgt = yuv.data_ptr<uint8_t>();
uint8_t* src = pFrame->data[0]; uint8_t* src = pFrame->data[0];
int linesize = pFrame->linesize[0]; int linesize = pFrame->linesize[0];
for (int h = 0; h < height; ++h) { for (int h = 0; h < height; ++h) {
...@@ -175,7 +170,9 @@ torch::Tensor convert_yuv420p(AVFrame* pFrame) { ...@@ -175,7 +170,9 @@ torch::Tensor convert_yuv420p(AVFrame* pFrame) {
src += linesize; src += linesize;
} }
} }
torch::Tensor uv = torch::empty({1, 2, height / 2, width / 2}, options);
// Prepare intermediate UV plane
torch::Tensor uv = get_buffer({1, 2, height / 2, width / 2});
{ {
uint8_t* tgt = uv.data_ptr<uint8_t>(); uint8_t* tgt = uv.data_ptr<uint8_t>();
uint8_t* src = pFrame->data[1]; uint8_t* src = pFrame->data[1];
...@@ -195,26 +192,24 @@ torch::Tensor convert_yuv420p(AVFrame* pFrame) { ...@@ -195,26 +192,24 @@ torch::Tensor convert_yuv420p(AVFrame* pFrame) {
} }
// Upsample width and height // Upsample width and height
namespace F = torch::nn::functional; namespace F = torch::nn::functional;
using namespace torch::indexing;
uv = F::interpolate( uv = F::interpolate(
uv, uv,
F::InterpolateFuncOptions() F::InterpolateFuncOptions()
.mode(torch::kNearest) .mode(torch::kNearest)
.size(std::vector<int64_t>({height, width}))); .size(std::vector<int64_t>({height, width})));
return torch::cat({y, uv}, 1); // Write to the UV plane
// yuv[:, 1:] = uv
yuv.index_put_({Slice(), Slice(1)}, uv);
} }
torch::Tensor convert_nv12_cpu(AVFrame* pFrame) { void write_nv12_cpu(AVFrame* pFrame, torch::Tensor& yuv) {
int width = pFrame->width; int height = static_cast<int>(yuv.size(2));
int height = pFrame->height; int width = static_cast<int>(yuv.size(3));
auto options = torch::TensorOptions()
.dtype(torch::kUInt8)
.layout(torch::kStrided)
.device(torch::kCPU);
torch::Tensor y = torch::empty({1, height, width, 1}, options); // Write Y plane directly
{ {
uint8_t* tgt = y.data_ptr<uint8_t>(); uint8_t* tgt = yuv.data_ptr<uint8_t>();
uint8_t* src = pFrame->data[0]; uint8_t* src = pFrame->data[0];
int linesize = pFrame->linesize[0]; int linesize = pFrame->linesize[0];
for (int h = 0; h < height; ++h) { for (int h = 0; h < height; ++h) {
...@@ -223,7 +218,9 @@ torch::Tensor convert_nv12_cpu(AVFrame* pFrame) { ...@@ -223,7 +218,9 @@ torch::Tensor convert_nv12_cpu(AVFrame* pFrame) {
src += linesize; src += linesize;
} }
} }
torch::Tensor uv = torch::empty({1, height / 2, width / 2, 2}, options);
// Prepare intermediate UV plane
torch::Tensor uv = get_buffer({1, height / 2, width / 2, 2});
{ {
uint8_t* tgt = uv.data_ptr<uint8_t>(); uint8_t* tgt = uv.data_ptr<uint8_t>();
uint8_t* src = pFrame->data[1]; uint8_t* src = pFrame->data[1];
...@@ -234,31 +231,28 @@ torch::Tensor convert_nv12_cpu(AVFrame* pFrame) { ...@@ -234,31 +231,28 @@ torch::Tensor convert_nv12_cpu(AVFrame* pFrame) {
src += linesize; src += linesize;
} }
} }
// Upsample width and height // Upsample width and height
namespace F = torch::nn::functional; namespace F = torch::nn::functional;
using namespace torch::indexing;
uv = F::interpolate( uv = F::interpolate(
uv.view({1, 1, height / 2, width / 2, 2}), uv.permute({0, 3, 1, 2}),
F::InterpolateFuncOptions() F::InterpolateFuncOptions()
.mode(torch::kNearest) .mode(torch::kNearest)
.size(std::vector<int64_t>({height, width, 2}))); .size(std::vector<int64_t>({height, width})));
torch::Tensor t = torch::cat({y, uv[0]}, -1); // Write to the UV plane
return t.permute({0, 3, 1, 2}); // NCHW // yuv[:, 1:] = uv
yuv.index_put_({Slice(), Slice(1)}, uv);
} }
#ifdef USE_CUDA #ifdef USE_CUDA
torch::Tensor convert_nv12_cuda(AVFrame* pFrame, const torch::Device& device) { void write_nv12_cuda(AVFrame* pFrame, torch::Tensor& yuv) {
int width = pFrame->width; int height = static_cast<int>(yuv.size(2));
int height = pFrame->height; int width = static_cast<int>(yuv.size(3));
auto options = torch::TensorOptions() // Write Y plane directly
.dtype(torch::kUInt8)
.layout(torch::kStrided)
.device(torch::kCUDA)
.device_index(device.index());
torch::Tensor y = torch::empty({1, height, width, 1}, options);
{ {
uint8_t* tgt = y.data_ptr<uint8_t>(); uint8_t* tgt = yuv.data_ptr<uint8_t>();
CUdeviceptr src = (CUdeviceptr)pFrame->data[0]; CUdeviceptr src = (CUdeviceptr)pFrame->data[0];
int linesize = pFrame->linesize[0]; int linesize = pFrame->linesize[0];
TORCH_CHECK( TORCH_CHECK(
...@@ -273,7 +267,8 @@ torch::Tensor convert_nv12_cuda(AVFrame* pFrame, const torch::Device& device) { ...@@ -273,7 +267,8 @@ torch::Tensor convert_nv12_cuda(AVFrame* pFrame, const torch::Device& device) {
cudaMemcpyDeviceToDevice), cudaMemcpyDeviceToDevice),
"Failed to copy Y plane to Cuda tensor."); "Failed to copy Y plane to Cuda tensor.");
} }
torch::Tensor uv = torch::empty({1, height / 2, width / 2, 2}, options); // Preapare intermediate UV planes
torch::Tensor uv = get_buffer({1, height / 2, width / 2, 2}, yuv.device());
{ {
uint8_t* tgt = uv.data_ptr<uint8_t>(); uint8_t* tgt = uv.data_ptr<uint8_t>();
CUdeviceptr src = (CUdeviceptr)pFrame->data[1]; CUdeviceptr src = (CUdeviceptr)pFrame->data[1];
...@@ -292,22 +287,24 @@ torch::Tensor convert_nv12_cuda(AVFrame* pFrame, const torch::Device& device) { ...@@ -292,22 +287,24 @@ torch::Tensor convert_nv12_cuda(AVFrame* pFrame, const torch::Device& device) {
} }
// Upsample width and height // Upsample width and height
namespace F = torch::nn::functional; namespace F = torch::nn::functional;
using namespace torch::indexing;
uv = F::interpolate( uv = F::interpolate(
uv.view({1, 1, height / 2, width / 2, 2}), uv.permute({0, 3, 1, 2}),
F::InterpolateFuncOptions() F::InterpolateFuncOptions()
.mode(torch::kNearest) .mode(torch::kNearest)
.size(std::vector<int64_t>({height, width, 2}))); .size(std::vector<int64_t>({height, width})));
torch::Tensor t = torch::cat({y, uv[0]}, -1); // Write to the UV plane
return t.permute({0, 3, 1, 2}); // NCHW // yuv[:, 1:] = uv
yuv.index_put_({Slice(), Slice(1)}, uv);
} }
#endif #endif
} // namespace
torch::Tensor convert_image(AVFrame* pFrame, const torch::Device& device) { torch::Tensor convert_image(AVFrame* frame, const torch::Device& device) {
// ref: // ref:
// https://ffmpeg.org/doxygen/4.1/filtering__video_8c_source.html#l00179 // https://ffmpeg.org/doxygen/4.1/filtering__video_8c_source.html#l00179
// https://ffmpeg.org/doxygen/4.1/decode__video_8c_source.html#l00038 // https://ffmpeg.org/doxygen/4.1/decode__video_8c_source.html#l00038
AVPixelFormat format = static_cast<AVPixelFormat>(pFrame->format); AVPixelFormat format = static_cast<AVPixelFormat>(frame->format);
torch::Tensor buf = get_image_buffer(frame, device);
switch (format) { switch (format) {
case AV_PIX_FMT_RGB24: case AV_PIX_FMT_RGB24:
case AV_PIX_FMT_BGR24: case AV_PIX_FMT_BGR24:
...@@ -315,25 +312,34 @@ torch::Tensor convert_image(AVFrame* pFrame, const torch::Device& device) { ...@@ -315,25 +312,34 @@ torch::Tensor convert_image(AVFrame* pFrame, const torch::Device& device) {
case AV_PIX_FMT_RGBA: case AV_PIX_FMT_RGBA:
case AV_PIX_FMT_ABGR: case AV_PIX_FMT_ABGR:
case AV_PIX_FMT_BGRA: case AV_PIX_FMT_BGRA:
case AV_PIX_FMT_GRAY8: case AV_PIX_FMT_GRAY8: {
return convert_interlaced_video(pFrame); write_interlaced_image(frame, buf);
case AV_PIX_FMT_YUV444P: return buf.permute({0, 3, 1, 2});
return convert_planar_video(pFrame); }
case AV_PIX_FMT_YUV420P: case AV_PIX_FMT_YUV444P: {
return convert_yuv420p(pFrame); write_planar_image(frame, buf);
case AV_PIX_FMT_NV12: return buf;
return convert_nv12_cpu(pFrame); }
case AV_PIX_FMT_YUV420P: {
write_yuv420p(frame, buf);
return buf;
}
case AV_PIX_FMT_NV12: {
write_nv12_cpu(frame, buf);
return buf;
}
#ifdef USE_CUDA #ifdef USE_CUDA
case AV_PIX_FMT_CUDA: { case AV_PIX_FMT_CUDA: {
AVHWFramesContext* hwctx = AVHWFramesContext* hwctx = (AVHWFramesContext*)frame->hw_frames_ctx->data;
(AVHWFramesContext*)pFrame->hw_frames_ctx->data;
AVPixelFormat sw_format = hwctx->sw_format; AVPixelFormat sw_format = hwctx->sw_format;
// cuvid decoder (nvdec frontend of ffmpeg) only supports the following // cuvid decoder (nvdec frontend of ffmpeg) only supports the following
// output formats // output formats
// https://github.com/FFmpeg/FFmpeg/blob/072101bd52f7f092ee976f4e6e41c19812ad32fd/libavcodec/cuviddec.c#L1121-L1124 // https://github.com/FFmpeg/FFmpeg/blob/072101bd52f7f092ee976f4e6e41c19812ad32fd/libavcodec/cuviddec.c#L1121-L1124
switch (sw_format) { switch (sw_format) {
case AV_PIX_FMT_NV12: case AV_PIX_FMT_NV12: {
return convert_nv12_cuda(pFrame, device); write_nv12_cuda(frame, buf);
return buf;
}
case AV_PIX_FMT_P010: case AV_PIX_FMT_P010:
case AV_PIX_FMT_P016: case AV_PIX_FMT_P016:
TORCH_CHECK( TORCH_CHECK(
......
...@@ -11,8 +11,16 @@ namespace detail { ...@@ -11,8 +11,16 @@ namespace detail {
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
torch::Tensor convert_audio(AVFrame* frame); torch::Tensor convert_audio(AVFrame* frame);
torch::Tensor get_interlaced_image_buffer(AVFrame* pFrame); torch::Tensor get_buffer(
torch::Tensor get_planar_image_buffer(AVFrame* pFrame); at::IntArrayRef shape,
const torch::Device& device = torch::Device(torch::kCPU));
torch::Tensor get_image_buffer(AVFrame* pFrame, const torch::Device& device);
void write_yuv420p(AVFrame* pFrame, torch::Tensor& yuv);
void write_nv12_cpu(AVFrame* pFrame, torch::Tensor& yuv);
#ifdef USE_CUDA
void write_nv12_cuda(AVFrame* pFrame, torch::Tensor& yuv);
#endif
void write_interlaced_image(AVFrame* pFrame, torch::Tensor& frame); void write_interlaced_image(AVFrame* pFrame, torch::Tensor& frame);
void write_planar_image(AVFrame* pFrame, torch::Tensor& frame); void write_planar_image(AVFrame* pFrame, torch::Tensor& frame);
torch::Tensor convert_image(AVFrame* frame, const torch::Device& device); torch::Tensor convert_image(AVFrame* frame, const torch::Device& device);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment