Unverified Commit 76662528 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Unified inputs for `F.rotate` (#2495)

* Added code for F_t.rotate with test
- updated F.affine tests

* Rotate test tolerance to 2%

* Fixes failing test

* Optimized _expanded_affine_grid with a single matmul op

* Recoded _compute_output_size
parent 23295fbb
......@@ -435,7 +435,7 @@ class Tester(unittest.TestCase):
)
# 3) Test translation
test_configs = [
[10, 12], (12, 13)
[10, 12], (-12, -13)
]
for t in test_configs:
for fn in [F.affine, scripted_affine]:
......@@ -447,21 +447,21 @@ class Tester(unittest.TestCase):
test_configs = [
(45, [5, 6], 1.0, [0.0, 0.0]),
(33, (5, -4), 1.0, [0.0, 0.0]),
(45, [5, 4], 1.2, [0.0, 0.0]),
(33, (4, 8), 2.0, [0.0, 0.0]),
(45, [-5, 4], 1.2, [0.0, 0.0]),
(33, (-4, -8), 2.0, [0.0, 0.0]),
(85, (10, -10), 0.7, [0.0, 0.0]),
(0, [0, 0], 1.0, [35.0, ]),
(25, [0, 0], 1.2, [0.0, 15.0]),
(45, [10, 0], 0.7, [2.0, 5.0]),
(45, [10, -10], 1.2, [4.0, 5.0]),
(45, [-10, 0], 0.7, [2.0, 5.0]),
(45, [-10, -10], 1.2, [4.0, 5.0]),
]
for r in [0, ]:
for a, t, s, sh in test_configs:
out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
for fn in [F.affine, scripted_affine]:
out_tensor = fn(tensor, angle=a, translate=t, scale=s, shear=sh, resample=r)
out_pil_img = F.affine(pil_img, angle=a, translate=t, scale=s, shear=sh, resample=r)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
# Tolerance : less than 5% of different pixels
......@@ -473,6 +473,47 @@ class Tester(unittest.TestCase):
)
)
def test_rotate(self):
    """Compare tensor ``F.rotate`` (eager and scripted) against the PIL result.

    For every combination of angle, ``expand`` flag and rotation center the
    tensor output must have the same shape as the PIL output and differ from
    it in less than 2% of the pixels.
    """
    # Tests on square image
    tensor, pil_img = self._create_data(26, 26)
    scripted_rotate = torch.jit.script(F.rotate)

    # PIL reports size as (width, height) and centers are given as (x, y),
    # so x is scaled by the width and y by the height.
    width, height = pil_img.size
    centers = [
        None,
        (int(width * 0.3), int(height * 0.4)),
        [int(width * 0.5), int(height * 0.6)],
    ]

    for r in [0, ]:
        for a in range(-120, 120, 23):
            for e in [True, False]:
                for c in centers:
                    # The PIL reference only depends on the configuration,
                    # so it is computed once for both tensor code paths.
                    out_pil_img = F.rotate(pil_img, angle=a, resample=r, expand=e, center=c)
                    out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
                    for fn in [F.rotate, scripted_rotate]:
                        out_tensor = fn(tensor, angle=a, resample=r, expand=e, center=c)

                        self.assertEqual(
                            out_tensor.shape,
                            out_pil_tensor.shape,
                            msg="{}: {} vs {}".format(
                                (r, a, e, c), out_tensor.shape, out_pil_tensor.shape
                            )
                        )
                        # A pixel counts once even if all three channels differ.
                        num_diff_pixels = (out_tensor != out_pil_tensor).sum().item() / 3.0
                        ratio_diff_pixels = num_diff_pixels / out_tensor.shape[-1] / out_tensor.shape[-2]
                        # Tolerance : less than 2% of different pixels
                        self.assertLess(
                            ratio_diff_pixels,
                            0.02,
                            msg="{}: {}\n{} vs \n{}".format(
                                (r, a, e, c), ratio_diff_pixels, out_tensor[0, :7, :7], out_pil_tensor[0, :7, :7]
                            )
                        )
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
......@@ -1266,7 +1266,7 @@ class Tester(unittest.TestCase):
x = np.zeros((100, 100, 3), dtype=np.uint8)
x[40, 40] = [255, 255, 255]
with self.assertRaises(TypeError):
with self.assertRaisesRegex(TypeError, r"img should be PIL Image"):
F.rotate(x, 10)
img = F.to_pil_image(x)
......
......@@ -756,40 +756,8 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
return F_t.adjust_gamma(img, gamma, gain)
def rotate(img, angle, resample=0, expand=False, center=None, fill=None):
    """Rotate the image by angle.

    Args:
        img (PIL Image): PIL Image to be rotated.
        angle (float or int): Rotation angle in degrees, counter-clockwise.
        resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
            An optional resampling filter. See `filters`_ for more information.
            If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output image to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (2-tuple, optional): Optional center of rotation.
            Origin is the upper left corner.
            Default is the center of the image.
        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
            image. If int or float, the value is used for all bands respectively.
            Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.

    Returns:
        PIL Image: Rotated image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """
    if not F_pil._is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    # The fill value is translated into keyword arguments for ``Image.rotate``.
    opts = _parse_fill(fill, img, '5.2.0')

    return img.rotate(angle, resample, expand, center, **opts)
def _get_inverse_affine_matrix(
center: List[int], angle: float, translate: List[float], scale: float, shear: List[float]
center: List[float], angle: float, translate: List[float], scale: float, shear: List[float]
) -> List[float]:
# Helper method to compute inverse matrix for affine transformation
......@@ -838,6 +806,56 @@ def _get_inverse_affine_matrix(
return matrix
def rotate(
        img: Tensor, angle: float, resample: int = 0, expand: bool = False,
        center: Optional[List[int]] = None, fill: Optional[int] = None
) -> Tensor:
    """Rotate an image counter-clockwise by ``angle`` degrees.

    Accepts either a PIL Image or a Tensor. A Tensor is expected to have
    [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    Args:
        img (PIL Image or Tensor): image to be rotated.
        angle (float or int): rotation angle value in degrees, counter-clockwise.
        resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
            An optional resampling filter. See `filters`_ for more information.
            If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output image to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (list or tuple, optional): Optional center of rotation. Origin is the upper left corner.
            Default is the center of the image.
        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
            image. If int or float, the value is used for all bands respectively.
            Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.
            It is forwarded unchanged to the PIL or Tensor backend.

    Returns:
        PIL Image or Tensor: Rotated image.

    Raises:
        TypeError: If ``angle`` is neither int nor float, or if ``center`` is
            given but is not a list or tuple.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """
    if not isinstance(angle, (int, float)):
        raise TypeError("Argument angle should be int or float")

    if center is not None and not isinstance(center, (list, tuple)):
        raise TypeError("Argument center should be a sequence")

    # Anything that is not a Tensor is handed to the PIL backend as-is.
    if not isinstance(img, torch.Tensor):
        return F_pil.rotate(img, angle=angle, resample=resample, expand=expand, center=center, fill=fill)

    if center is None:
        # No explicit center: rotate around the origin of the normalized frame.
        normalized_center = [0.0, 0.0]
    else:
        size = _get_image_size(img)
        # Center is normalized to [-1, +1]
        normalized_center = [2.0 * c / s - 1.0 for s, c in zip(size, center)]

    # The affine helper and rotate currently disagree on the direction of the
    # rotation angle, hence the sign flip.
    inverse_matrix = _get_inverse_affine_matrix(normalized_center, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
    return F_t.rotate(img, matrix=inverse_matrix, resample=resample, expand=expand, fill=fill)
def affine(
img: Tensor, angle: float, translate: List[int], scale: float, shear: List[float],
resample: int = 0, fillcolor: Optional[int] = None
......@@ -847,7 +865,7 @@ def affine(
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
Args:
img (PIL Image or Tensor): image to be rotated.
img (PIL Image or Tensor): image to transform.
angle (float or int): rotation angle in degrees between -180 and 180, clockwise direction.
translate (list or tuple of integers): horizontal and vertical translations (post-rotation translation)
scale (float): overall scale
......@@ -911,7 +929,7 @@ def affine(
# we need to rescale translate by image size / 2 as its values can be between -1 and 1
translate = [2.0 * t / s for s, t in zip(img_size, translate)]
matrix = _get_inverse_affine_matrix([0, 0], angle, translate, scale, shear)
matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
return F_t.affine(img, matrix=matrix, resample=resample, fillcolor=fillcolor)
......
......@@ -422,3 +422,37 @@ def affine(img, matrix, resample=0, fillcolor=None):
output_size = img.size
opts = _parse_fill(fillcolor, img, '5.0.0')
return img.transform(output_size, Image.AFFINE, matrix, resample, **opts)
@torch.jit.unused
def rotate(img, angle, resample=0, expand=False, center=None, fill=None):
    """Rotate a PIL image counter-clockwise by ``angle`` degrees.

    Args:
        img (PIL Image): image to be rotated.
        angle (float or int): rotation angle value in degrees, counter-clockwise.
        resample (``PIL.Image.NEAREST`` or ``PIL.Image.BILINEAR`` or ``PIL.Image.BICUBIC``, optional):
            An optional resampling filter. See `filters`_ for more information.
            If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
        expand (bool, optional): Optional expansion flag.
            If true, expands the output image to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        center (2-tuple, optional): Optional center of rotation.
            Origin is the upper left corner.
            Default is the center of the image.
        fill (n-tuple or int or float): Pixel fill value for area outside the rotated
            image. If int or float, the value is used for all bands respectively.
            Defaults to 0 for all bands. This option is only available for ``pillow>=5.2.0``.

    Returns:
        PIL Image: Rotated image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.

    .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

    """
    if _is_pil_image(img):
        # The fill value becomes keyword arguments understood by ``Image.rotate``.
        fill_kwargs = _parse_fill(fill, img, '5.2.0')
        return img.rotate(angle, resample, expand, center, **fill_kwargs)

    raise TypeError("img should be PIL Image. Got {}".format(type(img)))
import warnings
from typing import Optional
from typing import Optional, Dict, Tuple
import torch
from torch import Tensor
......@@ -619,48 +619,32 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
return img
def affine(
img: Tensor, matrix: List[float], resample: int = 0, fillcolor: Optional[int] = None
) -> Tensor:
"""Apply affine transformation on the Tensor image keeping image center invariant.
def _assert_grid_transform_inputs(
img: Tensor, matrix: List[float], resample: int, fillcolor: Optional[int], _interpolation_modes: Dict[int, str]
):
if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
raise TypeError("img should be Tensor Image. Got {}".format(type(img)))
Args:
img (Tensor): image to be rotated.
matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation.
resample (int, optional): An optional resampling filter. Default is nearest (=2). Other supported values:
bilinear(=2).
fillcolor (int, optional): this option is not supported for Tensor input. Fill value for the area outside the
transform in the output image is always 0.
if not isinstance(matrix, list):
raise TypeError("Argument matrix should be a list. Got {}".format(type(matrix)))
Returns:
Tensor: Transformed image.
"""
if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))
if len(matrix) != 6:
raise ValueError("Argument matrix should have 6 float values")
if fillcolor is not None:
warnings.warn("Argument fillcolor is not supported for Tensor input. Fill value is zero")
_interpolation_modes = {
0: "nearest",
2: "bilinear",
}
warnings.warn("Argument fill/fillcolor is not supported for Tensor input. Fill value is zero")
if resample not in _interpolation_modes:
raise ValueError("This resampling mode is unsupported with Tensor input")
theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3)
shape = img.shape
grid = affine_grid(theta, size=(1, shape[-3], shape[-2], shape[-1]), align_corners=False)
def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor:
# make image NCHW
need_squeeze = False
if img.ndim < 4:
img = img.unsqueeze(dim=0)
need_squeeze = True
mode = _interpolation_modes[resample]
out_dtype = img.dtype
need_cast = False
if img.dtype not in (torch.float32, torch.float64):
......@@ -677,3 +661,106 @@ def affine(
img = torch.round(img).to(out_dtype)
return img
def affine(
        img: Tensor, matrix: List[float], resample: int = 0, fillcolor: Optional[int] = None
) -> Tensor:
    """Apply an affine transformation to a Tensor image, keeping the image center invariant.

    Input validation is delegated to ``_assert_grid_transform_inputs`` and the
    actual resampling to ``_apply_grid_transform``.

    Args:
        img (Tensor): image to be transformed.
        matrix (list of floats): list of 6 float values representing inverse matrix for affine transformation.
        resample (int, optional): An optional resampling filter. Default is nearest (=0). Other supported values:
            bilinear(=2).
        fillcolor (int, optional): this option is not supported for Tensor input. Fill value for the area outside the
            transform in the output image is always 0.

    Returns:
        Tensor: Transformed image.
    """
    # Resample codes accepted for Tensor input and the matching sampling mode names.
    supported_modes = {
        0: "nearest",
        2: "bilinear",
    }
    _assert_grid_transform_inputs(img, matrix, resample, fillcolor, supported_modes)

    shape = img.shape
    # The six coefficients form a single (batch of one) 2x3 matrix.
    theta = torch.tensor(matrix, dtype=torch.float).reshape(1, 2, 3)
    grid = affine_grid(theta, size=(1, shape[-3], shape[-2], shape[-1]), align_corners=False)
    return _apply_grid_transform(img, grid, supported_modes[resample])
def _compute_output_size(theta: Tensor, w: int, h: int) -> Tuple[int, int]:
# pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
# we need to normalize coordinates according to
# [0, s] is mapped [-1, +1] as theta translation parameters are normalized like that
pts = torch.tensor([
[-1.0, -1.0, 1.0],
[-1.0, 1.0, 1.0],
[1.0, 1.0, 1.0],
[1.0, -1.0, 1.0],
])
# denormalize back to w, h:
new_pts = (torch.matmul(pts, theta.t()) + 1.0) * torch.tensor([w, h]) / 2.0
min_vals, _ = new_pts.min(dim=0)
max_vals, _ = new_pts.max(dim=0)
size = torch.ceil(max_vals) - torch.floor(min_vals)
return int(size[0]), int(size[1])
def _expanded_affine_grid(theta: Tensor, w: int, h: int, expand: bool = False) -> Tensor:
if expand:
ow, oh = _compute_output_size(theta, w, h)
else:
ow, oh = w, h
d = 0.5 # if not align_corners
x = (torch.arange(ow) + d - ow * 0.5) / (0.5 * w)
y = (torch.arange(oh) + d - oh * 0.5) / (0.5 * h)
y, x = torch.meshgrid(y, x)
pts = torch.stack([x, y, torch.ones_like(x)], dim=-1)
output_grid = torch.matmul(pts, theta.t())
return output_grid.unsqueeze(dim=0)
def rotate(
        img: Tensor, matrix: List[float], resample: int = 0, expand: bool = False, fill: Optional[int] = None
) -> Tensor:
    """Rotate the Tensor image using a precomputed inverse rotation matrix.

    Input validation is delegated to ``_assert_grid_transform_inputs`` and the
    actual resampling to ``_apply_grid_transform``.

    Args:
        img (Tensor): image to be rotated.
        matrix (list of floats): list of 6 float values representing inverse matrix for rotation transformation.
        resample (int, optional): An optional resampling filter. Default is nearest (=0). Other supported values:
            bilinear(=2).
        expand (bool, optional): Optional expansion flag.
            If true, expands the output image to make it large enough to hold the entire rotated image.
            If false or omitted, make the output image the same size as the input image.
            Note that the expand flag assumes rotation around the center and no translation.
        fill (n-tuple or int or float): this option is not supported for Tensor input.
            Fill value for the area outside the transform in the output image is always 0.

    Returns:
        Tensor: Rotated image.
    """
    # Resample codes accepted for Tensor input and the matching sampling mode names.
    _interpolation_modes = {
        0: "nearest",
        2: "bilinear",
    }

    _assert_grid_transform_inputs(img, matrix, resample, fill, _interpolation_modes)

    # Force a float matrix (as ``affine`` does): a list of ints would otherwise
    # produce an integer tensor that cannot be multiplied with the float grid.
    theta = torch.tensor(matrix, dtype=torch.float).reshape(2, 3)
    shape = img.shape
    # The grid helper takes (width, height), i.e. the last and second-to-last dims.
    grid = _expanded_affine_grid(theta, shape[-1], shape[-2], expand=expand)
    mode = _interpolation_modes[resample]

    return _apply_grid_transform(img, grid, mode)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment