Unverified Commit 5e2db86c authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Added functional `crop_bounding_box` op (#5781)

* [proto] Added crop_bounding_box op

* Removed "pass"

* Updated comment

* Removed unused args from signature
parent 1ac6e8b9
......@@ -321,6 +321,17 @@ def rotate_segmentation_mask():
)
@register_kernel_info_from_sample_inputs_fn
def crop_bounding_box():
for bounding_box, top, left in itertools.product(make_bounding_boxes(), [-8, 0, 9], [-8, 0, 9]):
yield SampleInput(
bounding_box,
format=bounding_box.format,
top=top,
left=left,
)
@pytest.mark.parametrize(
"kernel",
[
......@@ -808,3 +819,44 @@ def test_correctness_rotate_segmentation_mask_on_fixed_input(device):
expected_mask = torch.rot90(mask, k=1, dims=(-2, -1))
out_mask = F.rotate_segmentation_mask(mask, 90, expand=False)
torch.testing.assert_close(out_mask, expected_mask)
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
"top, left, height, width, expected_bboxes",
[
[8, 12, 30, 40, [(-2.0, 7.0, 13.0, 27.0), (38.0, -3.0, 58.0, 14.0), (33.0, 38.0, 44.0, 54.0)]],
[-8, 12, 70, 40, [(-2.0, 23.0, 13.0, 43.0), (38.0, 13.0, 58.0, 30.0), (33.0, 54.0, 44.0, 70.0)]],
],
)
def test_correctness_crop_bounding_box(device, top, left, height, width, expected_bboxes):
# Expected bboxes computed using Albumentations:
# import numpy as np
# from albumentations.augmentations.crops.functional import crop_bbox_by_coords, normalize_bbox, denormalize_bbox
# expected_bboxes = []
# for in_box in in_boxes:
# n_in_box = normalize_bbox(in_box, *size)
# n_out_box = crop_bbox_by_coords(
# n_in_box, (left, top, left + width, top + height), height, width, *size
# )
# out_box = denormalize_bbox(n_out_box, height, width)
# expected_bboxes.append(out_box)
size = (64, 76)
# xyxy format
in_boxes = [
[10.0, 15.0, 25.0, 35.0],
[50.0, 5.0, 70.0, 22.0],
[45.0, 46.0, 56.0, 62.0],
]
in_boxes = features.BoundingBox(in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=size, device=device)
output_boxes = F.crop_bounding_box(
in_boxes,
in_boxes.format,
top,
left,
)
torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
......@@ -57,9 +57,10 @@ from ._geometry import (
rotate_image_tensor,
rotate_image_pil,
rotate_segmentation_mask,
pad_bounding_box,
pad_image_tensor,
pad_image_pil,
pad_bounding_box,
crop_bounding_box,
crop_image_tensor,
crop_image_pil,
perspective_image_tensor,
......
......@@ -419,6 +419,27 @@ crop_image_tensor = _FT.crop
crop_image_pil = _FP.crop
def crop_bounding_box(
bounding_box: torch.Tensor,
format: features.BoundingBoxFormat,
top: int,
left: int,
) -> torch.Tensor:
shape = bounding_box.shape
bounding_box = convert_bounding_box_format(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
).view(-1, 4)
# Crop or implicit pad if left and/or top have negative values:
bounding_box[:, 0::2] -= left
bounding_box[:, 1::2] -= top
return convert_bounding_box_format(
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
).view(shape)
def perspective_image_tensor(
img: torch.Tensor,
perspective_coeffs: List[float],
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment