Unverified commit d4d36e60, authored by Nicolas Hug, committed by GitHub

[TRANS, IMP] Add new max_size parameter to Resize (#3494)

* WIP, still needs tests and docs

* tests

* flake8

* Docs + fixed some tests

* proper error messages
parent 7d415473
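
In short, `max_size` caps the longer edge when `size` only pins the smaller edge. A minimal sketch of the new behaviour (assuming a torchvision build that includes this commit; the input shape is made up for illustration):

```python
import torch
import torchvision.transforms as T

img = torch.randint(0, 256, size=(3, 400, 600), dtype=torch.uint8)  # (C, H, W)

# Without max_size, the smaller edge (400) is matched to 300 and the longer
# edge grows proportionally to 450.
print(T.Resize(300)(img).shape)                # torch.Size([3, 300, 450])

# With max_size=440, the longer edge (450) would exceed the cap, so the image
# is rescaled again: the longer edge becomes 440 and the smaller edge drops
# below the requested 300, to int(440 * 300 / 450) = 293.
print(T.Resize(300, max_size=440)(img).shape)  # torch.Size([3, 293, 440])
```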
test/test_functional_tensor.py

@@ -13,7 +13,7 @@ from torchvision.transforms import InterpolationMode
 from common_utils import TransformsTester
-from typing import Dict, List, Tuple
+from typing import Dict, List, Sequence, Tuple

 NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC

@@ -409,39 +409,44 @@ class Tester(TransformsTester):
                 batch_tensors = batch_tensors.to(dt)

             for size in [32, 26, [32, ], [32, 32], (32, 32), [26, 35]]:
-                for interpolation in [BILINEAR, BICUBIC, NEAREST]:
-                    resized_tensor = F.resize(tensor, size=size, interpolation=interpolation)
-                    resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation)
-
-                    self.assertEqual(
-                        resized_tensor.size()[1:], resized_pil_img.size[::-1], msg="{}, {}".format(size, interpolation)
-                    )
-
-                    if interpolation not in [NEAREST, ]:
-                        # We can not check values if mode = NEAREST, as results are different
-                        # E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
-                        # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
-                        resized_tensor_f = resized_tensor
-                        # we need to cast to uint8 to compare with PIL image
-                        if resized_tensor_f.dtype == torch.uint8:
-                            resized_tensor_f = resized_tensor_f.to(torch.float)
-
-                        # Pay attention to high tolerance for MAE
-                        self.approxEqualTensorToPIL(
-                            resized_tensor_f, resized_pil_img, tol=8.0, msg="{}, {}".format(size, interpolation)
-                        )
-
-                    if isinstance(size, int):
-                        script_size = [size, ]
-                    else:
-                        script_size = size
-
-                    resize_result = script_fn(tensor, size=script_size, interpolation=interpolation)
-                    self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))
-
-                    self._test_fn_on_batch(
-                        batch_tensors, F.resize, size=script_size, interpolation=interpolation
-                    )
+                for max_size in (None, 33, 40, 1000):
+                    if max_size is not None and isinstance(size, Sequence) and len(size) != 1:
+                        continue  # unsupported, see assertRaises below
+                    for interpolation in [BILINEAR, BICUBIC, NEAREST]:
+                        resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size)
+                        resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size)
+
+                        self.assertEqual(
+                            resized_tensor.size()[1:], resized_pil_img.size[::-1],
+                            msg="{}, {}".format(size, interpolation)
+                        )
+
+                        if interpolation not in [NEAREST, ]:
+                            # We can not check values if mode = NEAREST, as results are different
+                            # E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
+                            # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
+                            resized_tensor_f = resized_tensor
+                            # we need to cast to uint8 to compare with PIL image
+                            if resized_tensor_f.dtype == torch.uint8:
+                                resized_tensor_f = resized_tensor_f.to(torch.float)
+
+                            # Pay attention to high tolerance for MAE
+                            self.approxEqualTensorToPIL(
+                                resized_tensor_f, resized_pil_img, tol=8.0, msg="{}, {}".format(size, interpolation)
+                            )
+
+                        if isinstance(size, int):
+                            script_size = [size, ]
+                        else:
+                            script_size = size
+
+                        resize_result = script_fn(tensor, size=script_size, interpolation=interpolation,
+                                                  max_size=max_size)
+                        self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))
+
+                        self._test_fn_on_batch(
+                            batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size
+                        )

         # assert changed type warning
         with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):

@@ -449,6 +454,13 @@ class Tester(TransformsTester):
             res2 = F.resize(tensor, size=32, interpolation=BILINEAR)
         self.assertTrue(res1.equal(res2))

+        for img in (tensor, pil_img):
+            exp_msg = "max_size should only be passed if size specifies the length of the smaller edge"
+            with self.assertRaisesRegex(ValueError, exp_msg):
+                F.resize(img, size=(32, 34), max_size=35)
+            with self.assertRaisesRegex(ValueError, "max_size = 32 must be strictly greater"):
+                F.resize(img, size=32, max_size=32)

     def test_resized_crop(self):
         # test values of F.resized_crop in several cases:
         # 1) resize to the same size, crop to the same size => should be identity
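
The two `assertRaises` checks above correspond to these failure modes; a standalone sketch (image size assumed, messages taken from this diff):

```python
from PIL import Image
import torchvision.transforms.functional as F

img = Image.new("RGB", (36, 34))

# size=(h, w) already fixes both edges, so max_size has nothing to constrain:
try:
    F.resize(img, size=(32, 34), max_size=35)
except ValueError as e:
    print(e)  # max_size should only be passed if size specifies the length of the smaller edge, ...

# max_size must be strictly greater than the requested smaller-edge size:
try:
    F.resize(img, size=32, max_size=32)
except ValueError as e:
    print(e)  # max_size = 32 must be strictly greater than the requested size ...
```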
test/test_transforms.py

@@ -312,23 +312,30 @@ class Tester(unittest.TestCase):
             img = Image.new("RGB", size=(width, height), color=127)

             for osize in test_output_sizes_1:
-                t = transforms.Resize(osize)
-                result = t(img)
-
-                msg = "{}, {} - {}".format(height, width, osize)
-                osize = osize[0] if isinstance(osize, (list, tuple)) else osize
-                # If size is an int, smaller edge of the image will be matched to this number.
-                # i.e, if height > width, then image will be rescaled to (size * height / width, size).
-                if height < width:
-                    expected_size = (int(osize * width / height), osize)  # (w, h)
-                    self.assertEqual(result.size, expected_size, msg=msg)
-                elif width < height:
-                    expected_size = (osize, int(osize * height / width))  # (w, h)
-                    self.assertEqual(result.size, expected_size, msg=msg)
-                else:
-                    expected_size = (osize, osize)  # (w, h)
-                    self.assertEqual(result.size, expected_size, msg=msg)
+                for max_size in (None, 37, 1000):
+                    t = transforms.Resize(osize, max_size=max_size)
+                    result = t(img)
+
+                    msg = "{}, {} - {} - {}".format(height, width, osize, max_size)
+                    osize = osize[0] if isinstance(osize, (list, tuple)) else osize
+                    # If size is an int, the smaller edge of the image will be matched to this number.
+                    # i.e., if height > width, then the image will be rescaled to (size * height / width, size).
+                    if height < width:
+                        exp_w, exp_h = (int(osize * width / height), osize)  # (w, h)
+                        if max_size is not None and max_size < exp_w:
+                            exp_w, exp_h = max_size, int(max_size * exp_h / exp_w)
+                        self.assertEqual(result.size, (exp_w, exp_h), msg=msg)
+                    elif width < height:
+                        exp_w, exp_h = (osize, int(osize * height / width))  # (w, h)
+                        if max_size is not None and max_size < exp_h:
+                            exp_w, exp_h = int(max_size * exp_w / exp_h), max_size
+                        self.assertEqual(result.size, (exp_w, exp_h), msg=msg)
+                    else:
+                        exp_w, exp_h = (osize, osize)  # (w, h)
+                        if max_size is not None and max_size < osize:
+                            exp_w, exp_h = max_size, max_size
+                        self.assertEqual(result.size, (exp_w, exp_h), msg=msg)

         for height, width in input_sizes:
             img = Image.new("RGB", size=(width, height), color=127)
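
A concrete instance of the expected-size arithmetic exercised above (input size assumed for illustration):

```python
from PIL import Image
import torchvision.transforms as transforms

img = Image.new("RGB", size=(250, 500), color=127)     # (w, h), so width < height

# The smaller edge (w=250) is matched to 100; the longer edge scales to 200...
print(transforms.Resize(100)(img).size)                # (100, 200)

# ...but with max_size=150 the longer edge is capped and the smaller edge is
# pulled down proportionally: int(150 * 100 / 200) = 75.
print(transforms.Resize(100, max_size=150)(img).size)  # (75, 150)
```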
test/test_transforms_tensor.py

@@ -7,6 +7,7 @@ from torchvision.transforms import InterpolationMode
 import numpy as np

 import unittest
+from typing import Sequence

 from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes

@@ -322,32 +323,29 @@ class Tester(TransformsTester):
         tensor, _ = self._create_data(height=34, width=36, device=self.device)
         batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

-        script_fn = torch.jit.script(F.resize)
-
         for dt in [None, torch.float32, torch.float64]:
             if dt is not None:
                 # This is a trivial cast to float of uint8 data to test all cases
                 tensor = tensor.to(dt)
             for size in [32, 34, [32, ], [32, 32], (32, 32), [34, 35]]:
-                for interpolation in [BILINEAR, BICUBIC, NEAREST]:
-                    resized_tensor = F.resize(tensor, size=size, interpolation=interpolation)
-
-                    if isinstance(size, int):
-                        script_size = [size, ]
-                    else:
-                        script_size = size
-
-                    s_resized_tensor = script_fn(tensor, size=script_size, interpolation=interpolation)
-                    self.assertTrue(s_resized_tensor.equal(resized_tensor))
-
-                    transform = T.Resize(size=script_size, interpolation=interpolation)
-                    s_transform = torch.jit.script(transform)
-                    self._test_transform_vs_scripted(transform, s_transform, tensor)
-                    self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
+                for max_size in (None, 35, 1000):
+                    if max_size is not None and isinstance(size, Sequence) and len(size) != 1:
+                        continue  # Not supported
+                    for interpolation in [BILINEAR, BICUBIC, NEAREST]:
+                        if isinstance(size, int):
+                            script_size = [size, ]
+                        else:
+                            script_size = size
+
+                        transform = T.Resize(size=script_size, interpolation=interpolation, max_size=max_size)
+                        s_transform = torch.jit.script(transform)
+                        self._test_transform_vs_scripted(transform, s_transform, tensor)
+                        self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)

         with get_tmp_dir() as tmp_dir:
-            script_fn.save(os.path.join(tmp_dir, "t_resize.pt"))
+            s_transform.save(os.path.join(tmp_dir, "t_resize.pt"))

     def test_resized_crop(self):
         tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
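
The scriptability contract being tested, in isolation (a sketch; note that `size` must be a list rather than an int for the scripted path, and the input shape is assumed):

```python
import torch
import torchvision.transforms as T

transform = T.Resize(size=[32], max_size=35)
scripted = torch.jit.script(transform)

img = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8)
# Smaller edge 44 -> 32 would give a longer edge of 40 > 35, so the output
# is capped at (28, 35); eager and scripted must agree.
assert transform(img).shape == scripted(img).shape == torch.Size([3, 28, 35])
```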
torchvision/transforms/functional.py

@@ -337,7 +337,8 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
     return tensor


-def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR) -> Tensor:
+def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+           max_size: Optional[int] = None) -> Tensor:
     r"""Resize the input image to the given size.
     If the image is torch Tensor, it is expected
     to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions

@@ -355,6 +356,14 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte
             Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
             ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
             For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
+        max_size (int, optional): The maximum allowed length of the longer edge of
+            the resized image: if the longer edge of the image is greater
+            than ``max_size`` after being resized according to ``size``, then
+            the image is resized again so that the longer edge is equal to
+            ``max_size``. As a result, ``size`` might be overruled, i.e. the
+            smaller edge may be shorter than ``size``. This is only supported
+            if ``size`` is an int (or a sequence of length 1 in torchscript
+            mode).

     Returns:
         PIL Image or Tensor: Resized image.

@@ -372,9 +381,9 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte
     if not isinstance(img, torch.Tensor):
         pil_interpolation = pil_modes_mapping[interpolation]
-        return F_pil.resize(img, size=size, interpolation=pil_interpolation)
+        return F_pil.resize(img, size=size, interpolation=pil_interpolation, max_size=max_size)

-    return F_t.resize(img, size=size, interpolation=interpolation.value)
+    return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size)


 def scale(*args, **kwargs):
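
To make the docstring's "resized again" clause concrete (image size assumed):

```python
from PIL import Image
import torchvision.transforms.functional as F

img = Image.new("RGB", (500, 250))                 # (w, h); smaller edge is h=250

print(F.resize(img, size=100).size)                # (200, 100)
# The longer edge (200) exceeds max_size=150, so both edges shrink again and
# the smaller edge ends up at 75, overruling the requested size=100.
print(F.resize(img, size=100, max_size=150).size)  # (150, 75)
```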
torchvision/transforms/functional_pil.py

@@ -204,27 +204,40 @@ def crop(img: Image.Image, top: int, left: int, height: int, width: int) -> Imag
 @torch.jit.unused
-def resize(img, size, interpolation=Image.BILINEAR):
+def resize(img, size, interpolation=Image.BILINEAR, max_size=None):
     if not _is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
     if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))):
         raise TypeError('Got inappropriate size arg: {}'.format(size))

-    if isinstance(size, int) or len(size) == 1:
-        if isinstance(size, Sequence):
-            size = size[0]
-        w, h = img.size
-        if (w <= h and w == size) or (h <= w and h == size):
-            return img
-        if w < h:
-            ow = size
-            oh = int(size * h / w)
-            return img.resize((ow, oh), interpolation)
-        else:
-            oh = size
-            ow = int(size * w / h)
-            return img.resize((ow, oh), interpolation)
+    if isinstance(size, Sequence) and len(size) == 1:
+        size = size[0]
+    if isinstance(size, int):
+        w, h = img.size
+
+        short, long = (w, h) if w <= h else (h, w)
+        if short == size:
+            return img
+
+        new_short, new_long = size, int(size * long / short)
+
+        if max_size is not None:
+            if max_size <= size:
+                raise ValueError(
+                    f"max_size = {max_size} must be strictly greater than the requested "
+                    f"size for the smaller edge size = {size}"
+                )
+            if new_long > max_size:
+                new_short, new_long = int(max_size * new_short / new_long), max_size
+
+        new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
+        return img.resize((new_w, new_h), interpolation)
     else:
+        if max_size is not None:
+            raise ValueError(
+                "max_size should only be passed if size specifies the length of the smaller edge, "
+                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
+            )
         return img.resize(size[::-1], interpolation)
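
One detail worth noting in the branch above: when the smaller edge already matches `size`, the original image is returned as-is, before the `max_size` checks run (sketch with an assumed size):

```python
from PIL import Image
import torchvision.transforms.functional as F

img = Image.new("RGB", (300, 200))     # smaller edge is already 200
assert F.resize(img, size=200) is img  # early return: the same object comes back
```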
torchvision/transforms/functional_tensor.py

@@ -470,7 +470,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
     return img


-def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Tensor:
+def resize(img: Tensor, size: List[int], interpolation: str = "bilinear", max_size: Optional[int] = None) -> Tensor:
     _assert_image_tensor(img)

     if not isinstance(size, (int, tuple, list)):

@@ -484,34 +484,51 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Ten
     if isinstance(size, tuple):
         size = list(size)

-    if isinstance(size, list) and len(size) not in [1, 2]:
-        raise ValueError("Size must be an int or a 1 or 2 element tuple/list, not a "
-                         "{} element tuple/list".format(len(size)))
+    if isinstance(size, list):
+        if len(size) not in [1, 2]:
+            raise ValueError("Size must be an int or a 1 or 2 element tuple/list, not a "
+                             "{} element tuple/list".format(len(size)))
+        if max_size is not None and len(size) != 1:
+            raise ValueError(
+                "max_size should only be passed if size specifies the length of the smaller edge, "
+                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
+            )

     w, h = _get_image_size(img)

-    if isinstance(size, int):
-        size_w, size_h = size, size
-    elif len(size) < 2:
-        size_w, size_h = size[0], size[0]
-    else:
-        size_w, size_h = size[1], size[0]  # Convention (h, w)
-
-    if isinstance(size, int) or len(size) < 2:
-        if w < h:
-            size_h = int(size_w * h / w)
-        else:
-            size_w = int(size_h * w / h)
-
-        if (w <= h and w == size_w) or (h <= w and h == size_h):
-            return img
+    if isinstance(size, int) or len(size) == 1:  # specified size only for the smallest edge
+        short, long = (w, h) if w <= h else (h, w)
+
+        if isinstance(size, int):
+            requested_new_short = size
+        else:
+            requested_new_short = size[0]
+
+        if short == requested_new_short:
+            return img
+
+        new_short, new_long = requested_new_short, int(requested_new_short * long / short)
+
+        if max_size is not None:
+            if max_size <= requested_new_short:
+                raise ValueError(
+                    f"max_size = {max_size} must be strictly greater than the requested "
+                    f"size for the smaller edge size = {size}"
+                )
+            if new_long > max_size:
+                new_short, new_long = int(max_size * new_short / new_long), max_size
+
+        new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
+    else:  # specified both h and w
+        new_w, new_h = size[1], size[0]

     img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])

     # Define align_corners to avoid warnings
     align_corners = False if interpolation in ["bilinear", "bicubic"] else None

-    img = interpolate(img, size=[size_h, size_w], mode=interpolation, align_corners=align_corners)
+    img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners)

     if interpolation == "bicubic" and out_dtype == torch.uint8:
         img = img.clamp(min=0, max=255)
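
The tensor path hands `torch.nn.functional.interpolate` the target size in (H, W) order, while PIL works in (W, H); a quick standalone check of the call used above (shapes assumed):

```python
import torch
from torch.nn.functional import interpolate

x = torch.rand(1, 3, 400, 600)  # (N, C, H, W)
y = interpolate(x, size=[300, 450], mode="bilinear", align_corners=False)
print(y.shape)                  # torch.Size([1, 3, 300, 450])
```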
torchvision/transforms/transforms.py

@@ -241,16 +241,25 @@ class Resize(torch.nn.Module):
             If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` and
             ``InterpolationMode.BICUBIC`` are supported.
             For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
+        max_size (int, optional): The maximum allowed length of the longer edge of
+            the resized image: if the longer edge of the image is greater
+            than ``max_size`` after being resized according to ``size``, then
+            the image is resized again so that the longer edge is equal to
+            ``max_size``. As a result, ``size`` might be overruled, i.e. the
+            smaller edge may be shorter than ``size``. This is only supported
+            if ``size`` is an int (or a sequence of length 1 in torchscript
+            mode).
     """

-    def __init__(self, size, interpolation=InterpolationMode.BILINEAR):
+    def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None):
         super().__init__()
         if not isinstance(size, (int, Sequence)):
             raise TypeError("Size should be int or sequence. Got {}".format(type(size)))
         if isinstance(size, Sequence) and len(size) not in (1, 2):
             raise ValueError("If size is a sequence, it should have 1 or 2 values")
         self.size = size
+        self.max_size = max_size

         # Backward compatibility with integer value
         if isinstance(interpolation, int):

@@ -270,11 +279,12 @@ class Resize(torch.nn.Module):
         Returns:
             PIL Image or Tensor: Rescaled image.
         """
-        return F.resize(img, self.size, self.interpolation)
+        return F.resize(img, self.size, self.interpolation, self.max_size)

     def __repr__(self):
         interpolate_str = self.interpolation.value
-        return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str)
+        return self.__class__.__name__ + '(size={0}, interpolation={1}, max_size={2})'.format(
+            self.size, interpolate_str, self.max_size)


 class Scale(Resize):
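
With the updated `__repr__`, the transform now reports its `max_size` too (illustrative output):

```python
import torchvision.transforms as T

print(T.Resize(32, max_size=64))
# Resize(size=32, interpolation=bilinear, max_size=64)
```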