Unverified Commit 226a56b9 authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto][tests] Added ref functions for h/v flips (#6876)

* [proto][tests] Added ref functions for h/v flips

* Better dtype handling in reference_affine_bounding_box_helper
parent 8b4bb5f1
......@@ -145,6 +145,29 @@ def sample_inputs_horizontal_flip_video():
yield ArgsKwargs(video_loader)
def reference_horizontal_flip_bounding_box(bounding_box, *, format, spatial_size):
affine_matrix = np.array(
[
[-1, 0, spatial_size[1]],
[0, 1, 0],
],
dtype="float32",
)
expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
return expected_bboxes
def reference_inputs_flip_bounding_box():
for bounding_box_loader in make_bounding_box_loaders(extra_dims=[()]):
yield ArgsKwargs(
bounding_box_loader,
format=bounding_box_loader.format,
spatial_size=bounding_box_loader.spatial_size,
)
KERNEL_INFOS.extend(
[
KernelInfo(
......@@ -158,6 +181,8 @@ KERNEL_INFOS.extend(
KernelInfo(
F.horizontal_flip_bounding_box,
sample_inputs_fn=sample_inputs_horizontal_flip_bounding_box,
reference_fn=reference_horizontal_flip_bounding_box,
reference_inputs_fn=reference_inputs_flip_bounding_box,
),
KernelInfo(
F.horizontal_flip_mask,
......@@ -409,15 +434,13 @@ def _compute_affine_matrix(angle, translate, scale, shear, center):
return true_matrix
def reference_affine_bounding_box(bounding_box, *, format, spatial_size, angle, translate, scale, shear, center=None):
if center is None:
center = [s * 0.5 for s in spatial_size[::-1]]
def transform(bbox):
affine_matrix = _compute_affine_matrix(angle, translate, scale, shear, center)
affine_matrix = affine_matrix[:2, :]
bbox_xyxy = F.convert_format_bounding_box(bbox, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
def reference_affine_bounding_box_helper(bounding_box, *, format, affine_matrix):
def transform(bbox, affine_matrix_, format_):
# Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
in_dtype = bbox.dtype
bbox_xyxy = F.convert_format_bounding_box(
bbox.float(), old_format=format_, new_format=features.BoundingBoxFormat.XYXY, inplace=True
)
points = np.array(
[
[bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
......@@ -426,7 +449,7 @@ def reference_affine_bounding_box(bounding_box, *, format, spatial_size, angle,
[bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
]
)
transformed_points = np.matmul(points, affine_matrix.T)
transformed_points = np.matmul(points, affine_matrix_.T)
out_bbox = torch.tensor(
[
np.min(transformed_points[:, 0]).item(),
......@@ -434,14 +457,16 @@ def reference_affine_bounding_box(bounding_box, *, format, spatial_size, angle,
np.max(transformed_points[:, 0]).item(),
np.max(transformed_points[:, 1]).item(),
],
dtype=bbox.dtype,
)
return F.convert_format_bounding_box(out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=format)
out_bbox = F.convert_format_bounding_box(
out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
)
return out_bbox.to(dtype=in_dtype)
if bounding_box.ndim < 2:
bounding_box = [bounding_box]
expected_bboxes = [transform(bbox) for bbox in bounding_box]
expected_bboxes = [transform(bbox, affine_matrix, format) for bbox in bounding_box]
if len(expected_bboxes) > 1:
expected_bboxes = torch.stack(expected_bboxes)
else:
......@@ -450,6 +475,18 @@ def reference_affine_bounding_box(bounding_box, *, format, spatial_size, angle,
return expected_bboxes
def reference_affine_bounding_box(bounding_box, *, format, spatial_size, angle, translate, scale, shear, center=None):
if center is None:
center = [s * 0.5 for s in spatial_size[::-1]]
affine_matrix = _compute_affine_matrix(angle, translate, scale, shear, center)
affine_matrix = affine_matrix[:2, :]
expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
return expected_bboxes
def reference_inputs_affine_bounding_box():
for bounding_box_loader, affine_kwargs in itertools.product(
make_bounding_box_loaders(extra_dims=[()]),
......@@ -643,6 +680,20 @@ def sample_inputs_vertical_flip_video():
yield ArgsKwargs(video_loader)
def reference_vertical_flip_bounding_box(bounding_box, *, format, spatial_size):
affine_matrix = np.array(
[
[1, 0, 0],
[0, -1, spatial_size[0]],
],
dtype="float32",
)
expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
return expected_bboxes
KERNEL_INFOS.extend(
[
KernelInfo(
......@@ -656,6 +707,8 @@ KERNEL_INFOS.extend(
KernelInfo(
F.vertical_flip_bounding_box,
sample_inputs_fn=sample_inputs_vertical_flip_bounding_box,
reference_fn=reference_vertical_flip_bounding_box,
reference_inputs_fn=reference_inputs_flip_bounding_box,
),
KernelInfo(
F.vertical_flip_mask,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment