Unverified Commit f7c7bdf5 authored by Nicolas Hug's avatar Nicolas Hug Committed by GitHub
Browse files

Stricter SanitizeBoundingBoxes labels_getter heuristic (#7880)

parent 054432d2
......@@ -99,10 +99,9 @@ bboxes = datapoints.BoundingBoxes(
format="XYXY", canvas_size=img.shape[-2:])
transforms = v2.Compose([
v2.RandomPhotometricDistort(),
v2.RandomIoUCrop(),
v2.RandomHorizontalFlip(p=0.5),
v2.SanitizeBoundingBoxes(),
v2.RandomResizedCrop(size=(224, 224), antialias=True),
v2.RandomPhotometricDistort(p=1),
v2.RandomHorizontalFlip(p=1),
])
out_img, out_bboxes = transforms(img, bboxes)
......
......@@ -1256,6 +1256,20 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
assert out_labels.tolist() == valid_indices
def test_sanitize_bounding_boxes_no_label():
# Non-regression test for https://github.com/pytorch/vision/issues/7878
img = make_image()
boxes = make_bounding_boxes()
with pytest.raises(ValueError, match="or a two-tuple whose second item is a dict"):
transforms.SanitizeBoundingBoxes()(img, boxes)
out_img, out_boxes = transforms.SanitizeBoundingBoxes(labels_getter=None)(img, boxes)
assert isinstance(out_img, datapoints.Image)
assert isinstance(out_boxes, datapoints.BoundingBoxes)
def test_sanitize_bounding_boxes_errors():
good_bbox = datapoints.BoundingBoxes(
......
......@@ -112,7 +112,7 @@ def _find_labels_default_heuristic(inputs: Any) -> torch.Tensor:
inputs = inputs[1]
# MixUp, CutMix
if isinstance(inputs, torch.Tensor):
if is_pure_tensor(inputs):
return inputs
if not isinstance(inputs, collections.abc.Mapping):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment