Unverified Commit 3130b457 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Added functional `rotate_segmentation_mask` op (#5692)

* Added functional affine_bounding_box op with tests

* Updated comments and added another test case

* Update _geometry.py

* Added affine_segmentation_mask with tests

* Fixed device mismatch issue
Added a cuda/cpu test
Reduced the number of test samples

* Added test_correctness_affine_segmentation_mask_on_fixed_input

* Updates according to the review

* Replaced [None, ...] by [None, :]

* Addressed review comments

* Fixed formatting and more updates according to the review

* Fixed bad merge

* WIP

* Fixed tests

* Updated warning message
parent 890450a4
...@@ -266,7 +266,7 @@ def affine_bounding_box(): ...@@ -266,7 +266,7 @@ def affine_bounding_box():
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
def affine_segmentation_mask(): def affine_segmentation_mask():
for image, angle, translate, scale, shear in itertools.product( for mask, angle, translate, scale, shear in itertools.product(
make_segmentation_masks(extra_dims=((), (4,))), make_segmentation_masks(extra_dims=((), (4,))),
[-87, 15, 90], # angle [-87, 15, 90], # angle
[5, -5], # translate [5, -5], # translate
...@@ -274,7 +274,7 @@ def affine_segmentation_mask(): ...@@ -274,7 +274,7 @@ def affine_segmentation_mask():
[0, 12], # shear [0, 12], # shear
): ):
yield SampleInput( yield SampleInput(
image, mask,
angle=angle, angle=angle,
translate=(translate, translate), translate=(translate, translate),
scale=scale, scale=scale,
...@@ -285,8 +285,12 @@ def affine_segmentation_mask(): ...@@ -285,8 +285,12 @@ def affine_segmentation_mask():
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
def rotate_bounding_box(): def rotate_bounding_box():
for bounding_box, angle, expand, center in itertools.product( for bounding_box, angle, expand, center in itertools.product(
make_bounding_boxes(), [-87, 15, 90], [True, False], [None, [12, 23]] # angle # expand # center make_bounding_boxes(), [-87, 15, 90], [True, False], [None, [12, 23]]
): ):
if center is not None and expand:
# Skip warning: The provided center argument is ignored if expand is True
continue
yield SampleInput( yield SampleInput(
bounding_box, bounding_box,
format=bounding_box.format, format=bounding_box.format,
...@@ -297,6 +301,26 @@ def rotate_bounding_box(): ...@@ -297,6 +301,26 @@ def rotate_bounding_box():
) )
@register_kernel_info_from_sample_inputs_fn
def rotate_segmentation_mask():
    """Yield sample inputs for the ``rotate_segmentation_mask`` kernel."""
    angles = [-87, 15, 90]
    expand_flags = [True, False]
    centers = [None, [12, 23]]
    for mask, angle, expand, center in itertools.product(
        make_segmentation_masks(extra_dims=((), (4,))), angles, expand_flags, centers
    ):
        # center is ignored (with a warning) when expand=True, so skip that combination
        if expand and center is not None:
            continue
        yield SampleInput(mask, angle=angle, expand=expand, center=center)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"kernel", "kernel",
[ [
...@@ -411,8 +435,9 @@ def test_correctness_affine_bounding_box(angle, translate, scale, shear, center) ...@@ -411,8 +435,9 @@ def test_correctness_affine_bounding_box(angle, translate, scale, shear, center)
center=center, center=center,
) )
if center is None: center_ = center
center = [s // 2 for s in bboxes_image_size[::-1]] if center_ is None:
center_ = [s * 0.5 for s in bboxes_image_size[::-1]]
if bboxes.ndim < 2: if bboxes.ndim < 2:
bboxes = [bboxes] bboxes = [bboxes]
...@@ -421,7 +446,7 @@ def test_correctness_affine_bounding_box(angle, translate, scale, shear, center) ...@@ -421,7 +446,7 @@ def test_correctness_affine_bounding_box(angle, translate, scale, shear, center)
for bbox in bboxes: for bbox in bboxes:
bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size) bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
expected_bboxes.append( expected_bboxes.append(
_compute_expected_bbox(bbox, angle, (translate, translate), scale, (shear, shear), center) _compute_expected_bbox(bbox, angle, (translate, translate), scale, (shear, shear), center_)
) )
if len(expected_bboxes) > 1: if len(expected_bboxes) > 1:
expected_bboxes = torch.stack(expected_bboxes) expected_bboxes = torch.stack(expected_bboxes)
...@@ -510,8 +535,10 @@ def test_correctness_affine_segmentation_mask(angle, translate, scale, shear, ce ...@@ -510,8 +535,10 @@ def test_correctness_affine_segmentation_mask(angle, translate, scale, shear, ce
shear=(shear, shear), shear=(shear, shear),
center=center, center=center,
) )
if center is None:
center = [s // 2 for s in mask.shape[-2:][::-1]] center_ = center
if center_ is None:
center_ = [s * 0.5 for s in mask.shape[-2:][::-1]]
if mask.ndim < 4: if mask.ndim < 4:
masks = [mask] masks = [mask]
...@@ -520,7 +547,7 @@ def test_correctness_affine_segmentation_mask(angle, translate, scale, shear, ce ...@@ -520,7 +547,7 @@ def test_correctness_affine_segmentation_mask(angle, translate, scale, shear, ce
expected_masks = [] expected_masks = []
for mask in masks: for mask in masks:
expected_mask = _compute_expected_mask(mask, angle, (translate, translate), scale, (shear, shear), center) expected_mask = _compute_expected_mask(mask, angle, (translate, translate), scale, (shear, shear), center_)
expected_masks.append(expected_mask) expected_masks.append(expected_mask)
if len(expected_masks) > 1: if len(expected_masks) > 1:
expected_masks = torch.stack(expected_masks) expected_masks = torch.stack(expected_masks)
...@@ -550,8 +577,7 @@ def test_correctness_affine_segmentation_mask_on_fixed_input(device): ...@@ -550,8 +577,7 @@ def test_correctness_affine_segmentation_mask_on_fixed_input(device):
@pytest.mark.parametrize("angle", range(-90, 90, 56)) @pytest.mark.parametrize("angle", range(-90, 90, 56))
@pytest.mark.parametrize("expand", [True, False]) @pytest.mark.parametrize("expand, center", [(True, None), (False, None), (False, (12, 14))])
@pytest.mark.parametrize("center", [None, (12, 14)])
def test_correctness_rotate_bounding_box(angle, expand, center): def test_correctness_rotate_bounding_box(angle, expand, center):
def _compute_expected_bbox(bbox, angle_, expand_, center_): def _compute_expected_bbox(bbox, angle_, expand_, center_):
affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_) affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_)
...@@ -620,8 +646,9 @@ def test_correctness_rotate_bounding_box(angle, expand, center): ...@@ -620,8 +646,9 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
center=center, center=center,
) )
if center is None: center_ = center
center = [s // 2 for s in bboxes_image_size[::-1]] if center_ is None:
center_ = [s * 0.5 for s in bboxes_image_size[::-1]]
if bboxes.ndim < 2: if bboxes.ndim < 2:
bboxes = [bboxes] bboxes = [bboxes]
...@@ -629,7 +656,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center): ...@@ -629,7 +656,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
expected_bboxes = [] expected_bboxes = []
for bbox in bboxes: for bbox in bboxes:
bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size) bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
expected_bboxes.append(_compute_expected_bbox(bbox, -angle, expand, center)) expected_bboxes.append(_compute_expected_bbox(bbox, -angle, expand, center_))
if len(expected_bboxes) > 1: if len(expected_bboxes) > 1:
expected_bboxes = torch.stack(expected_bboxes) expected_bboxes = torch.stack(expected_bboxes)
else: else:
...@@ -638,7 +665,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center): ...@@ -638,7 +665,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
@pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("expand", [False]) # expand=True does not match D2, analysis in progress @pytest.mark.parametrize("expand", [False]) # expand=True does not match D2
def test_correctness_rotate_bounding_box_on_fixed_input(device, expand): def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
# Check transformation against known expected output # Check transformation against known expected output
image_size = (64, 64) image_size = (64, 64)
...@@ -689,3 +716,91 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand): ...@@ -689,3 +716,91 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
) )
torch.testing.assert_close(output_boxes.tolist(), expected_bboxes) torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
@pytest.mark.parametrize("angle", range(-90, 90, 37))
# (True, (12, 14)) is deliberately absent: center is ignored when expand=True.
@pytest.mark.parametrize("expand, center", [(True, None), (False, None), (False, (12, 14))])
def test_correctness_rotate_segmentation_mask(angle, expand, center):
    # Compare F.rotate_segmentation_mask against a slow per-pixel reference:
    # invert the affine matrix and sample the input mask with nearest lookup.
    def _compute_expected_mask(mask, angle_, expand_, center_):
        # The reference only handles a single-channel (1, H, W) mask.
        assert mask.ndim == 3 and mask.shape[0] == 1
        image_size = mask.shape[-2:]
        # Forward rotation matrix; inverted below to map output -> input coords.
        affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_)
        inv_affine_matrix = np.linalg.inv(affine_matrix)
        if expand_:
            # Pillow implementation on how to perform expand:
            # https://github.com/python-pillow/Pillow/blob/11de3318867e4398057373ee9f12dcb33db7335c/src/PIL/Image.py#L2054-L2069
            height, width = image_size
            # Homogeneous coordinates of the four image corners.
            points = np.array(
                [
                    [0.0, 0.0, 1.0],
                    [0.0, 1.0 * height, 1.0],
                    [1.0 * width, 1.0 * height, 1.0],
                    [1.0 * width, 0.0, 1.0],
                ]
            )
            new_points = points @ inv_affine_matrix.T
            min_vals = np.min(new_points, axis=0)[:2]
            max_vals = np.max(new_points, axis=0)[:2]
            # Truncate to 4 decimals before ceil/floor to avoid float noise
            # inflating the expanded size (mirrors Pillow's rounding).
            cmax = np.ceil(np.trunc(max_vals * 1e4) * 1e-4)
            cmin = np.floor(np.trunc((min_vals + 1e-8) * 1e4) * 1e-4)
            new_width, new_height = (cmax - cmin).astype("int32").tolist()
            # Shift the translation so the rotated content is centered in the
            # expanded canvas.
            tr = np.array([-(new_width - width) / 2.0, -(new_height - height) / 2.0, 1.0]) @ inv_affine_matrix.T
            inv_affine_matrix[:2, 2] = tr[:2]
            image_size = [new_height, new_width]
        # Keep only the 2x3 part used for coordinate mapping.
        inv_affine_matrix = inv_affine_matrix[:2, :]
        expected_mask = torch.zeros(1, *image_size, dtype=mask.dtype)
        # Nearest-neighbor backward mapping: for each output pixel, find its
        # source pixel; out-of-bounds sources stay zero.
        for out_y in range(expected_mask.shape[1]):
            for out_x in range(expected_mask.shape[2]):
                # Sample at the pixel center, then floor to an integer index.
                output_pt = np.array([out_x + 0.5, out_y + 0.5, 1.0])
                input_pt = np.floor(np.dot(inv_affine_matrix, output_pt)).astype(np.int32)
                in_x, in_y = input_pt[:2]
                if 0 <= in_x < mask.shape[2] and 0 <= in_y < mask.shape[1]:
                    expected_mask[0, out_y, out_x] = mask[0, in_y, in_x]
        return expected_mask.to(mask.device)

    for mask in make_segmentation_masks(extra_dims=((), (4,))):
        output_mask = F.rotate_segmentation_mask(
            mask,
            angle=angle,
            expand=expand,
            center=center,
        )
        # Default center is the image center, in (x, y) order — hence [::-1].
        center_ = center
        if center_ is None:
            center_ = [s * 0.5 for s in mask.shape[-2:][::-1]]
        # The reference works on single masks; split off any extra batch dim.
        if mask.ndim < 4:
            masks = [mask]
        else:
            masks = [m for m in mask]
        expected_masks = []
        for mask in masks:
            # Negated angle: the kernel and the reference use opposite rotation
            # direction conventions — TODO confirm against rotate_image_tensor.
            expected_mask = _compute_expected_mask(mask, -angle, expand, center_)
            expected_masks.append(expected_mask)
        if len(expected_masks) > 1:
            expected_masks = torch.stack(expected_masks)
        else:
            expected_masks = expected_masks[0]
        torch.testing.assert_close(output_mask, expected_masks)
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_correctness_rotate_segmentation_mask_on_fixed_input(device):
    # Known-answer check that also exercises CPU and CUDA devices.
    # Fixed input: two square labels, one near the top-left corner and one
    # near the bottom-left corner of a 32x32 mask.
    size = 32
    mask = torch.zeros(1, size, size, dtype=torch.long, device=device)
    mask[0, 2:10, 2:10] = 1
    mask[0, size - 9 : size - 3, 3:9] = 2

    # A 90-degree rotation without expand is exactly torch.rot90.
    expected_mask = torch.rot90(mask, k=1, dims=(-2, -1))
    out_mask = F.rotate_segmentation_mask(mask, 90, expand=False)
    torch.testing.assert_close(out_mask, expected_mask)
...@@ -56,6 +56,7 @@ from ._geometry import ( ...@@ -56,6 +56,7 @@ from ._geometry import (
rotate_bounding_box, rotate_bounding_box,
rotate_image_tensor, rotate_image_tensor,
rotate_image_pil, rotate_image_pil,
rotate_segmentation_mask,
pad_image_tensor, pad_image_tensor,
pad_image_pil, pad_image_pil,
pad_bounding_box, pad_bounding_box,
......
...@@ -324,7 +324,7 @@ def rotate_image_tensor( ...@@ -324,7 +324,7 @@ def rotate_image_tensor(
center_f = [0.0, 0.0] center_f = [0.0, 0.0]
if center is not None: if center is not None:
if expand: if expand:
warnings.warn("The provided center argument is ignored if expand is True") warnings.warn("The provided center argument has no effect on the result if expand is True")
else: else:
_, height, width = get_dimensions_image_tensor(img) _, height, width = get_dimensions_image_tensor(img)
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center. # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
...@@ -345,7 +345,7 @@ def rotate_image_pil( ...@@ -345,7 +345,7 @@ def rotate_image_pil(
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> PIL.Image.Image: ) -> PIL.Image.Image:
if center is not None and expand: if center is not None and expand:
warnings.warn("The provided center argument is ignored if expand is True") warnings.warn("The provided center argument has no effect on the result if expand is True")
center = None center = None
return _FP.rotate( return _FP.rotate(
...@@ -361,6 +361,10 @@ def rotate_bounding_box( ...@@ -361,6 +361,10 @@ def rotate_bounding_box(
expand: bool = False, expand: bool = False,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if center is not None and expand:
warnings.warn("The provided center argument has no effect on the result if expand is True")
center = None
original_shape = bounding_box.shape original_shape = bounding_box.shape
bounding_box = convert_bounding_box_format( bounding_box = convert_bounding_box_format(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
...@@ -373,6 +377,21 @@ def rotate_bounding_box( ...@@ -373,6 +377,21 @@ def rotate_bounding_box(
).view(original_shape) ).view(original_shape)
def rotate_segmentation_mask(
    img: torch.Tensor,
    angle: float,
    expand: bool = False,
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Rotate a segmentation mask by ``angle`` degrees.

    Delegates to :func:`rotate_image_tensor`, forcing nearest-neighbor
    interpolation so that discrete label values are never blended.
    """
    rotated = rotate_image_tensor(
        img,
        angle=angle,
        interpolation=InterpolationMode.NEAREST,
        expand=expand,
        center=center,
    )
    return rotated
pad_image_tensor = _FT.pad pad_image_tensor = _FT.pad
pad_image_pil = _FP.pad pad_image_pil = _FP.pad
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment