Unverified commit 2ba2f1d5, authored by vfdev and committed by GitHub

[proto] Speed-up crop on bboxes and tests (#6881)

* [proto] Speed-up crop on bboxes and tests

* Fix linter

* Update _geometry.py

* Fixed device issue

* Revert changes in test/prototype_transforms_kernel_infos.py

* Fixed failing correctness tests
parent 1921613a
test/prototype_transforms_kernel_infos.py

@@ -862,6 +862,27 @@ def sample_inputs_crop_video():
     yield ArgsKwargs(video_loader, top=4, left=3, height=7, width=8)
 
 
+def reference_crop_bounding_box(bounding_box, *, format, top, left, height, width):
+    affine_matrix = np.array(
+        [
+            [1, 0, -left],
+            [0, 1, -top],
+        ],
+        dtype="float32",
+    )
+
+    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
+    return expected_bboxes, (height, width)
+
+
+def reference_inputs_crop_bounding_box():
+    for bounding_box_loader, params in itertools.product(
+        make_bounding_box_loaders(extra_dims=((), (4,))), [_CROP_PARAMS[0], _CROP_PARAMS[-1]]
+    ):
+        yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **params)
+
+
 KERNEL_INFOS.extend(
     [
         KernelInfo(
@@ -875,6 +896,8 @@ KERNEL_INFOS.extend(
         KernelInfo(
             F.crop_bounding_box,
             sample_inputs_fn=sample_inputs_crop_bounding_box,
+            reference_fn=reference_crop_bounding_box,
+            reference_inputs_fn=reference_inputs_crop_bounding_box,
        ),
         KernelInfo(
             F.crop_mask,
......
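For context on the reference-test pattern above: the new `reference_crop_bounding_box` expresses cropping as the pure translation (x, y) -> (x - left, y - top) and hands the resulting 2x3 affine matrix to `reference_affine_bounding_box_helper`, which dispatches on the box format. The sketch below (hypothetical helper name `translate_xyxy_boxes`, assuming plain XYXY boxes) shows what applying that matrix to both corners of each box amounts to:

```python
import numpy as np

def translate_xyxy_boxes(boxes: np.ndarray, top: int, left: int) -> np.ndarray:
    # Cropping at (top, left) is the translation (x, y) -> (x - left, y - top).
    # Apply the same 2x3 affine matrix the reference builds to both box corners.
    affine_matrix = np.array(
        [
            [1, 0, -left],
            [0, 1, -top],
        ],
        dtype="float32",
    )
    ones = np.ones((len(boxes), 1), dtype="float32")
    p1 = np.concatenate([boxes[:, :2], ones], axis=1) @ affine_matrix.T  # shift (x1, y1)
    p2 = np.concatenate([boxes[:, 2:], ones], axis=1) @ affine_matrix.T  # shift (x2, y2)
    return np.concatenate([p1, p2], axis=1)

boxes = np.array([[10.0, 20.0, 30.0, 40.0]], dtype="float32")
print(translate_xyxy_boxes(boxes, top=4, left=3))  # [[ 7. 16. 27. 36.]]
```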
@@ -900,7 +900,8 @@ def test_correctness_center_crop_bounding_box(device, output_size):
     def _compute_expected_bbox(bbox, output_size_):
         format_ = bbox.format
         spatial_size_ = bbox.spatial_size
-        bbox = convert_format_bounding_box(bbox, format_, features.BoundingBoxFormat.XYWH)
+        dtype = bbox.dtype
+        bbox = convert_format_bounding_box(bbox.float(), format_, features.BoundingBoxFormat.XYWH)
 
         if len(output_size_) == 1:
             output_size_.append(output_size_[-1])
@@ -913,14 +914,9 @@ def test_correctness_center_crop_bounding_box(device, output_size):
             bbox[2].item(),
             bbox[3].item(),
         ]
-        out_bbox = features.BoundingBox(
-            out_bbox,
-            format=features.BoundingBoxFormat.XYWH,
-            spatial_size=output_size_,
-            dtype=bbox.dtype,
-            device=bbox.device,
-        )
-        return convert_format_bounding_box(out_bbox, features.BoundingBoxFormat.XYWH, format_)
+        out_bbox = torch.tensor(out_bbox)
+        out_bbox = convert_format_bounding_box(out_bbox, features.BoundingBoxFormat.XYWH, format_)
+        return out_bbox.to(dtype=dtype, device=bbox.device)
 
     for bboxes in make_bounding_boxes(extra_dims=((4,),)):
         bboxes = bboxes.to(device)
......
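The test fix above follows a common dtype round-trip pattern: integer boxes are promoted to float before the format conversion so intermediate arithmetic does not truncate, and the original dtype and device are restored only at the very end. A minimal, hypothetical illustration of the pattern (`xywh_to_cxcywh` is a stand-in written for this note, not the library's `convert_format_bounding_box`):

```python
import torch

def xywh_to_cxcywh(bbox: torch.Tensor) -> torch.Tensor:
    # Keep the input dtype aside, run the conversion math in float,
    # and cast back only as the last step.
    dtype = bbox.dtype
    x, y, w, h = bbox.float().unbind(-1)
    out = torch.stack([x + w / 2, y + h / 2, w, h], dim=-1)
    return out.to(dtype=dtype, device=bbox.device)

print(xywh_to_cxcywh(torch.tensor([10, 20, 5, 4], dtype=torch.int64)))
# tensor([12, 22,  5,  4])  -- the 12.5 center truncates only at the final cast
```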
torchvision/prototype/transforms/functional/_geometry.py

@@ -802,22 +802,17 @@ def crop_bounding_box(
     height: int,
     width: int,
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
-    # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
-    # BoundingBoxFormat instead of converting back and forth
-    bounding_box = convert_format_bounding_box(
-        bounding_box.clone(), old_format=format, new_format=features.BoundingBoxFormat.XYXY, inplace=True
-    )
+    bounding_box = bounding_box.clone()
 
     # Crop or implicit pad if left and/or top have negative values:
-    bounding_box[..., 0::2] -= left
-    bounding_box[..., 1::2] -= top
+    if format == features.BoundingBoxFormat.XYXY:
+        sub = torch.tensor([left, top, left, top], device=bounding_box.device)
+    else:
+        sub = torch.tensor([left, top, 0, 0], device=bounding_box.device)
 
-    return (
-        convert_format_bounding_box(
-            bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
-        ),
-        (height, width),
-    )
+    bounding_box = bounding_box.sub_(sub)
+
+    return bounding_box, (height, width)
 
 
 def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
......
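The speed-up itself comes from dropping the two `convert_format_bounding_box` round-trips: in every supported format, a crop is just a per-coordinate shift, so a single vectorized in-place subtraction on the cloned tensor suffices. XYXY boxes shift both corners, while XYWH and CXCYWH boxes shift only the origin or center and leave width and height untouched. A quick sketch of the two branches, mirroring the new code:

```python
import torch

left, top = 3, 4

# XYXY: both corners move by (-left, -top).
xyxy = torch.tensor([[10.0, 20.0, 30.0, 40.0]])
print(xyxy.clone().sub_(torch.tensor([left, top, left, top])))
# tensor([[ 7., 16., 27., 36.]])

# XYWH (and CXCYWH): only the first two coordinates move; w and h stay put.
xywh = torch.tensor([[10.0, 20.0, 20.0, 20.0]])
print(xywh.clone().sub_(torch.tensor([left, top, 0, 0])))
# tensor([[ 7., 16., 20., 20.]])
```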