Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
02b5a817
Unverified
Commit
02b5a817
authored
Oct 25, 2021
by
Nicolas Hug
Committed by
GitHub
Oct 25, 2021
Browse files
Keep 16bits png decoding private (#4732)
parent
15366c4d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
40 additions
and
23 deletions
+40
-23
test/test_image.py
test/test_image.py
+16
-7
torchvision/csrc/io/image/cpu/decode_png.cpp
torchvision/csrc/io/image/cpu/decode_png.cpp
+13
-4
torchvision/csrc/io/image/cpu/decode_png.h
torchvision/csrc/io/image/cpu/decode_png.h
+2
-1
torchvision/io/image.py
torchvision/io/image.py
+9
-11
No files found.
test/test_image.py
View file @
02b5a817
...
@@ -22,6 +22,7 @@ from torchvision.io.image import (
...
@@ -22,6 +22,7 @@ from torchvision.io.image import (
write_file
,
write_file
,
ImageReadMode
,
ImageReadMode
,
read_image
,
read_image
,
_read_png_16
,
)
)
IMAGE_ROOT
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"assets"
)
IMAGE_ROOT
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"assets"
)
...
@@ -156,8 +157,21 @@ def test_decode_png(img_path, pil_mode, mode):
...
@@ -156,8 +157,21 @@ def test_decode_png(img_path, pil_mode, mode):
img_pil
=
torch
.
from_numpy
(
np
.
array
(
img
))
img_pil
=
torch
.
from_numpy
(
np
.
array
(
img
))
img_pil
=
normalize_dimensions
(
img_pil
)
img_pil
=
normalize_dimensions
(
img_pil
)
data
=
read_file
(
img_path
)
img_lpng
=
decode_image
(
data
,
mode
=
mode
)
if
"16"
in
img_path
:
# 16 bits image decoding is supported, but only as a private API
# FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
with
pytest
.
raises
(
RuntimeError
,
match
=
"At most 8-bit PNG images are supported"
):
data
=
read_file
(
img_path
)
img_lpng
=
decode_image
(
data
,
mode
=
mode
)
img_lpng
=
_read_png_16
(
img_path
,
mode
=
mode
)
assert
img_lpng
.
dtype
==
torch
.
int32
# PIL converts 16 bits pngs in uint8
img_lpng
=
torch
.
round
(
img_lpng
/
(
2
**
16
-
1
)
*
255
).
to
(
torch
.
uint8
)
else
:
data
=
read_file
(
img_path
)
img_lpng
=
decode_image
(
data
,
mode
=
mode
)
tol
=
0
if
pil_mode
is
None
else
1
tol
=
0
if
pil_mode
is
None
else
1
...
@@ -168,11 +182,6 @@ def test_decode_png(img_path, pil_mode, mode):
...
@@ -168,11 +182,6 @@ def test_decode_png(img_path, pil_mode, mode):
# TODO: remove once fix is released in PIL. Should be > 8.3.1.
# TODO: remove once fix is released in PIL. Should be > 8.3.1.
img_lpng
,
img_pil
=
img_lpng
[
0
],
img_pil
[
0
]
img_lpng
,
img_pil
=
img_lpng
[
0
],
img_pil
[
0
]
if
"16"
in
img_path
:
# PIL converts 16 bits pngs in uint8
assert
img_lpng
.
dtype
==
torch
.
int32
img_lpng
=
torch
.
round
(
img_lpng
/
(
2
**
16
-
1
)
*
255
).
to
(
torch
.
uint8
)
torch
.
testing
.
assert_close
(
img_lpng
,
img_pil
,
atol
=
tol
,
rtol
=
0
)
torch
.
testing
.
assert_close
(
img_lpng
,
img_pil
,
atol
=
tol
,
rtol
=
0
)
...
...
torchvision/csrc/io/image/cpu/decode_png.cpp
View file @
02b5a817
...
@@ -5,7 +5,10 @@ namespace vision {
...
@@ -5,7 +5,10 @@ namespace vision {
namespace
image
{
namespace
image
{
#if !PNG_FOUND
#if !PNG_FOUND
torch
::
Tensor
decode_png
(
const
torch
::
Tensor
&
data
,
ImageReadMode
mode
)
{
torch
::
Tensor
decode_png
(
const
torch
::
Tensor
&
data
,
ImageReadMode
mode
,
bool
allow_16_bits
)
{
TORCH_CHECK
(
TORCH_CHECK
(
false
,
"decode_png: torchvision not compiled with libPNG support"
);
false
,
"decode_png: torchvision not compiled with libPNG support"
);
}
}
...
@@ -16,7 +19,10 @@ bool is_little_endian() {
...
@@ -16,7 +19,10 @@ bool is_little_endian() {
return
*
(
uint8_t
*
)
&
x
;
return
*
(
uint8_t
*
)
&
x
;
}
}
torch
::
Tensor
decode_png
(
const
torch
::
Tensor
&
data
,
ImageReadMode
mode
)
{
torch
::
Tensor
decode_png
(
const
torch
::
Tensor
&
data
,
ImageReadMode
mode
,
bool
allow_16_bits
)
{
// Check that the input tensor dtype is uint8
// Check that the input tensor dtype is uint8
TORCH_CHECK
(
data
.
dtype
()
==
torch
::
kU8
,
"Expected a torch.uint8 tensor"
);
TORCH_CHECK
(
data
.
dtype
()
==
torch
::
kU8
,
"Expected a torch.uint8 tensor"
);
// Check that the input tensor is 1-dimensional
// Check that the input tensor is 1-dimensional
...
@@ -77,9 +83,12 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
...
@@ -77,9 +83,12 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
TORCH_CHECK
(
retval
==
1
,
"Could read image metadata from content."
)
TORCH_CHECK
(
retval
==
1
,
"Could read image metadata from content."
)
}
}
if
(
bit_depth
>
16
)
{
auto
max_bit_depth
=
allow_16_bits
?
16
:
8
;
auto
err_msg
=
"At most "
+
std
::
to_string
(
max_bit_depth
)
+
"-bit PNG images are supported currently."
;
if
(
bit_depth
>
max_bit_depth
)
{
png_destroy_read_struct
(
&
png_ptr
,
&
info_ptr
,
nullptr
);
png_destroy_read_struct
(
&
png_ptr
,
&
info_ptr
,
nullptr
);
TORCH_CHECK
(
false
,
"At most 16-bit PNG images are supported currently."
)
TORCH_CHECK
(
false
,
err_msg
)
}
}
int
channels
=
png_get_channels
(
png_ptr
,
info_ptr
);
int
channels
=
png_get_channels
(
png_ptr
,
info_ptr
);
...
...
torchvision/csrc/io/image/cpu/decode_png.h
View file @
02b5a817
...
@@ -8,7 +8,8 @@ namespace image {
...
@@ -8,7 +8,8 @@ namespace image {
C10_EXPORT
torch
::
Tensor
decode_png
(
C10_EXPORT
torch
::
Tensor
decode_png
(
const
torch
::
Tensor
&
data
,
const
torch
::
Tensor
&
data
,
ImageReadMode
mode
=
IMAGE_READ_MODE_UNCHANGED
);
ImageReadMode
mode
=
IMAGE_READ_MODE_UNCHANGED
,
bool
allow_16_bits
=
false
);
}
// namespace image
}
// namespace image
}
// namespace vision
}
// namespace vision
torchvision/io/image.py
View file @
02b5a817
...
@@ -61,12 +61,7 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE
...
@@ -61,12 +61,7 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE
"""
"""
Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.
Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 in [0, 255], except for
The values of the output tensor are uint8 in [0, 255].
16-bits pngs which are int32 tensors in [0, 65535].
.. warning::
Should pytorch ever support the uint16 dtype natively, the dtype of the
output for 16-bits pngs will be updated from int32 to uint16.
Args:
Args:
input (Tensor[1]): a one dimensional uint8 tensor containing
input (Tensor[1]): a one dimensional uint8 tensor containing
...
@@ -79,7 +74,7 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE
...
@@ -79,7 +74,7 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE
Returns:
Returns:
output (Tensor[image_channels, image_height, image_width])
output (Tensor[image_channels, image_height, image_width])
"""
"""
output
=
torch
.
ops
.
image
.
decode_png
(
input
,
mode
.
value
)
output
=
torch
.
ops
.
image
.
decode_png
(
input
,
mode
.
value
,
False
)
return
output
return
output
...
@@ -193,8 +188,7 @@ def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN
...
@@ -193,8 +188,7 @@ def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN
operation to decode the image into a 3 dimensional RGB or grayscale Tensor.
operation to decode the image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 in [0, 255], except for
The values of the output tensor are uint8 in [0, 255].
16-bits pngs which are int32 tensors in [0, 65535].
Args:
Args:
input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
...
@@ -215,8 +209,7 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc
...
@@ -215,8 +209,7 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc
"""
"""
Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale Tensor.
Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 in [0, 255], except for
The values of the output tensor are uint8 in [0, 255].
16-bits pngs which are int32 tensors in [0, 65535].
Args:
Args:
path (str): path of the JPEG or PNG image.
path (str): path of the JPEG or PNG image.
...
@@ -230,3 +223,8 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc
...
@@ -230,3 +223,8 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc
"""
"""
data
=
read_file
(
path
)
data
=
read_file
(
path
)
return
decode_image
(
data
,
mode
)
return
decode_image
(
data
,
mode
)
def
_read_png_16
(
path
:
str
,
mode
:
ImageReadMode
=
ImageReadMode
.
UNCHANGED
)
->
torch
.
Tensor
:
data
=
read_file
(
path
)
return
torch
.
ops
.
image
.
decode_png
(
data
,
mode
.
value
,
True
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment