"tests/python/common/test_heterograph-kernel.py" did not exist on "653428bdc7880ebc45b759e675df09ae6eb146f8"
Unverified commit b7892d3a, authored by Nicolas Hug, committed by GitHub
Browse files

Make RandomIoUCrop compatible with SanitizeBoundingBoxes (#7268)


Co-authored-by: Philip Meier <github.pmeier@posteo.de>
parent d4d20f01
...@@ -1488,16 +1488,13 @@ class TestRandomIoUCrop: ...@@ -1488,16 +1488,13 @@ class TestRandomIoUCrop:
fn.assert_has_calls(expected_calls) fn.assert_has_calls(expected_calls)
expected_within_targets = sum(is_within_crop_area)
# check number of bboxes vs number of labels: # check number of bboxes vs number of labels:
output_bboxes = output[1] output_bboxes = output[1]
assert isinstance(output_bboxes, datapoints.BoundingBox) assert isinstance(output_bboxes, datapoints.BoundingBox)
assert len(output_bboxes) == expected_within_targets assert (output_bboxes[~is_within_crop_area] == 0).all()
output_masks = output[2] output_masks = output[2]
assert isinstance(output_masks, datapoints.Mask) assert isinstance(output_masks, datapoints.Mask)
assert len(output_masks) == expected_within_targets
class TestScaleJitter: class TestScaleJitter:
...@@ -2253,10 +2250,11 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor): ...@@ -2253,10 +2250,11 @@ def test_classif_preset(image_type, label_type, dataset_return_type, to_tensor):
@pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image)) @pytest.mark.parametrize("image_type", (PIL.Image, torch.Tensor, datapoints.Image))
@pytest.mark.parametrize("label_type", (torch.Tensor, list))
@pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite")) @pytest.mark.parametrize("data_augmentation", ("hflip", "lsj", "multiscale", "ssd", "ssdlite"))
@pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor)) @pytest.mark.parametrize("to_tensor", (transforms.ToTensor, transforms.ToImageTensor))
def test_detection_preset(image_type, label_type, data_augmentation, to_tensor): @pytest.mark.parametrize("sanitize", (True, False))
def test_detection_preset(image_type, data_augmentation, to_tensor, sanitize):
torch.manual_seed(0)
if data_augmentation == "hflip": if data_augmentation == "hflip":
t = [ t = [
transforms.RandomHorizontalFlip(p=1), transforms.RandomHorizontalFlip(p=1),
...@@ -2290,20 +2288,20 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor): ...@@ -2290,20 +2288,20 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
t = [ t = [
transforms.RandomPhotometricDistort(p=1), transforms.RandomPhotometricDistort(p=1),
transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})), transforms.RandomZoomOut(fill=defaultdict(lambda: (123.0, 117.0, 104.0), {datapoints.Mask: 0})),
# TODO: put back IoUCrop once we remove its hard requirement for Labels transforms.RandomIoUCrop(),
# transforms.RandomIoUCrop(),
transforms.RandomHorizontalFlip(p=1), transforms.RandomHorizontalFlip(p=1),
to_tensor(), to_tensor(),
transforms.ConvertImageDtype(torch.float), transforms.ConvertImageDtype(torch.float),
] ]
elif data_augmentation == "ssdlite": elif data_augmentation == "ssdlite":
t = [ t = [
# TODO: put back IoUCrop once we remove its hard requirement for Labels transforms.RandomIoUCrop(),
# transforms.RandomIoUCrop(),
transforms.RandomHorizontalFlip(p=1), transforms.RandomHorizontalFlip(p=1),
to_tensor(), to_tensor(),
transforms.ConvertImageDtype(torch.float), transforms.ConvertImageDtype(torch.float),
] ]
if sanitize:
t += [transforms.SanitizeBoundingBoxes()]
t = transforms.Compose(t) t = transforms.Compose(t)
num_boxes = 5 num_boxes = 5
...@@ -2317,10 +2315,7 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor): ...@@ -2317,10 +2315,7 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
assert is_simple_tensor(image) assert is_simple_tensor(image)
label = torch.randint(0, 10, size=(num_boxes,)) label = torch.randint(0, 10, size=(num_boxes,))
if label_type is list:
label = label.tolist()
# TODO: is the shape of the boxes OK? Should it be (1, num_boxes, 4)?? Same for masks
boxes = torch.randint(0, min(H, W) // 2, size=(num_boxes, 4)) boxes = torch.randint(0, min(H, W) // 2, size=(num_boxes, 4))
boxes[:, 2:] += boxes[:, :2] boxes[:, 2:] += boxes[:, :2]
boxes = boxes.clamp(min=0, max=min(H, W)) boxes = boxes.clamp(min=0, max=min(H, W))
...@@ -2343,8 +2338,19 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor): ...@@ -2343,8 +2338,19 @@ def test_detection_preset(image_type, label_type, data_augmentation, to_tensor):
assert isinstance(out["image"], datapoints.Image) assert isinstance(out["image"], datapoints.Image)
assert isinstance(out["label"], type(sample["label"])) assert isinstance(out["label"], type(sample["label"]))
out["label"] = torch.tensor(out["label"]) num_boxes_expected = {
assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes # ssd and ssdlite contain RandomIoUCrop which may "remove" some bbox. It
# doesn't remove them strictly speaking, it just marks some boxes as
# degenerate and those boxes will be later removed by
# SanitizeBoundingBoxes(), which we add to the pipelines if the sanitize
# param is True.
# Note that the values below are probably specific to the random seed
# set above (which is fine).
(True, "ssd"): 4,
(True, "ssdlite"): 4,
}.get((sanitize, data_augmentation), num_boxes)
assert out["boxes"].shape[0] == out["masks"].shape[0] == out["label"].shape[0] == num_boxes_expected
@pytest.mark.parametrize("min_size", (1, 10)) @pytest.mark.parametrize("min_size", (1, 10))
...@@ -2377,7 +2383,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter): ...@@ -2377,7 +2383,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
valid_indices = [i for (i, is_valid) in enumerate(is_valid_mask) if is_valid] valid_indices = [i for (i, is_valid) in enumerate(is_valid_mask) if is_valid]
boxes = torch.tensor(boxes) boxes = torch.tensor(boxes)
labels = torch.arange(boxes.shape[-2]) labels = torch.arange(boxes.shape[0])
boxes = datapoints.BoundingBox( boxes = datapoints.BoundingBox(
boxes, boxes,
...@@ -2385,12 +2391,15 @@ def test_sanitize_bounding_boxes(min_size, labels_getter): ...@@ -2385,12 +2391,15 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
spatial_size=(H, W), spatial_size=(H, W),
) )
masks = datapoints.Mask(torch.randint(0, 2, size=(boxes.shape[0], H, W)))
sample = { sample = {
"image": torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8), "image": torch.randint(0, 256, size=(1, 3, H, W), dtype=torch.uint8),
"labels": labels, "labels": labels,
"boxes": boxes, "boxes": boxes,
"whatever": torch.rand(10), "whatever": torch.rand(10),
"None": None, "None": None,
"masks": masks,
} }
out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample) out = transforms.SanitizeBoundingBoxes(min_size=min_size, labels_getter=labels_getter)(sample)
...@@ -2402,7 +2411,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter): ...@@ -2402,7 +2411,7 @@ def test_sanitize_bounding_boxes(min_size, labels_getter):
assert out["labels"] is sample["labels"] assert out["labels"] is sample["labels"]
else: else:
assert isinstance(out["labels"], torch.Tensor) assert isinstance(out["labels"], torch.Tensor)
assert out["boxes"].shape[:-1] == out["labels"].shape assert out["boxes"].shape[0] == out["labels"].shape[0] == out["masks"].shape[0]
# This works because we conveniently set labels to arange(num_boxes) # This works because we conveniently set labels to arange(num_boxes)
assert out["labels"].tolist() == valid_indices assert out["labels"].tolist() == valid_indices
......
...@@ -1090,13 +1090,16 @@ class TestRefDetTransforms: ...@@ -1090,13 +1090,16 @@ class TestRefDetTransforms:
"t_ref, t, data_kwargs", "t_ref, t, data_kwargs",
[ [
(det_transforms.RandomHorizontalFlip(p=1.0), v2_transforms.RandomHorizontalFlip(p=1.0), {}), (det_transforms.RandomHorizontalFlip(p=1.0), v2_transforms.RandomHorizontalFlip(p=1.0), {}),
# FIXME: make (
# v2_transforms.Compose([ det_transforms.RandomIoUCrop(),
# v2_transforms.RandomIoUCrop(), v2_transforms.Compose(
# v2_transforms.SanitizeBoundingBoxes() [
# ]) v2_transforms.RandomIoUCrop(),
# work v2_transforms.SanitizeBoundingBoxes(labels_getter=lambda sample: sample[1]["labels"]),
# (det_transforms.RandomIoUCrop(), v2_transforms.RandomIoUCrop(), {"with_mask": False}), ]
),
{"with_mask": False},
),
(det_transforms.RandomZoomOut(), v2_transforms.RandomZoomOut(), {"with_mask": False}), (det_transforms.RandomZoomOut(), v2_transforms.RandomZoomOut(), {"with_mask": False}),
(det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024)), {}), (det_transforms.ScaleJitter((1024, 1024)), v2_transforms.ScaleJitter((1024, 1024)), {}),
( (
......
...@@ -721,8 +721,6 @@ class RandomIoUCrop(Transform): ...@@ -721,8 +721,6 @@ class RandomIoUCrop(Transform):
if left == right or top == bottom: if left == right or top == bottom:
continue continue
# FIXME: I think we can stop here?
# check for any valid boxes with centers within the crop area # check for any valid boxes with centers within the crop area
xyxy_bboxes = F.convert_format_bounding_box( xyxy_bboxes = F.convert_format_bounding_box(
bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY bboxes.as_subclass(torch.Tensor), bboxes.format, datapoints.BoundingBoxFormat.XYXY
...@@ -745,23 +743,16 @@ class RandomIoUCrop(Transform): ...@@ -745,23 +743,16 @@ class RandomIoUCrop(Transform):
return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area) return dict(top=top, left=left, height=new_h, width=new_w, is_within_crop_area=is_within_crop_area)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# FIXME: refactor this to not remove anything
if len(params) < 1: if len(params) < 1:
return inpt return inpt
is_within_crop_area = params["is_within_crop_area"]
output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"]) output = F.crop(inpt, top=params["top"], left=params["left"], height=params["height"], width=params["width"])
if isinstance(output, datapoints.BoundingBox): if isinstance(output, datapoints.BoundingBox):
bboxes = output[is_within_crop_area] # We "mark" the invalid boxes as degenerate, and they can be
bboxes = F.clamp_bounding_box(bboxes, output.format, output.spatial_size) # removed by a later call to SanitizeBoundingBoxes()
output = datapoints.BoundingBox.wrap_like(output, bboxes) output[~params["is_within_crop_area"]] = 0
elif isinstance(output, datapoints.Mask):
# apply is_within_crop_area if mask is one-hot encoded
masks = output[is_within_crop_area]
output = datapoints.Mask.wrap_like(output, masks)
return output return output
......
...@@ -265,14 +265,14 @@ class SanitizeBoundingBoxes(Transform): ...@@ -265,14 +265,14 @@ class SanitizeBoundingBoxes(Transform):
), ),
) )
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1] ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
mask = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1) valid = (ws >= self.min_size) & (hs >= self.min_size) & (boxes >= 0).all(dim=-1)
# TODO: Do we really need to check for out of bounds here? All # TODO: Do we really need to check for out of bounds here? All
# transforms should be clamping anyway, so this should never happen? # transforms should be clamping anyway, so this should never happen?
image_h, image_w = boxes.spatial_size image_h, image_w = boxes.spatial_size
mask &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w) valid &= (boxes[:, 0] <= image_w) & (boxes[:, 2] <= image_w)
mask &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h) valid &= (boxes[:, 1] <= image_h) & (boxes[:, 3] <= image_h)
params = dict(mask=mask, labels=labels) params = dict(valid=valid, labels=labels)
flat_outputs = [ flat_outputs = [
# Even though it may look like we're transforming all inputs, we don't: # Even though it may look like we're transforming all inputs, we don't:
# _transform() will only care about BoundingBoxes and the labels # _transform() will only care about BoundingBoxes and the labels
...@@ -284,7 +284,9 @@ class SanitizeBoundingBoxes(Transform): ...@@ -284,7 +284,9 @@ class SanitizeBoundingBoxes(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if (inpt is not None and inpt is params["labels"]) or isinstance(inpt, datapoints.BoundingBox): if (inpt is not None and inpt is params["labels"]) or isinstance(
inpt = inpt[params["mask"]] inpt, (datapoints.BoundingBox, datapoints.Mask)
):
inpt = inpt[params["valid"]]
return inpt return inpt
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment