Unverified Commit aeafa912 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Refactored and modified private api for resize functional op (#6191)

* Refactored and modified private api for resize functional op

* Fixed failures

* More updates

* Fixed flake8
parent a5536de9
...@@ -394,9 +394,7 @@ class TestResize: ...@@ -394,9 +394,7 @@ class TestResize:
@pytest.mark.parametrize( @pytest.mark.parametrize(
"size", "size",
[ [
[ [32],
32,
],
[32, 32], [32, 32],
(32, 32), (32, 32),
[34, 35], [34, 35],
...@@ -412,7 +410,7 @@ class TestResize: ...@@ -412,7 +410,7 @@ class TestResize:
# This is a trivial cast to float of uint8 data to test all cases # This is a trivial cast to float of uint8 data to test all cases
tensor = tensor.to(dt) tensor = tensor.to(dt)
if max_size is not None and len(size) != 1: if max_size is not None and len(size) != 1:
pytest.xfail("with max_size, size must be a sequence with 2 elements") pytest.skip("Size should be an int or a sequence of length 1 if max_size is specified")
transform = T.Resize(size=size, interpolation=interpolation, max_size=max_size) transform = T.Resize(size=size, interpolation=interpolation, max_size=max_size)
s_transform = torch.jit.script(transform) s_transform = torch.jit.script(transform)
...@@ -420,11 +418,7 @@ class TestResize: ...@@ -420,11 +418,7 @@ class TestResize:
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
def test_resize_save(self, tmpdir): def test_resize_save(self, tmpdir):
transform = T.Resize( transform = T.Resize(size=[32])
size=[
32,
]
)
s_transform = torch.jit.script(transform) s_transform = torch.jit.script(transform)
s_transform.save(os.path.join(tmpdir, "t_resize.pt")) s_transform.save(os.path.join(tmpdir, "t_resize.pt"))
...@@ -435,12 +429,8 @@ class TestResize: ...@@ -435,12 +429,8 @@ class TestResize:
"size", "size",
[ [
(32,), (32,),
[ [44],
44, [32],
],
[
32,
],
[32, 32], [32, 32],
(32, 32), (32, 32),
[44, 55], [44, 55],
......
...@@ -42,6 +42,8 @@ def resize_image_tensor( ...@@ -42,6 +42,8 @@ def resize_image_tensor(
max_size: Optional[int] = None, max_size: Optional[int] = None,
antialias: Optional[bool] = None, antialias: Optional[bool] = None,
) -> torch.Tensor: ) -> torch.Tensor:
# TODO: use _compute_output_size to enable max_size option
max_size # ununsed right now
new_height, new_width = size new_height, new_width = size
num_channels, old_height, old_width = get_dimensions_image_tensor(image) num_channels, old_height, old_width = get_dimensions_image_tensor(image)
batch_shape = image.shape[:-3] batch_shape = image.shape[:-3]
...@@ -49,7 +51,6 @@ def resize_image_tensor( ...@@ -49,7 +51,6 @@ def resize_image_tensor(
image.reshape((-1, num_channels, old_height, old_width)), image.reshape((-1, num_channels, old_height, old_width)),
size=size, size=size,
interpolation=interpolation.value, interpolation=interpolation.value,
max_size=max_size,
antialias=antialias, antialias=antialias,
).reshape(batch_shape + (num_channels, new_height, new_width)) ).reshape(batch_shape + (num_channels, new_height, new_width))
...@@ -60,7 +61,9 @@ def resize_image_pil( ...@@ -60,7 +61,9 @@ def resize_image_pil(
interpolation: InterpolationMode = InterpolationMode.BILINEAR, interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None, max_size: Optional[int] = None,
) -> PIL.Image.Image: ) -> PIL.Image.Image:
return _FP.resize(img, size, interpolation=pil_modes_mapping[interpolation], max_size=max_size) # TODO: use _compute_output_size to enable max_size option
max_size # ununsed right now
return _FP.resize(img, size, interpolation=pil_modes_mapping[interpolation])
def resize_segmentation_mask( def resize_segmentation_mask(
......
...@@ -360,6 +360,29 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool ...@@ -360,6 +360,29 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace) return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace)
def _compute_output_size(image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None) -> List[int]:
if len(size) == 1: # specified size only for the smallest edge
h, w = image_size
short, long = (w, h) if w <= h else (h, w)
requested_new_short = size if isinstance(size, int) else size[0]
new_short, new_long = requested_new_short, int(requested_new_short * long / short)
if max_size is not None:
if max_size <= requested_new_short:
raise ValueError(
f"max_size = {max_size} must be strictly greater than the requested "
f"size for the smaller edge size = {size}"
)
if new_long > max_size:
new_short, new_long = int(max_size * new_short / new_long), max_size
new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
else: # specified both h and w
new_w, new_h = size[1], size[0]
return [new_h, new_w]
def resize( def resize(
img: Tensor, img: Tensor,
size: List[int], size: List[int],
...@@ -423,13 +446,32 @@ def resize( ...@@ -423,13 +446,32 @@ def resize(
if not isinstance(interpolation, InterpolationMode): if not isinstance(interpolation, InterpolationMode):
raise TypeError("Argument interpolation should be a InterpolationMode") raise TypeError("Argument interpolation should be a InterpolationMode")
if isinstance(size, (list, tuple)):
if len(size) not in [1, 2]:
raise ValueError(
f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list"
)
if max_size is not None and len(size) != 1:
raise ValueError(
"max_size should only be passed if size specifies the length of the smaller edge, "
"i.e. size should be an int or a sequence of length 1 in torchscript mode."
)
_, image_height, image_width = get_dimensions(img)
if isinstance(size, int):
size = [size]
output_size = _compute_output_size((image_height, image_width), size, max_size)
if (image_height, image_width) == output_size:
return img
if not isinstance(img, torch.Tensor): if not isinstance(img, torch.Tensor):
if antialias is not None and not antialias: if antialias is not None and not antialias:
warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.") warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.")
pil_interpolation = pil_modes_mapping[interpolation] pil_interpolation = pil_modes_mapping[interpolation]
return F_pil.resize(img, size=size, interpolation=pil_interpolation, max_size=max_size) return F_pil.resize(img, size=output_size, interpolation=pil_interpolation)
return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size, antialias=antialias) return F_t.resize(img, size=output_size, interpolation=interpolation.value, antialias=antialias)
def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor: def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
......
import numbers import numbers
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
...@@ -240,46 +240,16 @@ def crop( ...@@ -240,46 +240,16 @@ def crop(
@torch.jit.unused @torch.jit.unused
def resize( def resize(
img: Image.Image, img: Image.Image,
size: Union[Sequence[int], int], size: Union[List[int], int],
interpolation: int = _pil_constants.BILINEAR, interpolation: int = _pil_constants.BILINEAR,
max_size: Optional[int] = None,
) -> Image.Image: ) -> Image.Image:
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError(f"img should be PIL Image. Got {type(img)}") raise TypeError(f"img should be PIL Image. Got {type(img)}")
if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))): if not (isinstance(size, list) and len(size) == 2):
raise TypeError(f"Got inappropriate size arg: {size}") raise TypeError(f"Got inappropriate size arg: {size}")
if isinstance(size, Sequence) and len(size) == 1: return img.resize(size[::-1], interpolation)
size = size[0]
if isinstance(size, int):
w, h = img.size
short, long = (w, h) if w <= h else (h, w)
new_short, new_long = size, int(size * long / short)
if max_size is not None:
if max_size <= size:
raise ValueError(
f"max_size = {max_size} must be strictly greater than the requested "
f"size for the smaller edge size = {size}"
)
if new_long > max_size:
new_short, new_long = int(max_size * new_short / new_long), max_size
new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
if (w, h) == (new_w, new_h):
return img
else:
return img.resize((new_w, new_h), interpolation)
else:
if max_size is not None:
raise ValueError(
"max_size should only be passed if size specifies the length of the smaller edge, "
"i.e. size should be an int or a sequence of length 1 in torchscript mode."
)
return img.resize(size[::-1], interpolation)
@torch.jit.unused @torch.jit.unused
......
...@@ -430,70 +430,25 @@ def resize( ...@@ -430,70 +430,25 @@ def resize(
img: Tensor, img: Tensor,
size: List[int], size: List[int],
interpolation: str = "bilinear", interpolation: str = "bilinear",
max_size: Optional[int] = None,
antialias: Optional[bool] = None, antialias: Optional[bool] = None,
) -> Tensor: ) -> Tensor:
_assert_image_tensor(img) _assert_image_tensor(img)
if not isinstance(size, (int, tuple, list)):
raise TypeError("Got inappropriate size arg")
if not isinstance(interpolation, str):
raise TypeError("Got inappropriate interpolation arg")
if interpolation not in ["nearest", "bilinear", "bicubic"]:
raise ValueError("This interpolation mode is unsupported with Tensor input")
if isinstance(size, tuple): if isinstance(size, tuple):
size = list(size) size = list(size)
if isinstance(size, list):
if len(size) not in [1, 2]:
raise ValueError(
f"Size must be an int or a 1 or 2 element tuple/list, not a {len(size)} element tuple/list"
)
if max_size is not None and len(size) != 1:
raise ValueError(
"max_size should only be passed if size specifies the length of the smaller edge, "
"i.e. size should be an int or a sequence of length 1 in torchscript mode."
)
if antialias is None: if antialias is None:
antialias = False antialias = False
if antialias and interpolation not in ["bilinear", "bicubic"]: if antialias and interpolation not in ["bilinear", "bicubic"]:
raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only") raise ValueError("Antialias option is supported for bilinear and bicubic interpolation modes only")
_, h, w = get_dimensions(img)
if isinstance(size, int) or len(size) == 1: # specified size only for the smallest edge
short, long = (w, h) if w <= h else (h, w)
requested_new_short = size if isinstance(size, int) else size[0]
new_short, new_long = requested_new_short, int(requested_new_short * long / short)
if max_size is not None:
if max_size <= requested_new_short:
raise ValueError(
f"max_size = {max_size} must be strictly greater than the requested "
f"size for the smaller edge size = {size}"
)
if new_long > max_size:
new_short, new_long = int(max_size * new_short / new_long), max_size
new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
if (w, h) == (new_w, new_h):
return img
else: # specified both h and w
new_w, new_h = size[1], size[0]
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64]) img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])
# Define align_corners to avoid warnings # Define align_corners to avoid warnings
align_corners = False if interpolation in ["bilinear", "bicubic"] else None align_corners = False if interpolation in ["bilinear", "bicubic"] else None
img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners, antialias=antialias) img = interpolate(img, size=size, mode=interpolation, align_corners=align_corners, antialias=antialias)
if interpolation == "bicubic" and out_dtype == torch.uint8: if interpolation == "bicubic" and out_dtype == torch.uint8:
img = img.clamp(min=0, max=255) img = img.clamp(min=0, max=255)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment