"vscode:/vscode.git/clone" did not exist on "97e0ea9c6ebc454538b3fa505e1d199547b0feed"
Unverified Commit d8cec344 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

[prototype] Clean up and port the resize kernel in V2 (#6892)



* Ported `resize`

* Align with previous behaviour

* Update torchvision/prototype/transforms/functional/_geometry.py
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>

* Moving input verification on top of method.
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent e64784cd
...@@ -4,6 +4,7 @@ from typing import List, Optional, Sequence, Tuple, Union ...@@ -4,6 +4,7 @@ from typing import List, Optional, Sequence, Tuple, Union
import PIL.Image import PIL.Image
import torch import torch
from torch.nn.functional import interpolate
from torchvision.prototype import features from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
from torchvision.transforms.functional import ( from torchvision.transforms.functional import (
...@@ -115,6 +116,12 @@ def resize_image_tensor( ...@@ -115,6 +116,12 @@ def resize_image_tensor(
max_size: Optional[int] = None, max_size: Optional[int] = None,
antialias: bool = False, antialias: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
align_corners: Optional[bool] = None
if interpolation == InterpolationMode.BILINEAR or interpolation == InterpolationMode.BICUBIC:
align_corners = False
elif antialias:
raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only")
shape = image.shape shape = image.shape
num_channels, old_height, old_width = shape[-3:] num_channels, old_height, old_width = shape[-3:]
new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size) new_height, new_width = _compute_resized_output_size((old_height, old_width), size=size, max_size=max_size)
...@@ -122,13 +129,24 @@ def resize_image_tensor( ...@@ -122,13 +129,24 @@ def resize_image_tensor(
if image.numel() > 0: if image.numel() > 0:
image = image.reshape(-1, num_channels, old_height, old_width) image = image.reshape(-1, num_channels, old_height, old_width)
image = _FT.resize( dtype = image.dtype
need_cast = dtype not in (torch.float32, torch.float64)
if need_cast:
image = image.to(dtype=torch.float32)
image = interpolate(
image, image,
size=[new_height, new_width], size=[new_height, new_width],
interpolation=interpolation.value, mode=interpolation.value,
align_corners=align_corners,
antialias=antialias, antialias=antialias,
) )
if need_cast:
if interpolation == InterpolationMode.BICUBIC and dtype == torch.uint8:
image = image.clamp_(min=0, max=255)
image = image.round_().to(dtype=dtype)
return image.reshape(shape[:-3] + (num_channels, new_height, new_width)) return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
...@@ -1312,9 +1330,11 @@ def resized_crop( ...@@ -1312,9 +1330,11 @@ def resized_crop(
def _parse_five_crop_size(size: List[int]) -> List[int]: def _parse_five_crop_size(size: List[int]) -> List[int]:
if isinstance(size, numbers.Number): if isinstance(size, numbers.Number):
size = [int(size), int(size)] s = int(size)
size = [s, s]
elif isinstance(size, (tuple, list)) and len(size) == 1: elif isinstance(size, (tuple, list)) and len(size) == 1:
size = [size[0], size[0]] s = size[0]
size = [s, s]
if len(size) != 2: if len(size) != 2:
raise ValueError("Please provide only two dimensions (h, w) for size.") raise ValueError("Please provide only two dimensions (h, w) for size.")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment