Unverified Commit 602e8ca1 authored by Philip Meier, committed by GitHub

clamp bounding boxes in some geometry kernels (#7215)


Co-authored-by: vfdev-5 <vfdev.5@gmail.com>
parent 6af6bf45
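
The change routes the outputs of several bounding-box geometry kernels (affine, rotate, pad, crop, perspective, elastic) through `clamp_bounding_box`, so coordinates can no longer fall outside the image. Conceptually, the clamp limits every x coordinate to [0, width] and every y coordinate to [0, height], as the `_clamp_bounding_box` hunk at the bottom of this diff shows. A minimal standalone sketch of that behaviour for XYXY boxes (the helper name `clamp_xyxy_boxes` is ours, not part of torchvision):

import torch

def clamp_xyxy_boxes(boxes: torch.Tensor, spatial_size: tuple) -> torch.Tensor:
    # Clamp x coordinates to [0, width] and y coordinates to [0, height],
    # working in float and casting back to the input dtype, as the kernels below do.
    height, width = spatial_size
    out = boxes.clone() if boxes.is_floating_point() else boxes.float()
    out[..., 0::2].clamp_(min=0, max=width)
    out[..., 1::2].clamp_(min=0, max=height)
    return out.to(boxes.dtype)

boxes = torch.tensor([[-5, 10, 50, 40]])    # partially outside a 32x38 image
print(clamp_xyxy_boxes(boxes, (32, 38)))    # tensor([[ 0, 10, 38, 32]])
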
@@ -108,7 +108,7 @@ def float32_vs_uint8_pixel_difference(atol=1, mae=False):
     }


-def scripted_vs_eager_double_pixel_difference(device, atol=1e-6, rtol=1e-6):
+def scripted_vs_eager_float64_tolerances(device, atol=1e-6, rtol=1e-6):
     return {
         (("TestKernels", "test_scripted_vs_eager"), torch.float64, device): {"atol": atol, "rtol": rtol, "mae": False},
     }
@@ -211,10 +211,12 @@ def reference_horizontal_flip_bounding_box(bounding_box, *, format, spatial_size
             [-1, 0, spatial_size[1]],
             [0, 1, 0],
         ],
-        dtype="float32",
+        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
     )

-    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
+    expected_bboxes = reference_affine_bounding_box_helper(
+        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
+    )

     return expected_bboxes
@@ -322,7 +324,7 @@ def reference_inputs_resize_image_tensor():

 def sample_inputs_resize_bounding_box():
     for bounding_box_loader in make_bounding_box_loaders():
         for size in _get_resize_sizes(bounding_box_loader.spatial_size):
-            yield ArgsKwargs(bounding_box_loader, size=size, spatial_size=bounding_box_loader.spatial_size)
+            yield ArgsKwargs(bounding_box_loader, spatial_size=bounding_box_loader.spatial_size, size=size)


 def sample_inputs_resize_mask():
@@ -344,19 +346,20 @@ def reference_resize_bounding_box(bounding_box, *, spatial_size, size, max_size=
             [new_width / old_width, 0, 0],
             [0, new_height / old_height, 0],
         ],
-        dtype="float32",
+        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
     )

     expected_bboxes = reference_affine_bounding_box_helper(
-        bounding_box, format=datapoints.BoundingBoxFormat.XYXY, affine_matrix=affine_matrix
+        bounding_box,
+        format=bounding_box.format,
+        spatial_size=(new_height, new_width),
+        affine_matrix=affine_matrix,
     )

     return expected_bboxes, (new_height, new_width)


 def reference_inputs_resize_bounding_box():
-    for bounding_box_loader in make_bounding_box_loaders(
-        formats=[datapoints.BoundingBoxFormat.XYXY], extra_dims=((), (4,))
-    ):
+    for bounding_box_loader in make_bounding_box_loaders(extra_dims=((), (4,))):
         for size in _get_resize_sizes(bounding_box_loader.spatial_size):
             yield ArgsKwargs(bounding_box_loader, size=size, spatial_size=bounding_box_loader.spatial_size)
@@ -543,14 +546,17 @@ def _compute_affine_matrix(angle, translate, scale, shear, center):
     return true_matrix


-def reference_affine_bounding_box_helper(bounding_box, *, format, affine_matrix):
-    def transform(bbox, affine_matrix_, format_):
+def reference_affine_bounding_box_helper(bounding_box, *, format, spatial_size, affine_matrix):
+    def transform(bbox, affine_matrix_, format_, spatial_size_):
         # Go to float before converting to prevent precision loss in case of CXCYWH -> XYXY and W or H is 1
         in_dtype = bbox.dtype
         if not torch.is_floating_point(bbox):
             bbox = bbox.float()
         bbox_xyxy = F.convert_format_bounding_box(
-            bbox, old_format=format_, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
+            bbox.as_subclass(torch.Tensor),
+            old_format=format_,
+            new_format=datapoints.BoundingBoxFormat.XYXY,
+            inplace=True,
         )
         points = np.array(
             [
@@ -573,12 +579,15 @@ def reference_affine_bounding_box_helper(bounding_box, *, format, affine_matrix)
         out_bbox = F.convert_format_bounding_box(
             out_bbox, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format_, inplace=True
         )
-        return out_bbox.to(dtype=in_dtype)
+        # It is important to clamp before casting, especially for CXCYWH format, dtype=int64
+        out_bbox = F.clamp_bounding_box(out_bbox, format=format_, spatial_size=spatial_size_)
+        out_bbox = out_bbox.to(dtype=in_dtype)
+        return out_bbox

     if bounding_box.ndim < 2:
         bounding_box = [bounding_box]

-    expected_bboxes = [transform(bbox, affine_matrix, format) for bbox in bounding_box]
+    expected_bboxes = [transform(bbox, affine_matrix, format, spatial_size) for bbox in bounding_box]
     if len(expected_bboxes) > 1:
         expected_bboxes = torch.stack(expected_bboxes)
     else:
@@ -594,7 +603,9 @@ def reference_affine_bounding_box(bounding_box, *, format, spatial_size, angle,
     affine_matrix = _compute_affine_matrix(angle, translate, scale, shear, center)
     affine_matrix = affine_matrix[:2, :]

-    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
+    expected_bboxes = reference_affine_bounding_box_helper(
+        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
+    )

     return expected_bboxes
@@ -643,9 +654,6 @@ KERNEL_INFOS.extend(
             sample_inputs_fn=sample_inputs_affine_bounding_box,
             reference_fn=reference_affine_bounding_box,
             reference_inputs_fn=reference_inputs_affine_bounding_box,
-            closeness_kwargs={
-                (("TestKernels", "test_against_reference"), torch.int64, "cpu"): dict(atol=1, rtol=0),
-            },
             test_marks=[
                 xfail_jit_python_scalar_arg("shear"),
             ],
@@ -729,10 +737,12 @@ def reference_vertical_flip_bounding_box(bounding_box, *, format, spatial_size):
             [1, 0, 0],
             [0, -1, spatial_size[0]],
         ],
-        dtype="float32",
+        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
     )

-    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
+    expected_bboxes = reference_affine_bounding_box_helper(
+        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
+    )

     return expected_bboxes
@@ -806,6 +816,43 @@ def sample_inputs_rotate_bounding_box():
         )


+def reference_inputs_rotate_bounding_box():
+    for bounding_box_loader, angle in itertools.product(
+        make_bounding_box_loaders(extra_dims=((), (4,))), _ROTATE_ANGLES
+    ):
+        yield ArgsKwargs(
+            bounding_box_loader,
+            format=bounding_box_loader.format,
+            spatial_size=bounding_box_loader.spatial_size,
+            angle=angle,
+        )
+
+
+# TODO: add samples with expand=True and center
+def reference_rotate_bounding_box(bounding_box, *, format, spatial_size, angle, expand=False, center=None):
+    if center is None:
+        center = [spatial_size[1] * 0.5, spatial_size[0] * 0.5]
+
+    a = np.cos(angle * np.pi / 180.0)
+    b = np.sin(angle * np.pi / 180.0)
+    cx = center[0]
+    cy = center[1]
+    affine_matrix = np.array(
+        [
+            [a, b, cx - cx * a - b * cy],
+            [-b, a, cy + cx * b - a * cy],
+        ],
+        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
+    )
+
+    expected_bboxes = reference_affine_bounding_box_helper(
+        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
+    )
+
+    return expected_bboxes, spatial_size
+
+
 def sample_inputs_rotate_mask():
     for mask_loader in make_mask_loaders(sizes=["random"], num_categories=["random"], num_objects=["random"]):
         yield ArgsKwargs(mask_loader, angle=15.0)
@@ -834,9 +881,11 @@ KERNEL_INFOS.extend(
         KernelInfo(
             F.rotate_bounding_box,
             sample_inputs_fn=sample_inputs_rotate_bounding_box,
+            reference_fn=reference_rotate_bounding_box,
+            reference_inputs_fn=reference_inputs_rotate_bounding_box,
             closeness_kwargs={
-                **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
-                **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
+                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-6, rtol=1e-6),
+                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
             },
         ),
         KernelInfo(
@@ -897,17 +946,19 @@ def sample_inputs_crop_video():

 def reference_crop_bounding_box(bounding_box, *, format, top, left, height, width):
     affine_matrix = np.array(
         [
             [1, 0, -left],
             [0, 1, -top],
         ],
-        dtype="float32",
+        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
     )

-    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
-    return expected_bboxes, (height, width)
+    spatial_size = (height, width)
+    expected_bboxes = reference_affine_bounding_box_helper(
+        bounding_box, format=format, spatial_size=spatial_size, affine_matrix=affine_matrix
+    )
+    return expected_bboxes, spatial_size


 def reference_inputs_crop_bounding_box():
@@ -1119,13 +1170,15 @@ def reference_pad_bounding_box(bounding_box, *, format, spatial_size, padding, p
             [1, 0, left],
             [0, 1, top],
         ],
-        dtype="float32",
+        dtype="float64" if bounding_box.dtype == torch.float64 else "float32",
     )

     height = spatial_size[0] + top + bottom
     width = spatial_size[1] + left + right

-    expected_bboxes = reference_affine_bounding_box_helper(bounding_box, format=format, affine_matrix=affine_matrix)
+    expected_bboxes = reference_affine_bounding_box_helper(
+        bounding_box, format=format, spatial_size=(height, width), affine_matrix=affine_matrix
+    )

     return expected_bboxes, (height, width)
@@ -1225,14 +1278,16 @@ def sample_inputs_perspective_bounding_box():
         yield ArgsKwargs(
             bounding_box_loader,
             format=bounding_box_loader.format,
+            spatial_size=bounding_box_loader.spatial_size,
             startpoints=None,
             endpoints=None,
             coefficients=_PERSPECTIVE_COEFFS[0],
         )

     format = datapoints.BoundingBoxFormat.XYXY
+    loader = make_bounding_box_loader(format=format)
     yield ArgsKwargs(
-        make_bounding_box_loader(format=format), format=format, startpoints=_STARTPOINTS, endpoints=_ENDPOINTS
+        loader, format=format, spatial_size=loader.spatial_size, startpoints=_STARTPOINTS, endpoints=_ENDPOINTS
     )
@@ -1269,13 +1324,17 @@ KERNEL_INFOS.extend(
                 **pil_reference_pixel_difference(2, mae=True),
                 **cuda_vs_cpu_pixel_difference(),
                 **float32_vs_uint8_pixel_difference(),
-                **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
-                **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
+                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-5, rtol=1e-5),
+                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
             },
         ),
         KernelInfo(
             F.perspective_bounding_box,
             sample_inputs_fn=sample_inputs_perspective_bounding_box,
+            closeness_kwargs={
+                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-6, rtol=1e-6),
+                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-6, rtol=1e-6),
+            },
         ),
         KernelInfo(
             F.perspective_mask,
@@ -1292,8 +1351,8 @@ KERNEL_INFOS.extend(
             sample_inputs_fn=sample_inputs_perspective_video,
             closeness_kwargs={
                 **cuda_vs_cpu_pixel_difference(),
-                **scripted_vs_eager_double_pixel_difference("cpu", atol=1e-5, rtol=1e-5),
-                **scripted_vs_eager_double_pixel_difference("cuda", atol=1e-5, rtol=1e-5),
+                **scripted_vs_eager_float64_tolerances("cpu", atol=1e-5, rtol=1e-5),
+                **scripted_vs_eager_float64_tolerances("cuda", atol=1e-5, rtol=1e-5),
             },
         ),
     ]
@@ -1331,6 +1390,7 @@ def sample_inputs_elastic_bounding_box():
         yield ArgsKwargs(
             bounding_box_loader,
             format=bounding_box_loader.format,
+            spatial_size=bounding_box_loader.spatial_size,
             displacement=displacement,
         )
......
@@ -146,7 +146,7 @@ class TestSmoke:
             (transforms.RandomZoomOut(p=1.0), None),
             (transforms.Resize([16, 16], antialias=True), None),
             (transforms.ScaleJitter((16, 16), scale_range=(0.8, 1.2)), None),
-            (transforms.ClampBoundingBoxes(), None),
+            (transforms.ClampBoundingBox(), None),
             (transforms.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.CXCYWH), None),
             (transforms.ConvertDtype(), None),
             (transforms.GaussianBlur(kernel_size=3), None),
......
@@ -25,7 +25,7 @@ from torch.utils._pytree import tree_map
 from torchvision.prototype import datapoints
 from torchvision.prototype.transforms import functional as F
 from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding
-from torchvision.prototype.transforms.functional._meta import convert_format_bounding_box
+from torchvision.prototype.transforms.functional._meta import clamp_bounding_box, convert_format_bounding_box
 from torchvision.transforms.functional import _get_perspective_coeffs
@@ -257,16 +257,17 @@ class TestKernels:
     @reference_inputs
     def test_against_reference(self, test_id, info, args_kwargs):
         (input, *other_args), kwargs = args_kwargs.load("cpu")
-        input = input.as_subclass(torch.Tensor)

-        actual = info.kernel(input, *other_args, **kwargs)
+        actual = info.kernel(input.as_subclass(torch.Tensor), *other_args, **kwargs)
+        # We intentionally don't unwrap the input of the reference function in order for it to have access to all
+        # metadata regardless of whether the kernel takes it explicitly or not
         expected = info.reference_fn(input, *other_args, **kwargs)

         assert_close(
             actual,
             expected,
             **info.get_closeness_kwargs(test_id, dtype=input.dtype, device=input.device),
-            msg=parametrized_error_message(*other_args, **kwargs),
+            msg=parametrized_error_message(input, *other_args, **kwargs),
         )

     @make_info_args_kwargs_parametrization(
@@ -682,6 +683,10 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
         (48.56528888843238, 9.611532109828834, 53.35347829361575, 14.39972151501221),
     ]

+    expected_bboxes = clamp_bounding_box(
+        datapoints.BoundingBox(expected_bboxes, format="XYXY", spatial_size=spatial_size)
+    ).tolist()
+
     output_boxes = F.affine_bounding_box(
         in_boxes,
         format=format,
@@ -762,7 +767,8 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
             dtype=bbox.dtype,
             device=bbox.device,
         )
-        return convert_format_bounding_box(out_bbox, new_format=bbox.format), (height, width)
+        out_bbox = clamp_bounding_box(convert_format_bounding_box(out_bbox, new_format=bbox.format))
+        return out_bbox, (height, width)

     spatial_size = (32, 38)
@@ -839,6 +845,9 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
         [69.27564928, 12.39339828, 74.93250353, 18.05025253],
         [18.36396103, 1.07968978, 46.64823228, 29.36396103],
     ]
+    expected_bboxes = clamp_bounding_box(
+        datapoints.BoundingBox(expected_bboxes, format="XYXY", spatial_size=spatial_size)
+    ).tolist()

     output_boxes, _ = F.rotate_bounding_box(
         in_boxes,
@@ -905,6 +914,10 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
     if format != datapoints.BoundingBoxFormat.XYXY:
         in_boxes = convert_format_bounding_box(in_boxes, datapoints.BoundingBoxFormat.XYXY, format)

+    expected_bboxes = clamp_bounding_box(
+        datapoints.BoundingBox(expected_bboxes, format="XYXY", spatial_size=spatial_size)
+    ).tolist()
+
     output_boxes, output_spatial_size = F.crop_bounding_box(
         in_boxes,
         format,
@@ -1121,7 +1134,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
             dtype=bbox.dtype,
             device=bbox.device,
         )
-        return convert_format_bounding_box(out_bbox, new_format=bbox.format)
+        return clamp_bounding_box(convert_format_bounding_box(out_bbox, new_format=bbox.format))

     spatial_size = (32, 38)
@@ -1134,6 +1147,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
     output_bboxes = F.perspective_bounding_box(
         bboxes.as_subclass(torch.Tensor),
         format=bboxes.format,
+        spatial_size=bboxes.spatial_size,
         startpoints=None,
         endpoints=None,
         coefficients=pcoeffs,
@@ -1178,6 +1192,7 @@ def test_correctness_center_crop_bounding_box(device, output_size):
         ]
         out_bbox = torch.tensor(out_bbox)
         out_bbox = convert_format_bounding_box(out_bbox, datapoints.BoundingBoxFormat.XYWH, format_)
+        out_bbox = clamp_bounding_box(out_bbox, format=format_, spatial_size=output_size)
         return out_bbox.to(dtype=dtype, device=bbox.device)

     for bboxes in make_bounding_boxes(extra_dims=((4,),)):
@@ -1201,7 +1216,8 @@ def test_correctness_center_crop_bounding_box(device, output_size):
         expected_bboxes = torch.stack(expected_bboxes)
     else:
         expected_bboxes = expected_bboxes[0]
-    torch.testing.assert_close(output_boxes, expected_bboxes)
+
+    torch.testing.assert_close(output_boxes, expected_bboxes, atol=1, rtol=0)
     torch.testing.assert_close(output_spatial_size, output_size)
......
@@ -81,7 +81,10 @@ class BoundingBox(Datapoint):
         antialias: Optional[Union[str, bool]] = "warn",
     ) -> BoundingBox:
         output, spatial_size = self._F.resize_bounding_box(
-            self.as_subclass(torch.Tensor), spatial_size=self.spatial_size, size=size, max_size=max_size
+            self.as_subclass(torch.Tensor),
+            spatial_size=self.spatial_size,
+            size=size,
+            max_size=max_size,
         )
         return BoundingBox.wrap_like(self, output, spatial_size=spatial_size)
@@ -178,6 +181,7 @@ class BoundingBox(Datapoint):
         output = self._F.perspective_bounding_box(
             self.as_subclass(torch.Tensor),
             format=self.format,
+            spatial_size=self.spatial_size,
             startpoints=startpoints,
             endpoints=endpoints,
             coefficients=coefficients,
@@ -190,5 +194,7 @@ class BoundingBox(Datapoint):
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
         fill: FillTypeJIT = None,
     ) -> BoundingBox:
-        output = self._F.elastic_bounding_box(self.as_subclass(torch.Tensor), self.format, displacement)
+        output = self._F.elastic_bounding_box(
+            self.as_subclass(torch.Tensor), self.format, self.spatial_size, displacement=displacement
+        )
         return BoundingBox.wrap_like(self, output)
@@ -41,7 +41,7 @@ from ._geometry import (
     ScaleJitter,
     TenCrop,
 )
-from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype
+from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype
 from ._misc import (
     GaussianBlur,
     Identity,
......
@@ -42,7 +42,7 @@ class ConvertDtype(Transform):
 ConvertImageDtype = ConvertDtype


-class ClampBoundingBoxes(Transform):
+class ClampBoundingBox(Transform):
     _transformed_types = (datapoints.BoundingBox,)

     def _transform(self, inpt: datapoints.BoundingBox, params: Dict[str, Any]) -> datapoints.BoundingBox:
......
@@ -22,7 +22,7 @@ from torchvision.transforms.functional_tensor import _pad_symmetric
 from torchvision.utils import _log_api_usage_once

-from ._meta import convert_format_bounding_box, get_spatial_size_image_pil
+from ._meta import clamp_bounding_box, convert_format_bounding_box, get_spatial_size_image_pil
 from ._utils import is_simple_tensor
@@ -580,8 +580,9 @@ def affine_image_pil(
     return _FP.affine(image, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill)


-def _affine_bounding_box_xyxy(
+def _affine_bounding_box_with_expand(
     bounding_box: torch.Tensor,
+    format: datapoints.BoundingBoxFormat,
     spatial_size: Tuple[int, int],
     angle: Union[int, float],
     translate: List[float],
@@ -593,6 +594,17 @@ def _affine_bounding_box_xyxy(
     if bounding_box.numel() == 0:
         return bounding_box, spatial_size

+    original_shape = bounding_box.shape
+    original_dtype = bounding_box.dtype
+    bounding_box = bounding_box.clone() if bounding_box.is_floating_point() else bounding_box.float()
+    dtype = bounding_box.dtype
+    device = bounding_box.device
+    bounding_box = (
+        convert_format_bounding_box(
+            bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
+        )
+    ).reshape(-1, 4)
+
     angle, translate, shear, center = _affine_parse_args(
         angle, translate, scale, shear, InterpolationMode.NEAREST, center
     )
@@ -601,9 +613,6 @@ def _affine_bounding_box_xyxy(
         height, width = spatial_size
         center = [width * 0.5, height * 0.5]

-    dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
-    device = bounding_box.device
-
     affine_vector = _get_inverse_affine_matrix(center, angle, translate, scale, shear, inverted=False)
     transposed_affine_matrix = (
         torch.tensor(
@@ -651,7 +660,13 @@ def _affine_bounding_box_xyxy(
         new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
         spatial_size = (new_height, new_width)

-    return out_bboxes.to(bounding_box.dtype), spatial_size
+    out_bboxes = clamp_bounding_box(out_bboxes, format=datapoints.BoundingBoxFormat.XYXY, spatial_size=spatial_size)
+    out_bboxes = convert_format_bounding_box(
+        out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
+    ).reshape(original_shape)
+    out_bboxes = out_bboxes.to(original_dtype)
+    return out_bboxes, spatial_size


 def affine_bounding_box(
@@ -664,19 +679,18 @@ def affine_bounding_box(
     shear: List[float],
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
-    original_shape = bounding_box.shape
-
-    bounding_box = (
-        convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
-    ).reshape(-1, 4)
-
-    out_bboxes, _ = _affine_bounding_box_xyxy(bounding_box, spatial_size, angle, translate, scale, shear, center)
-
-    # out_bboxes should be of shape [N boxes, 4]
-
-    return convert_format_bounding_box(
-        out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
-    ).reshape(original_shape)
+    out_box, _ = _affine_bounding_box_with_expand(
+        bounding_box,
+        format=format,
+        spatial_size=spatial_size,
+        angle=angle,
+        translate=translate,
+        scale=scale,
+        shear=shear,
+        center=center,
+        expand=False,
+    )
+    return out_box


 def affine_mask(
@@ -852,14 +866,10 @@ def rotate_bounding_box(
         warnings.warn("The provided center argument has no effect on the result if expand is True")
         center = None

-    original_shape = bounding_box.shape
-    bounding_box = (
-        convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
-    ).reshape(-1, 4)
-
-    out_bboxes, spatial_size = _affine_bounding_box_xyxy(
+    return _affine_bounding_box_with_expand(
         bounding_box,
-        spatial_size,
+        format=format,
+        spatial_size=spatial_size,
         angle=-angle,
         translate=[0.0, 0.0],
         scale=1.0,
@@ -868,13 +878,6 @@ def rotate_bounding_box(
         expand=expand,
     )

-    return (
-        convert_format_bounding_box(
-            out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
-        ).reshape(original_shape),
-        spatial_size,
-    )
-

 def rotate_mask(
     mask: torch.Tensor,
@@ -1112,8 +1115,9 @@ def pad_bounding_box(
     height, width = spatial_size
     height += top + bottom
     width += left + right
+    spatial_size = (height, width)

-    return bounding_box, (height, width)
+    return clamp_bounding_box(bounding_box, format=format, spatial_size=spatial_size), spatial_size


 def pad_video(
@@ -1185,8 +1189,9 @@ def crop_bounding_box(
     sub = [left, top, 0, 0]

     bounding_box = bounding_box - torch.tensor(sub, dtype=bounding_box.dtype, device=bounding_box.device)
+    spatial_size = (height, width)

-    return bounding_box, (height, width)
+    return clamp_bounding_box(bounding_box, format=format, spatial_size=spatial_size), spatial_size


 def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
@@ -1332,6 +1337,7 @@ def perspective_image_pil(
 def perspective_bounding_box(
     bounding_box: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
+    spatial_size: Tuple[int, int],
     startpoints: Optional[List[List[int]]],
     endpoints: Optional[List[List[int]]],
     coefficients: Optional[List[float]] = None,
@@ -1342,6 +1348,7 @@ def perspective_bounding_box(
     perspective_coeffs = _perspective_coefficients(startpoints, endpoints, coefficients)
     original_shape = bounding_box.shape
+    # TODO: first cast to float if bbox is int64 before convert_format_bounding_box
     bounding_box = (
         convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)
@@ -1408,7 +1415,11 @@ def perspective_bounding_box(
     transformed_points = transformed_points.reshape(-1, 4, 2)
     out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
-    out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)
+    out_bboxes = clamp_bounding_box(
+        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype),
+        format=datapoints.BoundingBoxFormat.XYXY,
+        spatial_size=spatial_size,
+    )

     # out_bboxes should be of shape [N boxes, 4]
@@ -1549,6 +1560,7 @@ def _create_identity_grid(size: Tuple[int, int], device: torch.device, dtype: to
 def elastic_bounding_box(
     bounding_box: torch.Tensor,
     format: datapoints.BoundingBoxFormat,
+    spatial_size: Tuple[int, int],
     displacement: torch.Tensor,
 ) -> torch.Tensor:
     if bounding_box.numel() == 0:
@@ -1562,14 +1574,11 @@ def elastic_bounding_box(
     displacement = displacement.to(dtype=dtype, device=device)

     original_shape = bounding_box.shape
+    # TODO: first cast to float if bbox is int64 before convert_format_bounding_box
     bounding_box = (
         convert_format_bounding_box(bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)

-    # Question (vfdev-5): should we rely on good displacement shape and fetch image size from it
-    # Or add spatial_size arg and check displacement shape
-    spatial_size = displacement.shape[-3], displacement.shape[-2]
     id_grid = _create_identity_grid(spatial_size, device=device, dtype=dtype)
     # We construct an approximation of inverse grid as inv_grid = id_grid - displacement
     # This is not an exact inverse of the grid
@@ -1588,7 +1597,11 @@ def elastic_bounding_box(
     transformed_points = transformed_points.reshape(-1, 4, 2)
     out_bbox_mins, out_bbox_maxs = torch.aminmax(transformed_points, dim=1)
-    out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)
+    out_bboxes = clamp_bounding_box(
+        torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype),
+        format=datapoints.BoundingBoxFormat.XYXY,
+        spatial_size=spatial_size,
+    )

     return convert_format_bounding_box(
         out_bboxes, old_format=datapoints.BoundingBoxFormat.XYXY, new_format=format, inplace=True
@@ -1796,7 +1809,7 @@ def resized_crop_bounding_box(
     size: List[int],
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
     bounding_box, _ = crop_bounding_box(bounding_box, format, top, left, height, width)
-    return resize_bounding_box(bounding_box, (height, width), size)
+    return resize_bounding_box(bounding_box, spatial_size=(height, width), size=size)


 def resized_crop_mask(
......
@@ -245,12 +245,17 @@ def _clamp_bounding_box(
 ) -> torch.Tensor:
     # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
     # BoundingBoxFormat instead of converting back and forth
+    in_dtype = bounding_box.dtype
+    bounding_box = bounding_box.clone() if bounding_box.is_floating_point() else bounding_box.float()
     xyxy_boxes = convert_format_bounding_box(
-        bounding_box.clone(), old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
+        bounding_box, old_format=format, new_format=datapoints.BoundingBoxFormat.XYXY, inplace=True
     )
     xyxy_boxes[..., 0::2].clamp_(min=0, max=spatial_size[1])
     xyxy_boxes[..., 1::2].clamp_(min=0, max=spatial_size[0])
-    return convert_format_bounding_box(xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True)
+    out_boxes = convert_format_bounding_box(
+        xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True
+    )
+    return out_boxes.to(in_dtype)


 def clamp_bounding_box(
......