Unverified Commit 76144bad authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

prevent unwrapping in SanitizeBoundingBoxes (#7446)

parent 995f9b95
...@@ -2020,6 +2020,9 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type): ...@@ -2020,6 +2020,9 @@ def test_sanitize_bounding_boxes(min_size, labels_getter, sample_type):
assert out_image is input_img assert out_image is input_img
assert out_whatever is whatever assert out_whatever is whatever
assert isinstance(out_boxes, datapoints.BoundingBox)
assert isinstance(out_masks, datapoints.Mask)
if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None): if labels_getter is None or (callable(labels_getter) and labels_getter({"labels": "blah"}) is None):
assert out_labels is labels assert out_labels is labels
else: else:
......
...@@ -397,10 +397,15 @@ class SanitizeBoundingBox(Transform): ...@@ -397,10 +397,15 @@ class SanitizeBoundingBox(Transform):
return tree_unflatten(flat_outputs, spec) return tree_unflatten(flat_outputs, spec)
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
is_label = inpt is not None and inpt is params["labels"]
is_bounding_box_or_mask = isinstance(inpt, (datapoints.BoundingBox, datapoints.Mask))
if (inpt is not None and inpt is params["labels"]) or isinstance( if not (is_label or is_bounding_box_or_mask):
inpt, (datapoints.BoundingBox, datapoints.Mask)
):
inpt = inpt[params["valid"]]
return inpt return inpt
output = inpt[params["valid"]]
if is_label:
return output
return type(inpt).wrap_like(inpt, output)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment