"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "a027bbf4d7eb73c7448393f84f6181d0ab791a97"
Unverified commit ed48bb1c, authored by Nicolas Hug, committed by GitHub
Browse files

Extend default heuristic of SanitizeBoundingBoxes to support tuples (#7304)


Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent a46d97c9
...@@ -1935,7 +1935,14 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize): ...@@ -1935,7 +1935,14 @@ def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"labels_getter", ("default", "labels", lambda inputs: inputs["labels"], None, lambda inputs: None) "labels_getter", ("default", "labels", lambda inputs: inputs["labels"], None, lambda inputs: None)
) )
def test_sanitize_bounding_boxes(min_size, labels_getter): @pytest.mark.parametrize("sample_type", (tuple, dict))
def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
if sample_type is tuple and not isinstance(labels_getter, str):
# The "lambda inputs: inputs["labels"]" labels_getter used in this test
# doesn't work if the input is a tuple.
return
H, W = 256, 128 H, W = 256, 128
boxes_and_validity = [ boxes_and_validity = [
...@@ -1970,35 +1977,56 @@ def test_sanitize_bounding_boxes(min_size, labels_getter): ...@@ -1970,35 +1977,56 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
) )
masks = datapoints.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W))) masks = datapoints.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))
whatever = torch.rand(10)
input_img = torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8)
sample = { sample = {
"image": torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8), "image": input_img,
"labels": labels, "labels": labels,
"boxes": boxes, "boxes": boxes,
"whatever": torch.rand(10), "whatever": whatever,
"None": None, "None": None,
"masks": masks, "masks": masks,
} }
if sample_type is tuple:
img = sample.pop("image")
sample = (img, sample)
out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample) out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample)
assert out["image"] is sample["image"] if sample_type is tuple:
assert out["whatever"] is sample["whatever"] out_image = out[0]
out_labels = out[1]["labels"]
out_boxes = out[1]["boxes"]
out_masks = out[1]["masks"]
out_whatever = out[1]["whatever"]
else:
out_image = out["image"]
out_labels = out["labels"]
out_boxes = out["boxes"]
out_masks = out["masks"]
out_whatever = out["whatever"]
assert out_image is input_img
assert out_whatever is whatever
if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None): if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None):
assert out["labels"] is sample["labels"] assert out_labels is labels
else: else:
assert isinstance(out["labels"], torch.Tensor) assert isinstance(out_labels, torch.Tensor)
assert out["boxes"].shape[0] == out["labels"].shape[0] == out["masks"].shape[0] assert out_boxes.shape[0] == out_labels.shape[0] == out_masks.shape[0]
# This works because we conveniently set labels to arange(num_boxes) # This works because we conveniently set labels to arange(num_boxes)
assert out["labels"].tolist() == valid_indices assert out_labels.tolist() == valid_indices
@pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT")) @pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT"))
def test_sanitize_bounding_boxes_default_heuristic(key): @pytest.mark.parametrize("sample_type", (tuple, dict))
def test_sanitize_bounding_boxes_default_heuristic(key, sample_type):
labels = torch.arange(10) labels = torch.arange(10)
d = {key: labels} sample = {key: labels, "another_key": "whatever"}
assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(d) is labels if sample_type is tuple:
sample = (None, sample, "whatever_again")
assert transforms.SanitizeBoundingBoxes._find_labels_default_heuristic(sample) is labels
if key.lower() != "labels": if key.lower() != "labels":
# If "labels" is in the dict (case-insensitive), # If "labels" is in the dict (case-insensitive),
......
import collections import collections
import warnings import warnings
from contextlib import suppress from contextlib import suppress
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Type, Union from typing import Any, Callable, cast, Dict, List, Mapping, Optional, Sequence, Type, Union
import PIL.Image import PIL.Image
...@@ -269,7 +269,9 @@ class SanitizeBoundingBoxes(Transform): ...@@ -269,7 +269,9 @@ class SanitizeBoundingBoxes(Transform):
elif callable(labels_getter): elif callable(labels_getter):
self._labels_getter = labels_getter self._labels_getter = labels_getter
elif isinstance(labels_getter, str): elif isinstance(labels_getter, str):
self._labels_getter = lambda inputs: inputs[labels_getter] self._labels_getter = lambda inputs: SanitizeBoundingBoxes._get_dict_or_second_tuple_entry(inputs)[
labels_getter # type: ignore[index]
]
elif labels_getter is None: elif labels_getter is None:
self._labels_getter = None self._labels_getter = None
else: else:
...@@ -278,10 +280,27 @@ class SanitizeBoundingBoxes(Transform): ...@@ -278,10 +280,27 @@ class SanitizeBoundingBoxes(Transform):
f"Got {labels_getter} of type {type(labels_getter)}." f"Got {labels_getter} of type {type(labels_getter)}."
) )
@staticmethod
def _get_dict_or_second_tuple_entry(inputs: Any) -> Mapping[str, Any]:
    """Return the mapping that is expected to hold the sample's entries.

    Dataset outputs may be plain dicts like ``{"img": ..., "labels": ..., "bbox": ...}``
    or tuples like ``(img, {"labels": ..., "bbox": ...})``. This hacky helper accounts
    for both structures: a mapping is returned as-is, and for a tuple the second
    element is returned.

    Args:
        inputs: Either a mapping, or a tuple whose second element is a mapping.

    Returns:
        The mapping itself, or the second element of the tuple.

    Raises:
        ValueError: If ``inputs`` is a tuple with fewer than two elements, or if the
            selected object is not a ``collections.abc.Mapping``.
    """
    expectation = (
        "If labels_getter is a str or 'default', "
        "then the input to forward() must be a dict or a tuple whose second element is a dict."
    )
    if isinstance(inputs, tuple):
        # A tuple without a second element would otherwise surface as a bare IndexError;
        # report it with the same descriptive error as every other unsupported input.
        if len(inputs) < 2:
            raise ValueError(f"{expectation} Got a tuple of length {len(inputs)} instead.")
        inputs = inputs[1]
    if not isinstance(inputs, collections.abc.Mapping):
        raise ValueError(f"{expectation} Got {type(inputs)} instead.")
    return inputs
@staticmethod @staticmethod
def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]: def _find_labels_default_heuristic(inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
# Tries to find a "label" key, otherwise tries for the first key that contains "label" - case insensitive # Tries to find a "labels" key, otherwise tries for the first key that contains "label" - case insensitive
# Returns None if nothing is found # Returns None if nothing is found
inputs = SanitizeBoundingBoxes._get_dict_or_second_tuple_entry(inputs)
candidate_key = None candidate_key = None
with suppress(StopIteration): with suppress(StopIteration):
candidate_key = next(key for key in inputs.keys() if key.lower() == "labels") candidate_key = next(key for key in inputs.keys() if key.lower() == "labels")
...@@ -298,12 +317,6 @@ class SanitizeBoundingBoxes(Transform): ...@@ -298,12 +317,6 @@ class SanitizeBoundingBoxes(Transform):
def forward(self, *inputs: Any) -> Any: def forward(self, *inputs: Any) -> Any:
inputs = inputs if len(inputs) > 1 else inputs[0] inputs = inputs if len(inputs) > 1 else inputs[0]
if isinstance(self.labels_getter, str) and not isinstance(inputs, collections.abc.Mapping):
raise ValueError(
f"If labels_getter is a str or 'default' (got {self.labels_getter}), "
f"then the input to forward() must be a dict. Got {type(inputs)} instead."
)
if self._labels_getter is None: if self._labels_getter is None:
labels = None labels = None
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment