Unverified Commit 5f4b5794 authored by vfdev, committed by GitHub

Unified input for F.affine (#2444)

* [WIP] F.affine

* [WIP] F.affine + tests

* Unified input for F.affine

* Removed commented code

* Removed unused imports
parent 03b1d38b
@@ -348,6 +348,95 @@ class Tester(unittest.TestCase):
                 msg="{} vs {}".format(expected_out_tensor[0, :10, :10], out_tensor[0, :10, :10])
             )
+    def test_affine(self):
+        # Tests on square image
+        tensor, pil_img = self._create_data(26, 26)
+
+        scripted_affine = torch.jit.script(F.affine)
+
+        # 1) identity map
+        out_tensor = F.affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
+        self.assertTrue(
+            tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
+        )
+        out_tensor = scripted_affine(tensor, angle=0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
+        self.assertTrue(
+            tensor.equal(out_tensor), msg="{} vs {}".format(out_tensor[0, :5, :5], tensor[0, :5, :5])
+        )
+
+        # 2) Test rotation
+        test_configs = [
+            (90, torch.rot90(tensor, k=1, dims=(-1, -2))),
+            (45, None),
+            (30, None),
+            (-30, None),
+            (-45, None),
+            (-90, torch.rot90(tensor, k=-1, dims=(-1, -2))),
+            (180, torch.rot90(tensor, k=2, dims=(-1, -2))),
+        ]
+        for a, true_tensor in test_configs:
+            for fn in [F.affine, scripted_affine]:
+                out_tensor = fn(tensor, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
+                if true_tensor is not None:
+                    self.assertTrue(
+                        true_tensor.equal(out_tensor),
+                        msg="{}\n{} vs \n{}".format(a, out_tensor[0, :5, :5], true_tensor[0, :5, :5])
+                    )
+                else:
+                    true_tensor = out_tensor
+
+                out_pil_img = F.affine(pil_img, angle=a, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
+                out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+
+                num_diff_pixels = (true_tensor != out_pil_tensor).sum().item() / 3.0
+                ratio_diff_pixels = num_diff_pixels / true_tensor.shape[-1] / true_tensor.shape[-2]
+                # Tolerance: less than 6% of different pixels
+                self.assertLess(
+                    ratio_diff_pixels,
+                    0.06,
+                    msg="{}\n{} vs \n{}".format(
+                        ratio_diff_pixels, true_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
+                    )
+                )
+
+        # 3) Test translation
+        test_configs = [
+            [10, 12], (12, 13)
+        ]
+        for t in test_configs:
+            for fn in [F.affine, scripted_affine]:
+                out_tensor = fn(tensor, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
+                out_pil_img = F.affine(pil_img, angle=0, translate=t, scale=1.0, shear=[0.0, 0.0], resample=0)
+                self.compareTensorToPIL(out_tensor, out_pil_img)
+
+        # 4) Test rotation + translation + scale + shear
+        test_configs = [
+            (45, [5, 6], 1.0, [0.0, 0.0]),
+            (33, (5, -4), 1.0, [0.0, 0.0]),
+            (45, [5, 4], 1.2, [0.0, 0.0]),
+            (33, (4, 8), 2.0, [0.0, 0.0]),
+            (85, (10, -10), 0.7, [0.0, 0.0]),
+            (0, [0, 0], 1.0, [35.0, ]),
+            (25, [0, 0], 1.2, [0.0, 15.0]),
+            (45, [10, 0], 0.7, [2.0, 5.0]),
+            (45, [10, -10], 1.2, [4.0, 5.0]),
+        ]
+        for r in [0, ]:
+            for a, t, s, sh in test_configs:
+                for fn in [F.affine, scripted_affine]:
+                    out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r)
+                    out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r)
+                    out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
+
+                    num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
+                    ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
+                    # Tolerance: less than 5% of different pixels
+                    self.assertLess(
+                        ratio_diff_pixels,
+                        0.05,
+                        msg="{}: {}\n{} vs \n{}".format(
+                            (r, a, t, s, sh), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
+                        )
+                    )
 if __name__ == '__main__':
     unittest.main()
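Note on the comparison strategy above: PIL and grid_sample interpolate differently near boundaries, so the rotation tests compare outputs by the fraction of mismatching pixels instead of exact equality. A standalone sketch of the same check, written for this review and assuming a torchvision build that includes this patch (the 0.06 tolerance is taken from the test above):

import numpy as np
import torch
import torchvision.transforms.functional as F
from PIL import Image

# build matching tensor and PIL inputs
img = torch.randint(0, 256, (3, 26, 26), dtype=torch.uint8)
pil_img = Image.fromarray(img.permute(1, 2, 0).contiguous().numpy())

out_tensor = F.affine(img, angle=45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
out_pil = F.affine(pil_img, angle=45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], resample=0)
out_pil_tensor = torch.from_numpy(np.array(out_pil).transpose((2, 0, 1)))

# fraction of mismatching pixel positions, averaged over the 3 channels
num_diff = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio = num_diff / (out_tensor.shape[-1] * out_tensor.shape[-2])
assert ratio < 0.06, ratio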
@@ -1317,8 +1317,8 @@ class Tester(unittest.TestCase):
             for j in range(-5, 5):
                 input_img[pt[0] + i, pt[1] + j, :] = [255, 155, 55]
 
-        with self.assertRaises(TypeError):
-            F.affine(input_img, 10)
+        with self.assertRaises(TypeError, msg="Argument translate should be a sequence"):
+            F.affine(input_img, 10, translate=0, scale=1, shear=1)
 
         pil_img = F.to_pil_image(input_img)
 import math
 import numbers
 import warnings
-from typing import Any
+from typing import Any, Optional
 
 import numpy as np
-from numpy import sin, cos, tan
-from PIL import Image, __version__ as PILLOW_VERSION
+from PIL import Image
 
 import torch
 from torch import Tensor
@@ -21,6 +20,7 @@ from . import functional_tensor as F_t
 _is_pil_image = F_pil._is_pil_image
+_parse_fill = F_pil._parse_fill
 
 
 def _get_image_size(img: Tensor) -> List[int]:
@@ -485,43 +485,6 @@ def hflip(img: Tensor) -> Tensor:
     return F_t.hflip(img)
 
-def _parse_fill(fill, img, min_pil_version):
-    """Helper function to get the fill color for rotate and perspective transforms.
-
-    Args:
-        fill (n-tuple or int or float): Pixel fill value for area outside the transformed
-            image. If int or float, the value is used for all bands respectively.
-            Defaults to 0 for all bands.
-        img (PIL Image): Image to be filled.
-        min_pil_version (str): The minimum PILLOW version for when the ``fillcolor`` option
-            was first introduced in the calling function. (e.g. rotate->5.2.0, perspective->5.0.0)
-
-    Returns:
-        dict: kwarg for ``fillcolor``
-    """
-    major_found, minor_found = (int(v) for v in PILLOW_VERSION.split('.')[:2])
-    major_required, minor_required = (int(v) for v in min_pil_version.split('.')[:2])
-    if major_found < major_required or (major_found == major_required and minor_found < minor_required):
-        if fill is None:
-            return {}
-        else:
-            msg = ("The option to fill background area of the transformed image, "
-                   "requires pillow>={}")
-            raise RuntimeError(msg.format(min_pil_version))
-
-    num_bands = len(img.getbands())
-    if fill is None:
-        fill = 0
-    if isinstance(fill, (int, float)) and num_bands > 1:
-        fill = tuple([fill] * num_bands)
-    if not isinstance(fill, (int, float)) and len(fill) != num_bands:
-        msg = ("The number of elements in 'fill' does not match the number of "
-               "bands of the image ({} != {})")
-        raise ValueError(msg.format(len(fill), num_bands))
-
-    return {"fillcolor": fill}
 def _get_perspective_coeffs(startpoints, endpoints):
     """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.
@@ -827,7 +790,9 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=None):
     return img.rotate(angle, resample, expand, center, **opts)
-def _get_inverse_affine_matrix(center, angle, translate, scale, shear):
+def _get_inverse_affine_matrix(
+        center: List[int], angle: float, translate: List[float], scale: float, shear: List[float]
+) -> List[float]:
     # Helper method to compute inverse matrix for affine transformation
 
     # As it is explained in PIL.Image.rotate
@@ -847,14 +812,6 @@ def _get_inverse_affine_matrix(center, angle, translate, scale, shear):
     #
     # Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1
-    if isinstance(shear, numbers.Number):
-        shear = [shear, 0]
-
-    if not isinstance(shear, (tuple, list)) and len(shear) == 2:
-        raise ValueError(
-            "Shear should be a single value or a tuple/list containing " +
-            "two values. Got {}".format(shear))
     rot = math.radians(angle)
     sx, sy = [math.radians(s) for s in shear]
@@ -862,32 +819,37 @@ def _get_inverse_affine_matrix(center, angle, translate, scale, shear):
     tx, ty = translate
 
     # RSS without scaling
-    a = cos(rot - sy) / cos(sy)
-    b = -cos(rot - sy) * tan(sx) / cos(sy) - sin(rot)
-    c = sin(rot - sy) / cos(sy)
-    d = -sin(rot - sy) * tan(sx) / cos(sy) + cos(rot)
+    a = math.cos(rot - sy) / math.cos(sy)
+    b = -math.cos(rot - sy) * math.tan(sx) / math.cos(sy) - math.sin(rot)
+    c = math.sin(rot - sy) / math.cos(sy)
+    d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot)
     # Inverted rotation matrix with scale and shear
     # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
-    M = [d, -b, 0,
-         -c, a, 0]
-    M = [x / scale for x in M]
+    matrix = [d, -b, 0.0, -c, a, 0.0]
+    matrix = [x / scale for x in matrix]
     # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
-    M[2] += M[0] * (-cx - tx) + M[1] * (-cy - ty)
-    M[5] += M[3] * (-cx - tx) + M[4] * (-cy - ty)
+    matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
+    matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)
 
     # Apply center translation: C * RSS^-1 * C^-1 * T^-1
-    M[2] += cx
-    M[5] += cy
-    return M
+    matrix[2] += cx
+    matrix[5] += cy
+
+    return matrix
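A quick sanity check of the helper above, written for this review and not part of the commit: with no shear, unit scale, zero translation, and center at the origin, the returned coefficients are exactly the rotation matrix for -angle, i.e. the inverse rotation:

import math

matrix = _get_inverse_affine_matrix(
    center=[0, 0], angle=90.0, translate=[0.0, 0.0], scale=1.0, shear=[0.0, 0.0]
)
# rotation by -90 degrees: [cos, -sin, 0, sin, cos, 0] == [0, 1, 0, -1, 0, 0]
expected = [0.0, 1.0, 0.0, -1.0, 0.0, 0.0]
assert all(abs(m - e) < 1e-6 for m, e in zip(matrix, expected))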
-def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None):
-    """Apply affine transformation on the image keeping image center invariant
+def affine(
+        img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float],
+        resample: int = 0, fillcolor: Optional[int] = None
+) -> Tensor:
+    """Apply affine transformation on the image keeping image center invariant.
+
+    The image can be a PIL Image or a Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
     Args:
-        img (PIL Image): PIL Image to be rotated.
+        img (PIL Image or Tensor): image to be rotated.
         angle (float or int): rotation angle in degrees between -180 and 180, clockwise direction.
         translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation)
         scale (float): overall scale
@@ -895,27 +857,62 @@ def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None):
             If a tuple or list is specified, the first value corresponds to a shear parallel to the x axis, while
             the second value corresponds to a shear parallel to the y axis.
         resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
-            An optional resampling filter.
-            See `filters`_ for more information.
-            If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
+            An optional resampling filter. See `filters`_ for more information.
+            If omitted, or if the image is PIL Image and has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
+            If input is Tensor, only ``PIL.Image.NEAREST`` and ``PIL.Image.BILINEAR`` are supported.
         fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0)
+
+    Returns:
+        PIL Image or Tensor: Transformed image.
     """
-    if not F_pil._is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
-
-    assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
-        "Argument translate should be a list or tuple of length 2"
-    assert scale > 0.0, "Argument scale should be positive"
-
-    output_size = img.size
-    # center = (img.size[0] * 0.5 + 0.5, img.size[1] * 0.5 + 0.5)
-    # it is visually better to estimate the center without 0.5 offset
-    # otherwise image rotated by 90 degrees is shifted 1 pixel
-    center = (img.size[0] * 0.5, img.size[1] * 0.5)
-    matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
-    kwargs = {"fillcolor": fillcolor} if int(PILLOW_VERSION.split('.')[0]) >= 5 else {}
-    return img.transform(output_size, Image.AFFINE, matrix, resample, **kwargs)
+    if not isinstance(angle, (int, float)):
+        raise TypeError("Argument angle should be int or float")
+
+    if not isinstance(translate, (list, tuple)):
+        raise TypeError("Argument translate should be a sequence")
+
+    if len(translate) != 2:
+        raise ValueError("Argument translate should be a sequence of length 2")
+
+    if scale <= 0.0:
+        raise ValueError("Argument scale should be positive")
+
+    if not isinstance(shear, (numbers.Number, (list, tuple))):
+        raise TypeError("Shear should be either a single value or a sequence of two values")
+
+    if isinstance(angle, int):
+        angle = float(angle)
+
+    if isinstance(translate, tuple):
+        translate = list(translate)
+
+    if isinstance(shear, numbers.Number):
+        shear = [shear, 0.0]
+
+    if isinstance(shear, tuple):
+        shear = list(shear)
+
+    if len(shear) == 1:
+        shear = [shear[0], shear[0]]
+
+    if len(shear) != 2:
+        raise ValueError("Shear should be a sequence containing two values. Got {}".format(shear))
+
+    img_size = _get_image_size(img)
+    if not isinstance(img, torch.Tensor):
+        # center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
+        # it is visually better to estimate the center without 0.5 offset
+        # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
+        center = [img_size[0] * 0.5, img_size[1] * 0.5]
+        matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
+
+        return F_pil.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)
+
+    # we need to rescale translate by image size / 2 as its values can be between -1 and 1
+    translate_f = [2.0 * t / s for s, t in zip(img_size, translate)]
+    matrix = _get_inverse_affine_matrix([0, 0], angle, translate_f, scale, shear)
+
+    return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)
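With the dispatch above, one call site covers both backends, and torch.jit.script works because the PIL branch sits behind an isinstance check that TorchScript resolves statically for Tensor inputs (F_pil.affine is marked @torch.jit.unused). A small usage sketch of my own, assuming a torchvision build with this patch:

import torch
import torchvision.transforms.functional as F
from PIL import Image

tensor_img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
pil_img = Image.fromarray(tensor_img.permute(1, 2, 0).contiguous().numpy())

# identical call for both input types
out_t = F.affine(tensor_img, angle=30.0, translate=[5, -5], scale=1.1, shear=[0.0, 0.0])
out_p = F.affine(pil_img, angle=30.0, translate=[5, -5], scale=1.1, shear=[0.0, 0.0])

# scripting compiles only the tensor path
scripted = torch.jit.script(F.affine)
out_s = scripted(tensor_img, angle=30.0, translate=[5, -5], scale=1.1, shear=[0.0, 0.0])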
 def to_grayscale(img, num_output_channels=1):
 import numbers
 from typing import Any, List, Sequence
 
+import numpy as np
 import torch
+from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
 
 try:
     import accimage
 except ImportError:
     accimage = None
-from PIL import Image, ImageOps, ImageEnhance
-import numpy as np
 @torch.jit.unused
@@ -327,3 +328,65 @@ def resize(img, size, interpolation=Image.BILINEAR):
         return img.resize((ow, oh), interpolation)
     else:
         return img.resize(size[::-1], interpolation)
+
+
+@torch.jit.unused
+def _parse_fill(fill, img, min_pil_version):
+    """Helper function to get the fill color for rotate and perspective transforms.
+
+    Args:
+        fill (n-tuple or int or float): Pixel fill value for area outside the transformed
+            image. If int or float, the value is used for all bands respectively.
+            Defaults to 0 for all bands.
+        img (PIL Image): Image to be filled.
+        min_pil_version (str): The minimum PILLOW version for when the ``fillcolor`` option
+            was first introduced in the calling function. (e.g. rotate->5.2.0, perspective->5.0.0)
+
+    Returns:
+        dict: kwarg for ``fillcolor``
+    """
+    major_found, minor_found = (int(v) for v in PILLOW_VERSION.split('.')[:2])
+    major_required, minor_required = (int(v) for v in min_pil_version.split('.')[:2])
+    if major_found < major_required or (major_found == major_required and minor_found < minor_required):
+        if fill is None:
+            return {}
+        else:
+            msg = ("The option to fill background area of the transformed image, "
+                   "requires pillow>={}")
+            raise RuntimeError(msg.format(min_pil_version))
+
+    num_bands = len(img.getbands())
+    if fill is None:
+        fill = 0
+    if isinstance(fill, (int, float)) and num_bands > 1:
+        fill = tuple([fill] * num_bands)
+    if not isinstance(fill, (int, float)) and len(fill) != num_bands:
+        msg = ("The number of elements in 'fill' does not match the number of "
+               "bands of the image ({} != {})")
+        raise ValueError(msg.format(len(fill), num_bands))
+
+    return {"fillcolor": fill}
+
+
+@torch.jit.unused
+def affine(img, matrix, resample=0, fillcolor=None):
+    """Apply affine transformation on the PIL Image keeping image center invariant.
+
+    Args:
+        img (PIL Image): image to be rotated.
+        matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation.
+        resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
+            An optional resampling filter.
+            See `filters`_ for more information.
+            If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
+        fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0)
+
+    Returns:
+        PIL Image: Transformed image.
+    """
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+    output_size = img.size
+    opts = _parse_fill(fillcolor, img, '5.0.0')
+    return img.transform(output_size, Image.AFFINE, matrix, resample, **opts)
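Unlike the old functional.affine, this backend receives the already-inverted 2x3 matrix rather than angle/translate/scale/shear. A direct call, normally reached via F.affine, might look like this sketch of mine:

from PIL import Image

img = Image.new("RGB", (32, 32), (128, 64, 32))
identity = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0]  # row-major 2x3 inverse affine matrix
out = affine(img, matrix=identity, resample=0, fillcolor=0)
assert list(out.getdata()) == list(img.getdata())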
+import warnings
+from typing import Optional
+
 import torch
 from torch import Tensor
+from torch.nn.functional import affine_grid, grid_sample
 from torch.jit.annotations import List, BroadcastingList2
@@ -496,7 +500,8 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
             :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
             In torchscript mode size as a single int is not supported, use a tuple or
             list of length 1: ``[size, ]``.
-        interpolation (int, optional): Desired interpolation. Default is bilinear.
+        interpolation (int, optional): Desired interpolation. Default is bilinear (=2). Other supported values:
+            nearest (=0) and bicubic (=3).
     Returns:
         Tensor: Resized image.
@@ -571,3 +576,63 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
         img = img.to(out_dtype)
 
     return img
+
+
+def affine(
+        img: Tensor, matrix: List[float], resample: int = 0, fillcolor: Optional[int] = None
+) -> Tensor:
+    """Apply affine transformation on the Tensor image keeping image center invariant.
+
+    Args:
+        img (Tensor): image to be rotated.
+        matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation.
+        resample (int, optional): An optional resampling filter. Default is nearest (=0). Other supported values:
+            bilinear (=2).
+        fillcolor (int, optional): this option is not supported for Tensor input. Fill value for the area outside the
+            transform in the output image is always 0.
+
+    Returns:
+        Tensor: Transformed image.
+    """
+    if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
+        raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))
+
+    if fillcolor is not None:
+        warnings.warn("Argument fillcolor is not supported for Tensor input. Fill value is zero")
+
+    _interpolation_modes = {
+        0: "nearest",
+        2: "bilinear",
+    }
+
+    if resample not in _interpolation_modes:
+        raise ValueError("This resampling mode is unsupported with Tensor input")
+
+    theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3)
+    shape = img.shape
+    grid = affine_grid(theta, size=(1, shape[-3], shape[-2], shape[-1]), align_corners=False)
+
+    # make image NCHW
+    need_squeeze = False
+    if img.ndim < 4:
+        img = img.unsqueeze(dim=0)
+        need_squeeze = True
+
+    mode = _interpolation_modes[resample]
+
+    out_dtype = img.dtype
+    need_cast = False
+    if img.dtype not in (torch.float32, torch.float64):
+        need_cast = True
+        img = img.to(torch.float32)
+
+    img = grid_sample(img, grid, mode=mode, padding_mode="zeros", align_corners=False)
+
+    if need_squeeze:
+        img = img.squeeze(dim=0)
+
+    if need_cast:
+        # it is better to round before cast
+        img = torch.round(img).to(out_dtype)
+
+    return img
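The theta passed to affine_grid is expressed in normalized coordinates where x and y each span [-1, 1] across the image, which is why the caller in functional.affine rescales pixel translations by 2 / size and uses a zero center. A minimal sketch of that convention, independent of this patch:

import torch
from torch.nn.functional import affine_grid, grid_sample

img = torch.rand(1, 3, 32, 32)
# identity 2x3 affine matrix in normalized coordinates
theta = torch.tensor([[[1.0, 0.0, 0.0],
                       [0.0, 1.0, 0.0]]])
grid = affine_grid(theta, size=(1, 3, 32, 32), align_corners=False)
out = grid_sample(img, grid, mode="nearest", padding_mode="zeros", align_corners=False)
assert torch.allclose(out, img)  # identity theta samples each pixel at its own center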