Unverified commit 59c723cb, authored by vfdev and committed by GitHub
Browse files

Added center arg to F.affine and RandomAffine ops (#5208)

* Added center option to F.affine and RandomAffine ops

* Updates according to the review
parent abdae5a1
......@@ -232,7 +232,8 @@ class TestAffine:
@pytest.mark.parametrize("dt", ALL_DTYPES)
@pytest.mark.parametrize("angle", [90, 45, 15, -30, -60, -120])
@pytest.mark.parametrize("fn", [F.affine, scripted_affine])
def test_rect_rotations(self, device, height, width, dt, angle, fn):
@pytest.mark.parametrize("center", [None, [0, 0]])
def test_rect_rotations(self, device, height, width, dt, angle, fn, center):
# Tests on rectangular images
tensor, pil_img = _create_data(height, width, device=device)
......@@ -244,11 +245,13 @@ class TestAffine:
tensor = tensor.to(dtype=dt)
out_pil_img = F.affine(
pil_img, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST
pil_img, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST, center=center
)
out_pil_tensor = torch.from_numpy(np.array(out_pil_img).transpose((2, 0, 1)))
out_tensor = fn(tensor, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST).cpu()
out_tensor = fn(
tensor, angle=angle, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=NEAREST, center=center
).cpu()
if out_tensor.dtype != torch.uint8:
out_tensor = out_tensor.to(torch.uint8)
......
......@@ -1983,11 +1983,11 @@ class TestAffine:
result_matrix[2, 2] = 1
return np.linalg.inv(result_matrix)
def _test_transformation(self, angle, translate, scale, shear, pil_image, input_img):
def _test_transformation(self, angle, translate, scale, shear, pil_image, input_img, center=None):
a_rad = math.radians(angle)
s_rad = [math.radians(sh_) for sh_ in shear]
cnt = [20, 20]
cnt = [20, 20] if center is None else center
cx, cy = cnt
tx, ty = translate
sx, sy = s_rad
......@@ -2032,7 +2032,7 @@ class TestAffine:
if 0 <= _x < input_img.shape[1] and 0 <= _y < input_img.shape[0]:
true_result[y, x, :] = input_img[_y, _x, :]
result = F.affine(pil_image, angle=angle, translate=translate, scale=scale, shear=shear)
result = F.affine(pil_image, angle=angle, translate=translate, scale=scale, shear=shear, center=center)
assert result.size == pil_image.size
# Compute number of different pixels:
np_result = np.array(result)
......@@ -2050,6 +2050,18 @@ class TestAffine:
angle=angle, translate=(0, 0), scale=1.0, shear=(0.0, 0.0), pil_image=pil_image, input_img=input_img
)
# Test rotation
angle = 45
self._test_transformation(
angle=angle,
translate=(0, 0),
scale=1.0,
shear=(0.0, 0.0),
pil_image=pil_image,
input_img=input_img,
center=[0, 0],
)
# Test translation
translate = [10, 15]
self._test_transformation(
......@@ -2068,6 +2080,18 @@ class TestAffine:
angle=0.0, translate=(0.0, 0.0), scale=1.0, shear=shear, pil_image=pil_image, input_img=input_img
)
# Test shear with top-left as center
shear = [45.0, 25.0]
self._test_transformation(
angle=0.0,
translate=(0.0, 0.0),
scale=1.0,
shear=shear,
pil_image=pil_image,
input_img=input_img,
center=[0, 0],
)
@pytest.mark.parametrize("angle", range(-90, 90, 36))
@pytest.mark.parametrize("translate", range(-10, 10, 5))
@pytest.mark.parametrize("scale", [0.77, 1.0, 1.27])
......
......@@ -945,26 +945,31 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
def _get_inverse_affine_matrix(
center: List[float], angle: float, translate: List[float], scale: float, shear: List[float]
center: List[float],
angle: float,
translate: List[float],
scale: float,
shear: List[float],
) -> List[float]:
# Helper method to compute inverse matrix for affine transformation
# As it is explained in PIL.Image.rotate
# We need compute INVERSE of affine transformation matrix: M = T * C * RSS * C^-1
# Pillow requires inverse affine transformation matrix:
# Affine matrix is : M = T * C * RotateScaleShear * C^-1
#
# where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
# C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
# RSS is rotation with scale and shear matrix
# RSS(a, s, (sx, sy)) =
# RotateScaleShear is rotation with scale and shear matrix
#
# RotateScaleShear(a, s, (sx, sy)) =
# = R(a) * S(s) * SHy(sy) * SHx(sx)
# = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(x)/cos(y) - sin(a)), 0 ]
# [ s*sin(a + sy)/cos(sy), s*(-sin(a - sy)*tan(x)/cos(y) + cos(a)), 0 ]
# = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
# [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
# [ 0 , 0 , 1 ]
#
# where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
# SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0]
# [0, 1 ] [-tan(s), 1]
#
# Thus, the inverse is M^-1 = C * RSS^-1 * C^-1 * T^-1
# Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1
rot = math.radians(angle)
sx = math.radians(shear[0])
......@@ -1085,6 +1090,7 @@ def affine(
fill: Optional[List[float]] = None,
resample: Optional[int] = None,
fillcolor: Optional[List[float]] = None,
center: Optional[List[int]] = None,
) -> Tensor:
"""Apply affine transformation on the image keeping image center invariant.
If the image is torch Tensor, it is expected
......@@ -1112,6 +1118,8 @@ def affine(
Please use the ``fill`` parameter instead.
resample (int, optional): deprecated argument and will be removed since v0.10.0.
Please use the ``interpolation`` parameter instead.
center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
Default is the center of the image.
Returns:
PIL Image or Tensor: Transformed image.
......@@ -1172,18 +1180,28 @@ def affine(
if len(shear) != 2:
raise ValueError(f"Shear should be a sequence containing two values. Got {shear}")
if center is not None and not isinstance(center, (list, tuple)):
raise TypeError("Argument center should be a sequence")
img_size = get_image_size(img)
if not isinstance(img, torch.Tensor):
# center = (img_size[0] * 0.5 + 0.5, img_size[1] * 0.5 + 0.5)
# it is visually better to estimate the center without 0.5 offset
# otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
if center is None:
center = [img_size[0] * 0.5, img_size[1] * 0.5]
matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)
pil_interpolation = pil_modes_mapping[interpolation]
return F_pil.affine(img, matrix=matrix, interpolation=pil_interpolation, fill=fill)
center_f = [0.0, 0.0]
if center is not None:
img_size = get_image_size(img)
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, img_size)]
translate_f = [1.0 * t for t in translate]
matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, translate_f, scale, shear)
matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
return F_t.affine(img, matrix=matrix, interpolation=interpolation.value, fill=fill)
......
......@@ -1414,6 +1414,8 @@ class RandomAffine(torch.nn.Module):
Please use the ``fill`` parameter instead.
resample (int, optional): deprecated argument and will be removed since v0.10.0.
Please use the ``interpolation`` parameter instead.
center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
Default is the center of the image.
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
......@@ -1429,6 +1431,7 @@ class RandomAffine(torch.nn.Module):
fill=0,
fillcolor=None,
resample=None,
center=None,
):
super().__init__()
_log_api_usage_once(self)
......@@ -1482,6 +1485,11 @@ class RandomAffine(torch.nn.Module):
self.fillcolor = self.fill = fill
if center is not None:
_check_sequence_input(center, "center", req_sizes=(2,))
self.center = center
@staticmethod
def get_params(
degrees: List[float],
......@@ -1538,7 +1546,7 @@ class RandomAffine(torch.nn.Module):
ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img_size)
return F.affine(img, *ret, interpolation=self.interpolation, fill=fill)
return F.affine(img, *ret, interpolation=self.interpolation, fill=fill, center=self.center)
def __repr__(self):
s = "{name}(degrees={degrees}"
......@@ -1552,6 +1560,8 @@ class RandomAffine(torch.nn.Module):
s += ", interpolation={interpolation}"
if self.fill != 0:
s += ", fill={fill}"
if self.center is not None:
s += ", center={center}"
s += ")"
d = dict(self.__dict__)
d["interpolation"] = self.interpolation.value
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment