"tests/vscode:/vscode.git/clone" did not exist on "5c60f33c11f9f5deed8a536a2cb0ad526da72034"
Unverified Commit 61f20323 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Hack to improve performance of resize op with nearest mode on 2D (#6661)

* Hack to improve performance of resize op with nearest mode on 2D

* Moved hack to prototype

* Moved hack into proto and reused code from stable resize

* updates

* More updates
parent 30b879fc
......@@ -14,7 +14,12 @@ from torchvision.transforms.functional import (
pil_to_tensor,
to_pil_image,
)
from torchvision.transforms.functional_tensor import _parse_pad_padding
from torchvision.transforms.functional_tensor import (
_cast_squeeze_in,
_cast_squeeze_out,
_parse_pad_padding,
interpolate,
)
from ._meta import convert_format_bounding_box, get_dimensions_image_pil, get_dimensions_image_tensor
......@@ -104,12 +109,34 @@ def resize_image_tensor(
extra_dims = image.shape[:-3]
if image.numel() > 0:
image = _FT.resize(
image.view(-1, num_channels, old_height, old_width),
size=[new_height, new_width],
interpolation=interpolation.value,
antialias=antialias,
)
image = image.view(-1, num_channels, old_height, old_width)
# This is a perf hack to avoid slow channels_last upsample code path
# Related issue: https://github.com/pytorch/pytorch/issues/83840
# We are transforming (N, 1, H, W) into (N, 2, H, W) to force to take channels_first path
if image.shape[1] == 1 and interpolation == InterpolationMode.NEAREST:
# Below code is copied from _FT.resize
# This is due to the fact that we need to apply the hack on casted image and not before
# Otherwise, image will be copied while cast to float and interpolate will work on twice more data
image, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(image, [torch.float32, torch.float64])
shape = (image.shape[0], 2, image.shape[2], image.shape[3])
image = image.expand(shape)
image = interpolate(
image, size=[new_height, new_width], mode=interpolation.value, align_corners=None, antialias=False
)
image = image[:, 0, ...]
image = _cast_squeeze_out(image, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)
else:
image = _FT.resize(
image,
size=[new_height, new_width],
interpolation=interpolation.value,
antialias=antialias,
)
return image.view(extra_dims + (num_channels, new_height, new_width))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment