Unverified Commit 647016bd authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Added functional affine_segmentation_mask op (#5613)

* Added functional affine_bounding_box op with tests

* Updated comments and added another test case

* Update _geometry.py

* Added affine_segmentation_mask with tests

* Fixed device mismatch issue
Added a cuda/cpu test
Reduced the number of test samples

* Added test_correctness_affine_segmentation_mask_on_fixed_input

* Updates according to the review

* Replaced [None, ...] by [None, :]

* Addressed review comments

* Fixed formatting and more updates according to the review

* Fixed bad merge
parent 65d3a87b
...@@ -138,6 +138,22 @@ def make_one_hot_labels( ...@@ -138,6 +138,22 @@ def make_one_hot_labels(
yield make_one_hot_label(extra_dims_) yield make_one_hot_label(extra_dims_)
def make_segmentation_mask(size=None, *, num_categories=80, extra_dims=(), dtype=torch.long):
    """Build a random ``features.SegmentationMask`` for use in tests.

    Args:
        size: The two trailing dimensions of the mask. When falsy, both values
            are drawn at random from ``[16, 33)``.
        num_categories: Forwarded to ``make_tensor`` as the ``high`` bound.
        extra_dims: Leading dimensions placed in front of the singleton dimension.
        dtype: dtype requested from ``make_tensor``.

    Returns:
        A ``features.SegmentationMask`` wrapping data of shape
        ``(*extra_dims, 1, *size)``.
    """
    if not size:
        size = torch.randint(16, 33, (2,)).tolist()

    # Layout: optional leading dims, one singleton dimension, then the size dims.
    mask_shape = tuple(extra_dims) + (1,) + tuple(size)
    return features.SegmentationMask(
        make_tensor(mask_shape, low=0, high=num_categories, dtype=dtype)
    )
def make_segmentation_masks(
    image_sizes=((16, 16), (7, 33), (31, 9)),
    dtypes=(torch.long,),
    extra_dims=((), (4,), (2, 3)),
):
    """Yield one random segmentation mask per (size, dtype, extra dims) combination.

    Combinations are enumerated with ``itertools.product`` in the order
    ``image_sizes`` x ``dtypes`` x ``extra_dims`` and each one is built with
    ``make_segmentation_mask``.
    """
    combinations = itertools.product(image_sizes, dtypes, extra_dims)
    for size, mask_dtype, leading_dims in combinations:
        yield make_segmentation_mask(size=size, dtype=mask_dtype, extra_dims=leading_dims)
class SampleInput: class SampleInput:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.args = args self.args = args
...@@ -212,7 +228,7 @@ def resize_bounding_box(): ...@@ -212,7 +228,7 @@ def resize_bounding_box():
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
def affine_image_tensor(): def affine_image_tensor():
for image, angle, translate, scale, shear in itertools.product( for image, angle, translate, scale, shear in itertools.product(
make_images(extra_dims=()), make_images(extra_dims=((), (4,))),
[-87, 15, 90], # angle [-87, 15, 90], # angle
[5, -5], # translate [5, -5], # translate
[0.77, 1.27], # scale [0.77, 1.27], # scale
...@@ -248,6 +264,24 @@ def affine_bounding_box(): ...@@ -248,6 +264,24 @@ def affine_bounding_box():
) )
@register_kernel_info_from_sample_inputs_fn
def affine_segmentation_mask():
    """Yield ``SampleInput`` objects for the ``affine_segmentation_mask`` kernel.

    Every generated mask (unbatched and with a leading dimension of 4) is
    combined with each angle / translate / scale / shear value below. The
    scalar translate and shear values are duplicated into 2-tuples.
    """
    masks = make_segmentation_masks(extra_dims=((), (4,)))
    angles = [-87, 15, 90]
    offsets = [5, -5]
    scales = [0.77, 1.27]
    shear_values = [0, 12]

    for mask, angle, offset, scale, shear_value in itertools.product(
        masks, angles, offsets, scales, shear_values
    ):
        yield SampleInput(
            mask,
            angle=angle,
            translate=(offset, offset),
            scale=scale,
            shear=(shear_value, shear_value),
        )
@register_kernel_info_from_sample_inputs_fn @register_kernel_info_from_sample_inputs_fn
def rotate_bounding_box(): def rotate_bounding_box():
for bounding_box, angle, expand, center in itertools.product( for bounding_box, angle, expand, center in itertools.product(
...@@ -444,6 +478,76 @@ def test_correctness_affine_bounding_box_on_fixed_input(device): ...@@ -444,6 +478,76 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
np.testing.assert_allclose(out_box.cpu().numpy(), a_out_box) np.testing.assert_allclose(out_box.cpu().numpy(), a_out_box)
@pytest.mark.parametrize("angle", [-54, 56])
@pytest.mark.parametrize("translate", [-7, 8])
@pytest.mark.parametrize("scale", [0.89, 1.12])
@pytest.mark.parametrize("shear", [4])
@pytest.mark.parametrize("center", [None, (12, 14)])
def test_correctness_affine_segmentation_mask(angle, translate, scale, shear, center):
    """Compare ``F.affine_segmentation_mask`` with a per-pixel reference.

    The reference maps the centre of every output pixel back through the
    inverse affine matrix and copies the input pixel it lands in; output
    pixels that map outside the input stay zero.
    """

    def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_):
        # Reference implementation for a single (1, H, W) mask.
        assert mask.ndim == 3 and mask.shape[0] == 1
        affine_matrix = _compute_affine_matrix(angle_, translate_, scale_, shear_, center_)
        # Only the first two rows are needed to obtain (x, y) of the source point.
        inv_affine_matrix = np.linalg.inv(affine_matrix)[:2, :]

        expected_mask = torch.zeros_like(mask.cpu())
        for out_y in range(expected_mask.shape[1]):
            for out_x in range(expected_mask.shape[2]):
                # Sample at the pixel centre, hence the +0.5 offsets.
                output_pt = np.array([out_x + 0.5, out_y + 0.5, 1.0])
                input_pt = np.floor(np.dot(inv_affine_matrix, output_pt)).astype(np.int32)
                in_x, in_y = input_pt[:2]
                if 0 <= in_x < mask.shape[2] and 0 <= in_y < mask.shape[1]:
                    expected_mask[0, out_y, out_x] = mask[0, in_y, in_x]
        return expected_mask.to(mask.device)

    translate_xy = (translate, translate)
    shear_xy = (shear, shear)

    for mask in make_segmentation_masks(extra_dims=((), (4,))):
        output_mask = F.affine_segmentation_mask(
            mask,
            angle=angle,
            translate=translate_xy,
            scale=scale,
            shear=shear_xy,
            center=center,
        )

        # Resolve the default centre per mask WITHOUT overwriting the
        # parametrized ``center``; otherwise every mask after the first one
        # would be transformed with an explicit centre taken from the first
        # mask and ``center=None`` would be exercised only once.
        # NOTE(review): ``s * 0.5`` assumes the op's default centre is the exact
        # image centre (fractional for odd sizes) - confirm against the op.
        if center is None:
            expected_center = [s * 0.5 for s in mask.shape[-2:][::-1]]
        else:
            expected_center = center

        if mask.ndim < 4:
            expected_masks = _compute_expected_mask(
                mask, angle, translate_xy, scale, shear_xy, expected_center
            )
        else:
            # Batched input: compute the reference per sample and re-stack so
            # the batch dimension is kept even when it has a single element.
            expected_masks = torch.stack(
                [
                    _compute_expected_mask(
                        single_mask, angle, translate_xy, scale, shear_xy, expected_center
                    )
                    for single_mask in mask
                ]
            )

        torch.testing.assert_close(output_mask, expected_masks)
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_correctness_affine_segmentation_mask_on_fixed_input(device):
    """Check the op against a hand-built expected output on each device."""
    side = 32

    # Fixed input: two labelled squares, one near the top-left corner and one
    # near the bottom-left corner.
    mask = torch.zeros(1, side, side, dtype=torch.long, device=device)
    mask[0, 2:10, 2:10] = 1
    mask[0, side - 9 : side - 3, 3:9] = 2

    # Reference for "rotate by 90 and scale by 2": rot90 with k=-1, nearest
    # upsampling to 64x64, then a crop of the central 32x32 window.
    rotated = torch.rot90(mask, k=-1, dims=(-2, -1))
    upscaled = torch.nn.functional.interpolate(
        rotated[None, :].float(),
        size=(2 * side, 2 * side),
        mode="nearest",
    )
    crop_start = side // 2
    crop_stop = 2 * side - side // 2
    expected_mask = upscaled[0, :, crop_start:crop_stop, crop_start:crop_stop].long()

    out_mask = F.affine_segmentation_mask(mask, 90, [0.0, 0.0], 64.0 / 32.0, [0.0, 0.0])

    torch.testing.assert_close(out_mask, expected_mask)
@pytest.mark.parametrize("angle", range(-90, 90, 56)) @pytest.mark.parametrize("angle", range(-90, 90, 56))
@pytest.mark.parametrize("expand", [True, False]) @pytest.mark.parametrize("expand", [True, False])
@pytest.mark.parametrize("center", [None, (12, 14)]) @pytest.mark.parametrize("center", [None, (12, 14)])
......
...@@ -52,6 +52,7 @@ from ._geometry import ( ...@@ -52,6 +52,7 @@ from ._geometry import (
affine_bounding_box, affine_bounding_box,
affine_image_tensor, affine_image_tensor,
affine_image_pil, affine_image_pil,
affine_segmentation_mask,
rotate_bounding_box, rotate_bounding_box,
rotate_image_tensor, rotate_image_tensor,
rotate_image_pil, rotate_image_pil,
......
...@@ -294,6 +294,25 @@ def affine_bounding_box( ...@@ -294,6 +294,25 @@ def affine_bounding_box(
).view(original_shape) ).view(original_shape)
def affine_segmentation_mask(
    img: torch.Tensor,
    angle: float,
    translate: List[float],
    scale: float,
    shear: List[float],
    center: Optional[List[float]] = None,
) -> torch.Tensor:
    """Apply an affine transformation to a segmentation mask.

    Delegates to ``affine_image_tensor`` with the interpolation mode fixed to
    ``InterpolationMode.NEAREST``; all other arguments are forwarded unchanged.

    Args:
        img: Mask tensor to transform.
        angle: Forwarded as ``angle``.
        translate: Forwarded as ``translate``.
        scale: Forwarded as ``scale``.
        shear: Forwarded as ``shear``.
        center: Forwarded as ``center``; ``None`` keeps the callee's default.

    Returns:
        The tensor returned by ``affine_image_tensor``.
    """
    transformed = affine_image_tensor(
        img,
        angle=angle,
        translate=translate,
        scale=scale,
        shear=shear,
        center=center,
        # Interpolation is not configurable for masks; it is always NEAREST.
        interpolation=InterpolationMode.NEAREST,
    )
    return transformed
def rotate_image_tensor( def rotate_image_tensor(
img: torch.Tensor, img: torch.Tensor,
angle: float, angle: float,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment