Unverified Commit d6444529 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

check full sample in query_bounding_box (#6484)

* check full sample in query_bounding_box

* support no bounding boxes in FixedSizeCrop

* fix test
parent 6746986d
...@@ -1487,6 +1487,7 @@ class TestFixedSizeCrop: ...@@ -1487,6 +1487,7 @@ class TestFixedSizeCrop:
left_sentinel = mocker.MagicMock() left_sentinel = mocker.MagicMock()
height_sentinel = mocker.MagicMock() height_sentinel = mocker.MagicMock()
width_sentinel = mocker.MagicMock() width_sentinel = mocker.MagicMock()
is_valid = mocker.MagicMock() if needs_crop else None
padding_sentinel = mocker.MagicMock() padding_sentinel = mocker.MagicMock()
mocker.patch( mocker.patch(
"torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params", "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
...@@ -1496,6 +1497,7 @@ class TestFixedSizeCrop: ...@@ -1496,6 +1497,7 @@ class TestFixedSizeCrop:
left=left_sentinel, left=left_sentinel,
height=height_sentinel, height=height_sentinel,
width=width_sentinel, width=width_sentinel,
is_valid=is_valid,
padding=padding_sentinel, padding=padding_sentinel,
needs_pad=needs_pad, needs_pad=needs_pad,
), ),
......
...@@ -789,8 +789,12 @@ class FixedSizeCrop(Transform): ...@@ -789,8 +789,12 @@ class FixedSizeCrop(Transform):
top = int(offset_height * r) top = int(offset_height * r)
left = int(offset_width * r) left = int(offset_width * r)
if needs_crop: try:
bounding_boxes = query_bounding_box(sample) bounding_boxes = query_bounding_box(sample)
except ValueError:
bounding_boxes = None
if needs_crop and bounding_boxes is not None:
bounding_boxes = cast( bounding_boxes = cast(
features.BoundingBox, F.crop(bounding_boxes, top=top, left=left, height=height, width=width) features.BoundingBox, F.crop(bounding_boxes, top=top, left=left, height=height, width=width)
) )
...@@ -830,6 +834,8 @@ class FixedSizeCrop(Transform): ...@@ -830,6 +834,8 @@ class FixedSizeCrop(Transform):
height=params["height"], height=params["height"],
width=params["width"], width=params["width"],
) )
if params["is_valid"] is not None:
if isinstance(inpt, (features.Label, features.OneHotLabel, features.SegmentationMask)): if isinstance(inpt, (features.Label, features.OneHotLabel, features.SegmentationMask)):
inpt = inpt.new_like(inpt, inpt[params["is_valid"]]) # type: ignore[arg-type] inpt = inpt.new_like(inpt, inpt[params["is_valid"]]) # type: ignore[arg-type]
elif isinstance(inpt, features.BoundingBox): elif isinstance(inpt, features.BoundingBox):
...@@ -845,13 +851,14 @@ class FixedSizeCrop(Transform): ...@@ -845,13 +851,14 @@ class FixedSizeCrop(Transform):
def forward(self, *inputs: Any) -> Any:
    """Validate the input sample, then delegate to ``Transform.forward``.

    Raises:
        TypeError: If the sample contains no image (plain tensor, ``features.Image``,
            or PIL image), or if it contains a ``BoundingBox`` without an accompanying
            ``Label``/``OneHotLabel`` (needed so invalidated boxes and their labels
            can be dropped consistently).
    """
    sample = inputs if len(inputs) > 1 else inputs[0]

    if not has_any(sample, PIL.Image.Image, features.Image, is_simple_tensor):
        # Fixed grammar in the error message: "an tensor" -> "a tensor".
        raise TypeError(f"{type(self).__name__}() requires input sample to contain a tensor or PIL image.")

    # Bounding boxes alone are not allowed: cropping may invalidate boxes, and the
    # matching labels must be present so they can be filtered in lockstep.
    if has_any(sample, features.BoundingBox) and not has_any(sample, features.Label, features.OneHotLabel):
        raise TypeError(
            f"If a BoundingBox is contained in the input sample, "
            f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel."
        )

    return super().forward(sample)
...@@ -11,11 +11,12 @@ from .functional._meta import get_dimensions_image_pil, get_dimensions_image_ten ...@@ -11,11 +11,12 @@ from .functional._meta import get_dimensions_image_pil, get_dimensions_image_ten
def query_bounding_box(sample: Any) -> features.BoundingBox:
    """Return the single ``BoundingBox`` contained anywhere in *sample*.

    Args:
        sample: Arbitrarily nested structure; flattened with ``tree_flatten``.

    Returns:
        The unique bounding box found in the flattened sample.

    Raises:
        TypeError: If the sample contains no bounding box.
        ValueError: If the sample contains more than one distinct bounding box
            (the result would be ambiguous).
    """
    flat_sample, _ = tree_flatten(sample)
    bounding_boxes = {item for item in flat_sample if isinstance(item, features.BoundingBox)}
    if not bounding_boxes:
        raise TypeError("No bounding box was found in the sample")
    # BUG FIX: the original check was `> 2`, which silently accepted a sample
    # with exactly two distinct bounding boxes and returned an arbitrary one of
    # them via `set.pop()`. Any count above one is ambiguous and must be rejected.
    if len(bounding_boxes) > 1:
        raise ValueError("Found multiple bounding boxes in the sample")
    return bounding_boxes.pop()
def get_chw(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]: def get_chw(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
...@@ -41,7 +42,7 @@ def query_chw(sample: Any) -> Tuple[int, int, int]: ...@@ -41,7 +42,7 @@ def query_chw(sample: Any) -> Tuple[int, int, int]:
if not chws: if not chws:
raise TypeError("No image was found in the sample") raise TypeError("No image was found in the sample")
elif len(chws) > 2: elif len(chws) > 2:
raise TypeError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}") raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}")
return chws.pop() return chws.pop()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment