Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
81c46679
Unverified
Commit
81c46679
authored
Nov 23, 2022
by
Patrick von Platen
Committed by
GitHub
Nov 23, 2022
Browse files
[Image Transformers] to_pil fix float edge cases (#20406)
* Correct type checking * up
parent
1c6309bf
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
1 deletion
+24
-1
src/transformers/image_transforms.py
src/transformers/image_transforms.py
+1
-1
tests/test_image_transforms.py
tests/test_image_transforms.py
+23
-0
No files found.
src/transformers/image_transforms.py
View file @
81c46679
...
@@ -145,7 +145,7 @@ def to_pil_image(
...
@@ -145,7 +145,7 @@ def to_pil_image(
image
=
np
.
squeeze
(
image
,
axis
=-
1
)
if
image
.
shape
[
-
1
]
==
1
else
image
image
=
np
.
squeeze
(
image
,
axis
=-
1
)
if
image
.
shape
[
-
1
]
==
1
else
image
# PIL.Image can only store uint8 values, so we rescale the image to be between 0 and 255 if needed.
# PIL.Image can only store uint8 values, so we rescale the image to be between 0 and 255 if needed.
do_rescale
=
isinstance
(
image
.
flat
[
0
],
float
)
if
do_rescale
is
None
else
do_rescale
do_rescale
=
isinstance
(
image
.
flat
[
0
],
(
float
,
np
.
float32
,
np
.
float64
)
)
if
do_rescale
is
None
else
do_rescale
if
do_rescale
:
if
do_rescale
:
image
=
rescale
(
image
,
255
)
image
=
rescale
(
image
,
255
)
image
=
image
.
astype
(
np
.
uint8
)
image
=
image
.
astype
(
np
.
uint8
)
...
...
tests/test_image_transforms.py
View file @
81c46679
...
@@ -61,6 +61,8 @@ class ImageTransformsTester(unittest.TestCase):
...
@@ -61,6 +61,8 @@ class ImageTransformsTester(unittest.TestCase):
[
[
(
"numpy_float_channels_first"
,
(
3
,
4
,
5
),
np
.
float32
),
(
"numpy_float_channels_first"
,
(
3
,
4
,
5
),
np
.
float32
),
(
"numpy_float_channels_last"
,
(
4
,
5
,
3
),
np
.
float32
),
(
"numpy_float_channels_last"
,
(
4
,
5
,
3
),
np
.
float32
),
(
"numpy_float_channels_first"
,
(
3
,
4
,
5
),
np
.
float64
),
(
"numpy_float_channels_last"
,
(
4
,
5
,
3
),
np
.
float64
),
(
"numpy_int_channels_first"
,
(
3
,
4
,
5
),
np
.
int32
),
(
"numpy_int_channels_first"
,
(
3
,
4
,
5
),
np
.
int32
),
(
"numpy_uint_channels_first"
,
(
3
,
4
,
5
),
np
.
uint8
),
(
"numpy_uint_channels_first"
,
(
3
,
4
,
5
),
np
.
uint8
),
]
]
...
@@ -72,6 +74,27 @@ class ImageTransformsTester(unittest.TestCase):
...
@@ -72,6 +74,27 @@ class ImageTransformsTester(unittest.TestCase):
self
.
assertIsInstance
(
pil_image
,
PIL
.
Image
.
Image
)
self
.
assertIsInstance
(
pil_image
,
PIL
.
Image
.
Image
)
self
.
assertEqual
(
pil_image
.
size
,
(
5
,
4
))
self
.
assertEqual
(
pil_image
.
size
,
(
5
,
4
))
# make sure image is correctly rescaled
self
.
assertTrue
(
np
.
abs
(
np
.
asarray
(
pil_image
)).
sum
()
>
0
)
@
parameterized
.
expand
(
[
(
"numpy_float_channels_first"
,
(
3
,
4
,
5
),
np
.
float32
),
(
"numpy_float_channels_first"
,
(
3
,
4
,
5
),
np
.
float64
),
(
"numpy_float_channels_last"
,
(
4
,
5
,
3
),
np
.
float32
),
(
"numpy_float_channels_last"
,
(
4
,
5
,
3
),
np
.
float64
),
]
)
@
require_vision
def
test_to_pil_image_from_float
(
self
,
name
,
image_shape
,
dtype
):
image
=
np
.
random
.
rand
(
*
image_shape
).
astype
(
dtype
)
pil_image
=
to_pil_image
(
image
)
self
.
assertIsInstance
(
pil_image
,
PIL
.
Image
.
Image
)
self
.
assertEqual
(
pil_image
.
size
,
(
5
,
4
))
# make sure image is correctly rescaled
self
.
assertTrue
(
np
.
abs
(
np
.
asarray
(
pil_image
)).
sum
()
>
0
)
@
require_tf
@
require_tf
def
test_to_pil_image_from_tensorflow
(
self
):
def
test_to_pil_image_from_tensorflow
(
self
):
# channels_first
# channels_first
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment