Unverified commit 9c3e2bf4, authored by Philip Meier and committed by GitHub

port `FixedSizeCrop` from detection references to prototype transforms (#6417)

* port `FixedSizeCrop` from detection references to prototype transforms

* mypy

* [skip ci] cull invalid boxes and corresponding masks and labels

* cherry-pick missing functions from #6401

* fix feature wrapping

* add test

* mypy

* add input type restrictions

* add test for _get_params

* fix input checks
@@ -10,6 +10,7 @@ from common_utils import assert_equal, cpu_and_gpu
from test_prototype_transforms_functional import (
    make_bounding_box,
    make_bounding_boxes,
    make_image,
    make_images,
    make_label,
    make_one_hot_labels,
@@ -1328,3 +1329,161 @@ class TestRandomShortestSize:
        transform(inpt_sentinel)
        mock.assert_called_once_with(inpt_sentinel, size=size_sentinel, interpolation=interpolation_sentinel)


class TestFixedSizeCrop:
def test__get_params(self, mocker):
crop_size = (7, 7)
batch_shape = (10,)
image_size = (11, 5)
transform = transforms.FixedSizeCrop(size=crop_size)
sample = dict(
image=make_image(size=image_size, color_space=features.ColorSpace.RGB),
bounding_boxes=make_bounding_box(
format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=batch_shape
),
)
params = transform._get_params(sample)
assert params["needs_crop"]
assert params["height"] <= crop_size[0]
assert params["width"] <= crop_size[1]
assert (
isinstance(params["is_valid"], torch.Tensor)
and params["is_valid"].dtype is torch.bool
and params["is_valid"].shape == batch_shape
)
assert params["needs_pad"]
        assert any(pad > 0 for pad in params["padding"])

@pytest.mark.parametrize("needs", list(itertools.product((False, True), repeat=2)))
def test__transform(self, mocker, needs):
fill_sentinel = mocker.MagicMock()
padding_mode_sentinel = mocker.MagicMock()
transform = transforms.FixedSizeCrop((-1, -1), fill=fill_sentinel, padding_mode=padding_mode_sentinel)
transform._transformed_types = (mocker.MagicMock,)
mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
needs_crop, needs_pad = needs
top_sentinel = mocker.MagicMock()
left_sentinel = mocker.MagicMock()
height_sentinel = mocker.MagicMock()
width_sentinel = mocker.MagicMock()
padding_sentinel = mocker.MagicMock()
mocker.patch(
"torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
return_value=dict(
needs_crop=needs_crop,
top=top_sentinel,
left=left_sentinel,
height=height_sentinel,
width=width_sentinel,
padding=padding_sentinel,
needs_pad=needs_pad,
),
)
inpt_sentinel = mocker.MagicMock()
mock_crop = mocker.patch("torchvision.prototype.transforms._geometry.F.crop")
mock_pad = mocker.patch("torchvision.prototype.transforms._geometry.F.pad")
transform(inpt_sentinel)
if needs_crop:
mock_crop.assert_called_once_with(
inpt_sentinel,
top=top_sentinel,
left=left_sentinel,
height=height_sentinel,
width=width_sentinel,
)
else:
mock_crop.assert_not_called()
if needs_pad:
# If we cropped before, the input to F.pad is no longer inpt_sentinel. Thus, we can't use
# `MagicMock.assert_called_once_with` and have to perform the checks manually
mock_pad.assert_called_once()
args, kwargs = mock_pad.call_args
if not needs_crop:
assert args[0] is inpt_sentinel
assert args[1] is padding_sentinel
assert kwargs == dict(fill=fill_sentinel, padding_mode=padding_mode_sentinel)
else:
            mock_pad.assert_not_called()

def test__transform_culling(self, mocker):
batch_size = 10
image_size = (10, 10)
is_valid = torch.randint(0, 2, (batch_size,), dtype=torch.bool)
mocker.patch(
"torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
return_value=dict(
needs_crop=True,
top=0,
left=0,
height=image_size[0],
width=image_size[1],
is_valid=is_valid,
needs_pad=False,
),
)
bounding_boxes = make_bounding_box(
format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,)
)
segmentation_masks = make_segmentation_mask(size=image_size, extra_dims=(batch_size,))
labels = make_label(size=(batch_size,))
transform = transforms.FixedSizeCrop((-1, -1))
mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
output = transform(
dict(
bounding_boxes=bounding_boxes,
segmentation_masks=segmentation_masks,
labels=labels,
)
)
assert_equal(output["bounding_boxes"], bounding_boxes[is_valid])
assert_equal(output["segmentation_masks"], segmentation_masks[is_valid])
assert_equal(output["labels"], labels[is_valid])
def test__transform_bounding_box_clamping(self, mocker):
batch_size = 3
image_size = (10, 10)
mocker.patch(
"torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
return_value=dict(
needs_crop=True,
top=0,
left=0,
height=image_size[0],
width=image_size[1],
is_valid=torch.full((batch_size,), fill_value=True),
needs_pad=False,
),
)
bounding_box = make_bounding_box(
format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,)
)
mock = mocker.patch("torchvision.prototype.transforms._geometry.F.clamp_bounding_box")
transform = transforms.FixedSizeCrop((-1, -1))
mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)
transform(bounding_box)
mock.assert_called_once()
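
For reference, a condensed standalone sketch of the patching pattern these tests rely on: stubbing `FixedSizeCrop._get_params` makes the otherwise random crop parameters deterministic, and patching `has_all`/`has_any` bypasses the sample structure checks in `forward`. This is a hypothetical extra test, assuming `pytest-mock` is installed:

import torch

from torchvision.prototype import transforms


def test_fixed_size_crop_no_op(mocker):
    # Force the "nothing to do" branch: no crop and no pad.
    mocker.patch(
        "torchvision.prototype.transforms._geometry.FixedSizeCrop._get_params",
        return_value=dict(needs_crop=False, needs_pad=False),
    )
    # Bypass the input checks in forward(), as the tests above do.
    mocker.patch("torchvision.prototype.transforms._geometry.has_all", return_value=True)
    mocker.patch("torchvision.prototype.transforms._geometry.has_any", return_value=True)

    transform = transforms.FixedSizeCrop((-1, -1))
    inpt = torch.rand(3, 10, 10)
    # With both branches disabled, the input passes through unchanged.
    torch.testing.assert_close(transform(inpt), inpt)
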
@@ -20,6 +20,7 @@ from ._geometry import (
    CenterCrop,
    ElasticTransform,
    FiveCrop,
    FixedSizeCrop,
    Pad,
    RandomAffine,
    RandomCrop,
...
@@ -783,3 +783,100 @@ class RandomShortestSize:
    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
        return F.resize(inpt, size=params["size"], interpolation=self.interpolation)


class FixedSizeCrop(Transform):
def __init__(
self,
size: Union[int, Sequence[int]],
fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
padding_mode: str = "constant",
) -> None:
super().__init__()
size = tuple(_setup_size(size, error_msg="Please provide only two dimensions (h, w) for size."))
self.crop_height = size[0]
self.crop_width = size[1]
self.fill = fill # TODO: Fill is currently respected only on PIL. Apply tensor patch.
        self.padding_mode = padding_mode

    def _get_params(self, sample: Any) -> Dict[str, Any]:
image = query_image(sample)
_, height, width = get_image_dimensions(image)
new_height = min(height, self.crop_height)
new_width = min(width, self.crop_width)
needs_crop = new_height != height or new_width != width
offset_height = max(height - self.crop_height, 0)
offset_width = max(width - self.crop_width, 0)
r = torch.rand(1)
top = int(offset_height * r)
left = int(offset_width * r)
if needs_crop:
bounding_boxes = query_bounding_box(sample)
            bounding_boxes = cast(
                features.BoundingBox, F.crop(bounding_boxes, top=top, left=left, height=new_height, width=new_width)
            )
bounding_boxes = features.BoundingBox.new_like(
bounding_boxes,
F.clamp_bounding_box(
bounding_boxes, format=bounding_boxes.format, image_size=bounding_boxes.image_size
),
)
height_and_width = bounding_boxes.to_format(features.BoundingBoxFormat.XYWH)[..., 2:]
is_valid = torch.all(height_and_width > 0, dim=-1)
else:
is_valid = None
pad_bottom = max(self.crop_height - new_height, 0)
pad_right = max(self.crop_width - new_width, 0)
needs_pad = pad_bottom != 0 or pad_right != 0
return dict(
needs_crop=needs_crop,
top=top,
left=left,
height=new_height,
width=new_width,
is_valid=is_valid,
padding=[0, 0, pad_right, pad_bottom],
needs_pad=needs_pad,
)

    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if params["needs_crop"]:
inpt = F.crop(
inpt,
top=params["top"],
left=params["left"],
height=params["height"],
width=params["width"],
)
if isinstance(inpt, (features.Label, features.OneHotLabel, features.SegmentationMask)):
inpt = inpt.new_like(inpt, inpt[params["is_valid"]]) # type: ignore[arg-type]
elif isinstance(inpt, features.BoundingBox):
inpt = features.BoundingBox.new_like(
inpt,
F.clamp_bounding_box(inpt[params["is_valid"]], format=inpt.format, image_size=inpt.image_size),
)
if params["needs_pad"]:
inpt = F.pad(inpt, params["padding"], fill=self.fill, padding_mode=self.padding_mode)
return inpt

    def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
if not (
has_all(sample, features.BoundingBox)
and has_any(sample, PIL.Image.Image, features.Image, is_simple_tensor)
and has_any(sample, features.Label, features.OneHotLabel)
):
            raise TypeError(
                f"{type(self).__name__}() requires input sample to contain tensor images, Images, or PIL Images, "
                "BoundingBoxes, and Labels or OneHotLabels. Sample can also contain Segmentation Masks."
)
return super().forward(sample)
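
For context, a minimal usage sketch (not part of the diff) exercising the transform end to end, assuming the prototype API shown above (`torchvision.prototype.features` and `torchvision.prototype.transforms`); the sizes and box coordinates are illustrative:

import torch

from torchvision.prototype import features, transforms

transform = transforms.FixedSizeCrop(size=(224, 224), fill=0, padding_mode="constant")

sample = dict(
    image=features.Image(torch.randint(0, 256, (3, 300, 400), dtype=torch.uint8)),
    bounding_boxes=features.BoundingBox(
        [[10, 10, 50, 50], [200, 150, 260, 220]],
        format=features.BoundingBoxFormat.XYXY,
        image_size=(300, 400),
    ),
    labels=features.Label([1, 2]),
)

output = transform(sample)
# The image is cropped (and padded if needed) to exactly 224x224. Boxes that
# become degenerate after the crop are culled together with their labels.
assert output["image"].shape[-2:] == (224, 224)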