Unverified Commit 3a2631ba authored by Federico Pozzi's avatar Federico Pozzi Committed by GitHub
Browse files

feat: add functional center crop on mask (#5961)



* feat: add functional center crop on mask

* test: add correctness center crop with random segmentation mask

* test: improvements

* test: improvements

* Apply suggestions from code review
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
Co-authored-by: default avatarFederico Pozzi <federico.pozzi@argo.vision>
Co-authored-by: default avatarvfdev <vfdev.5@gmail.com>
Co-authored-by: default avatarPhilip Meier <github.pmeier@posteo.de>
parent 49496c4f
...@@ -10,11 +10,11 @@ from common_utils import cpu_and_gpu ...@@ -10,11 +10,11 @@ from common_utils import cpu_and_gpu
from torch import jit from torch import jit
from torch.nn.functional import one_hot from torch.nn.functional import one_hot
from torchvision.prototype import features from torchvision.prototype import features
from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding
from torchvision.prototype.transforms.functional._meta import convert_bounding_box_format from torchvision.prototype.transforms.functional._meta import convert_bounding_box_format
from torchvision.transforms.functional import _get_perspective_coeffs from torchvision.transforms.functional import _get_perspective_coeffs
from torchvision.transforms.functional_tensor import _max_value as get_max_value from torchvision.transforms.functional_tensor import _max_value as get_max_value
make_tensor = functools.partial(torch.testing.make_tensor, device="cpu") make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
...@@ -421,6 +421,14 @@ def center_crop_bounding_box(): ...@@ -421,6 +421,14 @@ def center_crop_bounding_box():
) )
def center_crop_segmentation_mask():
for mask, output_size in itertools.product(
make_segmentation_masks(image_sizes=((16, 16), (7, 33), (31, 9))),
[[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size
):
yield SampleInput(mask, output_size)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"kernel", "kernel",
[ [
...@@ -1337,3 +1345,26 @@ def test_correctness_center_crop_bounding_box(device, output_size): ...@@ -1337,3 +1345,26 @@ def test_correctness_center_crop_bounding_box(device, output_size):
else: else:
expected_bboxes = expected_bboxes[0] expected_bboxes = expected_bboxes[0]
torch.testing.assert_close(output_boxes, expected_bboxes) torch.testing.assert_close(output_boxes, expected_bboxes)
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("output_size", [[4, 2], [4], [7, 6]])
def test_correctness_center_crop_segmentation_mask(device, output_size):
def _compute_expected_segmentation_mask(mask, output_size):
crop_height, crop_width = output_size if len(output_size) > 1 else [output_size[0], output_size[0]]
_, image_height, image_width = mask.shape
if crop_width > image_height or crop_height > image_width:
padding = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
mask = F.pad_image_tensor(mask, padding, fill=0)
left = round((image_width - crop_width) * 0.5)
top = round((image_height - crop_height) * 0.5)
return mask[:, top : top + crop_height, left : left + crop_width]
mask = torch.randint(0, 2, size=(1, 6, 6), dtype=torch.long, device=device)
actual = F.center_crop_segmentation_mask(mask, output_size)
expected = _compute_expected_segmentation_mask(mask, output_size)
torch.testing.assert_close(expected, actual)
...@@ -46,6 +46,7 @@ from ._geometry import ( ...@@ -46,6 +46,7 @@ from ._geometry import (
resize_image_pil, resize_image_pil,
resize_segmentation_mask, resize_segmentation_mask,
center_crop_bounding_box, center_crop_bounding_box,
center_crop_segmentation_mask,
center_crop_image_tensor, center_crop_image_tensor,
center_crop_image_pil, center_crop_image_pil,
resized_crop_bounding_box, resized_crop_bounding_box,
......
...@@ -630,6 +630,10 @@ def center_crop_bounding_box( ...@@ -630,6 +630,10 @@ def center_crop_bounding_box(
return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left) return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left)
def center_crop_segmentation_mask(segmentation_mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
return center_crop_image_tensor(img=segmentation_mask, output_size=output_size)
def resized_crop_image_tensor( def resized_crop_image_tensor(
img: torch.Tensor, img: torch.Tensor,
top: int, top: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment