Unverified commit e212cc86, authored by vfdev, committed by GitHub

Unified input for resize op (#2394)

* [WIP] F.resize with tensor

* Adapted T.Resize and F.resize with a test

* Following review: fixed copy-pasted messages and unused imports
parent 971c3e45
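As a quick illustration of what this commit unifies: after the change, the same `F.resize` call accepts either a PIL Image or a tensor. A minimal sketch, not part of the diff; the data and sizes are illustrative:

import torch
from PIL import Image
import torchvision.transforms.functional as F

# Illustrative data: a uint8 CHW tensor and the equivalent PIL image
tensor_img = torch.randint(0, 256, (3, 26, 36), dtype=torch.uint8)
pil_img = Image.fromarray(tensor_img.permute(1, 2, 0).contiguous().numpy())

# The same call now works for both input types
resized_tensor = F.resize(tensor_img, size=[32, 32], interpolation=2)
resized_pil = F.resize(pil_img, size=[32, 32], interpolation=2)

print(resized_tensor.shape)  # torch.Size([3, 32, 32])
print(resized_pil.size)      # (32, 32)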
import unittest
import random
import colorsys

from PIL import Image
from PIL.Image import NEAREST, BILINEAR, BICUBIC

import numpy as np

import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional_tensor as F_t
import torchvision.transforms.functional_pil as F_pil
import torchvision.transforms.functional as F
class Tester(unittest.TestCase):

@@ -22,6 +25,14 @@ class Tester(unittest.TestCase):
        pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
        self.assertTrue(tensor.equal(pil_tensor), msg)

    def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None):
        pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))).to(tensor)
        mae = torch.abs(tensor - pil_tensor).mean().item()
        self.assertTrue(
            mae < tol,
            msg="{}: mae={}, tol={}: \n{}\nvs\n{}".format(msg, mae, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10])
        )
    def test_vflip(self):
        script_vflip = torch.jit.script(F_t.vflip)
        img_tensor = torch.randn(3, 16, 16)

@@ -282,6 +293,44 @@ class Tester(unittest.TestCase):
        with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
            F_t.pad(tensor, (-2, -3), padding_mode="symmetric")
    def test_resize(self):
        script_fn = torch.jit.script(F_t.resize)
        tensor, pil_img = self._create_data(26, 36)

        for dt in [None, torch.float32, torch.float64]:
            if dt is not None:
                # This is a trivial cast to float of uint8 data to test all cases
                tensor = tensor.to(dt)
            for size in [32, [32, ], [32, 32], (32, 32), ]:
                for interpolation in [BILINEAR, BICUBIC, NEAREST]:
                    resized_tensor = F_t.resize(tensor, size=size, interpolation=interpolation)
                    resized_pil_img = F_pil.resize(pil_img, size=size, interpolation=interpolation)

                    self.assertEqual(
                        resized_tensor.size()[1:], resized_pil_img.size[::-1], msg="{}, {}".format(size, interpolation)
                    )

                    if interpolation != NEAREST:
                        # We cannot check values if mode = NEAREST, as results are different
                        # E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
                        # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
                        resized_tensor_f = resized_tensor
                        # cast uint8 tensors to float so we can compute an MAE against the PIL result
                        if resized_tensor_f.dtype == torch.uint8:
                            resized_tensor_f = resized_tensor_f.to(torch.float)

                        # Pay attention to high tolerance for MAE
                        self.approxEqualTensorToPIL(
                            resized_tensor_f, resized_pil_img, tol=8.0, msg="{}, {}".format(size, interpolation)
                        )

                    if isinstance(size, int):
                        script_size = [size, ]
                    else:
                        script_size = size
                    resized_tensor_script = script_fn(tensor, size=script_size, interpolation=interpolation)
                    self.assertTrue(resized_tensor.equal(resized_tensor_script), msg="{}, {}".format(size, interpolation))
if __name__ == '__main__':
    unittest.main()
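The `tol=8.0` above is deliberately loose: PIL and `torch.nn.functional.interpolate` implement bilinear/bicubic filtering with different kernels and rounding, so the two backends only agree approximately on uint8 data. The check boils down to a mean absolute error, roughly as in this standalone sketch (the helper name `mean_abs_error` is illustrative, not part of the patch):

import numpy as np
import torch

def mean_abs_error(tensor_img, pil_img):
    # Convert the PIL result to a CHW float tensor, then average the elementwise differences
    pil_tensor = torch.as_tensor(np.array(pil_img).transpose((2, 0, 1))).float()
    return torch.abs(tensor_img.float() - pil_tensor).mean().item()

An MAE of a few intensity levels is expected between the two bilinear implementations, which is why the test tolerates up to 8.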
@@ -2,6 +2,7 @@ import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL import Image
from PIL.Image import NEAREST, BILINEAR, BICUBIC
import numpy as np

@@ -217,6 +218,33 @@ class Tester(unittest.TestCase):
            "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
    def test_resize(self):
        tensor, _ = self._create_data(height=34, width=36)
        script_fn = torch.jit.script(F.resize)

        for dt in [None, torch.float32, torch.float64]:
            if dt is not None:
                # This is a trivial cast to float of uint8 data to test all cases
                tensor = tensor.to(dt)
            for size in [32, [32, ], [32, 32], (32, 32), ]:
                for interpolation in [BILINEAR, BICUBIC, NEAREST]:
                    resized_tensor = F.resize(tensor, size=size, interpolation=interpolation)

                    if isinstance(size, int):
                        script_size = [size, ]
                    else:
                        script_size = size

                    s_resized_tensor = script_fn(tensor, size=script_size, interpolation=interpolation)
                    self.assertTrue(s_resized_tensor.equal(resized_tensor))

                    transform = T.Resize(size=script_size, interpolation=interpolation)
                    resized_tensor = transform(tensor)
                    script_transform = torch.jit.script(transform)
                    s_resized_tensor = script_transform(tensor)
                    self.assertTrue(s_resized_tensor.equal(resized_tensor))
if __name__ == '__main__':
    unittest.main()
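The `script_size` conversion in the test above exists because torchscript compiles `F.resize` against the `size: List[int]` annotation, so a bare Python int is only valid in eager mode. A minimal sketch of the constraint, assuming this patch is installed:

import torch
import torchvision.transforms.functional as F

script_fn = torch.jit.script(F.resize)
img = torch.randint(0, 256, (3, 34, 36), dtype=torch.uint8)

out = script_fn(img, size=[32, ], interpolation=2)  # OK: size is a List[int]
# script_fn(img, size=32, interpolation=2) would raise a RuntimeError once scripted
print(out.shape)  # expected torch.Size([3, 32, 33]): smaller edge (34) matched to 32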
@@ -311,41 +311,29 @@ def normalize(tensor, mean, std, inplace=False):
    return tensor


def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
    r"""Resize the input image to the given size.
    The image can be a PIL Image or a torch Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        img (PIL Image or Tensor): Image to be resized.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), the output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number maintaining
            the aspect ratio, i.e. if height > width, then image will be rescaled to
            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
            In torchscript mode size as a single int is not supported, use a tuple or
            list of length 1: ``[size, ]``.
        interpolation (int, optional): Desired interpolation. Default is bilinear.

    Returns:
        PIL Image or Tensor: Resized image.
    """
    if not isinstance(img, torch.Tensor):
        return F_pil.resize(img, size=size, interpolation=interpolation)

    return F_t.resize(img, size=size, interpolation=interpolation)
def scale(*args, **kwargs):
...
import numbers
from typing import Any, List, Sequence

import torch

try:

@@ -286,3 +286,44 @@ def crop(img: Image.Image, top: int, left: int, height: int, width: int) -> Imag
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.crop((left, top, left + width, top + height))
@torch.jit.unused
def resize(img, size, interpolation=Image.BILINEAR):
    r"""Resize the input PIL Image to the given size.

    Args:
        img (PIL Image): Image to be resized.
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), the output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number maintaining
            the aspect ratio, i.e. if height > width, then image will be rescaled to
            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
            For compatibility reasons with ``functional_tensor.resize``, if a tuple or list of
            length 1 is provided, it is interpreted as a single int.
        interpolation (int, optional): Desired interpolation. Default is ``PIL.Image.BILINEAR``.

    Returns:
        PIL Image: Resized image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))):
        raise TypeError('Got inappropriate size arg: {}'.format(size))

    if isinstance(size, int) or len(size) == 1:
        if isinstance(size, Sequence):
            size = size[0]
        w, h = img.size
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
            return img.resize((ow, oh), interpolation)
        else:
            oh = size
            ow = int(size * w / h)
            return img.resize((ow, oh), interpolation)
    else:
        return img.resize(size[::-1], interpolation)
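A short usage note on the length-1 compatibility rule above: in eager mode, a one-element sequence and a bare int resize identically. A minimal sketch with illustrative sizes:

from PIL import Image
import torchvision.transforms.functional_pil as F_pil

img = Image.new("RGB", (36, 26))  # PIL size is (width, height)

# A one-element sequence behaves exactly like the bare int: smaller edge -> 32
assert F_pil.resize(img, 32).size == F_pil.resize(img, [32, ]).size  # both (44, 32)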
@@ -8,6 +8,7 @@ def _is_tensor_a_torch_image(x: Tensor) -> bool:

def _get_image_size(img: Tensor) -> List[int]:
    """Returns (w, h) of tensor image"""
    if _is_tensor_a_torch_image(img):
        return [img.shape[-1], img.shape[-2]]
    raise TypeError("Unexpected type {}".format(type(img)))

@@ -433,6 +434,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
    if isinstance(padding, int):
        if torch.jit.is_scripting():
            # This may be unreachable
            raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
        pad_left = pad_right = pad_top = pad_bottom = padding
    elif len(padding) == 1:

@@ -480,3 +482,92 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
    img = img.to(out_dtype)

    return img
def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
    r"""Resize the input Tensor to the given size.

    Args:
        img (Tensor): Image to be resized.
        size (int or tuple or list): Desired output size. If size is a sequence like
            (h, w), the output size will be matched to this. If size is an int,
            the smaller edge of the image will be matched to this number maintaining
            the aspect ratio, i.e. if height > width, then image will be rescaled to
            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
            In torchscript mode size as a single int is not supported, use a tuple or
            list of length 1: ``[size, ]``.
        interpolation (int, optional): Desired interpolation. Default is bilinear.

    Returns:
        Tensor: Resized image.
    """
    if not _is_tensor_a_torch_image(img):
        raise TypeError("tensor is not a torch image.")

    if not isinstance(size, (int, tuple, list)):
        raise TypeError("Got inappropriate size arg")
    if not isinstance(interpolation, int):
        raise TypeError("Got inappropriate interpolation arg")

    _interpolation_modes = {
        0: "nearest",
        2: "bilinear",
        3: "bicubic",
    }

    if interpolation not in _interpolation_modes:
        raise ValueError("This interpolation mode is unsupported with Tensor input")

    if isinstance(size, tuple):
        size = list(size)

    if isinstance(size, list) and len(size) not in [1, 2]:
        raise ValueError("Size must be an int or a 1 or 2 element tuple/list, not a "
                         "{} element tuple/list".format(len(size)))

    w, h = _get_image_size(img)

    if isinstance(size, int):
        size_w, size_h = size, size
    elif len(size) < 2:
        size_w, size_h = size[0], size[0]
    else:
        # size follows the (h, w) convention documented above
        size_w, size_h = size[1], size[0]

    if isinstance(size, int) or len(size) < 2:
        if w < h:
            size_h = int(size_w * h / w)
        else:
            size_w = int(size_h * w / h)

        if (w <= h and w == size_w) or (h <= w and h == size_h):
            return img

    # make image NCHW
    need_squeeze = False
    if img.ndim < 4:
        img = img.unsqueeze(dim=0)
        need_squeeze = True

    mode = _interpolation_modes[interpolation]

    out_dtype = img.dtype
    need_cast = False
    if img.dtype not in (torch.float32, torch.float64):
        need_cast = True
        img = img.to(torch.float32)

    # Define align_corners to avoid warnings
    align_corners = False if mode in ["bilinear", "bicubic"] else None

    img = torch.nn.functional.interpolate(img, size=(size_h, size_w), mode=mode, align_corners=align_corners)

    if need_squeeze:
        img = img.squeeze(dim=0)

    if need_cast:
        if mode == "bicubic":
            img = img.clamp(min=0, max=255)
        img = img.to(out_dtype)

    return img
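To make the smaller-edge arithmetic above concrete: for a 26x36 (HxW) image and `size=32`, the height is the smaller edge, so `size_h` is pinned to 32 and `size_w` is scaled to preserve the aspect ratio. A standalone sketch of the same recipe with bilinear mode and illustrative data:

import torch

img = torch.randint(0, 256, (3, 26, 36), dtype=torch.uint8)  # H=26, W=36

# Same computation F_t.resize performs for an int size (smaller edge matched to 32)
size_h = 32                      # h is the smaller edge here
size_w = int(size_h * 36 / 26)   # 44, keeps the aspect ratio

out = torch.nn.functional.interpolate(
    img.unsqueeze(0).float(),    # NCHW float, as interpolate expects
    size=(size_h, size_w),
    mode="bilinear",
    align_corners=False,
).squeeze(0).to(torch.uint8)

print(out.shape)  # torch.Size([3, 32, 44])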
@@ -2,7 +2,7 @@ import math
import numbers
import random
import warnings
from collections.abc import Sequence
from typing import Tuple, List, Optional

import numpy as np

@@ -209,31 +209,38 @@ class Normalize(object):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
class Resize(torch.nn.Module):
    """Resize the input image to the given size.
    The image can be a PIL Image or a torch Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        size (sequence or int): Desired output size. If size is a sequence like
            (h, w), output size will be matched to this. If size is an int,
            smaller edge of the image will be matched to this number,
            i.e. if height > width, then image will be rescaled to
            (size * height / width, size).
            In torchscript mode size as a single int is not supported, use a tuple or
            list of length 1: ``[size, ]``.
        interpolation (int, optional): Desired interpolation. Default is ``PIL.Image.BILINEAR``.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        super().__init__()
        if not isinstance(size, (int, Sequence)):
            raise TypeError("Size should be int or sequence. Got {}".format(type(size)))
        if isinstance(size, Sequence) and len(size) not in (1, 2):
            raise ValueError("If size is a sequence, it should have 1 or 2 values")
        self.size = size
        self.interpolation = interpolation

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be scaled.

        Returns:
            PIL Image or Tensor: Rescaled image.
        """
        return F.resize(img, self.size, self.interpolation)
...
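Because `Resize` now derives from `torch.nn.Module` and implements `forward`, the transform object itself can be scripted, which is exactly what the new `test_resize` in the transforms test exercises. A minimal sketch with illustrative data:

import torch
import torchvision.transforms as T

# Resize is now an nn.Module with a forward(), so torch.jit.script accepts it
transform = T.Resize(size=[32, ], interpolation=2)
scripted = torch.jit.script(transform)

img = torch.randint(0, 256, (3, 26, 36), dtype=torch.uint8)
assert scripted(img).equal(transform(img))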