Unverified Commit f4fd1933 authored by Prabhat Roy's avatar Prabhat Roy Committed by GitHub
Browse files

Return RGB frames as output of GPU decoder (#5191)

* Return RGB frames as output of GPU decoder

* Move clamp to the conversion function

* Cleaned up a bit

* Remove utility functions from test

* Use data member width directly

* Fix linter error
parent 038828ea
......@@ -472,6 +472,7 @@ def get_extensions():
"z",
"pthread",
"dl",
"nppicc",
],
extra_compile_args=extra_compile_args,
)
......
......@@ -31,10 +31,10 @@ class TestVideoGPUDecoder:
decoder = VideoReader(full_path, device="cuda:0")
with av.open(full_path) as container:
for av_frame in container.decode(container.streams.video[0]):
av_frames = torch.tensor(av_frame.to_ndarray().flatten())
av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray())
vision_frames = next(decoder)["data"]
mean_delta = torch.mean(torch.abs(av_frames.float() - decoder._reformat(vision_frames).float()))
assert mean_delta < 0.1
mean_delta = torch.mean(torch.abs(av_frames.float() - vision_frames.cpu().float()))
assert mean_delta < 0.75
if __name__ == "__main__":
......
#include "decoder.h"
#include <c10/util/Logging.h>
#include <nppi_color_conversion.h>
#include <cmath>
#include <cstring>
#include <unordered_map>
......@@ -138,38 +139,24 @@ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* disp_info) {
}
auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA);
torch::Tensor decoded_frame = torch::empty({get_frame_size()}, options);
torch::Tensor decoded_frame = torch::empty({get_height(), width, 3}, options);
uint8_t* frame_ptr = decoded_frame.data_ptr<uint8_t>();
const uint8_t* const source_arr[] = {
(const uint8_t* const)source_frame,
(const uint8_t* const)(source_frame + source_pitch * ((surface_height + 1) & ~1))};
auto err = nppiNV12ToRGB_709CSC_8u_P2C3R(
source_arr,
source_pitch,
frame_ptr,
width * 3,
{(int)decoded_frame.size(1), (int)decoded_frame.size(0)});
TORCH_CHECK(
err == NPP_NO_ERROR,
"Failed to convert from NV12 to RGB. Error code:",
err);
// Copy luma plane
CUDA_MEMCPY2D m = {0};
m.srcMemoryType = CU_MEMORYTYPE_DEVICE;
m.srcDevice = source_frame;
m.srcPitch = source_pitch;
m.dstMemoryType = CU_MEMORYTYPE_DEVICE;
m.dstDevice = (CUdeviceptr)(m.dstHost = frame_ptr);
m.dstPitch = get_width() * bytes_per_pixel;
m.WidthInBytes = get_width() * bytes_per_pixel;
m.Height = luma_height;
check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__);
// Copy chroma plane
// NVDEC output has luma height aligned by 2. Adjust chroma offset by aligning
// height
m.srcDevice =
(CUdeviceptr)((uint8_t*)source_frame + m.srcPitch * ((surface_height + 1) & ~1));
m.dstDevice = (CUdeviceptr)(m.dstHost = frame_ptr + m.dstPitch * luma_height);
m.Height = chroma_height;
check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__);
if (num_chroma_planes == 2) {
m.srcDevice =
(CUdeviceptr)((uint8_t*)source_frame + m.srcPitch * ((surface_height + 1) & ~1) * 2);
m.dstDevice =
(CUdeviceptr)(m.dstHost = frame_ptr + m.dstPitch * luma_height * 2);
m.Height = chroma_height;
check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__);
}
check_for_cuda_errors(cuStreamSynchronize(cuvidStream), __LINE__, __FILE__);
decoded_frames.push(decoded_frame);
check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__);
......
......@@ -38,48 +38,8 @@ torch::Tensor GPUDecoder::decode() {
return frame;
}
/* Convert a tensor with data in NV12 format to a tensor with data in YUV420
 * (planar Y, then U, then V) format in-place.
 *
 * NV12 layout: a full-resolution luma (Y) plane followed by a single
 * half-resolution chroma plane with U and V samples interleaved (UVUV...).
 * This routine splits the interleaved chroma into two consecutive planar
 * half-resolution planes, reusing the tensor's own storage for Y and U and
 * a small scratch buffer for V.
 *
 * @param frameTensor uint8 tensor holding one NV12 frame sized by the
 *        decoder's current width/height; modified in place.
 * @return the same tensor, now containing planar YUV420 data.
 */
torch::Tensor GPUDecoder::nv12_to_yuv420(torch::Tensor frameTensor) {
  int width = decoder.get_width(), height = decoder.get_height();
  // Row pitch equals width here, so the "pitch != width" branch below is
  // effectively dead for this caller; kept for generality.
  int pitch = width;
  uint8_t* frame = frameTensor.data_ptr<uint8_t>();
  // Scratch buffer for the V plane; V samples cannot be compacted in place
  // because their final location overlaps chroma bytes not yet read.
  // NOTE(review): raw new/delete — a std::vector<uint8_t> would be safer.
  uint8_t* ptr = new uint8_t[((width + 1) / 2) * ((height + 1) / 2)];
  // sizes of source surface plane
  int sizePlaneY = pitch * height;
  int sizePlaneU = ((pitch + 1) / 2) * ((height + 1) / 2);
  int sizePlaneV = sizePlaneU;
  // uv points at the interleaved chroma; u/v are the planar destinations.
  // u aliases uv: the in-place compaction below is safe because the write
  // offset (y*((pitch+1)/2) + x) never exceeds the read offset
  // (y*pitch + 2*x) for pitch >= 1.
  uint8_t* uv = frame + sizePlaneY;
  uint8_t* u = uv;
  uint8_t* v = uv + sizePlaneU;
  // split chroma from interleave to planar
  for (int y = 0; y < (height + 1) / 2; y++) {
    for (int x = 0; x < (width + 1) / 2; x++) {
      // Even bytes are U (compacted in place), odd bytes are V (staged).
      u[y * ((pitch + 1) / 2) + x] = uv[y * pitch + x * 2];
      ptr[y * ((width + 1) / 2) + x] = uv[y * pitch + x * 2 + 1];
    }
  }
  if (pitch == width) {
    // Contiguous rows: copy the staged V plane in one shot.
    memcpy(v, ptr, sizePlaneV * sizeof(uint8_t));
  } else {
    // Padded rows: copy the V plane row by row, honoring the pitch.
    for (int i = 0; i < (height + 1) / 2; i++) {
      memcpy(
          v + ((pitch + 1) / 2) * i,
          ptr + ((width + 1) / 2) * i,
          ((width + 1) / 2) * sizeof(uint8_t));
    }
  }
  delete[] ptr;
  return frameTensor;
}
TORCH_LIBRARY(torchvision, m) {
m.class_<GPUDecoder>("GPUDecoder")
.def(torch::init<std::string, int64_t>())
.def("next", &GPUDecoder::decode)
.def("reformat", &GPUDecoder::nv12_to_yuv420);
.def("next", &GPUDecoder::decode);
}
......@@ -8,7 +8,6 @@ class GPUDecoder : public torch::CustomClassHolder {
GPUDecoder(std::string, int64_t);
~GPUDecoder();
torch::Tensor decode();
torch::Tensor nv12_to_yuv420(torch::Tensor);
private:
Demuxer demuxer;
......
......@@ -210,16 +210,6 @@ class VideoReader:
print("GPU decoding only works with video stream.")
return self._c.set_current_stream(stream)
def _reformat(self, tensor, output_format: str = "yuv420"):
supported_formats = [
"yuv420",
]
if output_format not in supported_formats:
raise RuntimeError(f"{output_format} not supported, please use one of {', '.join(supported_formats)}")
if not isinstance(tensor, torch.Tensor):
raise RuntimeError("Expected tensor as input parameter!")
return self._c.reformat(tensor.cpu())
__all__ = [
"write_video",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment