Unverified Commit 2489f370 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

add tests for perspective start- / endpoints (#7226)

parent 0aed8329
......@@ -15,11 +15,14 @@ from prototype_common_utils import (
get_num_channels,
ImageLoader,
InfoBase,
make_bounding_box_loader,
make_bounding_box_loaders,
make_detection_mask_loader,
make_image_loader,
make_image_loaders,
make_image_loaders_for_interpolation,
make_mask_loaders,
make_video_loader,
make_video_loaders,
mark_framework_limitation,
TestMark,
......@@ -1168,12 +1171,18 @@ _PERSPECTIVE_COEFFS = [
[1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
[0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
]
_STARTPOINTS = [[0, 1], [2, 3], [4, 5], [6, 7]]
_ENDPOINTS = [[9, 8], [7, 6], [5, 4], [3, 2]]
def sample_inputs_perspective_image_tensor():
for image_loader in make_image_loaders(sizes=["random"]):
for fill in get_fills(num_channels=image_loader.num_channels, dtype=image_loader.dtype):
yield ArgsKwargs(image_loader, None, None, fill=fill, coefficients=_PERSPECTIVE_COEFFS[0])
yield ArgsKwargs(
image_loader, startpoints=None, endpoints=None, fill=fill, coefficients=_PERSPECTIVE_COEFFS[0]
)
yield ArgsKwargs(make_image_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS)
def reference_inputs_perspective_image_tensor():
......@@ -1200,25 +1209,38 @@ def reference_inputs_perspective_image_tensor():
def sample_inputs_perspective_bounding_box():
for bounding_box_loader in make_bounding_box_loaders():
yield ArgsKwargs(
bounding_box_loader, bounding_box_loader.format, None, None, coefficients=_PERSPECTIVE_COEFFS[0]
bounding_box_loader,
format=bounding_box_loader.format,
startpoints=None,
endpoints=None,
coefficients=_PERSPECTIVE_COEFFS[0],
)
format = datapoints.BoundingBoxFormat.XYXY
yield ArgsKwargs(
make_bounding_box_loader(format=format), format=format, startpoints=_STARTPOINTS, endpoints=_ENDPOINTS
)
def sample_inputs_perspective_mask():
for mask_loader in make_mask_loaders(sizes=["random"]):
yield ArgsKwargs(mask_loader, None, None, coefficients=_PERSPECTIVE_COEFFS[0])
yield ArgsKwargs(mask_loader, startpoints=None, endpoints=None, coefficients=_PERSPECTIVE_COEFFS[0])
yield ArgsKwargs(make_detection_mask_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS)
def reference_inputs_perspective_mask():
for mask_loader, perspective_coeffs in itertools.product(
make_mask_loaders(extra_dims=[()], num_objects=[1]), _PERSPECTIVE_COEFFS
):
yield ArgsKwargs(mask_loader, None, None, coefficients=perspective_coeffs)
yield ArgsKwargs(mask_loader, startpoints=None, endpoints=None, coefficients=perspective_coeffs)
def sample_inputs_perspective_video():
for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
yield ArgsKwargs(video_loader, None, None, coefficients=_PERSPECTIVE_COEFFS[0])
yield ArgsKwargs(video_loader, startpoints=None, endpoints=None, coefficients=_PERSPECTIVE_COEFFS[0])
yield ArgsKwargs(make_video_loader(), startpoints=_STARTPOINTS, endpoints=_ENDPOINTS)
KERNEL_INFOS.extend(
......
......@@ -176,7 +176,11 @@ class BoundingBox(Datapoint):
coefficients: Optional[List[float]] = None,
) -> BoundingBox:
output = self._F.perspective_bounding_box(
self.as_subclass(torch.Tensor), startpoints, endpoints, self.format, coefficients=coefficients
self.as_subclass(torch.Tensor),
format=self.format,
startpoints=startpoints,
endpoints=endpoints,
coefficients=coefficients,
)
return BoundingBox.wrap_like(self, output)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment