Unverified Commit e44bba12 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Fix: don't call round() on float images for ResizeV2 (#7669)

parent 906c2e95
...@@ -1395,3 +1395,13 @@ def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwarg ...@@ -1395,3 +1395,13 @@ def test_memory_format_consistency_resize_image_tensor(test_id, info, args_kwarg
assert expected_stride == output_stride, error_msg_fn("") assert expected_stride == output_stride, error_msg_fn("")
else: else:
assert False, error_msg_fn("") assert False, error_msg_fn("")
def test_resize_float16_no_rounding():
# Make sure Resize() doesn't round float16 images
# Non-regression test for https://github.com/pytorch/vision/issues/7667
img = torch.randint(0, 256, size=(1, 3, 100, 100), dtype=torch.float16)
out = F.resize(img, size=(10, 10))
assert out.dtype == torch.float16
assert (out.round() - out).sum() > 0
...@@ -228,7 +228,9 @@ def resize_image_tensor( ...@@ -228,7 +228,9 @@ def resize_image_tensor(
if need_cast: if need_cast:
if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8: if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8:
image = image.clamp_(min=0, max=255) image = image.clamp_(min=0, max=255)
image = image.round_().to(dtype=dtype) if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
image = image.round_()
image = image.to(dtype=dtype)
return image.reshape(shape[:-3] + (num_channels, new_height, new_width)) return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment