Unverified Commit 0ab50f5f authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Remove performance workaround for mask resize (#6729)

* Remove performance workaround for mask resize

* Fix linter

* bug fixes

* remove unnecessary import

* Fixing linter
parent 3f1d9f6b
......@@ -19,7 +19,6 @@ from prototype_common_utils import (
make_video_loaders,
mark_framework_limitation,
TestMark,
VALID_EXTRA_DIMS,
)
from torchvision.prototype import features
from torchvision.transforms.functional_tensor import _max_value as get_max_value
......@@ -215,16 +214,6 @@ def sample_inputs_resize_image_tensor():
):
yield ArgsKwargs(image_loader, size=[min(image_loader.image_size) + 1], interpolation=interpolation)
# We have a speed hack in place for nearest interpolation and single channel images (grayscale)
for image_loader in make_image_loaders(
sizes=["random"],
color_spaces=[features.ColorSpace.GRAY],
extra_dims=VALID_EXTRA_DIMS,
):
yield ArgsKwargs(
image_loader, size=[min(image_loader.image_size) + 1], interpolation=F.InterpolationMode.NEAREST
)
yield ArgsKwargs(make_image_loader(size=(11, 17)), size=20, max_size=25)
......
......@@ -14,12 +14,7 @@ from torchvision.transforms.functional import (
pil_to_tensor,
to_pil_image,
)
from torchvision.transforms.functional_tensor import (
_cast_squeeze_in,
_cast_squeeze_out,
_parse_pad_padding,
interpolate,
)
from torchvision.transforms.functional_tensor import _parse_pad_padding
from ._meta import (
convert_format_bounding_box,
......@@ -130,32 +125,12 @@ def resize_image_tensor(
if image.numel() > 0:
image = image.view(-1, num_channels, old_height, old_width)
# This is a perf hack to avoid slow channels_last upsample code path
# Related issue: https://github.com/pytorch/pytorch/issues/83840
# We are transforming (N, 1, H, W) into (N, 2, H, W) to force to take channels_first path
if image.shape[1] == 1 and interpolation == InterpolationMode.NEAREST:
# Below code is copied from _FT.resize
# This is due to the fact that we need to apply the hack on casted image and not before
# Otherwise, image will be copied while cast to float and interpolate will work on twice more data
image, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(image, [torch.float32, torch.float64])
shape = (image.shape[0], 2, image.shape[2], image.shape[3])
image = image.expand(shape)
image = interpolate(
image, size=[new_height, new_width], mode=interpolation.value, align_corners=None, antialias=False
)
image = image[:, 0, ...]
image = _cast_squeeze_out(image, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)
else:
image = _FT.resize(
image,
size=[new_height, new_width],
interpolation=interpolation.value,
antialias=antialias,
)
image = _FT.resize(
image,
size=[new_height, new_width],
interpolation=interpolation.value,
antialias=antialias,
)
return image.view(extra_dims + (num_channels, new_height, new_width))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment