Unverified Commit 1e19d73c authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Add SanitizeBoundingBoxes transform (#7246)


Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent c5e9a10d
import itertools
import pathlib
import random
import re
import warnings
from collections import defaultdict
......@@ -2355,3 +2356,118 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
out["label"] = torch.tensor(out["label"])
assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes
@pytest.mark.parametrize("min_size", (1, 10))
@pytest.mark.parametrize(
"labels_getter", ("default", "labels", lambda inputs: inputs["labels"], None, lambda inputs: None)
)
def test_sanitize_bounding_boxes(min_size, labels_getter):
H, W = 256, 128
boxes_and_validity = [
([0, 1, 10, 1], False), # Y1 == Y2
([0, 1, 0, 20], False), # X1 == X2
([0, 0, min_size - 1, 10], False), # H < min_size
([0, 0, 10, min_size - 1], False), # W < min_size
([0, 0, 10, H + 1], False), # Y2 > H
([0, 0, W + 1, 10], False), # X2 > W
([-1, 1, 10, 20], False), # any < 0
([0, 0, -1, 20], False), # any < 0
([0, 0, -10, -1], False), # any < 0
([0, 0, min_size, 10], True), # H < min_size
([0, 0, 10, min_size], True), # W < min_size
([0, 0, W, H], True), # TODO: Is that actually OK?? Should it be -1?
([1, 1, 30, 20], True),
([0, 0, 10, 10], True),
([1, 1, 30, 20], True),
]
random.shuffle(boxes_and_validity) # For test robustness: mix order of wrong and correct cases
boxes, is_valid_mask = zip(*boxes_and_validity)
valid_indices = [i for (i, is_valid) in enumerate(is_valid_mask) if is_valid]
boxes = torch.tensor(boxes)
labels = torch.arange(boxes.shape[-2])
boxes = datapoints.BoundingBox(
boxes,
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=(H, W),
)
sample = {
"image": torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8),
"labels": labels,
"boxes": boxes,
"whatever": torch.rand(10),
"None": None,
}
out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample)
assert out["image"] is sample["image"]
assert out["whatever"] is sample["whatever"]
if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None):
assert out["labels"] is sample["labels"]
else:
assert isinstance(out["labels"], torch.Tensor)
assert out["boxes"].shape[:-1] == out["labels"].shape
# This works because we conveniently set labels to arange(num_boxes)
assert out["labels"].tolist() == valid_indices
@pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"))
def test_sanitize_bounding_boxes_default_heuristic(key):
labels = torch.arange(10)
d = {key: labels}
assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(d) is labels
if key.lower() != "labels":
# If "labels" is in the dict (case-insensitive),
# it takes precedence over other keys which would otherwise be a match
d = {key: "something_else", "labels": labels}
assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(d) is labels
def test_sanitize_bounding_boxes_errors():
good_bbox = datapoints.BoundingBox(
[[0, 0, 10, 10]],
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=(20, 20),
)
with pytest.raises(ValueError, match="min_size must be >= 1"):
transforms.SanitizeBoundingBoxes(min_size=0)
with pytest.raises(ValueError, match="labels_getter should either be a str"):
transforms.SanitizeBoundingBoxes(labels_getter=12)
with pytest.raises(ValueError, match="Could not infer where the labels are"):
bad_labels_key = {"bbox": good_bbox, "BAD_KEY": torch.arange(good_bbox.shape[0])}
transforms.SanitizeBoundingBoxes()(bad_labels_key)
with pytest.raises(ValueError, match="If labels_getter is a str or 'default'"):
not_a_dict = (good_bbox, torch.arange(good_bbox.shape[0]))
transforms.SanitizeBoundingBoxes()(not_a_dict)
with pytest.raises(ValueError, match="must be a tensor"):
not_a_tensor = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0]).tolist()}
transforms.SanitizeBoundingBoxes()(not_a_tensor)
with pytest.raises(ValueError, match="Number of boxes"):
different_sizes = {"bbox": good_bbox, "labels": torch.arange(good_bbox.shape[0] + 3)}
transforms.SanitizeBoundingBoxes()(different_sizes)
with pytest.raises(ValueError, match="boxes must be of shape"):
bad_bbox = datapoints.BoundingBox( # batch with 2 elements
[
[[0, 0, 10, 10]],
[[0, 0, 10, 10]],
],
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=(20, 20),
)
different_sizes = {"bbox": bad_bbox, "labels": torch.arange(bad_bbox.shape[0])}
transforms.SanitizeBoundingBoxes()(different_sizes)
......@@ -49,7 +49,7 @@ from ._misc import (
LinearTransformation,
Normalize,
PermuteDimensions,
RemoveSmallBoundingBoxes,
SanitizeBoundingBoxes,
ToDtype,
TransposeDimensions,
)
......
import collections
import warnings
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
from contextlib import suppress
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, Union
import PIL.Image
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision import transforms as _transforms
from torchvision.ops import remove_small_boxes
from torchvision.prototype import datapoints
from torchvision.prototype.transforms import functional as F, Transform
......@@ -225,28 +227,113 @@ class TransposeDimensions(Transform):
return inpt.transpose(*dims)
class RemoveSmallBoundingBoxes(Transform):
_transformed_types = (datapoints.BoundingBox, datapoints.Mask, datapoints.Label, datapoints.OneHotLabel)
class SanitizeBoundingBoxes(Transform):
# This removes boxes and their corresponding labels:
# - small or degenerate bboxes based on min_size (this includes those where X2 <= X1 or Y2 <= Y1)
# - boxes with any coordinate outside the range of the image (negative, or > spatial_size)
def __init__(self, min_size: float = 1.0) -> None:
def __init__(
self,
min_size: float = 1.0,
labels_getter: Union[Callable[[Any], Optional[torch.Tensor]], str, None] = "default",
) -> None:
super().__init__()
if min_size < 1:
raise ValueError(f"min_size must be >= 1, got {min_size}.")
self.min_size = min_size
def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
bounding_box = query_bounding_box(flat_inputs)
# TODO: We can improve performance here by not using the `remove_small_boxes` function. It requires the box to
# be in XYXY format only to calculate the width and height internally. Thus, if the box is in XYWH or CXCYWH
# format,we need to convert first just to afterwards compute the width and height again, although they were
# there in the first place for these formats.
bounding_box = F.convert_format_bounding_box(
bounding_box.as_subclass(torch.Tensor),
old_format=bounding_box.format,
new_format=datapoints.BoundingBoxFormat.XYXY,
)
valid_indices = remove_small_boxes(bounding_box, min_size=self.min_size)
self.labels_getter = labels_getter
self._labels_getter: Optional[Callable[[Any], Optional[torch.Tensor]]]
if labels_getter == "default":
self._labels_getter = self._find_labels_default_heuristic
elif callable(labels_getter):
self._labels_getter = labels_getter
elif isinstance(labels_getter, str):
self._labels_getter = lambda inputs: inputs[labels_getter]
elif labels_getter is None:
self._labels_getter = None
else:
raise ValueError(
"labels_getter should either be a str, callable, or 'default'. "
f"Got {labels_getter} of type {type(labels_getter)}."
)
@staticmethod
def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
# Tries to find a "label" key, otherwise tries for the first key that contains "label" - case insensitive
# Returns None if nothing is found
candidate_key = None
with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if key.lower() == "labels")
if candidate_key is None:
with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if "label" in key.lower())
if candidate_key is None:
raise ValueError(
"Could not infer where the labels are in the sample. Try passing a callable as the labels_getter parameter?"
"If there are no samples and it is by design, pass labels_getter=None."
)
return inputs[candidate_key]
def forward(self, *inputs: Any) -> Any:
inputs = inputs if len(inputs) > 1 else inputs[0]
if isinstance(self.labels_getter, str) and not isinstance(inputs, collections.abc.Mapping):
raise ValueError(
f"If labels_getter is a str or 'default' (got {self.labels_getter}), "
f"then the input to forward() must be a dict. Got {type(inputs)} instead."
)
if self._labels_getter is None:
labels = None
else:
labels = self._labels_getter(inputs)
if labels is not None and not isinstance(labels, torch.Tensor):
raise ValueError(f"The labels in the input to forward() must be a tensor, got {type(labels)} instead.")
return dict(valid_indices=valid_indices)
flat_inputs, spec = tree_flatten(inputs)
# TODO: this enforces one single BoundingBox entry.
# Assuming this transform needs to be called at the end of *any* pipeline that has bboxes...
# should we just enforce it for all transforms?? What are the benefits of *not* enforcing this?
boxes = query_bounding_box(flat_inputs)
if boxes.ndim != 2:
raise ValueError(f"boxes must be of shape (num_boxes, 4), got {boxes.shape}")
if labels is not None and boxes.shape[0] != labels.shape[0]:
raise ValueError(
f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match."
)
boxes = cast(
datapoints.BoundingBox,
F.convert_format_bounding_box(
boxes,
new_format=datapoints.BoundingBoxFormat.XYXY,
),
)
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
mask = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
# TODO: Do we really need to check for out of bounds here? All
# transforms should be clamping anyway, so this should never happen?
image_h, image_w = boxes.spatial_size
mask &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
mask &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)
params = dict(mask=mask, labels=labels)
flat_outputs = [
# Even-though it may look like we're transforming all inputs, we don't:
# _transform() will only care about BoundingBoxes and the labels
self._transform(inpt, params)
for inpt in flat_inputs
]
return tree_unflatten(flat_outputs, spec)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return inpt.wrap_like(inpt, inpt[params["valid_indices"]])
if (inpt is not None and inpt is params["labels"]) or isinstance(inpt, datapoints.BoundingBox):
inpt = inpt[params["mask"]]
return inpt
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment