Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
02b5a817
"git@developer.sourcefind.cn:OpenDAS/fairscale.git" did not exist on "2c663f5acb33290b97f50b72ab1581c6c08b2b49"
Unverified
Commit
02b5a817
authored
Oct 25, 2021
by
Nicolas Hug
Committed by
GitHub
Oct 25, 2021
Browse files
Keep 16bits png decoding private (#4732)
parent
15366c4d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
40 additions
and
23 deletions
+40
-23
test/test_image.py
test/test_image.py
+16
-7
torchvision/csrc/io/image/cpu/decode_png.cpp
torchvision/csrc/io/image/cpu/decode_png.cpp
+13
-4
torchvision/csrc/io/image/cpu/decode_png.h
torchvision/csrc/io/image/cpu/decode_png.h
+2
-1
torchvision/io/image.py
torchvision/io/image.py
+9
-11
No files found.
test/test_image.py
View file @
02b5a817
...
...
@@ -22,6 +22,7 @@ from torchvision.io.image import (
write_file
,
ImageReadMode
,
read_image
,
_read_png_16
,
)
IMAGE_ROOT
=
os
.
path
.
join
(
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
)),
"assets"
)
...
...
@@ -156,8 +157,21 @@ def test_decode_png(img_path, pil_mode, mode):
img_pil
=
torch
.
from_numpy
(
np
.
array
(
img
))
img_pil
=
normalize_dimensions
(
img_pil
)
data
=
read_file
(
img_path
)
img_lpng
=
decode_image
(
data
,
mode
=
mode
)
if
"16"
in
img_path
:
# 16 bits image decoding is supported, but only as a private API
# FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
with
pytest
.
raises
(
RuntimeError
,
match
=
"At most 8-bit PNG images are supported"
):
data
=
read_file
(
img_path
)
img_lpng
=
decode_image
(
data
,
mode
=
mode
)
img_lpng
=
_read_png_16
(
img_path
,
mode
=
mode
)
assert
img_lpng
.
dtype
==
torch
.
int32
# PIL converts 16 bits pngs in uint8
img_lpng
=
torch
.
round
(
img_lpng
/
(
2
**
16
-
1
)
*
255
).
to
(
torch
.
uint8
)
else
:
data
=
read_file
(
img_path
)
img_lpng
=
decode_image
(
data
,
mode
=
mode
)
tol
=
0
if
pil_mode
is
None
else
1
...
...
@@ -168,11 +182,6 @@ def test_decode_png(img_path, pil_mode, mode):
# TODO: remove once fix is released in PIL. Should be > 8.3.1.
img_lpng
,
img_pil
=
img_lpng
[
0
],
img_pil
[
0
]
if
"16"
in
img_path
:
# PIL converts 16 bits pngs in uint8
assert
img_lpng
.
dtype
==
torch
.
int32
img_lpng
=
torch
.
round
(
img_lpng
/
(
2
**
16
-
1
)
*
255
).
to
(
torch
.
uint8
)
torch
.
testing
.
assert_close
(
img_lpng
,
img_pil
,
atol
=
tol
,
rtol
=
0
)
...
...
torchvision/csrc/io/image/cpu/decode_png.cpp
View file @
02b5a817
...
...
@@ -5,7 +5,10 @@ namespace vision {
namespace
image
{
#if !PNG_FOUND
torch
::
Tensor
decode_png
(
const
torch
::
Tensor
&
data
,
ImageReadMode
mode
)
{
torch
::
Tensor
decode_png
(
const
torch
::
Tensor
&
data
,
ImageReadMode
mode
,
bool
allow_16_bits
)
{
TORCH_CHECK
(
false
,
"decode_png: torchvision not compiled with libPNG support"
);
}
...
...
@@ -16,7 +19,10 @@ bool is_little_endian() {
return
*
(
uint8_t
*
)
&
x
;
}
torch
::
Tensor
decode_png
(
const
torch
::
Tensor
&
data
,
ImageReadMode
mode
)
{
torch
::
Tensor
decode_png
(
const
torch
::
Tensor
&
data
,
ImageReadMode
mode
,
bool
allow_16_bits
)
{
// Check that the input tensor dtype is uint8
TORCH_CHECK
(
data
.
dtype
()
==
torch
::
kU8
,
"Expected a torch.uint8 tensor"
);
// Check that the input tensor is 1-dimensional
...
...
@@ -77,9 +83,12 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) {
TORCH_CHECK
(
retval
==
1
,
"Could read image metadata from content."
)
}
if
(
bit_depth
>
16
)
{
auto
max_bit_depth
=
allow_16_bits
?
16
:
8
;
auto
err_msg
=
"At most "
+
std
::
to_string
(
max_bit_depth
)
+
"-bit PNG images are supported currently."
;
if
(
bit_depth
>
max_bit_depth
)
{
png_destroy_read_struct
(
&
png_ptr
,
&
info_ptr
,
nullptr
);
TORCH_CHECK
(
false
,
"At most 16-bit PNG images are supported currently."
)
TORCH_CHECK
(
false
,
err_msg
)
}
int
channels
=
png_get_channels
(
png_ptr
,
info_ptr
);
...
...
torchvision/csrc/io/image/cpu/decode_png.h
View file @
02b5a817
...
...
@@ -8,7 +8,8 @@ namespace image {
C10_EXPORT
torch
::
Tensor
decode_png
(
const
torch
::
Tensor
&
data
,
ImageReadMode
mode
=
IMAGE_READ_MODE_UNCHANGED
);
ImageReadMode
mode
=
IMAGE_READ_MODE_UNCHANGED
,
bool
allow_16_bits
=
false
);
}
// namespace image
}
// namespace vision
torchvision/io/image.py
View file @
02b5a817
...
...
@@ -61,12 +61,7 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE
"""
Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 in [0, 255], except for
16-bits pngs which are int32 tensors in [0, 65535].
.. warning::
Should pytorch ever support the uint16 dtype natively, the dtype of the
output for 16-bits pngs will be updated from int32 to uint16.
The values of the output tensor are uint8 in [0, 255].
Args:
input (Tensor[1]): a one dimensional uint8 tensor containing
...
...
@@ -79,7 +74,7 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE
Returns:
output (Tensor[image_channels, image_height, image_width])
"""
output
=
torch
.
ops
.
image
.
decode_png
(
input
,
mode
.
value
)
output
=
torch
.
ops
.
image
.
decode_png
(
input
,
mode
.
value
,
False
)
return
output
...
...
@@ -193,8 +188,7 @@ def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN
operation to decode the image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 in [0, 255], except for
16-bits pngs which are int32 tensors in [0, 65535].
The values of the output tensor are uint8 in [0, 255].
Args:
input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
...
...
@@ -215,8 +209,7 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc
"""
Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 in [0, 255], except for
16-bits pngs which are int32 tensors in [0, 65535].
The values of the output tensor are uint8 in [0, 255].
Args:
path (str): path of the JPEG or PNG image.
...
...
@@ -230,3 +223,8 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc
"""
data
=
read_file
(
path
)
return
decode_image
(
data
,
mode
)
def
_read_png_16
(
path
:
str
,
mode
:
ImageReadMode
=
ImageReadMode
.
UNCHANGED
)
->
torch
.
Tensor
:
data
=
read_file
(
path
)
return
torch
.
ops
.
image
.
decode_png
(
data
,
mode
.
value
,
True
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment