Unverified commit d8cec344, authored by Vasilis Vryniotis and committed by GitHub
Browse files

[prototype] Clean up and port the resize kernel in V2 (#6892)



* Ported `resize`

* Align with previous behaviour

* Update torchvision/prototype/transforms/functional/_geometry.py
Co-authored-by: Philip Meier <github.pmeier@posteo.de>

* Moving input verification to the top of the method.
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent e64784cd
......@@ -4,6 +4,7 @@ from typing import List, Optional, Sequence, Tuple, Union
import PIL.Image
import torch
from torch.nn.functional import interpolate
from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
from torchvision.transforms.functional import (
......@@ -115,6 +116,12 @@ def resize_image_tensor(
max_size: Optional[int] = None,
antialias: bool = False,
) -> torch.Tensor:
align_corners: Optional[bool] = None
if interpolation == InterpolationMode.BILINEAR or interpolation == InterpolationMode.BICUBIC:
align_corners = False
elif antialias:
raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only")
shape = image.shape
num_channels, old_height, old_width = shape[-3:]
new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)
......@@ -122,13 +129,24 @@ def resize_image_tensor(
if image.numel() > 0:
image = image.reshape(-1, num_channels, old_height, old_width)
image = _FT.resize(
dtype = image.dtype
need_cast = dtype not in (torch.float32, torch.float64)
if need_cast:
image = image.to(dtype=torch.float32)
image = interpolate(
image,
size=[new_height, new_width],
interpolation=interpolation.value,
mode=interpolation.value,
align_corners=align_corners,
antialias=antialias,
)
if need_cast:
if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8:
image = image.clamp_(min=0, max=255)
image = image.round_().to(dtype=dtype)
return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
......@@ -1312,9 +1330,11 @@ def resized_crop(
def _parse_five_crop_size(size: List[int]) -> List[int]:
if isinstance(size, numbers.Number):
size = [int(size), int(size)]
s = int(size)
size = [s, s]
elif isinstance(size, (tuple, list)) and len(size) == 1:
size = [size[0], size[0]]
s = size[0]
size = [s, s]
if len(size) != 2:
raise ValueError("Please provide only two dimensions (h, w) for size.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment