"tests/git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "326e4ccb5bc3de3bdee97cc5907e2e2c75bdcc73"
Unverified Commit 7d0d7fd7 authored by vfdev, committed by GitHub

[proto] Added `center_crop_bounding_box` functional op (#5972)

* [proto] Added `center_crop_bounding_box` functional op

* Fixed mypy issue

* Added one more test case

* More test cases
parent f079f5a5
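For orientation, here is a minimal sketch of how the new kernel is meant to be called, inferred from the signature and the test added in this commit; the prototype import paths and the concrete box values are assumptions, not part of the diff.

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

# One box in XYXY format on a 32x32 (height, width) image; values are illustrative.
bbox = features.BoundingBox(
    [10.0, 12.0, 20.0, 22.0],
    format=features.BoundingBoxFormat.XYXY,
    image_size=(32, 32),
)

# Shift the box into the coordinate frame of an 18x18 center crop of the image.
out = F.center_crop_bounding_box(
    bbox, format=bbox.format, output_size=[18, 18], image_size=bbox.image_size
)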
@@ -95,7 +95,7 @@ def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch
         cx = torch.randint(1, width - 1, ())
         cy = torch.randint(1, height - 1, ())
         w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1)
-        h = randint_with_tensor_bounds(1, torch.minimum(cy, width - cy) + 1)
+        h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1)
         parts = (cx, cy, w, h)
     else:
         raise pytest.UsageError()
@@ -413,6 +413,14 @@ def perspective_segmentation_mask():
         )
 
 
+@register_kernel_info_from_sample_inputs_fn
+def center_crop_bounding_box():
+    for bounding_box, output_size in itertools.product(make_bounding_boxes(), [(24, 12), [16, 18], [46, 48], [12]]):
+        yield SampleInput(
+            bounding_box, format=bounding_box.format, output_size=output_size, image_size=bounding_box.image_size
+        )
+
+
 @pytest.mark.parametrize(
     "kernel",
     [
@@ -1273,3 +1281,59 @@ def test_correctness_perspective_segmentation_mask(device, startpoints, endpoint
         else:
             expected_masks = expected_masks[0]
         torch.testing.assert_close(output_mask, expected_masks)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize(
+    "output_size",
+    [(18, 18), [18, 15], (16, 19), [12], [46, 48]],
+)
+def test_correctness_center_crop_bounding_box(device, output_size):
+    def _compute_expected_bbox(bbox, output_size_):
+        format_ = bbox.format
+        image_size_ = bbox.image_size
+        bbox = convert_bounding_box_format(bbox, format_, features.BoundingBoxFormat.XYWH)
+
+        if len(output_size_) == 1:
+            output_size_.append(output_size_[-1])
+
+        cy = int(round((image_size_[0] - output_size_[0]) * 0.5))
+        cx = int(round((image_size_[1] - output_size_[1]) * 0.5))
+        out_bbox = [
+            bbox[0].item() - cx,
+            bbox[1].item() - cy,
+            bbox[2].item(),
+            bbox[3].item(),
+        ]
+        out_bbox = features.BoundingBox(
+            out_bbox,
+            format=features.BoundingBoxFormat.XYWH,
+            image_size=output_size_,
+            dtype=bbox.dtype,
+            device=bbox.device,
+        )
+        return convert_bounding_box_format(out_bbox, features.BoundingBoxFormat.XYWH, format_, copy=False)
+
+    for bboxes in make_bounding_boxes(
+        image_sizes=[(32, 32), (24, 33), (32, 25)],
+        extra_dims=((4,),),
+    ):
+        bboxes = bboxes.to(device)
+        bboxes_format = bboxes.format
+        bboxes_image_size = bboxes.image_size
+
+        output_boxes = F.center_crop_bounding_box(bboxes, bboxes_format, output_size, bboxes_image_size)
+
+        if bboxes.ndim < 2:
+            bboxes = [bboxes]
+
+        expected_bboxes = []
+        for bbox in bboxes:
+            bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
+            expected_bboxes.append(_compute_expected_bbox(bbox, output_size))
+
+        if len(expected_bboxes) > 1:
+            expected_bboxes = torch.stack(expected_bboxes)
+        else:
+            expected_bboxes = expected_bboxes[0]
+        torch.testing.assert_close(output_boxes, expected_bboxes)
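To make the expected-box computation in the test above concrete, here is the crop-anchor arithmetic it re-implements, shown standalone with illustrative numbers (these values are not taken from the diff).

# image_size and output_size are (height, width) pairs.
image_size = (32, 25)
output_size = (18, 15)

# Top-left corner of the centered crop window.
crop_top = int(round((image_size[0] - output_size[0]) * 0.5))   # (32 - 18) / 2 = 7
crop_left = int(round((image_size[1] - output_size[1]) * 0.5))  # (25 - 15) / 2 = 5

# An XYWH box moves into the crop's coordinate frame by subtracting the anchor;
# its width and height are unchanged.
x, y, w, h = 10, 12, 6, 8
cropped = (x - crop_left, y - crop_top, w, h)  # -> (5, 5, 6, 8)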
@@ -45,6 +45,7 @@ from ._geometry import (
     resize_image_tensor,
     resize_image_pil,
     resize_segmentation_mask,
+    center_crop_bounding_box,
     center_crop_image_tensor,
     center_crop_image_pil,
     resized_crop_bounding_box,
@@ -619,6 +619,17 @@ def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.I
     return crop_image_pil(img, crop_top, crop_left, crop_height, crop_width)
 
 
+def center_crop_bounding_box(
+    bounding_box: torch.Tensor,
+    format: features.BoundingBoxFormat,
+    output_size: List[int],
+    image_size: Tuple[int, int],
+) -> torch.Tensor:
+    crop_height, crop_width = _center_crop_parse_output_size(output_size)
+    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *image_size)
+    return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left)
+
+
 def resized_crop_image_tensor(
     img: torch.Tensor,
     top: int,
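Because the kernel only computes the centered crop anchor and delegates the coordinate shift to crop_bounding_box, it is typed to accept plain tensors as well. A hedged usage sketch, with the import path assumed from the __init__ change above and arbitrary example values:

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms.functional import center_crop_bounding_box

# Two XYXY boxes on a 24x33 (height, width) image, mapped into a 12x12 center crop.
boxes = torch.tensor([[4.0, 6.0, 20.0, 18.0], [2.0, 3.0, 10.0, 11.0]])
out = center_crop_bounding_box(
    boxes,
    format=features.BoundingBoxFormat.XYXY,
    output_size=[12, 12],
    image_size=(24, 33),
)
# Each box is translated by the (crop_top, crop_left) anchor of the 12x12 window.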