Unverified Commit d4d36e60 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

[TRANS, IMP] Add new max_size parameter to Resize (#3494)

* WIP, still needs tests and docs

* tests

* flake8

* Docs + fixed some tests

* proper error messages
parent 7d415473
......@@ -13,7 +13,7 @@ from torchvision.transforms import InterpolationMode
from common_utils import TransformsTester
from typing import Dict, List, Tuple
from typing import Dict, List, Sequence, Tuple
NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
......@@ -409,12 +409,16 @@ class Tester(TransformsTester):
batch_tensors = batch_tensors.to(dt)
for size in [32, 26, [32, ], [32, 32], (32, 32), [26, 35]]:
for max_size in (None, 33, 40, 1000):
if max_size is not None and isinstance(size, Sequence) and len(size) != 1:
continue # unsupported, see assertRaises below
for interpolation in [BILINEAR, BICUBIC, NEAREST]:
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation)
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation)
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size)
resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size)
self.assertEqual(
resized_tensor.size()[1:], resized_pil_img.size[::-1], msg="{}, {}".format(size, interpolation)
resized_tensor.size()[1:], resized_pil_img.size[::-1],
msg="{}, {}".format(size, interpolation)
)
if interpolation not in [NEAREST, ]:
......@@ -436,11 +440,12 @@ class Tester(TransformsTester):
else:
script_size = size
resize_result = script_fn(tensor, size=script_size, interpolation=interpolation)
resize_result = script_fn(tensor, size=script_size, interpolation=interpolation,
max_size=max_size)
self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))
self._test_fn_on_batch(
batch_tensors, F.resize, size=script_size, interpolation=interpolation
batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size
)
# assert changed type warning
......@@ -449,6 +454,13 @@ class Tester(TransformsTester):
res2 = F.resize(tensor, size=32, interpolation=BILINEAR)
self.assertTrue(res1.equal(res2))
for img in (tensor, pil_img):
exp_msg = "max_size should only be passed if size specifies the length of the smaller edge"
with self.assertRaisesRegex(ValueError, exp_msg):
F.resize(img, size=(32, 34), max_size=35)
with self.assertRaisesRegex(ValueError, "max_size = 32 must be strictly greater"):
F.resize(img, size=32, max_size=32)
def test_resized_crop(self):
# test values of F.resized_crop in several cases:
# 1) resize to the same size, crop to the same size => should be identity
......
......@@ -312,23 +312,30 @@ class Tester(unittest.TestCase):
img = Image.new("RGB", size=(width, height), color=127)
for osize in test_output_sizes_1:
for max_size in (None, 37, 1000):
t = transforms.Resize(osize)
t = transforms.Resize(osize, max_size=max_size)
result = t(img)
msg = "{}, {} - {}".format(height, width, osize)
msg = "{}, {} - {} - {}".format(height, width, osize, max_size)
osize = osize[0] if isinstance(osize, (list, tuple)) else osize
# If size is an int, smaller edge of the image will be matched to this number.
# i.e, if height > width, then image will be rescaled to (size * height / width, size).
if height < width:
expected_size = (int(osize * width / height), osize) # (w, h)
self.assertEqual(result.size, expected_size, msg=msg)
exp_w, exp_h = (int(osize * width / height), osize) # (w, h)
if max_size is not None and max_size < exp_w:
exp_w, exp_h = max_size, int(max_size * exp_h / exp_w)
self.assertEqual(result.size, (exp_w, exp_h), msg=msg)
elif width < height:
expected_size = (osize, int(osize * height / width)) # (w, h)
self.assertEqual(result.size, expected_size, msg=msg)
exp_w, exp_h = (osize, int(osize * height / width)) # (w, h)
if max_size is not None and max_size < exp_h:
exp_w, exp_h = int(max_size * exp_w / exp_h), max_size
self.assertEqual(result.size, (exp_w, exp_h), msg=msg)
else:
expected_size = (osize, osize) # (w, h)
self.assertEqual(result.size, expected_size, msg=msg)
exp_w, exp_h = (osize, osize) # (w, h)
if max_size is not None and max_size < osize:
exp_w, exp_h = max_size, max_size
self.assertEqual(result.size, (exp_w, exp_h), msg=msg)
for height, width in input_sizes:
img = Image.new("RGB", size=(width, height), color=127)
......
......@@ -7,6 +7,7 @@ from torchvision.transforms import InterpolationMode
import numpy as np
import unittest
from typing import Sequence
from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes
......@@ -322,32 +323,29 @@ class Tester(TransformsTester):
tensor, _ = self._create_data(height=34, width=36, device=self.device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
script_fn = torch.jit.script(F.resize)
for dt in [None, torch.float32, torch.float64]:
if dt is not None:
# This is a trivial cast to float of uint8 data to test all cases
tensor = tensor.to(dt)
for size in [32, 34, [32, ], [32, 32], (32, 32), [34, 35]]:
for max_size in (None, 35, 1000):
if max_size is not None and isinstance(size, Sequence) and len(size) != 1:
continue # Not supported
for interpolation in [BILINEAR, BICUBIC, NEAREST]:
resized_tensor = F.resize(tensor, size=size, interpolation=interpolation)
if isinstance(size, int):
script_size = [size, ]
else:
script_size = size
s_resized_tensor = script_fn(tensor, size=script_size, interpolation=interpolation)
self.assertTrue(s_resized_tensor.equal(resized_tensor))
transform = T.Resize(size=script_size, interpolation=interpolation)
transform = T.Resize(size=script_size, interpolation=interpolation, max_size=max_size)
s_transform = torch.jit.script(transform)
self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
with get_tmp_dir() as tmp_dir:
script_fn.save(os.path.join(tmp_dir, "t_resize.pt"))
s_transform.save(os.path.join(tmp_dir, "t_resize.pt"))
def test_resized_crop(self):
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
......
......@@ -337,7 +337,8 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
return tensor
def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR) -> Tensor:
def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR,
max_size: Optional[int] = None) -> Tensor:
r"""Resize the input image to the given size.
If the image is torch Tensor, it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
......@@ -355,6 +356,14 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte
Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
max_size (int, optional): The maximum allowed for the longer edge of
the resized image: if the longer edge of the image is greater
than ``max_size`` after being resized according to ``size``, then
the image is resized again so that the longer edge is equal to
        ``max_size``. As a result, ``size`` might be overruled, i.e. the
smaller edge may be shorter than ``size``. This is only supported
if ``size`` is an int (or a sequence of length 1 in torchscript
mode).
Returns:
PIL Image or Tensor: Resized image.
......@@ -372,9 +381,9 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte
if not isinstance(img, torch.Tensor):
pil_interpolation = pil_modes_mapping[interpolation]
return F_pil.resize(img, size=size, interpolation=pil_interpolation)
return F_pil.resize(img, size=size, interpolation=pil_interpolation, max_size=max_size)
return F_t.resize(img, size=size, interpolation=interpolation.value)
return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size)
def scale(*args, **kwargs):
......
......@@ -204,27 +204,40 @@ def crop(img: Image.Image, top: int, left: int, height: int, width: int) -> Imag
@torch.jit.unused
def resize(img, size, interpolation=Image.BILINEAR, max_size=None):
    """Resize a PIL Image to the given size.

    Args:
        img (PIL Image): Image to be resized.
        size (int or sequence of length 1 or 2): Desired output size. If an
            int (or a 1-element sequence), the smaller edge is matched to
            ``size`` and the aspect ratio is preserved; if a 2-element
            sequence, it is interpreted as ``(h, w)``.
        interpolation (int): PIL interpolation mode. Default: ``Image.BILINEAR``.
        max_size (int, optional): Cap on the longer edge of the resized image.
            Only supported when ``size`` is an int (or length-1 sequence); if
            the computed longer edge would exceed ``max_size``, the image is
            scaled down so the longer edge equals ``max_size`` (the smaller
            edge may then be shorter than ``size``).

    Returns:
        PIL Image: Resized image (the input object itself if no resize is needed).

    Raises:
        TypeError: If ``img`` is not a PIL Image or ``size`` is malformed.
        ValueError: If ``max_size`` is passed with a 2-element ``size``, or
            ``max_size <= size``.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))):
        raise TypeError('Got inappropriate size arg: {}'.format(size))

    # A 1-element sequence behaves exactly like a plain int.
    if isinstance(size, Sequence) and len(size) == 1:
        size = size[0]
    if isinstance(size, int):
        w, h = img.size
        short, long = (w, h) if w <= h else (h, w)

        # Scale so the smaller edge matches `size`, preserving aspect ratio.
        new_short, new_long = size, int(size * long / short)

        if max_size is not None:
            if max_size <= size:
                raise ValueError(
                    f"max_size = {max_size} must be strictly greater than the requested "
                    f"size for the smaller edge size = {size}"
                )
            if new_long > max_size:
                # Rescale so the longer edge equals max_size; the smaller edge
                # may end up shorter than the requested `size`.
                new_short, new_long = int(max_size * new_short / new_long), max_size

        new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
        # Early-return only after max_size has been applied, so an image whose
        # short edge already equals `size` is still shrunk when its long edge
        # exceeds max_size (and max_size validation above always runs).
        if (new_w, new_h) == (w, h):
            return img
        return img.resize((new_w, new_h), interpolation)
    else:
        if max_size is not None:
            raise ValueError(
                "max_size should only be passed if size specifies the length of the smaller edge, "
                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
            )
        # Explicit (h, w) pair: PIL expects (w, h), hence the reversal.
        return img.resize(size[::-1], interpolation)
......
......@@ -470,7 +470,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
return img
def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Tensor:
def resize(img: Tensor, size: List[int], interpolation: str = "bilinear", max_size: Optional[int] = None) -> Tensor:
_assert_image_tensor(img)
if not isinstance(size, (int, tuple, list)):
......@@ -484,34 +484,51 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Ten
if isinstance(size, tuple):
size = list(size)
if isinstance(size, list) and len(size) not in [1, 2]:
if isinstance(size, list):
if len(size) not in [1, 2]:
raise ValueError("Size must be an int or a 1 or 2 element tuple/list, not a "
"{} element tuple/list".format(len(size)))
if max_size is not None and len(size) != 1:
raise ValueError(
"max_size should only be passed if size specifies the length of the smaller edge, "
"i.e. size should be an int or a sequence of length 1 in torchscript mode."
)
w, h = _get_image_size(img)
if isinstance(size, int):
size_w, size_h = size, size
elif len(size) < 2:
size_w, size_h = size[0], size[0]
else:
size_w, size_h = size[1], size[0] # Convention (h, w)
if isinstance(size, int) or len(size) == 1: # specified size only for the smallest edge
short, long = (w, h) if w <= h else (h, w)
if isinstance(size, int) or len(size) < 2:
if w < h:
size_h = int(size_w * h / w)
if isinstance(size, int):
requested_new_short = size
else:
size_w = int(size_h * w / h)
requested_new_short = size[0]
if (w <= h and w == size_w) or (h <= w and h == size_h):
if short == requested_new_short:
return img
new_short, new_long = requested_new_short, int(requested_new_short * long / short)
if max_size is not None:
if max_size <= requested_new_short:
raise ValueError(
f"max_size = {max_size} must be strictly greater than the requested "
f"size for the smaller edge size = {size}"
)
if new_long > max_size:
new_short, new_long = int(max_size * new_short / new_long), max_size
new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
else: # specified both h and w
new_w, new_h = size[1], size[0]
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])
# Define align_corners to avoid warnings
align_corners = False if interpolation in ["bilinear", "bicubic"] else None
img = interpolate(img, size=[size_h, size_w], mode=interpolation, align_corners=align_corners)
img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners)
if interpolation == "bicubic" and out_dtype == torch.uint8:
img = img.clamp(min=0, max=255)
......
......@@ -241,16 +241,25 @@ class Resize(torch.nn.Module):
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and
``InterpolationMode.BICUBIC`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
max_size (int, optional): The maximum allowed for the longer edge of
the resized image: if the longer edge of the image is greater
than ``max_size`` after being resized according to ``size``, then
the image is resized again so that the longer edge is equal to
            ``max_size``. As a result, ``size`` might be overruled, i.e. the
smaller edge may be shorter than ``size``. This is only supported
if ``size`` is an int (or a sequence of length 1 in torchscript
mode).
"""
def __init__(self, size, interpolation=InterpolationMode.BILINEAR):
def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None):
super().__init__()
if not isinstance(size, (int, Sequence)):
raise TypeError("Size should be int or sequence. Got {}".format(type(size)))
if isinstance(size, Sequence) and len(size) not in (1, 2):
raise ValueError("If size is a sequence, it should have 1 or 2 values")
self.size = size
self.max_size = max_size
# Backward compatibility with integer value
if isinstance(interpolation, int):
......@@ -270,11 +279,12 @@ class Resize(torch.nn.Module):
Returns:
PIL Image or Tensor: Rescaled image.
"""
return F.resize(img, self.size, self.interpolation)
return F.resize(img, self.size, self.interpolation, self.max_size)
def __repr__(self):
interpolate_str = self.interpolation.value
return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
return self.__class__.__name__ + '(size={0}, interpolation={1}, max_size={2})'.format(
self.size, interpolate_str, self.max_size)
class Scale(Resize):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment