Unverified Commit de31e4b8 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Added `resized_crop_segmentation_mask` op (#5855)

* [proto] Added crop_bounding_box op

* Added `crop_segmentation_mask` op

* Fixed failed mypy

* Added tests for resized_crop_bounding_box

* Fixed code formatting

* Added resized_crop_segmentation_mask op

* Added tests
parent 6d85d74b
...@@ -362,6 +362,14 @@ def resized_crop_bounding_box(): ...@@ -362,6 +362,14 @@ def resized_crop_bounding_box():
) )
@register_kernel_info_from_sample_inputs_fn
def resized_crop_segmentation_mask():
for mask, top, left, height, width, size in itertools.product(
make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20], [(32, 32), (16, 18)]
):
yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"kernel", "kernel",
[ [
...@@ -998,3 +1006,28 @@ def test_correctness_resized_crop_bounding_box(device, format, top, left, height ...@@ -998,3 +1006,28 @@ def test_correctness_resized_crop_bounding_box(device, format, top, left, height
output_boxes = convert_bounding_box_format(output_boxes, format, features.BoundingBoxFormat.XYXY) output_boxes = convert_bounding_box_format(output_boxes, format, features.BoundingBoxFormat.XYXY)
torch.testing.assert_close(output_boxes, expected_bboxes) torch.testing.assert_close(output_boxes, expected_bboxes)
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
"top, left, height, width, size",
[
[0, 0, 30, 30, (60, 60)],
[5, 5, 35, 45, (32, 34)],
],
)
def test_correctness_resized_crop_segmentation_mask(device, top, left, height, width, size):
def _compute_expected(mask, top_, left_, height_, width_, size_):
output = mask.clone()
output = output[:, top_ : top_ + height_, left_ : left_ + width_]
output = torch.nn.functional.interpolate(output[None, :].float(), size=size_, mode="nearest")
output = output[0, :].long()
return output
in_mask = torch.zeros(1, 100, 100, dtype=torch.long, device=device)
in_mask[0, 10:20, 10:20] = 1
in_mask[0, 5:15, 12:23] = 2
expected_mask = _compute_expected(in_mask, top, left, height, width, size)
output_mask = F.resized_crop_segmentation_mask(in_mask, top, left, height, width, size)
torch.testing.assert_close(output_mask, expected_mask)
...@@ -47,9 +47,10 @@ from ._geometry import ( ...@@ -47,9 +47,10 @@ from ._geometry import (
resize_segmentation_mask, resize_segmentation_mask,
center_crop_image_tensor, center_crop_image_tensor,
center_crop_image_pil, center_crop_image_pil,
resized_crop_bounding_box,
resized_crop_image_tensor, resized_crop_image_tensor,
resized_crop_image_pil, resized_crop_image_pil,
resized_crop_bounding_box, resized_crop_segmentation_mask,
affine_bounding_box, affine_bounding_box,
affine_image_tensor, affine_image_tensor,
affine_image_pil, affine_image_pil,
......
...@@ -555,6 +555,18 @@ def resized_crop_bounding_box( ...@@ -555,6 +555,18 @@ def resized_crop_bounding_box(
return resize_bounding_box(bounding_box, size, (height, width)) return resize_bounding_box(bounding_box, size, (height, width))
def resized_crop_segmentation_mask(
mask: torch.Tensor,
top: int,
left: int,
height: int,
width: int,
size: List[int],
) -> torch.Tensor:
mask = crop_segmentation_mask(mask, top, left, height, width)
return resize_segmentation_mask(mask, size)
def _parse_five_crop_size(size: List[int]) -> List[int]: def _parse_five_crop_size(size: List[int]) -> List[int]:
if isinstance(size, numbers.Number): if isinstance(size, numbers.Number):
size = [int(size), int(size)] size = [int(size), int(size)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment