"src/array/cuda/array_cumsum.hip" did not exist on "1547bd931d17cd1da144a6d38bb687c0f2c3b364"
Unverified Commit b1f6c9e2 authored by Vasilis Vryniotis, committed by GitHub

[prototype] Optimize and clean up all affine methods (#6945)

* Clean up `_get_inverse_affine_matrix` and `_compute_affine_output_size`

* Optimize `_apply_grid_transform`

* Cleanup `_assert_grid_transform_inputs`

* Fix bugs on `_pad_with_scalar_fill` & `crop_mask` and port `crop_image_tensor`

* Call directly `_pad_with_scalar_fill`

* Fix linter

* Clean up `center_crop_image_tensor`

* Fix comments.

* Fix rounding issues.

* Bump tolerance for rotate; the underlying mismatch is unrelated to this PR.

* Fix tolerance threshold for RandomPerspective.

* Clean up `_affine_grid` and `affine_image_tensor`

* Clean up `rotate_image_tensor`

* Fix linter

* Address code-review comments.
parent de350bc0
......@@ -915,7 +915,7 @@ KERNEL_INFOS.extend(
reference_inputs_fn=reference_inputs_rotate_image_tensor,
float32_vs_uint8=True,
# TODO: investigate
closeness_kwargs=pil_reference_pixel_difference(100, agg_method="mean"),
closeness_kwargs=pil_reference_pixel_difference(110, agg_method="mean"),
test_marks=[
xfail_jit_tuple_instead_of_list("fill"),
        # TODO: check if this is a regression, since it seems it should be supported if `int` is ok
......
......@@ -401,6 +401,7 @@ CONSISTENCY_CONFIGS = [
ArgsKwargs(p=1, distortion_scale=0.1, fill=1),
ArgsKwargs(p=1, distortion_scale=0.4, fill=(1, 2, 3)),
],
closeness_kwargs={"atol": None, "rtol": None},
),
ConsistencyConfig(
prototype_transforms.RandomRotation,
......
import math
import numbers
import warnings
from typing import List, Optional, Sequence, Tuple, Union
import PIL.Image
import torch
from torch.nn.functional import interpolate, pad as torch_pad
from torch.nn.functional import grid_sample, interpolate, pad as torch_pad
from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
from torchvision.transforms.functional import (
_compute_resized_output_size as __compute_resized_output_size,
_get_inverse_affine_matrix,
_get_perspective_coeffs,
InterpolationMode,
pil_modes_mapping,
......@@ -272,6 +272,195 @@ def _affine_parse_args(
return angle, translate, shear, center
def _get_inverse_affine_matrix(
center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
) -> List[float]:
# Helper method to compute inverse matrix for affine transformation
# Pillow requires inverse affine transformation matrix:
# Affine matrix is : M = T * C * RotateScaleShear * C^-1
#
# where T is translation matrix: [1, 0, tx | 0, 1, ty | 0, 0, 1]
# C is translation matrix to keep center: [1, 0, cx | 0, 1, cy | 0, 0, 1]
# RotateScaleShear is rotation with scale and shear matrix
#
# RotateScaleShear(a, s, (sx, sy)) =
# = R(a) * S(s) * SHy(sy) * SHx(sx)
# = [ s*cos(a - sy)/cos(sy), s*(-cos(a - sy)*tan(sx)/cos(sy) - sin(a)), 0 ]
# [ s*sin(a - sy)/cos(sy), s*(-sin(a - sy)*tan(sx)/cos(sy) + cos(a)), 0 ]
# [ 0 , 0 , 1 ]
# where R is a rotation matrix, S is a scaling matrix, and SHx and SHy are the shears:
# SHx(s) = [1, -tan(s)] and SHy(s) = [1 , 0]
# [0, 1 ] [-tan(s), 1]
#
# Thus, the inverse is M^-1 = C * RotateScaleShear^-1 * C^-1 * T^-1
rot = math.radians(angle)
sx = math.radians(shear[0])
sy = math.radians(shear[1])
cx, cy = center
tx, ty = translate
# Cached results
cos_sy = math.cos(sy)
tan_sx = math.tan(sx)
rot_minus_sy = rot - sy
cx_plus_tx = cx + tx
cy_plus_ty = cy + ty
# Rotate Scale Shear (RSS) without scaling
a = math.cos(rot_minus_sy) / cos_sy
b = -(a * tan_sx + math.sin(rot))
c = math.sin(rot_minus_sy) / cos_sy
d = math.cos(rot) - c * tan_sx
if inverted:
# Inverted rotation matrix with scale and shear
# det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
matrix = [d / scale, -b / scale, 0.0, -c / scale, a / scale, 0.0]
# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
# and then apply center translation: C * RSS^-1 * C^-1 * T^-1
matrix[2] += cx - matrix[0] * cx_plus_tx - matrix[1] * cy_plus_ty
matrix[5] += cy - matrix[3] * cx_plus_tx - matrix[4] * cy_plus_ty
else:
matrix = [a * scale, b * scale, 0.0, c * scale, d * scale, 0.0]
# Apply inverse of center translation: RSS * C^-1
# and then apply translation and center : T * C * RSS * C^-1
matrix[2] += cx_plus_tx - matrix[0] * cx - matrix[1] * cy
matrix[5] += cy_plus_ty - matrix[3] * cx - matrix[4] * cy
return matrix
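As a quick sanity check of the helper above (illustrative values): the matrix returned with inverted=True is the exact inverse of the forward one, so composing the two as 3x3 matrices should give the identity.

# Illustrative sketch: compose the forward and inverted matrices; expect identity.
args = ([10.0, 20.0], 30.0, [5.0, -3.0], 1.5, [10.0, 5.0])  # center, angle, translate, scale, shear
fwd = _get_inverse_affine_matrix(*args, inverted=False)
inv = _get_inverse_affine_matrix(*args)  # inverted=True is the default
m_fwd = torch.tensor(fwd + [0.0, 0.0, 1.0], dtype=torch.float64).reshape(3, 3)
m_inv = torch.tensor(inv + [0.0, 0.0, 1.0], dtype=torch.float64).reshape(3, 3)
assert torch.allclose(m_fwd @ m_inv, torch.eye(3, dtype=torch.float64), atol=1e-10)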
def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
    # Inspired by the PIL implementation:
# https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054
    # pts are the four corner points: Top-Left, Bottom-Left, Bottom-Right, Top-Right.
    # The points are expressed relative to the image center, following the torch
    # affine-matrix convention: the pivot point (w * 0.5, h * 0.5) maps to (0, 0).
half_w = 0.5 * w
half_h = 0.5 * h
pts = torch.tensor(
[
[-half_w, -half_h, 1.0],
[-half_w, half_h, 1.0],
[half_w, half_h, 1.0],
[half_w, -half_h, 1.0],
]
)
theta = torch.tensor(matrix, dtype=torch.float).view(2, 3)
new_pts = torch.matmul(pts, theta.T)
min_vals, max_vals = new_pts.aminmax(dim=0)
# shift points to [0, w] and [0, h] interval to match PIL results
halfs = torch.tensor((half_w, half_h))
min_vals.add_(halfs)
max_vals.add_(halfs)
    # Truncate precision to 1e-4 so that tiny fp residue (e.g. ~1e-15) is not ceil'ed up to 1.0
tol = 1e-4
inv_tol = 1.0 / tol
cmax = max_vals.mul_(inv_tol).trunc_().mul_(tol).ceil_()
cmin = min_vals.mul_(inv_tol).trunc_().mul_(tol).floor_()
size = cmax.sub_(cmin)
return int(size[0]), int(size[1]) # w, h
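A worked example with hypothetical sizes: a pure 90-degree rotation should simply swap the canvas dimensions, and the 1e-4 truncation above is what stops the ~6e-17 cosine residue from being ceil'ed into an extra pixel.

# Illustrative: a 90-degree rotation of a 100x50 image swaps width and height.
matrix = _get_inverse_affine_matrix([0.0, 0.0], 90.0, [0.0, 0.0], 1.0, [0.0, 0.0])
assert _compute_affine_output_size(matrix, w=100, h=50) == (50, 100)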
def _apply_grid_transform(
float_img: torch.Tensor, grid: torch.Tensor, mode: str, fill: features.FillTypeJIT
) -> torch.Tensor:
shape = float_img.shape
if shape[0] > 1:
# Apply same grid to a batch of images
grid = grid.expand(shape[0], -1, -1, -1)
    # Append a dummy mask for customized fill colors; this should be faster than calling grid_sample() twice
if fill is not None:
mask = torch.ones((shape[0], 1, shape[2], shape[3]), dtype=float_img.dtype, device=float_img.device)
float_img = torch.cat((float_img, mask), dim=1)
float_img = grid_sample(float_img, grid, mode=mode, padding_mode="zeros", align_corners=False)
# Fill with required color
if fill is not None:
float_img, mask = torch.tensor_split(float_img, indices=(-1,), dim=-3)
mask = mask.expand_as(float_img)
fill_list = fill if isinstance(fill, (tuple, list)) else [float(fill)]
fill_img = torch.tensor(fill_list, dtype=float_img.dtype, device=float_img.device).view(1, -1, 1, 1)
if mode == "nearest":
bool_mask = mask < 0.5
float_img[bool_mask] = fill_img.expand_as(float_img)[bool_mask]
else: # 'bilinear'
# The following is mathematically equivalent to:
# img * mask + (1.0 - mask) * fill = img * mask - fill * mask + fill = mask * (img - fill) + fill
float_img = float_img.sub_(fill_img).mul_(mask).add_(fill_img)
return float_img
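To see the single-pass fill trick in action, here is a minimal sketch (shapes and fill values are made up): a grid placed entirely outside [-1, 1] samples only zero padding, so the appended mask is 0 everywhere and every output pixel takes the fill color.

# Illustrative: an all-out-of-bounds grid paints the whole output with `fill`.
img = torch.rand(1, 3, 4, 4)
grid = torch.full((1, 4, 4, 2), 2.0)  # every sample point outside [-1, 1]
out = _apply_grid_transform(img, grid, mode="bilinear", fill=[0.5, 0.25, 0.0])
expected = torch.tensor([0.5, 0.25, 0.0]).view(1, 3, 1, 1).expand_as(out)
assert torch.allclose(out, expected)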
def _assert_grid_transform_inputs(
image: torch.Tensor,
matrix: Optional[List[float]],
interpolation: str,
fill: features.FillTypeJIT,
supported_interpolation_modes: List[str],
coeffs: Optional[List[float]] = None,
) -> None:
if matrix is not None:
if not isinstance(matrix, list):
raise TypeError("Argument matrix should be a list")
elif len(matrix) != 6:
raise ValueError("Argument matrix should have 6 float values")
if coeffs is not None and len(coeffs) != 8:
raise ValueError("Argument coeffs should have 8 float values")
if fill is not None:
if isinstance(fill, (tuple, list)):
length = len(fill)
num_channels = image.shape[-3]
if length > 1 and length != num_channels:
raise ValueError(
"The number of elements in 'fill' cannot broadcast to match the number of "
f"channels of the image ({length} != {num_channels})"
)
elif not isinstance(fill, (int, float)):
raise ValueError("Argument fill should be either int, float, tuple or list")
if interpolation not in supported_interpolation_modes:
raise ValueError(f"Interpolation mode '{interpolation}' is unsupported with Tensor input")
def _affine_grid(
theta: torch.Tensor,
w: int,
h: int,
ow: int,
oh: int,
) -> torch.Tensor:
# https://github.com/pytorch/pytorch/blob/74b65c32be68b15dc7c9e8bb62459efbfbde33d8/aten/src/ATen/native/
# AffineGridGenerator.cpp#L18
    # Differences from AffineGridGenerator:
    # 1) we normalize grid values after applying theta
    # 2) we can normalize by another image size, so that it covers the "expand" option as in PIL.Image.rotate
dtype = theta.dtype
device = theta.device
base_grid = torch.empty(1, oh, ow, 3, dtype=dtype, device=device)
x_grid = torch.linspace((1.0 - ow) * 0.5, (ow - 1.0) * 0.5, steps=ow, device=device)
base_grid[..., 0].copy_(x_grid)
y_grid = torch.linspace((1.0 - oh) * 0.5, (oh - 1.0) * 0.5, steps=oh, device=device).unsqueeze_(-1)
base_grid[..., 1].copy_(y_grid)
base_grid[..., 2].fill_(1)
rescaled_theta = theta.transpose(1, 2).div_(torch.tensor([0.5 * w, 0.5 * h], dtype=dtype, device=device))
output_grid = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta)
return output_grid.view(1, oh, ow, 2)
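With an identity theta and an unchanged canvas (ow == w, oh == h), the helper reduces to the standard normalized grid, which makes for an easy equivalence check against torch.nn.functional.affine_grid (a sketch; note the clone, since the helper divides theta in place):

# Illustrative: identity theta reproduces affine_grid (align_corners=False).
theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
expected = torch.nn.functional.affine_grid(theta, size=(1, 1, 6, 8), align_corners=False)
grid = _affine_grid(theta.clone(), w=8, h=6, ow=8, oh=6)  # clone: theta is modified in place
assert torch.allclose(grid, expected, atol=1e-6)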
def affine_image_tensor(
image: torch.Tensor,
angle: Union[int, float],
......@@ -286,9 +475,19 @@ def affine_image_tensor(
return image
shape = image.shape
num_channels, height, width = shape[-3:]
image = image.reshape(-1, num_channels, height, width)
ndim = image.ndim
fp = torch.is_floating_point(image)
if ndim > 4:
image = image.reshape((-1,) + shape[-3:])
needs_unsquash = True
elif ndim == 3:
image = image.unsqueeze(0)
needs_unsquash = True
else:
needs_unsquash = False
height, width = shape[-2:]
angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)
center_f = [0.0, 0.0]
......@@ -299,8 +498,20 @@ def affine_image_tensor(
translate_f = [float(t) for t in translate]
matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)
output = _FT.affine(image, matrix, interpolation=interpolation.value, fill=fill)
return output.reshape(shape)
_assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])
dtype = image.dtype if fp else torch.float32
theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height)
output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)
if not fp:
output = output.round_().to(image.dtype)
if needs_unsquash:
output = output.reshape(shape)
return output
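Taken together, the new path squashes any leading batch dims down to NCHW, samples in float, and rounds back to the input dtype. A hypothetical usage sketch:

# Illustrative: uint8 input with stacked batch dims round-trips through the
# float grid-sample path and comes back as uint8 with the same shape.
img = torch.randint(0, 256, (2, 5, 3, 32, 32), dtype=torch.uint8)
out = affine_image_tensor(
    img, angle=30.0, translate=[0.0, 0.0], scale=1.0, shear=[0.0, 0.0], interpolation=InterpolationMode.BILINEAR
)
assert out.shape == img.shape and out.dtype == torch.uint8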
@torch.jit.unused
......@@ -395,7 +606,7 @@ def _affine_bounding_box_xyxy(
out_bboxes.sub_(tr.repeat((1, 2)))
# Estimate meta-data for image with inverted=True and with center=[0,0]
affine_vector = _get_inverse_affine_matrix([0.0, 0.0], angle, translate, scale, shear)
new_width, new_height = _FT._compute_affine_output_size(affine_vector, width, height)
new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
spatial_size = (new_height, new_width)
return out_bboxes.to(bounding_box.dtype), spatial_size
......@@ -543,18 +754,26 @@ def rotate_image_tensor(
matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
if image.numel() > 0:
image = _FT.rotate(
image.reshape(-1, num_channels, height, width),
matrix,
interpolation=interpolation.value,
expand=expand,
fill=fill,
)
new_height, new_width = image.shape[-2:]
fp = torch.is_floating_point(image)
image = image.reshape(-1, num_channels, height, width)
_assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])
ow, oh = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
dtype = image.dtype if fp else torch.float32
theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
grid = _affine_grid(theta, w=width, h=height, ow=ow, oh=oh)
output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)
if not fp:
output = output.round_().to(image.dtype)
new_height, new_width = output.shape[-2:]
else:
new_width, new_height = _FT._compute_affine_output_size(matrix, width, height) if expand else (width, height)
output = image
new_width, new_height = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
return output.reshape(shape[:-3] + (num_channels, new_height, new_width))
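A brief sketch of the expand behavior on the ported kernel (illustrative sizes): with expand=True the output canvas grows to fit the rotated image.

# Illustrative: a 90-degree rotation with expand=True swaps H and W.
img = torch.rand(3, 32, 64)
out = rotate_image_tensor(img, angle=90.0, expand=True)
assert out.shape == (3, 64, 32)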
@torch.jit.unused
......@@ -944,7 +1163,6 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
# x_out = (coeffs[0] * x + coeffs[1] * y + coeffs[2]) / (coeffs[6] * x + coeffs[7] * y + 1)
# y_out = (coeffs[3] * x + coeffs[4] * y + coeffs[5]) / (coeffs[6] * x + coeffs[7] * y + 1)
#
# TODO: should we define them transposed?
theta1 = torch.tensor(
[[[coeffs[0], coeffs[1], coeffs[2]], [coeffs[3], coeffs[4], coeffs[5]]]], dtype=dtype, device=device
)
......@@ -959,8 +1177,9 @@ def _perspective_grid(coeffs: List[float], ow: int, oh: int, dtype: torch.dtype,
base_grid[..., 2].fill_(1)
rescaled_theta1 = theta1.transpose(1, 2).div_(torch.tensor([0.5 * ow, 0.5 * oh], dtype=dtype, device=device))
output_grid1 = base_grid.view(1, oh * ow, 3).bmm(rescaled_theta1)
output_grid2 = base_grid.view(1, oh * ow, 3).bmm(theta2.transpose(1, 2))
shape = (1, oh * ow, 3)
output_grid1 = base_grid.view(shape).bmm(rescaled_theta1)
output_grid2 = base_grid.view(shape).bmm(theta2.transpose(1, 2))
output_grid = output_grid1.div_(output_grid2).sub_(1.0)
return output_grid.view(1, oh, ow, 2)
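As with the affine grid, identity coefficients make the projective divide a no-op, so the helper can be sanity-checked against torch.nn.functional.affine_grid (a hedged sketch, assuming the pixel-center base grid that is elided from the hunk above):

# Illustrative: identity perspective coefficients give a plain normalized grid.
coeffs = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]
grid = _perspective_grid(coeffs, ow=8, oh=6, dtype=torch.float32, device=torch.device("cpu"))
theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
expected = torch.nn.functional.affine_grid(theta, size=(1, 1, 6, 8), align_corners=False)
assert torch.allclose(grid, expected, atol=1e-6)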
......@@ -996,14 +1215,19 @@ def perspective_image_tensor(
return image
shape = image.shape
ndim = image.ndim
fp = torch.is_floating_point(image)
if image.ndim > 4:
if ndim > 4:
image = image.reshape((-1,) + shape[-3:])
needs_unsquash = True
elif ndim == 3:
image = image.unsqueeze(0)
needs_unsquash = True
else:
needs_unsquash = False
_FT._assert_grid_transform_inputs(
_assert_grid_transform_inputs(
image,
matrix=None,
interpolation=interpolation.value,
......@@ -1012,10 +1236,13 @@ def perspective_image_tensor(
coeffs=perspective_coeffs,
)
ow, oh = image.shape[-1], image.shape[-2]
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
oh, ow = shape[-2:]
dtype = image.dtype if fp else torch.float32
grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device)
output = _FT._apply_grid_transform(image, grid, interpolation.value, fill=fill)
output = _apply_grid_transform(image if fp else image.to(dtype), grid, interpolation.value, fill=fill)
if not fp:
output = output.round_().to(image.dtype)
if needs_unsquash:
output = output.reshape(shape)
......@@ -1086,7 +1313,6 @@ def perspective_bounding_box(
(-perspective_coeffs[0] * perspective_coeffs[7] + perspective_coeffs[1] * perspective_coeffs[6]) / denom,
]
# TODO: should we define them transposed?
theta1 = torch.tensor(
[[inv_coeffs[0], inv_coeffs[1], inv_coeffs[2]], [inv_coeffs[3], inv_coeffs[4], inv_coeffs[5]]],
dtype=dtype,
......@@ -1193,17 +1419,25 @@ def elastic_image_tensor(
return image
shape = image.shape
ndim = image.ndim
device = image.device
fp = torch.is_floating_point(image)
if image.ndim > 4:
if ndim > 4:
image = image.reshape((-1,) + shape[-3:])
needs_unsquash = True
elif ndim == 3:
image = image.unsqueeze(0)
needs_unsquash = True
else:
needs_unsquash = False
image_height, image_width = shape[-2:]
grid = _create_identity_grid((image_height, image_width), device=device).add_(displacement.to(device))
output = _FT._apply_grid_transform(image, grid, interpolation.value, fill)
output = _apply_grid_transform(image if fp else image.to(torch.float32), grid, interpolation.value, fill=fill)
if not fp:
output = output.round_().to(image.dtype)
if needs_unsquash:
output = output.reshape(shape)
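Since the displacement field is simply added to an identity grid, a zero field makes the elastic transform a numerical no-op, which is a cheap sanity check (illustrative sketch):

# Illustrative: zero displacement leaves the image essentially unchanged.
img = torch.rand(1, 3, 8, 8)
displacement = torch.zeros(1, 8, 8, 2)
out = elastic_image_tensor(img, displacement, interpolation=InterpolationMode.BILINEAR)
assert torch.allclose(out, img, atol=1e-5)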
......@@ -1361,7 +1595,7 @@ def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> tor
if crop_height > image_height or crop_width > image_width:
padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
image = _FT.torch_pad(image, _FT._parse_pad_padding(padding_ltrb), value=0.0)
image = torch_pad(image, _parse_pad_padding(padding_ltrb), value=0.0)
image_height, image_width = image.shape[-2:]
if crop_width == image_width and crop_height == image_height:
......
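With the port above, center_crop_image_tensor pads via the module-level torch_pad whenever the requested crop is larger than the input; a hypothetical example:

# Illustrative: cropping to a size larger than the input zero-pads the borders.
img = torch.rand(3, 4, 4)
out = center_crop_image_tensor(img, output_size=[6, 6])
assert out.shape == (3, 6, 6)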