"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "4d2fa1908e531c6c815c026533b0e51a10ef9aef"
Unverified commit ee28bb3c authored by Philip Meier, committed by GitHub

cleanup affine grid image kernels (#8004)

parent f96deba0
```diff
@@ -2491,7 +2491,7 @@ class TestElastic:
         interpolation=[transforms.InterpolationMode.NEAREST, transforms.InterpolationMode.BILINEAR],
         fill=EXHAUSTIVE_TYPE_FILLS,
     )
-    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
+    @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8, torch.float16])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_kernel_image(self, param, value, dtype, device):
         image = make_image_tensor(dtype=dtype, device=device)
@@ -2502,6 +2502,7 @@ class TestElastic:
             displacement=self._make_displacement(image),
             **{param: value},
             check_scripted_vs_eager=not (param == "fill" and isinstance(value, (int, float))),
+            check_cuda_vs_cpu=dtype is not torch.float16,
         )

     @pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
```
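For context, `check_cuda_vs_cpu` gates a CUDA-vs-CPU consistency assertion in the test harness; it is disabled for float16 because reduced-precision grid sampling can legitimately diverge between the two backends beyond default tolerances. A minimal sketch of what such a check amounts to (the helper name and tolerances are illustrative assumptions, not the actual test internals):

```python
import torch

def check_cuda_vs_cpu_close(kernel, image, rtol=1e-5, atol=1e-5, **kwargs):
    # Illustrative helper: run the same kernel on CPU and CUDA and
    # assert the results agree within tolerance. Callers skip this
    # for float16, where the backends are expected to drift apart.
    cpu_out = kernel(image.cpu(), **kwargs)
    cuda_out = kernel(image.cuda(), **kwargs)
    torch.testing.assert_close(cuda_out.cpu(), cpu_out, rtol=rtol, atol=atol)
```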
```diff
@@ -551,19 +551,30 @@ def _compute_affine_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]:
 def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill: _FillTypeJIT) -> torch.Tensor:
+    input_shape = img.shape
+    output_height, output_width = grid.shape[1], grid.shape[2]
+    num_channels, input_height, input_width = input_shape[-3:]
+    output_shape = input_shape[:-3] + (num_channels, output_height, output_width)
+
+    if img.numel() == 0:
+        return img.reshape(output_shape)
+
+    img = img.reshape(-1, num_channels, input_height, input_width)
+    squashed_batch_size = img.shape[0]
+
     # We are using context knowledge that grid should have float dtype
     fp = img.dtype == grid.dtype
     float_img = img if fp else img.to(grid.dtype)

-    shape = float_img.shape
-    if shape[0] > 1:
+    if squashed_batch_size > 1:
         # Apply same grid to a batch of images
-        grid = grid.expand(shape[0], -1, -1, -1)
+        grid = grid.expand(squashed_batch_size, -1, -1, -1)

     # Append a dummy mask for customized fill colors, should be faster than grid_sample() twice
     if fill is not None:
-        mask = torch.ones((shape[0], 1, shape[2], shape[3]), dtype=float_img.dtype, device=float_img.device)
+        mask = torch.ones(
+            (squashed_batch_size, 1, input_height, input_width), dtype=float_img.dtype, device=float_img.device
+        )
         float_img = torch.cat((float_img, mask), dim=1)

     float_img = grid_sample(float_img, grid, mode=mode, padding_mode="zeros", align_corners=False)
@@ -584,7 +595,7 @@ def _apply_grid_transform(img: torch.Tensor, grid: torch.Tensor, mode: str, fill: _FillTypeJIT) -> torch.Tensor:

     img = float_img.round_().to(img.dtype) if not fp else float_img

-    return img
+    return img.reshape(output_shape)


 def _assert_grid_transform_inputs(
```
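This hunk is the heart of the cleanup: each kernel previously repeated the same "squash arbitrary leading batch dims, run the 4D op, unsquash" dance, and it now lives once inside `_apply_grid_transform` (along with the empty-tensor early return). A standalone sketch of that pattern, assuming a generic op that only accepts `(N, C, H, W)` input:

```python
import torch

def apply_batched_4d_op(t: torch.Tensor, op) -> torch.Tensor:
    # Collapse any leading batch dims of a (..., C, H, W) tensor into
    # one, apply an op that requires 4D input, then restore the dims.
    shape = t.shape
    flat = t.reshape(-1, *shape[-3:])               # (N, C, H, W)
    out = op(flat)
    return out.reshape(shape[:-3] + out.shape[-3:])  # H/W may have changed

x = torch.rand(2, 5, 3, 8, 8)                       # two leading batch dims
y = apply_batched_4d_op(x, lambda t: t.flip(-1))
assert y.shape == x.shape
```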
```diff
@@ -661,24 +672,10 @@ def affine_image(
 ) -> torch.Tensor:
     interpolation = _check_interpolation(interpolation)

-    if image.numel() == 0:
-        return image
-
-    shape = image.shape
-    ndim = image.ndim
-
-    if ndim > 4:
-        image = image.reshape((-1,) + shape[-3:])
-        needs_unsquash = True
-    elif ndim == 3:
-        image = image.unsqueeze(0)
-        needs_unsquash = True
-    else:
-        needs_unsquash = False
-
-    height, width = shape[-2:]
     angle, translate, shear, center = _affine_parse_args(angle, translate, scale, shear, interpolation, center)

+    height, width = image.shape[-2:]
     center_f = [0.0, 0.0]
     if center is not None:
         # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
@@ -692,12 +689,7 @@ def affine_image(
     dtype = image.dtype if torch.is_floating_point(image) else torch.float32
     theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
     grid = _affine_grid(theta, w=width, h=height, ow=width, oh=height)
-    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
-
-    if needs_unsquash:
-        output = output.reshape(shape)
-
-    return output
+    return _apply_grid_transform(image, grid, interpolation.value, fill=fill)


 @_register_kernel_internal(affine, PIL.Image.Image)
```
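All of these kernels reduce an affine warp to "build a sampling grid, then one `grid_sample` call". A rough equivalent using the public PyTorch API (torchvision's private `_affine_grid` additionally handles output sizes that differ from the input, which this sketch does not):

```python
import math
import torch
import torch.nn.functional as F

# Sketch: warp an image with an inverse affine map, as affine_image does
# internally. theta is the (1, 2, 3) matrix in normalized coordinates.
img = torch.rand(1, 3, 32, 32)
a = math.radians(30.0)
theta = torch.tensor([[[math.cos(a), -math.sin(a), 0.0],
                       [math.sin(a),  math.cos(a), 0.0]]])
grid = F.affine_grid(theta, size=(1, 3, 32, 32), align_corners=False)
out = F.grid_sample(img, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
```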
```diff
@@ -969,35 +961,26 @@ def rotate_image(
 ) -> torch.Tensor:
     interpolation = _check_interpolation(interpolation)

-    shape = image.shape
-    num_channels, height, width = shape[-3:]
+    input_height, input_width = image.shape[-2:]

     center_f = [0.0, 0.0]
     if center is not None:
         # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
-        center_f = [(c - s * 0.5) for c, s in zip(center, [width, height])]
+        center_f = [(c - s * 0.5) for c, s in zip(center, [input_width, input_height])]

     # due to current incoherence of rotation angle direction between affine and rotate implementations
     # we need to set -angle.
     matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])

-    if image.numel() > 0:
-        image = image.reshape(-1, num_channels, height, width)
-
-        _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])
-
-        ow, oh = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
-        dtype = image.dtype if torch.is_floating_point(image) else torch.float32
-        theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
-        grid = _affine_grid(theta, w=width, h=height, ow=ow, oh=oh)
-        output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
-
-        new_height, new_width = output.shape[-2:]
-    else:
-        output = image
-        new_width, new_height = _compute_affine_output_size(matrix, width, height) if expand else (width, height)
-
-    return output.reshape(shape[:-3] + (num_channels, new_height, new_width))
+    _assert_grid_transform_inputs(image, matrix, interpolation.value, fill, ["nearest", "bilinear"])
+
+    output_width, output_height = (
+        _compute_affine_output_size(matrix, input_width, input_height) if expand else (input_width, input_height)
+    )
+    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
+    theta = torch.tensor(matrix, dtype=dtype, device=image.device).reshape(1, 2, 3)
+    grid = _affine_grid(theta, w=input_width, h=input_height, ow=output_width, oh=output_height)
+    return _apply_grid_transform(image, grid, interpolation.value, fill=fill)


 @_register_kernel_internal(rotate, PIL.Image.Image)
```
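`_compute_affine_output_size` is what `expand=True` relies on: it maps the image corners through the affine matrix and takes their axis-aligned extent. For a pure rotation it reduces to the familiar bounding-box formula, sketched here as a simplified stand-in:

```python
import math

def rotated_output_size(w: int, h: int, angle_deg: float):
    # Simplified stand-in for the pure-rotation case: the tight
    # axis-aligned bounding box of a w x h rectangle rotated by angle.
    a = math.radians(angle_deg)
    ow = math.ceil(abs(w * math.cos(a)) + abs(h * math.sin(a)))
    oh = math.ceil(abs(w * math.sin(a)) + abs(h * math.cos(a)))
    return ow, oh

print(rotated_output_size(100, 50, 45))  # -> (107, 107)
```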
```diff
@@ -1509,21 +1492,6 @@ def perspective_image(
     perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
     interpolation = _check_interpolation(interpolation)

-    if image.numel() == 0:
-        return image
-
-    shape = image.shape
-    ndim = image.ndim
-
-    if ndim > 4:
-        image = image.reshape((-1,) + shape[-3:])
-        needs_unsquash = True
-    elif ndim == 3:
-        image = image.unsqueeze(0)
-        needs_unsquash = True
-    else:
-        needs_unsquash = False
-
     _assert_grid_transform_inputs(
         image,
         matrix=None,
@@ -1533,15 +1501,10 @@ def perspective_image(
         coeffs=perspective_coeffs,
     )

-    oh, ow = shape[-2:]
+    oh, ow = image.shape[-2:]
     dtype = image.dtype if torch.is_floating_point(image) else torch.float32
     grid = _perspective_grid(perspective_coeffs, ow=ow, oh=oh, dtype=dtype, device=image.device)
-    output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)
-
-    if needs_unsquash:
-        output = output.reshape(shape)
-
-    return output
+    return _apply_grid_transform(image, grid, interpolation.value, fill=fill)


 @_register_kernel_internal(perspective, PIL.Image.Image)
```
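For reference, `_perspective_grid` evaluates the standard 8-coefficient projective map at every output pixel (the same coefficient convention PIL uses). A sketch of the pointwise form:

```python
def perspective_point(coeffs, x, y):
    # (x, y) -> ((a*x + b*y + c) / (g*x + h*y + 1),
    #            (d*x + e*y + f) / (g*x + h*y + 1))
    a, b, c, d, e, f, g, h = coeffs
    denom = g * x + h * y + 1.0
    return (a * x + b * y + c) / denom, (d * x + e * y + f) / denom

# Identity coefficients leave points unchanged:
print(perspective_point([1, 0, 0, 0, 1, 0, 0, 0], 3.0, 4.0))  # (3.0, 4.0)
```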
```diff
@@ -1759,12 +1722,7 @@ def elastic_image(
     interpolation = _check_interpolation(interpolation)

-    if image.numel() == 0:
-        return image
-
-    shape = image.shape
-    ndim = image.ndim
+    height, width = image.shape[-2:]

     device = image.device
     dtype = image.dtype if torch.is_floating_point(image) else torch.float32
@@ -1775,32 +1733,18 @@ def elastic_image(
         dtype = torch.float32

     # We are aware that if input image dtype is uint8 and displacement is float64 then
-    # displacement will be casted to float32 and all computations will be done with float32
+    # displacement will be cast to float32 and all computations will be done with float32
     # We can fix this later if needed

-    expected_shape = (1,) + shape[-2:] + (2,)
+    expected_shape = (1, height, width, 2)
     if expected_shape != displacement.shape:
         raise ValueError(f"Argument displacement shape should be {expected_shape}, but given {displacement.shape}")

-    if ndim > 4:
-        image = image.reshape((-1,) + shape[-3:])
-        needs_unsquash = True
-    elif ndim == 3:
-        image = image.unsqueeze(0)
-        needs_unsquash = True
-    else:
-        needs_unsquash = False
-
-    if displacement.dtype != dtype or displacement.device != device:
-        displacement = displacement.to(dtype=dtype, device=device)
-
-    image_height, image_width = shape[-2:]
-    grid = _create_identity_grid((image_height, image_width), device=device, dtype=dtype).add_(displacement)
+    grid = _create_identity_grid((height, width), device=device, dtype=dtype).add_(
+        displacement.to(dtype=dtype, device=device)
+    )
     output = _apply_grid_transform(image, grid, interpolation.value, fill=fill)

-    if needs_unsquash:
-        output = output.reshape(shape)
-
     if is_cpu_half:
         output = output.to(torch.float16)
```
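Conceptually, the elastic kernel is "identity grid plus displacement field, then sample". A rough public-API equivalent, using `F.affine_grid` with an identity matrix in place of the private `_create_identity_grid`:

```python
import torch
import torch.nn.functional as F

# Sketch of an elastic warp: perturb the identity sampling grid with a
# small per-pixel displacement, then resample in one grid_sample call.
h, w = 16, 16
img = torch.rand(1, 3, h, w)
identity = torch.eye(2, 3).unsqueeze(0)  # (1, 2, 3) identity affine map
base_grid = F.affine_grid(identity, size=(1, 3, h, w), align_corners=False)  # (1, H, W, 2)
displacement = 0.05 * torch.randn(1, h, w, 2)
out = F.grid_sample(img, base_grid + displacement, mode="bilinear", align_corners=False)
```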