Unverified Commit ee28bb3c authored by Philip Meier, committed by GitHub

cleanup affine grid image kernels (#8004)

parent f96deba0
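
This cleanup consolidates bookkeeping that the four kernels below each repeated: squashing arbitrary leading batch dimensions into one, short-circuiting empty inputs, and restoring the original shape afterwards. That logic now lives once in the shared _apply_grid_transform helper. A minimal sketch of the consolidated pattern (an illustration, not the torchvision source; dtype restoration is omitted for brevity):

import torch
import torch.nn.functional as F

def apply_grid(img: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
    # img: (..., C, H, W); grid: (1, H_out, W_out, 2) in [-1, 1] coordinates
    input_shape = img.shape
    num_channels, height, width = input_shape[-3:]
    output_shape = input_shape[:-3] + (num_channels, grid.shape[1], grid.shape[2])
    if img.numel() == 0:
        return img.reshape(output_shape)
    # Squash every leading dim into a single batch dim, warp, then restore.
    batched = img.reshape(-1, num_channels, height, width).float()
    grid = grid.expand(batched.shape[0], -1, -1, -1)
    out = F.grid_sample(batched, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
    return out.reshape(output_shape)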
@@ -2491,7 +2491,7 @@ class TestElastic:
         interpolation=[transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR],
         fill=EXHAUSTIVE_TYPE_FILLS,
     )
-    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8, torch.float16])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_kernel_image(self, param, value, dtype, device):
         image = make_image_tensor(dtype=dtype, device=device)
@@ -2502,6 +2502,7 @@ class TestElastic:
             displacement=self._make_displacement(image),
             **{param: value},
             check_scripted_vs_eager=not (param == "fill" and isinstance(value, (int, float))),
+            check_cuda_vs_cpu=dtype is not torch.float16,
         )

     @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
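
The new float16 case sets check_cuda_vs_cpu=False, since half-precision grid sampling on CUDA and CPU does not produce closely matching results. Illustrative only, with a hypothetical helper name: when a half-precision output does need checking, one option is comparing it against a float32 reference under loose tolerances.

import torch

def check_against_fp32_reference(actual_fp16: torch.Tensor, reference_fp32: torch.Tensor) -> None:
    # Hypothetical check, not part of the test suite: upcast the half-precision
    # result and allow errors on the order of float16's ~1e-3 resolution.
    torch.testing.assert_close(actual_fp16.float(), reference_fp32, rtol=1e-3, atol=1e-3)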
@@ -551,19 +551,30 @@ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[in

 def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill: _FillTypeJIT) -> torch.Tensor:
+    input_shape = img.shape
+    output_height, output_width = grid.shape[1], grid.shape[2]
+    num_channels, input_height, input_width = input_shape[-3:]
+    output_shape = input_shape[:-3] + (num_channels, output_height, output_width)
+
+    if img.numel() == 0:
+        return img.reshape(output_shape)
+
+    img = img.reshape(-1, num_channels, input_height, input_width)
+    squashed_batch_size = img.shape[0]
+
     # We are using context knowledge that grid should have float dtype
     fp = img.dtype == grid.dtype
     float_img = img if fp else img.to(grid.dtype)

-    shape = float_img.shape
-    if shape[0] > 1:
+    if squashed_batch_size > 1:
         # Apply same grid to a batch of images
-        grid = grid.expand(shape[0], -1, -1, -1)
+        grid = grid.expand(squashed_batch_size, -1, -1, -1)

     # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
     if fill is not None:
-        mask = torch.ones((shape[0], 1, shape[2], shape[3]), dtype=float_img.dtype, device=float_img.device)
+        mask = torch.ones(
+            (squashed_batch_size, 1, input_height, input_width), dtype=float_img.dtype, device=float_img.device
+        )
         float_img = torch.cat((float_img, mask), dim=1)

     float_img = grid_sample(float_img, grid, mode=mode, padding_mode="zeros", align_corners=False)
@@ -584,7 +595,7 @@ def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill
     img = float_img.round_().to(img.dtype) if not fp else float_img

-    return img
+    return img.reshape(output_shape)


 def _assert_grid_transform_inputs(
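
The "dummy mask" comment above refers to the existing single-pass fill trick: a constant ones channel is warped together with the image, and wherever the warped mask drops below one the sample came at least partly from outside the input, so the fill color is blended in. A standalone sketch of the idea, assuming a single scalar fill and bilinear mode (the torchvision code also handles per-channel fills and nearest mode):

import torch
import torch.nn.functional as F

def warp_with_fill(img: torch.Tensor, grid: torch.Tensor, fill: float) -> torch.Tensor:
    # img: (N, C, H, W) float; grid: (N, H_out, W_out, 2)
    n, _, h, w = img.shape
    mask = torch.ones((n, 1, h, w), dtype=img.dtype, device=img.device)
    out = F.grid_sample(
        torch.cat((img, mask), dim=1), grid, mode="bilinear", padding_mode="zeros", align_corners=False
    )
    out, mask = out[:, :-1], out[:, -1:]
    # mask is 1 where the sample was fully inside the input, 0 fully outside.
    return out * mask + (1.0 - mask) * fill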
@@ -661,24 +672,10 @@ def affine_image(
 ) -> torch.Tensor:
     interpolation = _check_interpolation(interpolation)

-    if image.numel() == 0:
-        return image
-
-    shape = image.shape
-    ndim = image.ndim
-
-    if ndim > 4:
-        image = image.reshape((-1,) + shape[-3:])
-        needs_unsquash = True
-    elif ndim == 3:
-        image = image.unsqueeze(0)
-        needs_unsquash = True
-    else:
-        needs_unsquash = False
-
-    height, width = shape[-2:]
     angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

+    height, width = image.shape[-2:]
+
     center_f = [0.0, 0.0]
     if center is not None:
         # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
@@ -692,12 +689,7 @@ def affine_image(
     dtype = image.dtype if torch.is_floating_point(image) else torch.float32
     theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
     grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height)
-    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
-
-    if needs_unsquash:
-        output = output.reshape(shape)
-
-    return output
+    return _apply_grid_transform(image, grid, interpolation.value, fill=fill)


 @_register_kernel_internal(affine, PIL.Image.Image)
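
theta here is the 2x3 inverse affine matrix that drives grid generation. For intuition, the normalized-coordinate identity case can be reproduced with the public PyTorch API; torchvision's private _affine_grid instead works from a matrix expressed in pixel coordinates and supports the expanded canvases used by rotate. A small sanity-check sketch (illustration only):

import torch
import torch.nn.functional as F

img = torch.rand(1, 3, 32, 32)
# Identity 2x3 affine matrix, shape (1, 2, 3): output samples the input unchanged.
theta = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
grid = F.affine_grid(theta, size=[1, 3, 32, 32], align_corners=False)  # (1, 32, 32, 2)
out = F.grid_sample(img, grid, mode="bilinear", align_corners=False)
assert torch.allclose(out, img, atol=1e-5)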
@@ -969,35 +961,26 @@ def rotate_image(
 ) -> torch.Tensor:
     interpolation = _check_interpolation(interpolation)

-    shape = image.shape
-    num_channels, height, width = shape[-3:]
+    input_height, input_width = image.shape[-2:]

     center_f = [0.0, 0.0]
     if center is not None:
         # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
-        center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])]
+        center_f = [(c - s * 0.5) for c, s in zip(center, [input_width, input_height])]

     # due to current incoherence of rotation angle direction between affine and rotate implementations
     # we need to set -angle.
     matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])

-    if image.numel() > 0:
-        image = image.reshape(-1, num_channels, height, width)
-
-        _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])
-
-        ow, oh = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
-        dtype = image.dtype if torch.is_floating_point(image) else torch.float32
-        theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
-        grid = _affine_grid(theta, w=width, h=height, ow=ow, oh=oh)
-        output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
-
-        new_height, new_width = output.shape[-2:]
-    else:
-        output = image
-        new_width, new_height = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
+    _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])

-    return output.reshape(shape[:-3] + (num_channels, new_height, new_width))
+    output_width, output_height = (
+        _compute_affine_output_size(matrix, input_width, input_height) if expand else (input_width, input_height)
+    )
+    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
+    theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
+    grid = _affine_grid(theta, w=input_width, h=input_height, ow=output_width, oh=output_height)
+    return _apply_grid_transform(image, grid, interpolation.value, fill=fill)


 @_register_kernel_internal(rotate, PIL.Image.Image)
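
With expand=True the canvas grows so the rotated image is not cropped: _compute_affine_output_size takes the bounding box of the transformed corner points. For a pure rotation this bounding box has a well-known closed form, sketched below (a simplification; the real helper works on the full affine matrix):

import math
from typing import Tuple

def rotated_canvas_size(width: int, height: int, angle_deg: float) -> Tuple[int, int]:
    # Closed form for the axis-aligned bounding box of the four rotated corners.
    a = math.radians(angle_deg)
    cos_a, sin_a = abs(math.cos(a)), abs(math.sin(a))
    new_width = int(math.ceil(width * cos_a + height * sin_a))
    new_height = int(math.ceil(width * sin_a + height * cos_a))
    return new_width, new_height

print(rotated_canvas_size(100, 50, 45))  # -> (107, 107)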
@@ -1509,21 +1492,6 @@ def perspective_image(
     perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
     interpolation = _check_interpolation(interpolation)

-    if image.numel() == 0:
-        return image
-
-    shape = image.shape
-    ndim = image.ndim
-
-    if ndim > 4:
-        image = image.reshape((-1,) + shape[-3:])
-        needs_unsquash = True
-    elif ndim == 3:
-        image = image.unsqueeze(0)
-        needs_unsquash = True
-    else:
-        needs_unsquash = False
-
     _assert_grid_transform_inputs(
         image,
         matrix=None,
@@ -1533,15 +1501,10 @@ def perspective_image(
         coeffs=perspective_coeffs,
     )

-    oh, ow = shape[-2:]
+    oh, ow = image.shape[-2:]
     dtype = image.dtype if torch.is_floating_point(image) else torch.float32
     grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device)
-    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
-
-    if needs_unsquash:
-        output = output.reshape(shape)
-
-    return output
+    return _apply_grid_transform(image, grid, interpolation.value, fill=fill)


 @_register_kernel_internal(perspective, PIL.Image.Image)
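
_perspective_grid evaluates the standard eight-coefficient homography (a, b, c, d, e, f, g, h): each output pixel (x, y) samples the source at ((a*x + b*y + c) / (g*x + h*y + 1), (d*x + e*y + f) / (g*x + h*y + 1)). A loop-free sketch of that mapping in pixel coordinates (illustration only; the real helper also rescales into grid_sample's [-1, 1] range):

import torch

def perspective_source_coords(coeffs, height, width):
    # coeffs = [a, b, c, d, e, f, g, h] of the inverse homography
    a, b, c, d, e, f, g, h = coeffs
    ys, xs = torch.meshgrid(
        torch.arange(height, dtype=torch.float32),
        torch.arange(width, dtype=torch.float32),
        indexing="ij",
    )
    denom = g * xs + h * ys + 1.0
    src_x = (a * xs + b * ys + c) / denom
    src_y = (d * xs + e * ys + f) / denom
    return src_x, src_y  # pixel coordinates in the source image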
@@ -1759,12 +1722,7 @@ def elastic_image(
     interpolation = _check_interpolation(interpolation)

-    if image.numel() == 0:
-        return image
-
-    shape = image.shape
-    ndim = image.ndim
+    height, width = image.shape[-2:]

     device = image.device
     dtype = image.dtype if torch.is_floating_point(image) else torch.float32
@@ -1775,32 +1733,18 @@ def elastic_image(
         dtype = torch.float32

     # We are aware that if input image dtype is uint8 and displacement is float64 then
-    # displacement will be casted to float32 and all computations will be done with float32
+    # displacement will be cast to float32 and all computations will be done with float32
     # We can fix this later if needed

-    expected_shape = (1,) + shape[-2:] + (2,)
+    expected_shape = (1, height, width, 2)
     if expected_shape != displacement.shape:
         raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")

-    if ndim > 4:
-        image = image.reshape((-1,) + shape[-3:])
-        needs_unsquash = True
-    elif ndim == 3:
-        image = image.unsqueeze(0)
-        needs_unsquash = True
-    else:
-        needs_unsquash = False
-
-    if displacement.dtype != dtype or displacement.device != device:
-        displacement = displacement.to(dtype=dtype, device=device)
-
-    image_height, image_width = shape[-2:]
-    grid = _create_identity_grid((image_height, image_width), device=device, dtype=dtype).add_(displacement)
+    grid = _create_identity_grid((height, width), device=device, dtype=dtype).add_(
+        displacement.to(dtype=dtype, device=device)
+    )
     output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

-    if needs_unsquash:
-        output = output.reshape(shape)
-
     if is_cpu_half:
         output = output.to(torch.float16)
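
elastic_image now forms its sampling grid in one expression: an identity grid plus the user-supplied displacement field, both in grid_sample's normalized coordinates. A self-contained sketch of the same computation (the endpoint convention is an assumption; torchvision's private _create_identity_grid may place samples slightly differently):

import torch
import torch.nn.functional as F

def elastic_sketch(img: torch.Tensor, displacement: torch.Tensor) -> torch.Tensor:
    # img: (N, C, H, W) float; displacement: (1, H, W, 2) in normalized units
    n, _, h, w = img.shape
    ys = torch.linspace(-1.0, 1.0, h, dtype=img.dtype, device=img.device)
    xs = torch.linspace(-1.0, 1.0, w, dtype=img.dtype, device=img.device)
    gy, gx = torch.meshgrid(ys, xs, indexing="ij")
    identity = torch.stack((gx, gy), dim=-1).unsqueeze(0)  # (1, H, W, 2), (x, y) order
    grid = (identity + displacement).expand(n, -1, -1, -1)
    return F.grid_sample(img, grid, mode="bilinear", padding_mode="zeros", align_corners=False)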