Unverified Commit 2ba2f1d5 authored by vfdev, committed by GitHub

[proto] Speed-up crop on bboxes and tests (#6881)

* [proto] Speed-up crop on bboxes and tests

* Fix linter

* Update _geometry.py

* Fixed device issue

* Revert changes in test/prototype_transforms_kernel_infos.py

* Fixed failing correctness tests
parent 1921613a
@@ -862,6 +862,27 @@ def sample_inputs_crop_video():
         yield ArgsKwargs(video_loader, top=4, left=3, height=7, width=8)
 
 
+def reference_crop_bounding_box(bounding_box, *, format, top, left, height, width):
+
+    affine_matrix = np.array(
+        [
+            [1, 0, -left],
+            [0, 1, -top],
+        ],
+        dtype="float32",
+    )
+
+    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
+    return expected_bboxes, (height, width)
+
+
+def reference_inputs_crop_bounding_box():
+    for bounding_box_loader, params in itertools.product(
+        make_bounding_box_loaders(extra_dims=((), (4,))), [_CROP_PARAMS[0], _CROP_PARAMS[-1]]
+    ):
+        yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **params)
+
+
 KERNEL_INFOS.extend(
     [
         KernelInfo(
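The new reference treats the crop as a pure translation by (-left, -top): the 2x3 matrix above is exactly that translation in homogeneous coordinates, and reference_affine_bounding_box_helper applies it per box. A minimal sketch of the same computation, assuming XYXY input (hand-rolled here for illustration; the helper in this test file also handles the other formats):

import numpy as np

# Minimal illustration (assumed XYXY input): cropping translates both corners
# of each box by (-left, -top), i.e. multiplication by the affine matrix above.
left, top = 3, 4
affine_matrix = np.array([[1, 0, -left], [0, 1, -top]], dtype="float32")

box = np.array([10.0, 20.0, 15.0, 26.0])  # x1, y1, x2, y2
corners = np.array([[box[0], box[1], 1.0], [box[2], box[3], 1.0]])  # homogeneous points
shifted = corners @ affine_matrix.T
print(shifted.reshape(-1))  # [ 7. 16. 12. 22.]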
@@ -875,6 +896,8 @@ KERNEL_INFOS.extend(
         KernelInfo(
             F.crop_bounding_box,
             sample_inputs_fn=sample_inputs_crop_bounding_box,
+            reference_fn=reference_crop_bounding_box,
+            reference_inputs_fn=reference_inputs_crop_bounding_box,
         ),
         KernelInfo(
             F.crop_mask,
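Wiring reference_fn and reference_inputs_fn into the KernelInfo above lets the correctness suite check the kernel against the NumPy-based reference. A simplified sketch of the comparison this enables (the actual harness materializes the loaders and applies per-kernel tolerances; `bounding_box` here is a hypothetical input):

import torch

# Simplified sketch: compare the kernel against the reference on one input.
actual, actual_size = F.crop_bounding_box(
    bounding_box, format=bounding_box.format, top=4, left=3, height=7, width=8
)
expected, expected_size = reference_crop_bounding_box(
    bounding_box, format=bounding_box.format, top=4, left=3, height=7, width=8
)
assert actual_size == expected_size
torch.testing.assert_close(actual, torch.as_tensor(expected, dtype=actual.dtype))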
@@ -900,7 +900,8 @@ def test_correctness_center_crop_bounding_box(device, output_size):
     def _compute_expected_bbox(bbox, output_size_):
         format_ = bbox.format
         spatial_size_ = bbox.spatial_size
-        bbox = convert_format_bounding_box(bbox, format_, features.BoundingBoxFormat.XYWH)
+        dtype = bbox.dtype
+        bbox = convert_format_bounding_box(bbox.float(), format_, features.BoundingBoxFormat.XYWH)
 
         if len(output_size_) == 1:
             output_size_.append(output_size_[-1])
@@ -913,14 +914,9 @@ def test_correctness_center_crop_bounding_box(device, output_size):
             bbox[2].item(),
             bbox[3].item(),
         ]
-        out_bbox = features.BoundingBox(
-            out_bbox,
-            format=features.BoundingBoxFormat.XYWH,
-            spatial_size=output_size_,
-            dtype=bbox.dtype,
-            device=bbox.device,
-        )
-        return convert_format_bounding_box(out_bbox, features.BoundingBoxFormat.XYWH, format_)
+        out_bbox = torch.tensor(out_bbox)
+        out_bbox = convert_format_bounding_box(out_bbox, features.BoundingBoxFormat.XYWH, format_)
+        return out_bbox.to(dtype=dtype, device=bbox.device)
 
     for bboxes in make_bounding_boxes(extra_dims=((4,),)):
         bboxes = bboxes.to(device)
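This change matches the "Fixed failing correctness tests" entry in the commit log: the expected box is now computed in float and only cast back to the original dtype (and device) at the end, since format conversion in an integer dtype can silently truncate sub-pixel values. A small illustration of that truncation, with hypothetical numbers:

import torch

# Illustration (hypothetical values): the center of an XYXY box with odd extent
# falls on a half pixel, which integer arithmetic silently truncates.
xyxy = torch.tensor([2, 2, 7, 7])  # int64 box, true center x = 4.5

cx_int = torch.div(xyxy[0] + xyxy[2], 2, rounding_mode="floor")  # tensor(4)
cx_float = (xyxy.float()[0] + xyxy.float()[2]) / 2               # tensor(4.5000)

print(cx_int.item(), cx_float.item())  # 4 4.5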
@@ -802,22 +802,17 @@ def crop_bounding_box(
     height: int,
     width: int,
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
-    # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
-    #  BoundingBoxFormat instead of converting back and forth
-    bounding_box = convert_format_bounding_box(
-        bounding_box.clone(), old_format=format, new_format=features.BoundingBoxFormat.XYXY, inplace=True
-    )
+    bounding_box = bounding_box.clone()
 
     # Crop or implicit pad if left and/or top have negative values:
-    bounding_box[..., 0::2] -= left
-    bounding_box[..., 1::2] -= top
+    if format == features.BoundingBoxFormat.XYXY:
+        sub = torch.tensor([left, top, left, top], device=bounding_box.device)
+    else:
+        sub = torch.tensor([left, top, 0, 0], device=bounding_box.device)
+
+    bounding_box = bounding_box.sub_(sub)
 
-    return (
-        convert_format_bounding_box(
-            bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
-        ),
-        (height, width),
-    )
+    return bounding_box, (height, width)
 
 
 def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
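The kernel rewrite above is where the speed-up comes from: a crop only translates boxes, and in XYWH/CXCYWH the trailing width/height entries are translation-invariant, so the shift can be applied directly in the native format with a single in-place subtraction instead of a round-trip through XYXY plus two strided subtractions. A quick equivalence check, as a sketch assuming the prototype API at this commit:

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

# Sketch: the old convert -> shift -> convert-back path and the new
# native-format subtraction agree on XYWH boxes.
left, top = 3, 4
xywh = torch.tensor([[10.0, 20.0, 5.0, 6.0]])

xyxy = F.convert_format_bounding_box(
    xywh.clone(), old_format=features.BoundingBoxFormat.XYWH, new_format=features.BoundingBoxFormat.XYXY
)
old = F.convert_format_bounding_box(
    xyxy - torch.tensor([left, top, left, top]),
    old_format=features.BoundingBoxFormat.XYXY,
    new_format=features.BoundingBoxFormat.XYWH,
)

# New path: width/height are unaffected by translation, so shift positions only.
new = xywh - torch.tensor([left, top, 0, 0])

torch.testing.assert_close(old, new)  # both give [[7., 16., 5., 6.]]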