Unverified Commit d6444529 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

check full sample in query_bounding_box (#6484)

* check full sample in query_bounding_box

* support no bounding boxes in FixedSizeCrop

* fix test
parent 6746986d
......@@ -1487,6 +1487,7 @@ class TestFixedSizeCrop:
left_sentinel = mocker.MagicMock()
height_sentinel = mocker.MagicMock()
width_sentinel = mocker.MagicMock()
is_valid = mocker.MagicMock() if needs_crop else None
padding_sentinel = mocker.MagicMock()
mocker.patch(
"torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
......@@ -1496,6 +1497,7 @@ class TestFixedSizeCrop:
left=left_sentinel,
height=height_sentinel,
width=width_sentinel,
is_valid=is_valid,
padding=padding_sentinel,
needs_pad=needs_pad,
),
......
......@@ -789,8 +789,12 @@ class FixedSizeCrop(Transform):
top = int(offset_height * r)
left = int(offset_width * r)
if needs_crop:
try:
bounding_boxes = query_bounding_box(sample)
except ValueError:
bounding_boxes = None
if needs_crop and bounding_boxes is not None:
bounding_boxes = cast(
features.BoundingBox, F.crop(bounding_boxes, top=top, left=left, height=height, width=width)
)
......@@ -830,6 +834,8 @@ class FixedSizeCrop(Transform):
height=params["height"],
width=params["width"],
)
if params["is_valid"] is not None:
if isinstance(inpt, (features.Label, features.OneHotLabel, features.SegmentationMask)):
inpt = inpt.new_like(inpt, inpt[params["is_valid"]]) # type: ignore[arg-type]
elif isinstance(inpt, features.BoundingBox):
......@@ -845,13 +851,14 @@ class FixedSizeCrop(Transform):
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if not (
has_all(sample, features.BoundingBox)
and has_any(sample, PIL.Image.Image, features.Image, is_simple_tensor)
and has_any(sample, features.Label, features.OneHotLabel)
):
if not has_any(sample, PIL.Image.Image, features.Image, is_simple_tensor):
raise TypeError(f"{type(self).__name__}() requires input sample to contain an tensor or PIL image.")
if has_any(sample, features.BoundingBox) and not has_any(sample, features.Label, features.OneHotLabel):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
"BoundingBoxes and Labels or OneHotLabels. Sample can also contain Segmentation Masks."
f"If a BoundingBox is contained in the input sample, "
f"{type(self).__name__}() also requires it to contain a Label or OneHotLabel."
)
return super().forward(sample)
......@@ -11,11 +11,12 @@ from .functional._meta import get_dimensions_image_pil, get_dimensions_image_ten
def query_bounding_box(sample: Any) -> features.BoundingBox:
flat_sample, _ = tree_flatten(sample)
for i in flat_sample:
if isinstance(i, features.BoundingBox):
return i
raise TypeError("No bounding box was found in the sample")
bounding_boxes = {item for item in flat_sample if isinstance(item, features.BoundingBox)}
if not bounding_boxes:
raise TypeError("No bounding box was found in the sample")
elif len(bounding_boxes) > 2:
raise ValueError("Found multiple bounding boxes in the sample")
return bounding_boxes.pop()
def get_chw(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
......@@ -41,7 +42,7 @@ def query_chw(sample: Any) -> Tuple[int, int, int]:
if not chws:
raise TypeError("No image was found in the sample")
elif len(chws) > 2:
raise TypeError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}")
raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}")
return chws.pop()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment