Unverified Commit ed48bb1c authored by Nicolas Hug, committed by GitHub
Browse files

Extend default heuristic of SanitizeBoundingBoxes to support tuples (#7304)


Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent a46d97c9
......@@ -1935,7 +1935,14 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
@pytest.mark.parametrize(
"labels_getter", ("default", "labels", lambda inputs: inputs["labels"], None, lambda inputs: None)
)
def test_sanitize_bounding_boxes(min_size, labels_getter):
@pytest.mark.parametrize("sample_type", (tuple, dict))
def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
if sample_type is tuple and not isinstance(labels_getter, str):
# The "lambda inputs: inputs["labels"]" labels_getter used in this test
# doesn't work if the input is a tuple.
return
H, W = 256, 128
boxes_and_validity = [
......@@ -1970,35 +1977,56 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
)
masks = datapoints.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))
whatever = torch.rand(10)
input_img = torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8)
sample = {
"image": torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8),
"image": input_img,
"labels": labels,
"boxes": boxes,
"whatever": torch.rand(10),
"whatever": whatever,
"None": None,
"masks": masks,
}
if sample_type is tuple:
img = sample.pop("image")
sample = (img, sample)
out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample)
assert out["image"] is sample["image"]
assert out["whatever"] is sample["whatever"]
if sample_type is tuple:
out_image = out[0]
out_labels = out[1]["labels"]
out_boxes = out[1]["boxes"]
out_masks = out[1]["masks"]
out_whatever = out[1]["whatever"]
else:
out_image = out["image"]
out_labels = out["labels"]
out_boxes = out["boxes"]
out_masks = out["masks"]
out_whatever = out["whatever"]
assert out_image is input_img
assert out_whatever is whatever
if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None):
assert out["labels"] is sample["labels"]
assert out_labels is labels
else:
assert isinstance(out["labels"], torch.Tensor)
assert out["boxes"].shape[0] == out["labels"].shape[0] == out["masks"].shape[0]
assert isinstance(out_labels, torch.Tensor)
assert out_boxes.shape[0] == out_labels.shape[0] == out_masks.shape[0]
# This works because we conveniently set labels to arange(num_boxes)
assert out["labels"].tolist() == valid_indices
assert out_labels.tolist() == valid_indices
@pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"))
def test_sanitize_bounding_boxes_default_heuristic(key):
@pytest.mark.parametrize("sample_type", (tuple, dict))
def test_sanitize_bounding_boxes_default_heuristic(key, sample_type):
labels = torch.arange(10)
d = {key: labels}
assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(d) is labels
sample = {key: labels, "another_key": "whatever"}
if sample_type is tuple:
sample = (None, sample, "whatever_again")
assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(sample) is labels
if key.lower() != "labels":
# If "labels" is in the dict (case-insensitive),
......
import collections
import warnings
from contextlib import suppress
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Type, Union
from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Sequence, Type, Union
import PIL.Image
......@@ -269,7 +269,9 @@ class SanitizeBoundingBoxes(Transform):
elif callable(labels_getter):
self._labels_getter = labels_getter
elif isinstance(labels_getter, str):
self._labels_getter = lambda inputs: inputs[labels_getter]
self._labels_getter = lambda inputs: SanitizeBoundingBoxes._get_dict_or_second_tuple_entry(inputs)[
labels_getter # type: ignore[index]
]
elif labels_getter is None:
self._labels_getter = None
else:
......@@ -278,10 +280,27 @@ class SanitizeBoundingBoxes(Transform):
f"Got {labels_getter} of type {type(labels_getter)}."
)
@staticmethod
def _get_dict_or_second_tuple_entry(inputs: Any) -> Mapping[str, Any]:
    """Return the mapping that is expected to hold the labels of a sample.

    Dataset outputs may be plain dicts like ``{"img": ..., "labels": ..., "bbox": ...}``
    or tuples like ``(img, {"labels": ..., "bbox": ...})``. This helper accounts for
    both structures.

    Args:
        inputs: Either a mapping, or a tuple whose second element is a mapping.

    Returns:
        ``inputs`` itself when it is a mapping, otherwise ``inputs[1]``.

    Raises:
        ValueError: If ``inputs`` is neither a mapping nor a tuple whose second
            element is a mapping. This includes tuples with fewer than two elements.
    """
    candidate = inputs
    # Only unwrap tuples that actually have a second entry. A shorter tuple falls
    # through to the ValueError below instead of surfacing as an opaque IndexError.
    if isinstance(inputs, tuple) and len(inputs) > 1:
        candidate = inputs[1]

    if not isinstance(candidate, collections.abc.Mapping):
        raise ValueError(
            "If labels_getter is a str or 'default', "
            "then the input to forward() must be a dict or a tuple whose second element is a dict."
            f" Got {type(candidate)} instead."
        )
    return candidate
@staticmethod
def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
# Tries to find a "label" key, otherwise tries for the first key that contains "label" - case insensitive
# Tries to find a "labels" key, otherwise tries for the first key that contains "label" - case insensitive
# Returns None if nothing is found
inputs = SanitizeBoundingBoxes._get_dict_or_second_tuple_entry(inputs)
candidate_key = None
with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if key.lower() == "labels")
......@@ -298,12 +317,6 @@ class SanitizeBoundingBoxes(Transform):
def forward(self, *inputs: Any) -> Any:
inputs = inputs if len(inputs) > 1 else inputs[0]
if isinstance(self.labels_getter, str) and not isinstance(inputs, collections.abc.Mapping):
raise ValueError(
f"If labels_getter is a str or 'default' (got {self.labels_getter}), "
f"then the input to forward() must be a dict. Got {type(inputs)} instead."
)
if self._labels_getter is None:
labels = None
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment