Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
vision
Commits
e44bba12
Unverified
Commit
e44bba12
authored
Jun 13, 2023
by
Nicolas Hug
Committed by
GitHub
Jun 13, 2023
Browse files
Fix: don't call round() on float images for ResizeV2 (#7669)
parent
906c2e95
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
1 deletion
+13
-1
test/test_transforms_v2_functional.py
test/test_transforms_v2_functional.py
+10
-0
torchvision/transforms/v2/functional/_geometry.py
torchvision/transforms/v2/functional/_geometry.py
+3
-1
No files found.
test/test_transforms_v2_functional.py
View file @
e44bba12
...
@@ -1395,3 +1395,13 @@ def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwarg
...
@@ -1395,3 +1395,13 @@ def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwarg
assert
expected_stride
==
output_stride
,
error_msg_fn
(
""
)
assert
expected_stride
==
output_stride
,
error_msg_fn
(
""
)
else
:
else
:
assert
False
,
error_msg_fn
(
""
)
assert
False
,
error_msg_fn
(
""
)
def
test_resize_float16_no_rounding
():
# Make sure Resize() doesn't round float16 images
# Non-regression test for https://github.com/pytorch/vision/issues/7667
img
=
torch
.
randint
(
0
,
256
,
size
=
(
1
,
3
,
100
,
100
),
dtype
=
torch
.
float16
)
out
=
F
.
resize
(
img
,
size
=
(
10
,
10
))
assert
out
.
dtype
==
torch
.
float16
assert
(
out
.
round
()
-
out
).
sum
()
>
0
torchvision/transforms/v2/functional/_geometry.py
View file @
e44bba12
...
@@ -228,7 +228,9 @@ def resize_image_tensor(
...
@@ -228,7 +228,9 @@ def resize_image_tensor(
if
need_cast
:
if
need_cast
:
if
interpolation
==
InterpolationMode
.
BICUBIC
and
dtype
==
torch
.
uint8
:
if
interpolation
==
InterpolationMode
.
BICUBIC
and
dtype
==
torch
.
uint8
:
image
=
image
.
clamp_
(
min
=
0
,
max
=
255
)
image
=
image
.
clamp_
(
min
=
0
,
max
=
255
)
image
=
image
.
round_
().
to
(
dtype
=
dtype
)
if
dtype
in
(
torch
.
uint8
,
torch
.
int8
,
torch
.
int16
,
torch
.
int32
,
torch
.
int64
):
image
=
image
.
round_
()
image
=
image
.
to
(
dtype
=
dtype
)
return
image
.
reshape
(
shape
[:
-
3
]
+
(
num_channels
,
new_height
,
new_width
))
return
image
.
reshape
(
shape
[:
-
3
]
+
(
num_channels
,
new_height
,
new_width
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment