Unverified commit 2b70774e, authored by mpearce25, committed by GitHub
Browse files

Singular Sanitize BoundingBox (#7316)


Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
parent 0daffad3
......@@ -105,13 +105,13 @@ transform = transforms.Compose(
transforms.RandomHorizontalFlip(),
transforms.ToImageTensor(),
transforms.ConvertImageDtype(torch.float32),
transforms.SanitizeBoundingBoxes(),
transforms.SanitizeBoundingBox(),
]
)
########################################################################################################################
# .. note::
# Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes` transform is a no-op in this example, it
# Although the :class:`~torchvision.transforms.v2.SanitizeBoundingBox` transform is a no-op in this example, it
# should be placed at least once at the end of a detection pipeline to remove degenerate bounding boxes as well as
# the corresponding labels and optionally masks. It is particularly critical to add it if
# :class:`~torchvision.transforms.v2.RandomIoUCrop` was used.
......
......@@ -275,7 +275,7 @@ class TestSmoke:
boxes=datapoints.BoundingBox([[0, 0, 0, 0]], format=format, spatial_size=(224, 244)),
labels=torch.tensor([3]),
)
assert transforms.SanitizeBoundingBoxes()(sample)["boxes"].shape == (0, 4)
assert transforms.SanitizeBoundingBox()(sample)["boxes"].shape == (0, 4)
@parametrize(
[
......@@ -1876,7 +1876,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
transforms.ConvertImageDtype(torch.float),
]
if sanitize:
t += [transforms.SanitizeBoundingBoxes()]
t += [transforms.SanitizeBoundingBox()]
t = transforms.Compose(t)
num_boxes = 5
......@@ -1917,7 +1917,7 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
# ssd and ssdlite contain RandomIoUCrop which may "remove" some bbox. It
# doesn't remove them strictly speaking, it just marks some boxes as
# degenerate and those boxes will be later removed by
# SanitizeBoundingBoxes(), which we add to the pipelines if the sanitize
# SanitizeBoundingBox(), which we add to the pipelines if the sanitize
# param is True.
# Note that the values below are probably specific to the random seed
# set above (which is fine).
......@@ -1989,7 +1989,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
img = sample.pop("image")
sample = (img, sample)
out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample)
out = transforms.SanitizeBoundingBox(min_size=min_size, labels_getter=labels_getter)(sample)
if sample_type is tuple:
out_image = out[0]
......@@ -2023,13 +2023,13 @@ def test_sanitize_bounding_boxes_default_heuristic(key, sample_type):
sample = {key: labels, "another_key": "whatever"}
if sample_type is tuple:
sample = (None, sample, "whatever_again")
assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(sample) is labels
assert transforms.SanitizeBoundingBox._find_labels_default_heuristic(sample) is labels
if key.lower() != "labels":
# If "labels" is in the dict (case-insensitive),
# it takes precedence over other keys which would otherwise be a match
d = {key: "something_else", "labels": labels}
assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(d) is labels
assert transforms.SanitizeBoundingBox._find_labels_default_heuristic(d) is labels
def test_sanitize_bounding_boxes_errors():
......@@ -2041,25 +2041,25 @@ def test_sanitize_bounding_boxes_errors():
)
with pytest.raises(ValueError, match="min_size must be >= 1"):
transforms.SanitizeBoundingBoxes(min_size=0)
transforms.SanitizeBoundingBox(min_size=0)
with pytest.raises(ValueError, match="labels_getter should either be a str"):
transforms.SanitizeBoundingBoxes(labels_getter=12)
transforms.SanitizeBoundingBox(labels_getter=12)
with pytest.raises(ValueError, match="Could not infer where the labels are"):
bad_labels_key = {"bbox": good_bbox, "BAD_KEY": torch.arange(good_bbox.shape[0])}
transforms.SanitizeBoundingBoxes()(bad_labels_key)
transforms.SanitizeBoundingBox()(bad_labels_key)
with pytest.raises(ValueError, match="If labels_getter is a str or 'default'"):
not_a_dict = (good_bbox, torch.arange(good_bbox.shape[0]))
transforms.SanitizeBoundingBoxes()(not_a_dict)
transforms.SanitizeBoundingBox()(not_a_dict)
with pytest.raises(ValueError, match="must be a tensor"):
not_a_tensor = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0]).tolist()}
transforms.SanitizeBoundingBoxes()(not_a_tensor)
transforms.SanitizeBoundingBox()(not_a_tensor)
with pytest.raises(ValueError, match="Number of boxes"):
different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
transforms.SanitizeBoundingBoxes()(different_sizes)
transforms.SanitizeBoundingBox()(different_sizes)
with pytest.raises(ValueError, match="boxes must be of shape"):
bad_bbox = datapoints.BoundingBox( # batch with 2 elements
......@@ -2071,7 +2071,7 @@ def test_sanitize_bounding_boxes_errors():
spatial_size=(20, 20),
)
different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])}
transforms.SanitizeBoundingBoxes()(different_sizes)
transforms.SanitizeBoundingBox()(different_sizes)
@pytest.mark.parametrize(
......
......@@ -1099,7 +1099,7 @@ class TestRefDetTransforms:
v2_transforms.Compose(
[
v2_transforms.RandomIoUCrop(),
v2_transforms.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"]),
v2_transforms.SanitizeBoundingBox(labels_getter=lambda sample: sample[1]["labels"]),
]
),
{"with_mask": False},
......
......@@ -40,7 +40,7 @@ from ._geometry import (
TenCrop,
)
from ._meta import ClampBoundingBox, ConvertBoundingBoxFormat, ConvertDtype, ConvertImageDtype
from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, SanitizeBoundingBoxes, ToDtype
from ._misc import GaussianBlur, Identity, Lambda, LinearTransformation, Normalize, SanitizeBoundingBox, ToDtype
from ._temporal import UniformTemporalSubsample
from ._type_conversion import PILToTensor, ToImagePIL, ToImageTensor, ToPILImage
......
......@@ -1114,7 +1114,7 @@ class RandomIoUCrop(Transform):
.. warning::
In order to properly remove the bounding boxes below the IoU threshold, `RandomIoUCrop`
must be followed by :class:`~torchvision.transforms.v2.SanitizeBoundingBoxes`, either immediately
must be followed by :class:`~torchvision.transforms.v2.SanitizeBoundingBox`, either immediately
after or later in the transforms pipeline.
If the input is a :class:`torch.Tensor` or a ``Datapoint`` (e.g. :class:`~torchvision.datapoints.Image`,
......@@ -1222,7 +1222,7 @@ class RandomIoUCrop(Transform):
if isinstance(output, datapoints.BoundingBox):
# We "mark" the invalid boxes as degenerate, and they can be
# removed by a later call to SanitizeBoundingBoxes()
# removed by a later call to SanitizeBoundingBox()
output[~params["is_within_crop_area"]] = 0
return output
......
......@@ -246,7 +246,7 @@ class ToDtype(Transform):
return inpt.to(dtype=dtype)
class SanitizeBoundingBoxes(Transform):
class SanitizeBoundingBox(Transform):
# This removes boxes and their corresponding labels:
# - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1)
# - boxes with any coordinate outside the range of the image (negative, or > spatial_size)
......@@ -269,7 +269,7 @@ class SanitizeBoundingBoxes(Transform):
elif callable(labels_getter):
self._labels_getter = labels_getter
elif isinstance(labels_getter, str):
self._labels_getter = lambda inputs: SanitizeBoundingBoxes._get_dict_or_second_tuple_entry(inputs)[
self._labels_getter = lambda inputs: SanitizeBoundingBox._get_dict_or_second_tuple_entry(inputs)[
labels_getter # type: ignore[index]
]
elif labels_getter is None:
......@@ -300,7 +300,7 @@ class SanitizeBoundingBoxes(Transform):
def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
# Tries to find a "labels" key, otherwise tries for the first key that contains "label" - case insensitive
# Returns None if nothing is found
inputs = SanitizeBoundingBoxes._get_dict_or_second_tuple_entry(inputs)
inputs = SanitizeBoundingBox._get_dict_or_second_tuple_entry(inputs)
candidate_key = None
with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if key.lower() == "labels")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment