Unverified commit b1f6c9e2, authored by Vasilis Vryniotis and committed by GitHub

[prototype] Optimize and clean up all affine methods (#6945)

* Clean up `_get_inverse_affine_matrix` and `_compute_affine_output_size`

* Optimize `_apply_grid_transform`

* Cleanup `_assert_grid_transform_inputs`

* Fix bugs on `_pad_with_scalar_fill` & `crop_mask` and port `crop_image_tensor`

* Call directly `_pad_with_scalar_fill`

* Fix linter

* Clean up `center_crop_image_tensor`

* Fix comments.

* Fixing rounding issues.

* Bumping tolerance for rotate which is unrelated to this PR.

* Fix tolerance threshold for RandomPerspective.

* Clean up `_affine_grid` and `affine_image_tensor`

* Clean up `rotate_image_tensor`

* Fixing linter

* Address code-review comments.
parent de350bc0
...
@@ -915,7 +915,7 @@ KERNEL_INFOS.extend(
         reference_inputs_fn=reference_inputs_rotate_image_tensor,
         float32_vs_uint8=True,
         # TODO: investigate
-        closeness_kwargs=pil_reference_pixel_difference(100, agg_method="mean"),
+        closeness_kwargs=pil_reference_pixel_difference(110, agg_method="mean"),
         test_marks=[
             xfail_jit_tuple_instead_of_list("fill"),
             # TODO: check if this is a regression since it seems that should be supported if `int` is ok
...
...
@@ -401,6 +401,7 @@ CONSISTENCY_CONFIGS = [
             ArgsKwargs(p=1, distortion_scale=0.1, fill=1),
             ArgsKwargs(p=1, distortion_scale=0.4, fill=(1, 2, 3)),
         ],
+        closeness_kwargs={"atol": None, "rtol": None},
     ),
     ConsistencyConfig(
         prototype_transforms.RandomRotation,
...
+import math
 import numbers
 import warnings
 from typing import List, Optional, Sequence, Tuple, Union

 import PIL.Image
 import torch
-from torch.nn.functional import interpolate, pad as torch_pad
+from torch.nn.functional import grid_sample, interpolate, pad as torch_pad
 from torchvision.prototype import features
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
 from torchvision.transforms.functional import (
     _compute_resized_output_size as __compute_resized_output_size,
-    _get_inverse_affine_matrix,
     _get_perspective_coeffs,
     InterpolationMode,
     pil_modes_mapping,
...
@@ -272,6 +272,195 @@ def _affine_parse_args(
     return angle, translate, shear, center


+def _get_inverse_affine_matrix(
+    center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
+) -> List[float]:
+    # Helper method to compute inverse matrix for affine transformation
+
+    # Pillow requires inverse affine transformation matrix:
+    # Affine matrix is : M = T * C * RotateScaleShear * C^-1
+    #
+    # where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
+    #       C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
+    #       RotateScaleShear is rotation with scale and shear matrix
+    #
+    #       RotateScaleShear(a, s, (sx, sy)) =
+    #       = R(a) * S(s) * SHy(sy) * SHx(sx)
+    #       = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
+    #         [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
+    #         [ 0                    , 0                                        , 1 ]
+    # where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
+    # SHx(s) = [1, -tan(s)] and SHy(s) = [1      , 0]
+    #          [0, 1      ]             [-tan(s), 1]
+    #
+    # Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1
+    rot = math.radians(angle)
+    sx = math.radians(shear[0])
+    sy = math.radians(shear[1])
+
+    cx, cy = center
+    tx, ty = translate
+
+    # Cached results
+    cos_sy = math.cos(sy)
+    tan_sx = math.tan(sx)
+    rot_minus_sy = rot - sy
+    cx_plus_tx = cx + tx
+    cy_plus_ty = cy + ty
+
+    # Rotate Scale Shear (RSS) without scaling
+    a = math.cos(rot_minus_sy) / cos_sy
+    b = -(a * tan_sx + math.sin(rot))
+    c = math.sin(rot_minus_sy) / cos_sy
+    d = math.cos(rot) - c * tan_sx
+
+    if inverted:
+        # Inverted rotation matrix with scale and shear
+        # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
+        matrix = [d / scale, -b / scale, 0.0, -c / scale, a / scale, 0.0]
+        # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
+        # and then apply center translation: C * RSS^-1 * C^-1 * T^-1
+        matrix[2] += cx - matrix[0] * cx_plus_tx - matrix[1] * cy_plus_ty
+        matrix[5] += cy - matrix[3] * cx_plus_tx - matrix[4] * cy_plus_ty
+    else:
+        matrix = [a * scale, b * scale, 0.0, c * scale, d * scale, 0.0]
+        # Apply inverse of center translation: RSS * C^-1
+        # and then apply translation and center : T * C * RSS * C^-1
+        matrix[2] += cx_plus_tx - matrix[0] * cx - matrix[1] * cy
+        matrix[5] += cy_plus_ty - matrix[3] * cx - matrix[4] * cy
+
+    return matrix
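A quick round-trip check (illustrative aside, not part of the diff; assumes the helper above is in scope): composing the forward (`inverted=False`) and inverted matrices should map any point back to itself.

```python
import math

# center, angle, translate, scale, shear -- arbitrary illustrative values
args = ([7.0, 3.0], 33.0, [2.0, -1.0], 1.2, [10.0, 5.0])
fwd = _get_inverse_affine_matrix(*args, inverted=False)
inv = _get_inverse_affine_matrix(*args, inverted=True)
for x, y in [(0.0, 0.0), (5.0, -4.0)]:
    u = fwd[0] * x + fwd[1] * y + fwd[2]
    v = fwd[3] * x + fwd[4] * y + fwd[5]
    assert math.isclose(inv[0] * u + inv[1] * v + inv[2], x, abs_tol=1e-8)
    assert math.isclose(inv[3] * u + inv[4] * v + inv[5], y, abs_tol=1e-8)
```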
+
+
+def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
+    # Inspired of PIL implementation:
+    # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
+
+    # pts are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
+    # Points are shifted due to affine matrix torch convention about
+    # the center point. Center is (0, 0) for image center pivot point (w * 0.5, h * 0.5)
+    half_w = 0.5 * w
+    half_h = 0.5 * h
+    pts = torch.tensor(
+        [
+            [-half_w, -half_h, 1.0],
+            [-half_w, half_h, 1.0],
+            [half_w, half_h, 1.0],
+            [half_w, -half_h, 1.0],
+        ]
+    )
+    theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
+    new_pts = torch.matmul(pts, theta.T)
+    min_vals, max_vals = new_pts.aminmax(dim=0)
+
+    # shift points to [0, w] and [0, h] interval to match PIL results
+    halfs = torch.tensor((half_w, half_h))
+    min_vals.add_(halfs)
+    max_vals.add_(halfs)
+
+    # Truncate precision to 1e-4 to avoid ceil of Xe-15 to 1.0
+    tol = 1e-4
+    inv_tol = 1.0 / tol
+    cmax = max_vals.mul_(inv_tol).trunc_().mul_(tol).ceil_()
+    cmin = min_vals.mul_(inv_tol).trunc_().mul_(tol).floor_()
+    size = cmax.sub_(cmin)
+    return int(size[0]), int(size[1])  # w, h
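A worked example of the sizing rule (illustrative aside, not part of the diff): a 45-degree rotation of a 100x100 image needs a canvas of about 100 * sqrt(2) ~= 141.4 pixels per side, which the ceil/floor snapping above grows to 142.

```python
# Pure rotation about the center: translate and shear are zero, scale is 1.
matrix = _get_inverse_affine_matrix([0.0, 0.0], 45.0, [0.0, 0.0], 1.0, [0.0, 0.0])
assert _compute_affine_output_size(matrix, 100, 100) == (142, 142)
```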
+
+
+def _apply_grid_transform(
+    float_img: torch.Tensor, grid: torch.Tensor, mode: str, fill: features.FillTypeJIT
+) -> torch.Tensor:
+    shape = float_img.shape
+    if shape[0] > 1:
+        # Apply same grid to a batch of images
+        grid = grid.expand(shape[0], -1, -1, -1)
+
+    # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
+    if fill is not None:
+        mask = torch.ones((shape[0], 1, shape[2], shape[3]), dtype=float_img.dtype, device=float_img.device)
+        float_img = torch.cat((float_img, mask), dim=1)
+
+    float_img = grid_sample(float_img, grid, mode=mode, padding_mode="zeros", align_corners=False)
+
+    # Fill with required color
+    if fill is not None:
+        float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3)
+        mask = mask.expand_as(float_img)
+        fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)]
+        fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1)
+        if mode == "nearest":
+            bool_mask = mask < 0.5
+            float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]
+        else:  # 'bilinear'
+            # The following is mathematically equivalent to:
+            # img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill
+            float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img)
+
+    return float_img
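The appended all-ones channel is what lets a single `grid_sample` call handle custom fills: wherever the sampled mask drops below 1, the pixel drew from outside the image and is blended toward `fill`. A minimal sketch (illustrative aside, not part of the diff; assumes the function above is in scope) using an identity grid, where the mask stays at 1 and the image passes through unchanged:

```python
import torch
from torch.nn.functional import affine_grid

img = torch.rand(1, 3, 8, 8)
identity = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
grid = affine_grid(identity, size=(1, 3, 8, 8), align_corners=False)
out = _apply_grid_transform(img, grid, "bilinear", fill=[0.5, 0.5, 0.5])
assert torch.allclose(out, img, atol=1e-5)
```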
+
+
+def _assert_grid_transform_inputs(
+    image: torch.Tensor,
+    matrix: Optional[List[float]],
+    interpolation: str,
+    fill: features.FillTypeJIT,
+    supported_interpolation_modes: List[str],
+    coeffs: Optional[List[float]] = None,
+) -> None:
+    if matrix is not None:
+        if not isinstance(matrix, list):
+            raise TypeError("Argument matrix should be a list")
+        elif len(matrix) != 6:
+            raise ValueError("Argument matrix should have 6 float values")
+
+    if coeffs is not None and len(coeffs) != 8:
+        raise ValueError("Argument coeffs should have 8 float values")
+
+    if fill is not None:
+        if isinstance(fill, (tuple, list)):
+            length = len(fill)
+            num_channels = image.shape[-3]
+            if length > 1 and length != num_channels:
+                raise ValueError(
+                    "The number of elements in 'fill' cannot broadcast to match the number of "
+                    f"channels of the image ({length} != {num_channels})"
+                )
+        elif not isinstance(fill, (int, float)):
+            raise ValueError("Argument fill should be either int, float, tuple or list")
+
+    if interpolation not in supported_interpolation_modes:
+        raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")
+
+
+def _affine_grid(
+    theta: torch.Tensor,
+    w: int,
+    h: int,
+    ow: int,
+    oh: int,
+) -> torch.Tensor:
+    # https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
+    # AffineGridGenerator.cpp#L18
+    # Difference with AffineGridGenerator is that:
+    # 1) we normalize grid values after applying theta
+    # 2) we can normalize by other image size, such that it covers "extend" option like in PIL.Image.rotate
+    dtype = theta.dtype
+    device = theta.device
+
+    base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
+    x_grid = torch.linspace((1.0 - ow) * 0.5, (ow - 1.0) * 0.5, steps=ow, device=device)
+    base_grid[..., 0].copy_(x_grid)
+    y_grid = torch.linspace((1.0 - oh) * 0.5, (oh - 1.0) * 0.5, steps=oh, device=device).unsqueeze_(-1)
+    base_grid[..., 1].copy_(y_grid)
+    base_grid[..., 2].fill_(1)
+
+    rescaled_theta = theta.transpose(1, 2).div_(torch.tensor([0.5 * w, 0.5 * h], dtype=dtype, device=device))
+    output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
+    return output_grid.view(1, oh, ow, 2)
+
+
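Because `_affine_grid` normalizes by the input size (`w`, `h`) while iterating over the output size (`ow`, `oh`), an identity theta with matching sizes reproduces the usual `align_corners=False` pixel-center coordinates. A quick check (illustrative aside, not part of the diff):

```python
import torch

identity = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
grid = _affine_grid(identity, w=4, h=4, ow=4, oh=4)
print(grid[0, 0, :, 0])  # tensor([-0.7500, -0.2500,  0.2500,  0.7500])
```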
 def affine_image_tensor(
     image: torch.Tensor,
     angle: Union[int, float],
...
@@ -286,9 +475,19 @@ def affine_image_tensor(
         return image

     shape = image.shape
-    num_channels, height, width = shape[-3:]
-    image = image.reshape(-1, num_channels, height, width)
+    ndim = image.ndim
+    fp = torch.is_floating_point(image)

+    if ndim > 4:
+        image = image.reshape((-1,) + shape[-3:])
+        needs_unsquash = True
+    elif ndim == 3:
+        image = image.unsqueeze(0)
+        needs_unsquash = True
+    else:
+        needs_unsquash = False
+
+    height, width = shape[-2:]
     angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

     center_f = [0.0, 0.0]
...
@@ -299,8 +498,20 @@ def affine_image_tensor(
     translate_f = [float(t) for t in translate]
     matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)

-    output = _FT.affine(image, matrix, interpolation=interpolation.value, fill=fill)
-    return output.reshape(shape)
+    _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])
+
+    dtype = image.dtype if fp else torch.float32
+    theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
+    grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height)
+    output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)
+
+    if not fp:
+        output = output.round_().to(image.dtype)
+
+    if needs_unsquash:
+        output = output.reshape(shape)
+
+    return output
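With the call to `_FT.affine` gone, the kernel handles dtype and batching itself: integer images are sampled in float32 and rounded back, and 3D inputs are unsqueezed then restored. A smoke test assuming the kernel's defaults of nearest interpolation and no fill (illustrative values, not part of the diff):

```python
import torch

img_u8 = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
out = affine_image_tensor(img_u8, angle=30.0, translate=[0, 0], scale=1.0, shear=[0.0, 0.0])
assert out.dtype == torch.uint8 and out.shape == img_u8.shape
```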


 @torch.jit.unused
...
@@ -395,7 +606,7 @@ def _affine_bounding_box_xyxy(
     out_bboxes.sub_(tr.repeat((1, 2)))
     # Estimate meta-data for image with inverted=True and with center=[0,0]
     affine_vector = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
-    new_width, new_height = _FT._compute_affine_output_size(affine_vector, width, height)
+    new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
     spatial_size = (new_height, new_width)

     return out_bboxes.to(bounding_box.dtype), spatial_size
...
@@ -543,18 +754,26 @@ def rotate_image_tensor(
     matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])

     if image.numel() > 0:
-        image = _FT.rotate(
-            image.reshape(-1, num_channels, height, width),
-            matrix,
-            interpolation=interpolation.value,
-            expand=expand,
-            fill=fill,
-        )
-        new_height, new_width = image.shape[-2:]
+        fp = torch.is_floating_point(image)
+        image = image.reshape(-1, num_channels, height, width)
+
+        _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])
+
+        ow, oh = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
+        dtype = image.dtype if fp else torch.float32
+        theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
+        grid = _affine_grid(theta, w=width, h=height, ow=ow, oh=oh)
+        output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)
+
+        if not fp:
+            output = output.round_().to(image.dtype)
+
+        new_height, new_width = output.shape[-2:]
     else:
-        new_width, new_height = _FT._compute_affine_output_size(matrix, width, height) if expand else (width, height)
+        output = image
+        new_width, new_height = _compute_affine_output_size(matrix, width, height) if expand else (width, height)

-    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
+    return output.reshape(shape[:-3] + (num_channels, new_height, new_width))


 @torch.jit.unused
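With `expand=True`, the enlarged canvas now comes from `_compute_affine_output_size` before the grid is built, consistent with the 45-degree example above. A sketch assuming the kernel's defaults (illustrative aside, not part of the diff):

```python
import torch

img = torch.rand(1, 3, 100, 100)
out = rotate_image_tensor(img, angle=45.0, expand=True)
print(out.shape)  # torch.Size([1, 3, 142, 142])
```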
...
@@ -944,7 +1163,6 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
     # x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
     # y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
     #
-    # TODO: should we define them transposed?
     theta1 = torch.tensor(
         [[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
     )
@@ -959,8 +1177,9 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
     base_grid[..., 2].fill_(1)

     rescaled_theta1 = theta1.transpose(1, 2).div_(torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device))
-    output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
-    output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))
+    shape = (1, oh * ow, 3)
+    output_grid1 = base_grid.view(shape).bmm(rescaled_theta1)
+    output_grid2 = base_grid.view(shape).bmm(theta2.transpose(1, 2))

     output_grid = output_grid1.div_(output_grid2).sub_(1.0)
     return output_grid.view(1, oh, ow, 2)
...
@@ -996,14 +1215,19 @@ def perspective_image_tensor(
         return image

     shape = image.shape
+    ndim = image.ndim
+    fp = torch.is_floating_point(image)

-    if image.ndim > 4:
+    if ndim > 4:
         image = image.reshape((-1,) + shape[-3:])
         needs_unsquash = True
+    elif ndim == 3:
+        image = image.unsqueeze(0)
+        needs_unsquash = True
     else:
         needs_unsquash = False

-    _FT._assert_grid_transform_inputs(
+    _assert_grid_transform_inputs(
         image,
         matrix=None,
         interpolation=interpolation.value,
...
@@ -1012,10 +1236,13 @@ def perspective_image_tensor(
         coeffs=perspective_coeffs,
     )

-    ow, oh = image.shape[-1], image.shape[-2]
-    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
+    oh, ow = shape[-2:]
+    dtype = image.dtype if fp else torch.float32
     grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device)
-    output = _FT._apply_grid_transform(image, grid, interpolation.value, fill=fill)
+    output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)
+
+    if not fp:
+        output = output.round_().to(image.dtype)

     if needs_unsquash:
         output = output.reshape(shape)
...
@@ -1086,7 +1313,6 @@ def perspective_bounding_box(
         (-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
     ]

-    # TODO: should we define them transposed?
     theta1 = torch.tensor(
         [[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
         dtype=dtype,
...
@@ -1193,17 +1419,25 @@ def elastic_image_tensor(
         return image

     shape = image.shape
+    ndim = image.ndim
     device = image.device
+    fp = torch.is_floating_point(image)

-    if image.ndim > 4:
+    if ndim > 4:
         image = image.reshape((-1,) + shape[-3:])
         needs_unsquash = True
+    elif ndim == 3:
+        image = image.unsqueeze(0)
+        needs_unsquash = True
     else:
         needs_unsquash = False

     image_height, image_width = shape[-2:]
     grid = _create_identity_grid((image_height, image_width), device=device).add_(displacement.to(device))
-    output = _FT._apply_grid_transform(image, grid, interpolation.value, fill)
+    output = _apply_grid_transform(image if fp else image.to(torch.float32), grid, interpolation.value, fill=fill)
+
+    if not fp:
+        output = output.round_().to(image.dtype)

     if needs_unsquash:
         output = output.reshape(shape)
...
@@ -1361,7 +1595,7 @@ def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
-        image = _FT.torch_pad(image, _FT._parse_pad_padding(padding_ltrb), value=0.0)
+        image = torch_pad(image, _parse_pad_padding(padding_ltrb), value=0.0)

         image_height, image_width = image.shape[-2:]

     if crop_width == image_width and crop_height == image_height:
...