Unverified commit 02b5a817, authored by Nicolas Hug, committed by GitHub
Browse files

Keep 16bits png decoding private (#4732)

parent 15366c4d
...@@ -22,6 +22,7 @@ from torchvision.io.image import ( ...@@ -22,6 +22,7 @@ from torchvision.io.image import (
write_file, write_file,
ImageReadMode, ImageReadMode,
read_image, read_image,
_read_png_16,
) )
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
...@@ -156,8 +157,21 @@ def test_decode_png(img_path, pil_mode, mode): ...@@ -156,8 +157,21 @@ def test_decode_png(img_path, pil_mode, mode):
img_pil = torch.from_numpy(np.array(img)) img_pil = torch.from_numpy(np.array(img))
img_pil = normalize_dimensions(img_pil) img_pil = normalize_dimensions(img_pil)
data = read_file(img_path)
img_lpng = decode_image(data, mode=mode) if "16" in img_path:
# 16 bits image decoding is supported, but only as a private API
# FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"):
data = read_file(img_path)
img_lpng = decode_image(data, mode=mode)
img_lpng = _read_png_16(img_path, mode=mode)
assert img_lpng.dtype == torch.int32
# PIL converts 16 bits pngs in uint8
img_lpng = torch.round(img_lpng / (2 ** 16 - 1) * 255).to(torch.uint8)
else:
data = read_file(img_path)
img_lpng = decode_image(data, mode=mode)
tol = 0 if pil_mode is None else 1 tol = 0 if pil_mode is None else 1
...@@ -168,11 +182,6 @@ def test_decode_png(img_path, pil_mode, mode): ...@@ -168,11 +182,6 @@ def test_decode_png(img_path, pil_mode, mode):
# TODO: remove once fix is released in PIL. Should be > 8.3.1. # TODO: remove once fix is released in PIL. Should be > 8.3.1.
img_lpng, img_pil = img_lpng[0], img_pil[0] img_lpng, img_pil = img_lpng[0], img_pil[0]
if "16" in img_path:
# PIL converts 16 bits pngs in uint8
assert img_lpng.dtype == torch.int32
img_lpng = torch.round(img_lpng / (2 ** 16 - 1) * 255).to(torch.uint8)
torch.testing.assert_close(img_lpng, img_pil, atol=tol, rtol=0) torch.testing.assert_close(img_lpng, img_pil, atol=tol, rtol=0)
......
...@@ -5,7 +5,10 @@ namespace vision { ...@@ -5,7 +5,10 @@ namespace vision {
namespace image { namespace image {
#if !PNG_FOUND #if !PNG_FOUND
torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) { torch::Tensor decode_png(
const torch::Tensor& data,
ImageReadMode mode,
bool allow_16_bits) {
TORCH_CHECK( TORCH_CHECK(
false, "decode_png: torchvision not compiled with libPNG support"); false, "decode_png: torchvision not compiled with libPNG support");
} }
...@@ -16,7 +19,10 @@ bool is_little_endian() { ...@@ -16,7 +19,10 @@ bool is_little_endian() {
return *(uint8_t*)&x; return *(uint8_t*)&x;
} }
torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) { torch::Tensor decode_png(
const torch::Tensor& data,
ImageReadMode mode,
bool allow_16_bits) {
// Check that the input tensor dtype is uint8 // Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional // Check that the input tensor is 1-dimensional
...@@ -77,9 +83,12 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) { ...@@ -77,9 +83,12 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
TORCH_CHECK(retval == 1, "Could read image metadata from content.") TORCH_CHECK(retval == 1, "Could read image metadata from content.")
} }
if (bit_depth > 16) { auto max_bit_depth = allow_16_bits ? 16 : 8;
auto err_msg = "At most " + std::to_string(max_bit_depth) +
"-bit PNG images are supported currently.";
if (bit_depth > max_bit_depth) {
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(false, "At most 16-bit PNG images are supported currently.") TORCH_CHECK(false, err_msg)
} }
int channels = png_get_channels(png_ptr, info_ptr); int channels = png_get_channels(png_ptr, info_ptr);
......
...@@ -8,7 +8,8 @@ namespace image { ...@@ -8,7 +8,8 @@ namespace image {
C10_EXPORT torch::Tensor decode_png( C10_EXPORT torch::Tensor decode_png(
const torch::Tensor& data, const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED); ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED,
bool allow_16_bits = false);
} // namespace image } // namespace image
} // namespace vision } // namespace vision
...@@ -61,12 +61,7 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE ...@@ -61,12 +61,7 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE
""" """
Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor. Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format. Optionally converts the image to the desired format.
The values of the output tensor are uint8 in [0, 255], except for The values of the output tensor are uint8 in [0, 255].
16-bits pngs which are int32 tensors in [0, 65535].
.. warning::
Should pytorch ever support the uint16 dtype natively, the dtype of the
output for 16-bits pngs will be updated from int32 to uint16.
Args: Args:
input (Tensor[1]): a one dimensional uint8 tensor containing input (Tensor[1]): a one dimensional uint8 tensor containing
...@@ -79,7 +74,7 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE ...@@ -79,7 +74,7 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE
Returns: Returns:
output (Tensor[image_channels, image_height, image_width]) output (Tensor[image_channels, image_height, image_width])
""" """
output = torch.ops.image.decode_png(input, mode.value) output = torch.ops.image.decode_png(input, mode.value, False)
return output return output
...@@ -193,8 +188,7 @@ def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN ...@@ -193,8 +188,7 @@ def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN
operation to decode the image into a 3 dimensional RGB or grayscale Tensor. operation to decode the image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format. Optionally converts the image to the desired format.
The values of the output tensor are uint8 in [0, 255], except for The values of the output tensor are uint8 in [0, 255].
16-bits pngs which are int32 tensors in [0, 65535].
Args: Args:
input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
...@@ -215,8 +209,7 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc ...@@ -215,8 +209,7 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc
""" """
Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale Tensor. Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format. Optionally converts the image to the desired format.
The values of the output tensor are uint8 in [0, 255], except for The values of the output tensor are uint8 in [0, 255].
16-bits pngs which are int32 tensors in [0, 65535].
Args: Args:
path (str): path of the JPEG or PNG image. path (str): path of the JPEG or PNG image.
...@@ -230,3 +223,8 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc ...@@ -230,3 +223,8 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc
""" """
data = read_file(path) data = read_file(path)
return decode_image(data, mode) return decode_image(data, mode)
def _read_png_16(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Private helper that reads a PNG file and decodes it with 16-bit support enabled.

    The file at ``path`` is loaded with ``read_file`` and handed to the
    ``torch.ops.image.decode_png`` operator with its third argument
    (``allow_16_bits`` in the C++ signature) set to ``True``, whereas the
    public ``decode_png`` wrapper passes ``False``.

    Args:
        path (str): path of the PNG image.
        mode (ImageReadMode): the read mode; its ``.value`` is forwarded to
            the operator. Defaults to ``ImageReadMode.UNCHANGED``.

    Returns:
        output (Tensor): the tensor returned by ``torch.ops.image.decode_png``.
    """
    # Deliberately bypasses decode_image/decode_png so that the 16-bit flag
    # can be switched on without exposing it in the public API.
    allow_16_bits = True
    raw_bytes = read_file(path)
    return torch.ops.image.decode_png(raw_bytes, mode.value, allow_16_bits)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment