Unverified Commit f87ce881 authored by Nicolas Hug, committed by GitHub
Browse files

Support for decoding jpegs on GPU with nvjpeg (#3792)


Co-authored-by: James Thewlis <james@unitary.ai>
parent 45002089
...@@ -61,7 +61,7 @@ include(CMakePackageConfigHelpers) ...@@ -61,7 +61,7 @@ include(CMakePackageConfigHelpers)
set(TVCPP torchvision/csrc) set(TVCPP torchvision/csrc)
list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCPP}/models ${TVCPP}/ops list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCPP}/models ${TVCPP}/ops
${TVCPP}/ops/autograd ${TVCPP}/ops/cpu) ${TVCPP}/ops/autograd ${TVCPP}/ops/cpu ${TVCPP}/io/image/cuda)
if(WITH_CUDA) if(WITH_CUDA)
list(APPEND ALLOW_LISTED ${TVCPP}/ops/cuda ${TVCPP}/ops/autocast) list(APPEND ALLOW_LISTED ${TVCPP}/ops/cuda ${TVCPP}/ops/autocast)
endif() endif()
......
...@@ -315,8 +315,23 @@ def get_extensions(): ...@@ -315,8 +315,23 @@ def get_extensions():
image_library += [jpeg_lib] image_library += [jpeg_lib]
image_include += [jpeg_include] image_include += [jpeg_include]
# Locating nvjpeg
# Should be included in CUDA_HOME for CUDA >= 10.1, which is the minimum version we have in the CI
nvjpeg_found = (
extension is CUDAExtension and
CUDA_HOME is not None and
os.path.exists(os.path.join(CUDA_HOME, 'include', 'nvjpeg.h'))
)
print('NVJPEG found: {0}'.format(nvjpeg_found))
image_macros += [('NVJPEG_FOUND', str(int(nvjpeg_found)))]
if nvjpeg_found:
print('Building torchvision with NVJPEG image support')
image_link_flags.append('nvjpeg')
image_path = os.path.join(extensions_dir, 'io', 'image') image_path = os.path.join(extensions_dir, 'io', 'image')
image_src = glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp')) image_src = (glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp'))
+ glob.glob(os.path.join(image_path, 'cuda', '*.cpp')))
if png_found or jpeg_found: if png_found or jpeg_found:
ext_modules.append(extension( ext_modules.append(extension(
......
...@@ -24,6 +24,9 @@ IS_PY39 = sys.version_info.major == 3 and sys.version_info.minor == 9 ...@@ -24,6 +24,9 @@ IS_PY39 = sys.version_info.major == 3 and sys.version_info.minor == 9
PY39_SEGFAULT_SKIP_MSG = "Segmentation fault with Python 3.9, see https://github.com/pytorch/vision/issues/3367" PY39_SEGFAULT_SKIP_MSG = "Segmentation fault with Python 3.9, see https://github.com/pytorch/vision/issues/3367"
PY39_SKIP = unittest.skipIf(IS_PY39, PY39_SEGFAULT_SKIP_MSG) PY39_SKIP = unittest.skipIf(IS_PY39, PY39_SEGFAULT_SKIP_MSG)
IN_CIRCLE_CI = os.getenv("CIRCLECI", False) == 'true' IN_CIRCLE_CI = os.getenv("CIRCLECI", False) == 'true'
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
CUDA_NOT_AVAILABLE_MSG = 'CUDA device not available'
@contextlib.contextmanager @contextlib.contextmanager
...@@ -407,11 +410,8 @@ def call_args_to_kwargs_only(call_args, *callable_or_arg_names): ...@@ -407,11 +410,8 @@ def call_args_to_kwargs_only(call_args, *callable_or_arg_names):
def cpu_and_gpu(): def cpu_and_gpu():
import pytest # noqa import pytest # noqa
# ignore CPU tests in RE as they're already covered by another contbuild
IN_RE_WORKER = os.environ.get("INSIDE_RE_WORKER") is not None
IN_FBCODE = os.environ.get("IN_FBCODE_TORCHVISION") == "1"
CUDA_NOT_AVAILABLE_MSG = 'CUDA device not available'
# ignore CPU tests in RE as they're already covered by another contbuild
devices = [] if IN_RE_WORKER else ['cpu'] devices = [] if IN_RE_WORKER else ['cpu']
if torch.cuda.is_available(): if torch.cuda.is_available():
...@@ -427,3 +427,17 @@ def cpu_and_gpu(): ...@@ -427,3 +427,17 @@ def cpu_and_gpu():
devices.append(pytest.param('cuda', marks=cuda_marks)) devices.append(pytest.param('cuda', marks=cuda_marks))
return devices return devices
def needs_cuda(test_func):
    """Decorate ``test_func`` so that it only runs when CUDA is usable.

    Three outcomes are possible:

    * in fbcode outside of an RE worker, the test is tagged with the
      ``dont_collect`` mark instead of being skipped;
    * when ``torch.cuda.is_available()`` is true, the test is returned as is;
    * otherwise the test is marked as skipped with ``CUDA_NOT_AVAILABLE_MSG``.
    """
    import pytest  # noqa

    in_fbcode_outside_re = IN_FBCODE and not IN_RE_WORKER
    if in_fbcode_outside_re:
        # We don't want to skip in fbcode, so we just don't collect
        # TODO: slightly more robust way would be to detect if we're in a sandcastle instance
        # so that the test will still be collected (and skipped) in the devvms.
        return pytest.mark.dont_collect(test_func)

    if not torch.cuda.is_available():
        skip_without_cuda = pytest.mark.skip(reason=CUDA_NOT_AVAILABLE_MSG)
        return skip_without_cuda(test_func)

    return test_func
...@@ -3,10 +3,11 @@ import io ...@@ -3,10 +3,11 @@ import io
import os import os
import unittest import unittest
import pytest
import numpy as np import numpy as np
import torch import torch
from PIL import Image from PIL import Image
from common_utils import get_tmp_dir from common_utils import get_tmp_dir, needs_cuda
from torchvision.io.image import ( from torchvision.io.image import (
decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file, decode_png, decode_jpeg, encode_jpeg, write_jpeg, decode_image, read_file,
...@@ -278,5 +279,43 @@ class ImageTester(unittest.TestCase): ...@@ -278,5 +279,43 @@ class ImageTester(unittest.TestCase):
os.unlink(fpath) os.unlink(fpath)
@needs_cuda
@pytest.mark.parametrize('mode', [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB])
@pytest.mark.parametrize('img_path', get_images(IMAGE_ROOT, ".jpg"))
@pytest.mark.parametrize('scripted', (False, True))
def test_decode_jpeg_cuda(mode, img_path, scripted):
    """Check that decoding on 'cuda' stays close to the ``decode_image`` result.

    The comparison is run for every read mode and every ``.jpg`` test image,
    both with the eager ``decode_jpeg`` and with its TorchScript version.
    """
    if 'cmyk' in img_path:
        pytest.xfail("Decoding a CMYK jpeg isn't supported")

    tester = ImageTester()
    encoded = read_file(img_path)
    reference = decode_image(encoded, mode=mode)

    if scripted:
        decode_fn = torch.jit.script(decode_jpeg)
    else:
        decode_fn = decode_jpeg
    decoded_on_gpu = decode_fn(encoded, mode=mode, device='cuda')

    # Some difference expected between jpeg implementations
    mean_abs_diff = (reference.float() - decoded_on_gpu.cpu().float()).abs().mean()
    tester.assertTrue(mean_abs_diff < 2)
@needs_cuda
@pytest.mark.parametrize('cuda_device', ('cuda', 'cuda:0', torch.device('cuda')))
def test_decode_jpeg_cuda_device_param(cuda_device):
    """The ``device`` argument accepts both strings and ``torch.device`` objects."""
    first_jpeg = next(get_images(IMAGE_ROOT, ".jpg"))
    encoded = read_file(first_jpeg)
    decode_jpeg(encoded, device=cuda_device)
@needs_cuda
def test_decode_jpeg_cuda_errors():
    """Invalid inputs to the CUDA JPEG decoding path must raise ``RuntimeError``."""
    encoded = read_file(next(get_images(IMAGE_ROOT, ".jpg")))

    # Each entry pairs the expected error message with a factory for the
    # invalid input. The factory is called inside the ``raises`` context so the
    # input is built at the same point as the decoding call.
    invalid_inputs = (
        ("Expected a non empty 1-dimensional tensor", lambda: encoded.reshape(-1, 1)),
        ("input tensor must be on CPU", lambda: encoded.to('cuda')),
        ("Expected a torch.uint8 tensor", lambda: encoded.to(torch.float)),
    )
    for expected_message, make_invalid_input in invalid_inputs:
        with pytest.raises(RuntimeError, match=expected_message):
            decode_jpeg(make_invalid_input(), device='cuda')

    # Calling the registered operator directly with a CPU device must fail too.
    with pytest.raises(RuntimeError, match="Expected a cuda device"):
        torch.ops.image.decode_jpeg_cuda(encoded, ImageReadMode.UNCHANGED.value, 'cpu')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
#include "decode_jpeg_cuda.h"
#include <ATen/ATen.h>
#if NVJPEG_FOUND
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <nvjpeg.h>
#endif
#include <string>
namespace vision {
namespace image {
#if !NVJPEG_FOUND
// Fallback used when torchvision is built without nvJPEG: the arguments are
// ignored and every call fails with an explanatory error.
torch::Tensor decode_jpeg_cuda(
    const torch::Tensor& /*data*/,
    ImageReadMode /*mode*/,
    torch::Device /*device*/) {
  TORCH_CHECK(
      false, "decode_jpeg_cuda: torchvision not compiled with nvJPEG support");
}
#else
namespace {
// nvJPEG handle shared by every call to decode_jpeg_cuda in this translation
// unit. It is null until nvjpegCreateSimple has succeeded.
nvjpegHandle_t nvjpeg_handle = nullptr;
} // namespace
// Decodes a JPEG byte buffer with nvJPEG and returns the image as a uint8
// tensor of shape (channels, height, width) allocated on `device`.
//
// Arguments:
//   data:   non-empty, 1-dimensional uint8 tensor holding the encoded bytes.
//           It must not be a CUDA tensor.
//   mode:   IMAGE_READ_MODE_UNCHANGED, IMAGE_READ_MODE_GRAY or
//           IMAGE_READ_MODE_RGB. Any other value is rejected.
//   device: CUDA device on which the output tensor is allocated and whose
//           current stream is used for decoding.
//
// Invalid arguments and nvJPEG failures are all reported through TORCH_CHECK.
torch::Tensor decode_jpeg_cuda(
    const torch::Tensor& data,
    ImageReadMode mode,
    torch::Device device) {
  TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
  TORCH_CHECK(
      !data.is_cuda(),
      "The input tensor must be on CPU when decoding with nvjpeg");
  TORCH_CHECK(
      data.dim() == 1 && data.numel() > 0,
      "Expected a non empty 1-dimensional tensor");
  TORCH_CHECK(device.is_cuda(), "Expected a cuda device");

  at::cuda::CUDAGuard device_guard(device);

  // Create the global nvJPEG handle once. The flag needs static storage
  // duration: a flag that is local to each call is always fresh, so
  // std::call_once would neither serialize concurrent callers nor remember
  // that the handle already exists.
  // When the callable throws, std::call_once leaves the flag unset, so a
  // failed creation can be retried by a later call in the same process.
  static std::once_flag nvjpeg_handle_creation_flag;
  std::call_once(nvjpeg_handle_creation_flag, []() {
    nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle);
    if (create_status != NVJPEG_STATUS_SUCCESS) {
      // Reset the handle so that the function can be called again after a
      // failure. The handle is an opaque nvJPEG object that was not obtained
      // from malloc, so it must not be passed to free().
      nvjpeg_handle = nullptr;
    }
    TORCH_CHECK(
        create_status == NVJPEG_STATUS_SUCCESS,
        "nvjpegCreateSimple failed: ",
        create_status);
  });

  // Create the jpeg state
  nvjpegJpegState_t jpeg_state;
  nvjpegStatus_t state_status =
      nvjpegJpegStateCreate(nvjpeg_handle, &jpeg_state);

  TORCH_CHECK(
      state_status == NVJPEG_STATUS_SUCCESS,
      "nvjpegJpegStateCreate failed: ",
      state_status);

  auto datap = data.data_ptr<uint8_t>();

  // Get the image information
  int num_channels;
  nvjpegChromaSubsampling_t subsampling;
  int widths[NVJPEG_MAX_COMPONENT];
  int heights[NVJPEG_MAX_COMPONENT];
  nvjpegStatus_t info_status = nvjpegGetImageInfo(
      nvjpeg_handle,
      datap,
      data.numel(),
      &num_channels,
      &subsampling,
      widths,
      heights);

  // From here on, the jpeg state is destroyed explicitly before each error is
  // raised, because TORCH_CHECK(false, ...) leaves the function immediately.
  if (info_status != NVJPEG_STATUS_SUCCESS) {
    nvjpegJpegStateDestroy(jpeg_state);
    TORCH_CHECK(false, "nvjpegGetImageInfo failed: ", info_status);
  }

  if (subsampling == NVJPEG_CSS_UNKNOWN) {
    nvjpegJpegStateDestroy(jpeg_state);
    TORCH_CHECK(false, "Unknown NVJPEG chroma subsampling");
  }

  // The first component's dimensions are used as the output image size.
  int width = widths[0];
  int height = heights[0];

  nvjpegOutputFormat_t output_format;
  int num_channels_output;

  switch (mode) {
    case IMAGE_READ_MODE_UNCHANGED:
      num_channels_output = num_channels;
      // For some reason, setting output_format to NVJPEG_OUTPUT_UNCHANGED will
      // not properly decode RGB images (it's fine for grayscale), so we set
      // output_format manually here
      if (num_channels == 1) {
        output_format = NVJPEG_OUTPUT_Y;
      } else if (num_channels == 3) {
        output_format = NVJPEG_OUTPUT_RGB;
      } else {
        nvjpegJpegStateDestroy(jpeg_state);
        TORCH_CHECK(
            false,
            "When mode is UNCHANGED, only 1 or 3 input channels are allowed.");
      }
      break;
    case IMAGE_READ_MODE_GRAY:
      output_format = NVJPEG_OUTPUT_Y;
      num_channels_output = 1;
      break;
    case IMAGE_READ_MODE_RGB:
      output_format = NVJPEG_OUTPUT_RGB;
      num_channels_output = 3;
      break;
    default:
      nvjpegJpegStateDestroy(jpeg_state);
      TORCH_CHECK(
          false, "The provided mode is not supported for JPEG decoding on GPU");
  }

  auto out_tensor = torch::empty(
      {int64_t(num_channels_output), int64_t(height), int64_t(width)},
      torch::dtype(torch::kU8).device(device));

  // nvjpegImage_t is a struct with
  // - an array of pointers to each channel
  // - the pitch for each channel
  // which must be filled in manually
  nvjpegImage_t out_image;

  for (int c = 0; c < num_channels_output; c++) {
    out_image.channel[c] = out_tensor[c].data_ptr<uint8_t>();
    out_image.pitch[c] = width;
  }
  // Mark the remaining channels as unused.
  for (int c = num_channels_output; c < NVJPEG_MAX_COMPONENT; c++) {
    out_image.channel[c] = nullptr;
    out_image.pitch[c] = 0;
  }

  cudaStream_t stream = at::cuda::getCurrentCUDAStream(device.index());

  nvjpegStatus_t decode_status = nvjpegDecode(
      nvjpeg_handle,
      jpeg_state,
      datap,
      data.numel(),
      output_format,
      &out_image,
      stream);

  nvjpegJpegStateDestroy(jpeg_state);

  TORCH_CHECK(
      decode_status == NVJPEG_STATUS_SUCCESS,
      "nvjpegDecode failed: ",
      decode_status);

  return out_tensor;
}
#endif // NVJPEG_FOUND
} // namespace image
} // namespace vision
#pragma once
#include <torch/types.h>
#include "../image_read_mode.h"
namespace vision {
namespace image {
// Decodes the JPEG bytes in `data` with nvJPEG and returns the image as a
// tensor allocated on `device`. Builds without nvJPEG support still define
// this function, but calling it always raises an error.
C10_EXPORT torch::Tensor decode_jpeg_cuda(
const torch::Tensor& data,
ImageReadMode mode,
torch::Device device);
} // namespace image
} // namespace vision
...@@ -21,7 +21,8 @@ static auto registry = torch::RegisterOperators() ...@@ -21,7 +21,8 @@ static auto registry = torch::RegisterOperators()
.op("image::encode_jpeg", &encode_jpeg) .op("image::encode_jpeg", &encode_jpeg)
.op("image::read_file", &read_file) .op("image::read_file", &read_file)
.op("image::write_file", &write_file) .op("image::write_file", &write_file)
.op("image::decode_image", &decode_image); .op("image::decode_image", &decode_image)
.op("image::decode_jpeg_cuda", &decode_jpeg_cuda);
} // namespace image } // namespace image
} // namespace vision } // namespace vision
...@@ -6,3 +6,4 @@ ...@@ -6,3 +6,4 @@
#include "cpu/encode_jpeg.h" #include "cpu/encode_jpeg.h"
#include "cpu/encode_png.h" #include "cpu/encode_png.h"
#include "cpu/read_write_file.h" #include "cpu/read_write_file.h"
#include "cuda/decode_jpeg_cuda.h"
...@@ -148,7 +148,8 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6): ...@@ -148,7 +148,8 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
write_file(filename, output) write_file(filename, output)
def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor: def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED,
device: str = 'cpu') -> torch.Tensor:
""" """
Decodes a JPEG image into a 3 dimensional RGB Tensor. Decodes a JPEG image into a 3 dimensional RGB Tensor.
Optionally converts the image to the desired format. Optionally converts the image to the desired format.
...@@ -156,16 +157,25 @@ def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANG ...@@ -156,16 +157,25 @@ def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANG
Args: Args:
input (Tensor[1]): a one dimensional uint8 tensor containing input (Tensor[1]): a one dimensional uint8 tensor containing
the raw bytes of the JPEG image. the raw bytes of the JPEG image. This tensor must be on CPU,
regardless of the ``device`` parameter.
mode (ImageReadMode): the read mode used for optionally mode (ImageReadMode): the read mode used for optionally
converting the image. Default: `ImageReadMode.UNCHANGED`. converting the image. Default: `ImageReadMode.UNCHANGED`.
See `ImageReadMode` class for more information on various See `ImageReadMode` class for more information on various
available modes. available modes.
device (str or torch.device): The device on which the decoded image will
be stored. If a cuda device is specified, the image will be decoded
with `nvjpeg <https://developer.nvidia.com/nvjpeg>`_. This is only
supported for CUDA version >= 10.1
Returns: Returns:
output (Tensor[image_channels, image_height, image_width]) output (Tensor[image_channels, image_height, image_width])
""" """
output = torch.ops.image.decode_jpeg(input, mode.value) device = torch.device(device)
if device.type == 'cuda':
output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device)
else:
output = torch.ops.image.decode_jpeg(input, mode.value)
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment