Unverified Commit fa82fd3b authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Allow SanitizeBoundingBoxes to sanitize more labels (#8319)

parent 53869eb8
...@@ -5706,7 +5706,17 @@ class TestSanitizeBoundingBoxes: ...@@ -5706,7 +5706,17 @@ class TestSanitizeBoundingBoxes:
return boxes, expected_valid_mask return boxes, expected_valid_mask
@pytest.mark.parametrize("min_size", (1, 10)) @pytest.mark.parametrize("min_size", (1, 10))
@pytest.mark.parametrize("labels_getter", ("default", lambda inputs: inputs["labels"], None, lambda inputs: None)) @pytest.mark.parametrize(
"labels_getter",
(
"default",
lambda inputs: inputs["labels"],
lambda inputs: (inputs["labels"], inputs["other_labels"]),
lambda inputs: [inputs["labels"], inputs["other_labels"]],
None,
lambda inputs: None,
),
)
@pytest.mark.parametrize("sample_type", (tuple, dict)) @pytest.mark.parametrize("sample_type", (tuple, dict))
def test_transform(self, min_size, labels_getter, sample_type): def test_transform(self, min_size, labels_getter, sample_type):
...@@ -5721,12 +5731,16 @@ class TestSanitizeBoundingBoxes: ...@@ -5721,12 +5731,16 @@ class TestSanitizeBoundingBoxes:
labels = torch.arange(boxes.shape[0]) labels = torch.arange(boxes.shape[0])
masks = tv_tensors.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W))) masks = tv_tensors.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))
# other_labels corresponds to properties from COCO like iscrowd, area...
# We only sanitize it when labels_getter returns a tuple
other_labels = torch.arange(boxes.shape[0])
whatever = torch.rand(10) whatever = torch.rand(10)
input_img = torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8) input_img = torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8)
sample = { sample = {
"image": input_img, "image": input_img,
"labels": labels, "labels": labels,
"boxes": boxes, "boxes": boxes,
"other_labels": other_labels,
"whatever": whatever, "whatever": whatever,
"None": None, "None": None,
"masks": masks, "masks": masks,
...@@ -5741,12 +5755,14 @@ class TestSanitizeBoundingBoxes: ...@@ -5741,12 +5755,14 @@ class TestSanitizeBoundingBoxes:
if sample_type is tuple: if sample_type is tuple:
out_image = out[0] out_image = out[0]
out_labels = out[1]["labels"] out_labels = out[1]["labels"]
out_other_labels = out[1]["other_labels"]
out_boxes = out[1]["boxes"] out_boxes = out[1]["boxes"]
out_masks = out[1]["masks"] out_masks = out[1]["masks"]
out_whatever = out[1]["whatever"] out_whatever = out[1]["whatever"]
else: else:
out_image = out["image"] out_image = out["image"]
out_labels = out["labels"] out_labels = out["labels"]
out_other_labels = out["other_labels"]
out_boxes = out["boxes"] out_boxes = out["boxes"]
out_masks = out["masks"] out_masks = out["masks"]
out_whatever = out["whatever"] out_whatever = out["whatever"]
...@@ -5757,14 +5773,20 @@ class TestSanitizeBoundingBoxes: ...@@ -5757,14 +5773,20 @@ class TestSanitizeBoundingBoxes:
assert isinstance(out_boxes, tv_tensors.BoundingBoxes) assert isinstance(out_boxes, tv_tensors.BoundingBoxes)
assert isinstance(out_masks, tv_tensors.Mask) assert isinstance(out_masks, tv_tensors.Mask)
if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None): if labels_getter is None or (callable(labels_getter) and labels_getter(sample) is None):
assert out_labels is labels assert out_labels is labels
assert out_other_labels is other_labels
else: else:
assert isinstance(out_labels, torch.Tensor) assert isinstance(out_labels, torch.Tensor)
assert out_boxes.shape[0] == out_labels.shape[0] == out_masks.shape[0] assert out_boxes.shape[0] == out_labels.shape[0] == out_masks.shape[0]
# This works because we conveniently set labels to arange(num_boxes) # This works because we conveniently set labels to arange(num_boxes)
assert out_labels.tolist() == valid_indices assert out_labels.tolist() == valid_indices
if callable(labels_getter) and isinstance(labels_getter(sample), (tuple, list)):
assert_equal(out_other_labels, out_labels)
else:
assert_equal(out_other_labels, other_labels)
@pytest.mark.parametrize("input_type", (torch.Tensor, tv_tensors.BoundingBoxes)) @pytest.mark.parametrize("input_type", (torch.Tensor, tv_tensors.BoundingBoxes))
def test_functional(self, input_type): def test_functional(self, input_type):
# Note: the "functional" F.sanitize_bounding_boxes was added after the class, so there is some # Note: the "functional" F.sanitize_bounding_boxes was added after the class, so there is some
......
...@@ -321,6 +321,9 @@ class SanitizeBoundingBoxes(Transform): ...@@ -321,6 +321,9 @@ class SanitizeBoundingBoxes(Transform):
- have any coordinate outside of their corresponding image. You may want to - have any coordinate outside of their corresponding image. You may want to
call :class:`~torchvision.transforms.v2.ClampBoundingBoxes` first to avoid undesired removals. call :class:`~torchvision.transforms.v2.ClampBoundingBoxes` first to avoid undesired removals.
It can also sanitize other tensors like the "iscrowd" or "area" properties from COCO
(see ``labels_getter`` parameter).
It is recommended to call it at the end of a pipeline, before passing the It is recommended to call it at the end of a pipeline, before passing the
input to the models. It is critical to call this transform if input to the models. It is critical to call this transform if
:class:`~torchvision.transforms.v2.RandomIoUCrop` was called. :class:`~torchvision.transforms.v2.RandomIoUCrop` was called.
...@@ -330,18 +333,26 @@ class SanitizeBoundingBoxes(Transform): ...@@ -330,18 +333,26 @@ class SanitizeBoundingBoxes(Transform):
Args: Args:
min_size (float, optional) The size below which bounding boxes are removed. Default is 1. min_size (float, optional) The size below which bounding boxes are removed. Default is 1.
labels_getter (callable or str or None, optional): indicates how to identify the labels in the input. labels_getter (callable or str or None, optional): indicates how to identify the labels in the input
(or anything else that needs to be sanitized along with the bounding boxes).
By default, this will try to find a "labels" key in the input (case-insensitive), if By default, this will try to find a "labels" key in the input (case-insensitive), if
the input is a dict or it is a tuple whose second element is a dict. the input is a dict or it is a tuple whose second element is a dict.
This heuristic should work well with a lot of datasets, including the built-in torchvision datasets. This heuristic should work well with a lot of datasets, including the built-in torchvision datasets.
It can also be a callable that takes the same input
as the transform, and returns the labels. It can also be a callable that takes the same input as the transform, and returns either:
- A single tensor (the labels)
- A tuple/list of tensors, each of which will be subject to the same sanitization as the bounding boxes.
This is useful to sanitize multiple tensors like the labels, and the "iscrowd" or "area" properties
from COCO.
If ``labels_getter`` is None then only bounding boxes are sanitized.
""" """
def __init__( def __init__(
self, self,
min_size: float = 1.0, min_size: float = 1.0,
labels_getter: Union[Callable[[Any], Optional[torch.Tensor]], str, None] = "default", labels_getter: Union[Callable[[Any], Any], str, None] = "default",
) -> None: ) -> None:
super().__init__() super().__init__()
...@@ -356,17 +367,27 @@ class SanitizeBoundingBoxes(Transform): ...@@ -356,17 +367,27 @@ class SanitizeBoundingBoxes(Transform):
inputs = inputs if len(inputs) > 1 else inputs[0] inputs = inputs if len(inputs) > 1 else inputs[0]
labels = self._labels_getter(inputs) labels = self._labels_getter(inputs)
if labels is not None and not isinstance(labels, torch.Tensor): if labels is not None:
raise ValueError( msg = "The labels in the input to forward() must be a tensor or None, got {type} instead."
f"The labels in the input to forward() must be a tensor or None, got {type(labels)} instead." if isinstance(labels, torch.Tensor):
) labels = (labels,)
elif isinstance(labels, (tuple, list)):
for entry in labels:
if not isinstance(entry, torch.Tensor):
# TODO: we don't need to enforce tensors, just that entries are indexable as t[bool_mask]
raise ValueError(msg.format(type=type(entry)))
else:
raise ValueError(msg.format(type=type(labels)))
flat_inputs, spec = tree_flatten(inputs) flat_inputs, spec = tree_flatten(inputs)
boxes = get_bounding_boxes(flat_inputs) boxes = get_bounding_boxes(flat_inputs)
if labels is not None and boxes.shape[0] != labels.shape[0]: if labels is not None:
for label in labels:
if boxes.shape[0] != label.shape[0]:
raise ValueError( raise ValueError(
f"Number of boxes (shape={boxes.shape}) and number of labels (shape={labels.shape}) do not match." f"Number of boxes (shape={boxes.shape}) and must match the number of labels."
f"Found labels with shape={label.shape})."
) )
valid = F._misc._get_sanitize_bounding_boxes_mask( valid = F._misc._get_sanitize_bounding_boxes_mask(
...@@ -381,7 +402,7 @@ class SanitizeBoundingBoxes(Transform): ...@@ -381,7 +402,7 @@ class SanitizeBoundingBoxes(Transform):
return tree_unflatten(flat_outputs, spec) return tree_unflatten(flat_outputs, spec)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
is_label = inpt is not None and inpt is params["labels"] is_label = params["labels"] is not None and any(inpt is label for label in params["labels"])
is_bounding_boxes_or_mask = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask)) is_bounding_boxes_or_mask = isinstance(inpt, (tv_tensors.BoundingBoxes, tv_tensors.Mask))
if not (is_label or is_bounding_boxes_or_mask): if not (is_label or is_bounding_boxes_or_mask):
...@@ -391,5 +412,5 @@ class SanitizeBoundingBoxes(Transform): ...@@ -391,5 +412,5 @@ class SanitizeBoundingBoxes(Transform):
if is_label: if is_label:
return output return output
else:
return tv_tensors.wrap(output, like=inpt) return tv_tensors.wrap(output, like=inpt)
...@@ -4,7 +4,7 @@ import collections.abc ...@@ -4,7 +4,7 @@ import collections.abc
import numbers import numbers
from contextlib import suppress from contextlib import suppress
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union from typing import Any, Callable, Dict, List, Literal, Sequence, Tuple, Type, Union
import PIL.Image import PIL.Image
import torch import torch
...@@ -139,9 +139,7 @@ def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor: ...@@ -139,9 +139,7 @@ def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor:
return inputs[candidate_key] return inputs[candidate_key]
def _parse_labels_getter( def _parse_labels_getter(labels_getter: Union[str, Callable[[Any], Any], None]) -> Callable[[Any], Any]:
labels_getter: Union[str, Callable[[Any], Optional[torch.Tensor]], None]
) -> Callable[[Any], Optional[torch.Tensor]]:
if labels_getter == "default": if labels_getter == "default":
return _find_labels_default_heuristic return _find_labels_default_heuristic
elif callable(labels_getter): elif callable(labels_getter):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment