Unverified commit 112accf9, authored by Philip Meier and committed by GitHub
Browse files

Cleanup prototype kernels for degenerate inputs (#6544)

* avoid double padding parsing

* remove cloning in degenerate case

* fix affine and rotate for degenerate inputs

* fix rotate for degenerate inputs if expand=True
parent 84dcf695
......@@ -86,6 +86,8 @@ class TestSmoke:
transforms.RandomHorizontalFlip(),
transforms.Pad(5),
transforms.RandomZoomOut(),
transforms.RandomRotation(degrees=(-45, 45)),
transforms.RandomAffine(degrees=(-45, 45)),
transforms.RandomCrop([16, 16], padding=1, pad_if_needed=True),
# TODO: Something wrong with input data setup. Let's fix that
# transforms.RandomEqualize(),
......@@ -93,8 +95,6 @@ class TestSmoke:
# transforms.RandomPosterize(bits=4),
# transforms.RandomSolarize(threshold=0.5),
# transforms.RandomAdjustSharpness(sharpness_factor=0.5),
# transforms.RandomRotation(degrees=(-45, 45)),
# transforms.RandomAffine(degrees=(-45, 45)),
)
def test_common(self, transform, input):
transform(input)
......
......@@ -290,7 +290,7 @@ def resize_segmentation_mask():
@register_kernel_info_from_sample_inputs_fn
def affine_image_tensor():
for image, angle, translate, scale, shear in itertools.product(
make_images(extra_dims=((), (4,))),
make_images(),
[-87, 15, 90], # angle
[5, -5], # translate
[0.77, 1.27], # scale
......@@ -329,7 +329,7 @@ def affine_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def affine_segmentation_mask():
for mask, angle, translate, scale, shear in itertools.product(
make_segmentation_masks(extra_dims=((), (4,)), num_objects=[10]),
make_segmentation_masks(),
[-87, 15, 90], # angle
[5, -5], # translate
[0.77, 1.27], # scale
......@@ -347,7 +347,7 @@ def affine_segmentation_mask():
@register_kernel_info_from_sample_inputs_fn
def rotate_image_tensor():
for image, angle, expand, center, fill in itertools.product(
make_images(extra_dims=((), (4,))),
make_images(),
[-87, 15, 90], # angle
[True, False], # expand
[None, [12, 23]], # center
......@@ -382,7 +382,7 @@ def rotate_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def rotate_segmentation_mask():
for mask, angle, expand, center in itertools.product(
make_segmentation_masks(extra_dims=((), (4,)), num_objects=[10]),
make_segmentation_masks(),
[-87, 15, 90], # angle
[True, False], # expand
[None, [12, 23]], # center
......
......@@ -108,17 +108,14 @@ def resize_image_tensor(
extra_dims = image.shape[:-3]
if image.numel() > 0:
resized_image = _FT.resize(
image = _FT.resize(
image.view(-1, num_channels, old_height, old_width),
size=[new_height, new_width],
interpolation=interpolation.value,
antialias=antialias,
)
else:
# TODO: the cloning is probably unnecessary. Review this together with the other perf candidates
resized_image = image.clone()
return resized_image.view(extra_dims + (num_channels, new_height, new_width))
return image.view(extra_dims + (num_channels, new_height, new_width))
def resize_image_pil(
......@@ -229,6 +226,9 @@ def affine_image_tensor(
fill: Optional[List[float]] = None,
center: Optional[List[float]] = None,
) -> torch.Tensor:
if img.numel() == 0:
return img
num_channels, height, width = img.shape[-3:]
extra_dims = img.shape[:-3]
img = img.view(-1, num_channels, height, width)
......@@ -452,23 +452,32 @@ def rotate_image_tensor(
) -> torch.Tensor:
num_channels, height, width = img.shape[-3:]
extra_dims = img.shape[:-3]
img = img.view(-1, num_channels, height, width)
center_f = [0.0, 0.0]
if center is not None:
if expand:
warnings.warn("The provided center argument has no effect on the result if expand is True")
else:
_, height, width = get_dimensions_image_tensor(img)
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
# due to current incoherence of rotation angle direction between affine and rotate implementations
# we need to set -angle.
matrix = _get_inverse_affine_matrix(center_f, -angle, [0.0, 0.0], 1.0, [0.0, 0.0])
output = _FT.rotate(img, matrix, interpolation=interpolation.value, expand=expand, fill=fill)
new_height, new_width = output.shape[-2:]
return output.view(extra_dims + (num_channels, new_height, new_width))
if img.numel() > 0:
img = _FT.rotate(
img.view(-1, num_channels, height, width),
matrix,
interpolation=interpolation.value,
expand=expand,
fill=fill,
)
new_height, new_width = img.shape[-2:]
else:
new_width, new_height = _FT._compute_affine_output_size(matrix, width, height) if expand else (width, height)
return img.view(extra_dims + (num_channels, new_height, new_width))
def rotate_image_pil(
......@@ -557,19 +566,17 @@ def pad_image_tensor(
num_channels, height, width = img.shape[-3:]
extra_dims = img.shape[:-3]
left, right, top, bottom = _FT._parse_pad_padding(padding)
new_height = height + top + bottom
new_width = width + left + right
if img.numel() > 0:
padded_image = _FT.pad(
img = _FT.pad(
img=img.view(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode
)
new_height, new_width = img.shape[-2:]
else:
# TODO: the cloning is probably unnecessary. Review this together with the other perf candidates
padded_image = img.clone()
left, right, top, bottom = _FT._parse_pad_padding(padding)
new_height = height + top + bottom
new_width = width + left + right
return padded_image.view(extra_dims + (num_channels, new_height, new_width))
return img.view(extra_dims + (num_channels, new_height, new_width))
# TODO: This should be removed once pytorch pad supports non-scalar padding values
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment