"git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "dd65ee211ea5ec1c876d323c4387f066bee41a77"
Unverified Commit 226a56b9 authored by vfdev, committed by GitHub

[proto][tests] Added ref functions for h/v flips (#6876)

* [proto][tests] Added ref functions for h/v flips

* Better dtype handling in reference_affine_bounding_box_helper
parent 8b4bb5f1
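Note on the approach: a horizontal or vertical flip is a special case of an affine transform, so the new reference functions only need to build the right 2x3 matrix and delegate to the shared affine helper. The snippet below is an illustrative sketch (plain NumPy, arbitrary values, not part of the test suite) of how the two matrices used in this commit act on a point (x, y) for spatial_size = (H, W).

```python
import numpy as np

# Illustrative only: these are the flip matrices constructed in this commit.
H, W = 10, 20  # spatial_size = (H, W)

hflip = np.array([[-1, 0, W], [0, 1, 0]], dtype="float32")  # x -> W - x, y -> y
vflip = np.array([[1, 0, 0], [0, -1, H]], dtype="float32")  # x -> x, y -> H - y

point = np.array([3.0, 4.0, 1.0])  # homogeneous coordinates (x, y, 1)

print(hflip @ point)  # [17.  4.]  i.e. (W - 3, 4)
print(vflip @ point)  # [3. 6.]    i.e. (3, H - 4)
```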
@@ -145,6 +145,29 @@ def sample_inputs_horizontal_flip_video():
     yield ArgsKwargs(video_loader)
 
 
+def reference_horizontal_flip_bounding_box(bounding_box, *, format, spatial_size):
+    affine_matrix = np.array(
+        [
+            [-1, 0, spatial_size[1]],
+            [0, 1, 0],
+        ],
+        dtype="float32",
+    )
+
+    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
+
+    return expected_bboxes
+
+
+def reference_inputs_flip_bounding_box():
+    for bounding_box_loader in make_bounding_box_loaders(extra_dims=[()]):
+        yield ArgsKwargs(
+            bounding_box_loader,
+            format=bounding_box_loader.format,
+            spatial_size=bounding_box_loader.spatial_size,
+        )
+
+
 KERNEL_INFOS.extend(
     [
         KernelInfo(
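Quick sanity check on reference_horizontal_flip_bounding_box above (illustrative only, arbitrary numbers): applying its matrix to the four corners of an XYXY box and taking the per-axis min/max, which is what reference_affine_bounding_box_helper does, maps (x1, y1, x2, y2) to (W - x2, y1, W - x1, y2).

```python
import numpy as np

W = 20
x1, y1, x2, y2 = 2.0, 3.0, 7.0, 9.0

# Same matrix as reference_horizontal_flip_bounding_box builds for spatial_size[1] == W
affine_matrix = np.array([[-1, 0, W], [0, 1, 0]], dtype="float32")

corners = np.array(
    [
        [x1, y1, 1.0],
        [x2, y1, 1.0],
        [x1, y2, 1.0],
        [x2, y2, 1.0],
    ]
)
transformed = corners @ affine_matrix.T  # shape (4, 2)
out = [
    float(transformed[:, 0].min()),
    float(transformed[:, 1].min()),
    float(transformed[:, 0].max()),
    float(transformed[:, 1].max()),
]
print(out)  # [13.0, 3.0, 18.0, 9.0] == [W - x2, y1, W - x1, y2]
```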
@@ -158,6 +181,8 @@ KERNEL_INFOS.extend(
         KernelInfo(
             F.horizontal_flip_bounding_box,
             sample_inputs_fn=sample_inputs_horizontal_flip_bounding_box,
+            reference_fn=reference_horizontal_flip_bounding_box,
+            reference_inputs_fn=reference_inputs_flip_bounding_box,
         ),
         KernelInfo(
             F.horizontal_flip_mask,
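For readers not familiar with the KernelInfo machinery: reference_fn and reference_inputs_fn are consumed by the kernel tests, which (roughly) run the kernel and the reference on the same inputs and assert the outputs match. The sketch below only approximates that idea; the attribute names, args_kwargs.load(), and the tolerance handling are assumptions for illustration, not the actual harness.

```python
import torch

def check_against_reference(kernel_info):
    # Simplified sketch of a KernelInfo-driven comparison; the real test suite
    # additionally handles devices, dtypes and per-kernel tolerances.
    if kernel_info.reference_fn is None or kernel_info.reference_inputs_fn is None:
        return
    for args_kwargs in kernel_info.reference_inputs_fn():
        args, kwargs = args_kwargs.load()  # assumed: materialize loaders into tensors
        actual = kernel_info.kernel(*args, **kwargs)
        expected = kernel_info.reference_fn(*args, **kwargs)
        torch.testing.assert_close(actual, expected)
```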
@@ -409,15 +434,13 @@ def _compute_affine_matrix(angle, translate, scale, shear, center):
     return true_matrix
 
 
-def reference_affine_bounding_box(bounding_box, *, format, spatial_size, angle, translate, scale, shear, center=None):
-    if center is None:
-        center = [s * 0.5 for s in spatial_size[::-1]]
-
-    def transform(bbox):
-        affine_matrix = _compute_affine_matrix(angle, translate, scale, shear, center)
-        affine_matrix = affine_matrix[:2, :]
-
-        bbox_xyxy = F.convert_format_bounding_box(bbox, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
+def reference_affine_bounding_box_helper(bounding_box, *, format, affine_matrix):
+    def transform(bbox, affine_matrix_, format_):
+        # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
+        in_dtype = bbox.dtype
+        bbox_xyxy = F.convert_format_bounding_box(
+            bbox.float(), old_format=format_, new_format=features.BoundingBoxFormat.XYXY, inplace=True
+        )
         points = np.array(
             [
                 [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
@@ -426,7 +449,7 @@ def reference_affine_bounding_box(bounding_box, *, format, spatial_size, angle,
                 [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
             ]
         )
-        transformed_points = np.matmul(points, affine_matrix.T)
+        transformed_points = np.matmul(points, affine_matrix_.T)
         out_bbox = torch.tensor(
             [
                 np.min(transformed_points[:, 0]).item(),
@@ -434,14 +457,16 @@ def reference_affine_bounding_box(bounding_box, *, format, spatial_size, angle,
                 np.max(transformed_points[:, 0]).item(),
                 np.max(transformed_points[:, 1]).item(),
             ],
-            dtype=bbox.dtype,
         )
-        return F.convert_format_bounding_box(out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=format)
+        out_bbox = F.convert_format_bounding_box(
+            out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
+        )
+        return out_bbox.to(dtype=in_dtype)
 
     if bounding_box.ndim < 2:
         bounding_box = [bounding_box]
 
-    expected_bboxes = [transform(bbox) for bbox in bounding_box]
+    expected_bboxes = [transform(bbox, affine_matrix, format) for bbox in bounding_box]
     if len(expected_bboxes) > 1:
         expected_bboxes = torch.stack(expected_bboxes)
     else:
@@ -450,6 +475,18 @@ def reference_affine_bounding_box(bounding_box, *, format, spatial_size, angle,
     return expected_bboxes
 
 
+def reference_affine_bounding_box(bounding_box, *, format, spatial_size, angle, translate, scale, shear, center=None):
+    if center is None:
+        center = [s * 0.5 for s in spatial_size[::-1]]
+
+    affine_matrix = _compute_affine_matrix(angle, translate, scale, shear, center)
+    affine_matrix = affine_matrix[:2, :]
+
+    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
+
+    return expected_bboxes
+
+
 def reference_inputs_affine_bounding_box():
     for bounding_box_loader, affine_kwargs in itertools.product(
         make_bounding_box_loaders(extra_dims=[()]),
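About the float() cast and the "Better dtype handling" bullet in the commit message: converting a CXCYWH box whose width or height is 1 to XYXY produces half-pixel corners, which an integer dtype cannot hold. The snippet below is a simplified illustration with hand-written arithmetic (not the real convert_format_bounding_box); it shows why the helper computes in float and only casts back to in_dtype at the very end.

```python
import torch

# CXCYWH box with W == 1 and H == 1: the XYXY corners land on half-pixel values.
box_int = torch.tensor([5, 5, 1, 1], dtype=torch.int64)

def cxcywh_to_xyxy(b):
    # Simplified stand-in for the format conversion, for illustration only.
    cx, cy, w, h = b.unbind(-1)
    return torch.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], dim=-1)

exact = cxcywh_to_xyxy(box_int.float())
print(exact)                    # tensor([4.5000, 4.5000, 5.5000, 5.5000])
print(exact.to(box_int.dtype))  # tensor([4, 4, 5, 5]) -- casting back to int64 truncates the half pixels
```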
@@ -643,6 +680,20 @@ def sample_inputs_vertical_flip_video():
     yield ArgsKwargs(video_loader)
 
 
+def reference_vertical_flip_bounding_box(bounding_box, *, format, spatial_size):
+    affine_matrix = np.array(
+        [
+            [1, 0, 0],
+            [0, -1, spatial_size[0]],
+        ],
+        dtype="float32",
+    )
+
+    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
+
+    return expected_bboxes
+
+
 KERNEL_INFOS.extend(
     [
         KernelInfo(
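The analogous check for reference_vertical_flip_bounding_box above (again illustrative, arbitrary numbers): with spatial_size = (H, W), the matrix leaves x alone and maps the box's y-extent (y1, y2) to (H - y2, H - y1).

```python
import numpy as np

H = 10
y1, y2 = 3.0, 9.0

vflip = np.array([[1, 0, 0], [0, -1, H]], dtype="float32")
top = vflip @ np.array([0.0, y1, 1.0])
bottom = vflip @ np.array([0.0, y2, 1.0])
print(sorted([float(top[1]), float(bottom[1])]))  # [1.0, 7.0] == [H - y2, H - y1]
```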
@@ -656,6 +707,8 @@ KERNEL_INFOS.extend(
         KernelInfo(
             F.vertical_flip_bounding_box,
             sample_inputs_fn=sample_inputs_vertical_flip_bounding_box,
+            reference_fn=reference_vertical_flip_bounding_box,
+            reference_inputs_fn=reference_inputs_flip_bounding_box,
         ),
         KernelInfo(
             F.vertical_flip_mask,
...