"vscode:/vscode.git/clone" did not exist on "5d7e80f4132d1f66feab7bcf48dc144f5aaa3110"
Unverified Commit ab4c1022 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Reduced number of graphs for compiled resize (#8108)


Co-authored-by: default avatarNicolas Hug <contact@nicolas-hug.com>
parent 893b4abd
......@@ -188,6 +188,20 @@ def resize(
return kernel(inpt, size=size, interpolation=interpolation, max_size=max_size, antialias=antialias)
# This is an internal helper method for resize_image. We should put it here instead of keeping it
# inside resize_image due to torchscript.
# uint8 dtype support for bilinear and bicubic is limited to cpu and
# according to our benchmarks on eager, non-AVX CPUs should still prefer u8->f32->interpolate->u8 path for bilinear
def _do_native_uint8_resize_on_cpu(interpolation: InterpolationMode) -> bool:
if interpolation == InterpolationMode.BILINEAR:
if torch._dynamo.is_compiling():
return True
else:
return "AVX2" in torch.backends.cpu.get_cpu_capability()
return interpolation == InterpolationMode.BICUBIC
@_register_kernel_internal(resize, torch.Tensor)
@_register_kernel_internal(resize, tv_tensors.Image)
def resize_image(
......@@ -215,21 +229,16 @@ def resize_image(
if (new_height, new_width) == (old_height, old_width):
return image
elif numel > 0:
image = image.reshape(-1, num_channels, old_height, old_width)
dtype = image.dtype
acceptable_dtypes = [torch.float32, torch.float64]
if interpolation == InterpolationMode.NEAREST or interpolation == InterpolationMode.NEAREST_EXACT:
# uint8 dtype can be included for cpu and cuda input if nearest mode
acceptable_dtypes.append(torch.uint8)
elif image.device.type == "cpu":
# uint8 dtype support for bilinear and bicubic is limited to cpu and
# according to our benchmarks, non-AVX CPUs should still prefer u8->f32->interpolate->u8 path for bilinear
if (interpolation == InterpolationMode.BILINEAR and "AVX2" in torch.backends.cpu.get_cpu_capability()) or (
interpolation == InterpolationMode.BICUBIC
):
if _do_native_uint8_resize_on_cpu(interpolation):
acceptable_dtypes.append(torch.uint8)
image = image.reshape(-1, num_channels, old_height, old_width)
strides = image.stride()
if image.is_contiguous(memory_format=torch.channels_last) and image.shape[0] == 1 and numel != strides[0]:
# There is a weird behaviour in torch core where the output tensor of `interpolate()` can be allocated as
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment