Unverified commit 02b5a817, authored by Nicolas Hug, committed by GitHub
Browse files

Keep 16bits png decoding private (#4732)

parent 15366c4d
......@@ -22,6 +22,7 @@ from torchvision.io.image import (
write_file,
ImageReadMode,
read_image,
_read_png_16,
)
IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
......@@ -156,6 +157,19 @@ def test_decode_png(img_path, pil_mode, mode):
img_pil = torch.from_numpy(np.array(img))
img_pil = normalize_dimensions(img_pil)
if "16" in img_path:
# 16 bits image decoding is supported, but only as a private API
# FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"):
data = read_file(img_path)
img_lpng = decode_image(data, mode=mode)
img_lpng = _read_png_16(img_path, mode=mode)
assert img_lpng.dtype == torch.int32
# PIL converts 16-bit PNGs to uint8
img_lpng = torch.round(img_lpng / (2 ** 16 - 1) * 255).to(torch.uint8)
else:
data = read_file(img_path)
img_lpng = decode_image(data, mode=mode)
......@@ -168,11 +182,6 @@ def test_decode_png(img_path, pil_mode, mode):
# TODO: remove once fix is released in PIL. Should be > 8.3.1.
img_lpng, img_pil = img_lpng[0], img_pil[0]
if "16" in img_path:
# PIL converts 16 bits pngs in uint8
assert img_lpng.dtype == torch.int32
img_lpng = torch.round(img_lpng / (2 ** 16 - 1) * 255).to(torch.uint8)
torch.testing.assert_close(img_lpng, img_pil, atol=tol, rtol=0)
......
......@@ -5,7 +5,10 @@ namespace vision {
namespace image {
#if !PNG_FOUND
torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
torch::Tensor decode_png(
const torch::Tensor& data,
ImageReadMode mode,
bool allow_16_bits) {
TORCH_CHECK(
false, "decode_png: torchvision not compiled with libPNG support");
}
......@@ -16,7 +19,10 @@ bool is_little_endian() {
return *(uint8_t*)&x;
}
torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
torch::Tensor decode_png(
const torch::Tensor& data,
ImageReadMode mode,
bool allow_16_bits) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
......@@ -77,9 +83,12 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
TORCH_CHECK(retval == 1, "Could read image metadata from content.")
}
if (bit_depth > 16) {
auto max_bit_depth = allow_16_bits ? 16 : 8;
auto err_msg = "At most " + std::to_string(max_bit_depth) +
"-bit PNG images are supported currently.";
if (bit_depth > max_bit_depth) {
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(false, "At most 16-bit PNG images are supported currently.")
TORCH_CHECK(false, err_msg)
}
int channels = png_get_channels(png_ptr, info_ptr);
......
......@@ -8,7 +8,8 @@ namespace image {
C10_EXPORT torch::Tensor decode_png(
const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED);
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED,
bool allow_16_bits = false);
} // namespace image
} // namespace vision
......@@ -61,12 +61,7 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE
"""
Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 in [0, 255], except for
16-bits pngs which are int32 tensors in [0, 65535].
.. warning::
Should pytorch ever support the uint16 dtype natively, the dtype of the
output for 16-bits pngs will be updated from int32 to uint16.
The values of the output tensor are uint8 in [0, 255].
Args:
input (Tensor[1]): a one dimensional uint8 tensor containing
......@@ -79,7 +74,7 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE
Returns:
output (Tensor[image_channels, image_height, image_width])
"""
output = torch.ops.image.decode_png(input, mode.value)
output = torch.ops.image.decode_png(input, mode.value, False)
return output
......@@ -193,8 +188,7 @@ def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN
operation to decode the image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 in [0, 255], except for
16-bits pngs which are int32 tensors in [0, 65535].
The values of the output tensor are uint8 in [0, 255].
Args:
input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
......@@ -215,8 +209,7 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc
"""
Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 in [0, 255], except for
16-bits pngs which are int32 tensors in [0, 65535].
The values of the output tensor are uint8 in [0, 255].
Args:
path (str): path of the JPEG or PNG image.
......@@ -230,3 +223,8 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc
"""
data = read_file(path)
return decode_image(data, mode)
def _read_png_16(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """
    Private helper: read a PNG file from disk and decode it with 16-bit decoding allowed.

    Args:
        path (str): path of the PNG image.
        mode (ImageReadMode): the read mode used for optionally converting the image.

    Returns:
        output (Tensor): the tensor returned by the ``image.decode_png`` operator.
    """
    raw_bytes = read_file(path)
    decode_op = torch.ops.image.decode_png
    # The trailing ``True`` is the operator's third argument; the public
    # ``decode_png`` wrapper passes ``False`` in that position.
    return decode_op(raw_bytes, mode.value, True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment