"vscode:/vscode.git/clone" did not exist on "38ae5a25da17015729d185db6e58b7baa0e095d0"
Unverified Commit 2c19af37 authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

Fix prototype transforms for `(*, H, W)` segmentation masks (#6574)

* add generator functions for segmentation masks

* update functional tests and fix kernels

* fix transforms tests
parent 52ecad8d
......@@ -178,7 +178,8 @@ def make_one_hot_labels(
yield make_one_hot_label(extra_dims_)
def make_segmentation_mask(size=None, *, num_objects=None, extra_dims=(), dtype=torch.uint8):
def make_detection_mask(size=None, *, num_objects=None, extra_dims=(), dtype=torch.uint8):
# This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects
size = size if size is not None else torch.randint(16, 33, (2,)).tolist()
num_objects = num_objects if num_objects is not None else int(torch.randint(1, 11, ()))
shape = (*extra_dims, num_objects, *size)
......@@ -186,14 +187,49 @@ def make_segmentation_mask(size=None, *, num_objects=None, extra_dims=(), dtype=
return features.SegmentationMask(data)
def make_detection_masks(
*,
sizes=((16, 16), (7, 33), (31, 9)),
dtypes=(torch.uint8,),
extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
num_objects=(1, 0, None),
):
for size, dtype, extra_dims_ in itertools.product(sizes, dtypes, extra_dims):
yield make_detection_mask(size=size, dtype=dtype, extra_dims=extra_dims_)
for dtype, extra_dims_, num_objects_ in itertools.product(dtypes, extra_dims, num_objects):
yield make_detection_mask(size=sizes[0], num_objects=num_objects_, dtype=dtype, extra_dims=extra_dims_)
def make_segmentation_mask(size=None, *, num_categories=None, extra_dims=(), dtype=torch.uint8):
# This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
size = size if size is not None else torch.randint(16, 33, (2,)).tolist()
num_categories = num_categories if num_categories is not None else int(torch.randint(1, 11, ()))
shape = (*extra_dims, *size)
data = make_tensor(shape, low=0, high=num_categories, dtype=dtype)
return features.SegmentationMask(data)
def make_segmentation_masks(
*,
sizes=((16, 16), (7, 33), (31, 9)),
dtypes=(torch.uint8,),
extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
num_objects=(1, 0, 10),
num_categories=(1, 2, None),
):
for size, dtype, extra_dims_ in itertools.product(sizes, dtypes, extra_dims):
yield make_segmentation_mask(size=size, dtype=dtype, extra_dims=extra_dims_)
for dtype, extra_dims_, num_objects_ in itertools.product(dtypes, extra_dims, num_objects):
yield make_segmentation_mask(size=sizes[0], num_objects=num_objects_, dtype=dtype, extra_dims=extra_dims_)
for dtype, extra_dims_, num_categories_ in itertools.product(dtypes, extra_dims, num_categories):
yield make_segmentation_mask(size=sizes[0], num_categories=num_categories_, dtype=dtype, extra_dims=extra_dims_)
def make_detection_and_segmentation_masks(
sizes=((16, 16), (7, 33), (31, 9)),
dtypes=(torch.uint8,),
extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
num_objects=(1, 0, None),
num_categories=(1, 2, None),
):
yield from make_detection_masks(sizes=sizes, dtypes=dtypes, extra_dims=extra_dims, num_objects=num_objects)
yield from make_segmentation_masks(sizes=sizes, dtypes=dtypes, extra_dims=extra_dims, num_categories=num_categories)
......@@ -10,6 +10,8 @@ from common_utils import assert_equal, cpu_and_gpu
from prototype_common_utils import (
make_bounding_box,
make_bounding_boxes,
make_detection_and_segmentation_masks,
make_detection_mask,
make_image,
make_images,
make_label,
......@@ -62,6 +64,7 @@ def parametrize_from_transforms(*transforms):
make_one_hot_labels,
make_vanilla_tensor_images,
make_pil_images,
make_detection_and_segmentation_masks,
]:
inputs = list(creation_fn())
try:
......@@ -131,7 +134,12 @@ class TestSmoke:
# Check if we raise an error if sample contains bbox or mask or label
err_msg = "does not support bounding boxes, segmentation masks and plain labels"
input_copy = dict(input)
for unsup_data in [make_label(), make_bounding_box(format="XYXY"), make_segmentation_mask()]:
for unsup_data in [
make_label(),
make_bounding_box(format="XYXY"),
make_detection_mask(),
make_segmentation_mask(),
]:
input_copy["unsupported"] = unsup_data
with pytest.raises(TypeError, match=err_msg):
transform(input_copy)
......@@ -233,7 +241,7 @@ class TestSmoke:
color_space=features.ColorSpace.RGB, old_color_space=features.ColorSpace.GRAY
)
for inpt in [make_bounding_box(format="XYXY"), make_segmentation_mask()]:
for inpt in [make_bounding_box(format="XYXY"), make_detection_and_segmentation_masks()]:
output = transform(inpt)
assert output is inpt
......@@ -1206,7 +1214,7 @@ class TestRandomIoUCrop:
bboxes = make_bounding_box(format="XYXY", image_size=(32, 24), extra_dims=(6,))
label = features.Label(torch.randint(0, 10, size=(6,)))
ohe_label = features.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1))
masks = make_segmentation_mask((32, 24), num_objects=6)
masks = make_detection_mask((32, 24), num_objects=6)
sample = [image, bboxes, label, ohe_label, masks]
......@@ -1578,7 +1586,7 @@ class TestFixedSizeCrop:
bounding_boxes = make_bounding_box(
format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,)
)
segmentation_masks = make_segmentation_mask(size=image_size, extra_dims=(batch_size,))
segmentation_masks = make_detection_mask(size=image_size, extra_dims=(batch_size,))
labels = make_label(size=(batch_size,))
transform = transforms.FixedSizeCrop((-1, -1))
......
......@@ -8,7 +8,14 @@ import pytest
import torch.testing
import torchvision.prototype.transforms.functional as F
from common_utils import cpu_and_gpu
from prototype_common_utils import ArgsKwargs, make_bounding_boxes, make_image, make_images, make_segmentation_masks
from prototype_common_utils import (
ArgsKwargs,
make_bounding_boxes,
make_detection_and_segmentation_masks,
make_detection_masks,
make_image,
make_images,
)
from torch import jit
from torchvision.prototype import features
from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding
......@@ -55,7 +62,7 @@ def horizontal_flip_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def horizontal_flip_segmentation_mask():
for mask in make_segmentation_masks():
for mask in make_detection_and_segmentation_masks():
yield ArgsKwargs(mask)
......@@ -73,7 +80,7 @@ def vertical_flip_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def vertical_flip_segmentation_mask():
for mask in make_segmentation_masks():
for mask in make_detection_and_segmentation_masks():
yield ArgsKwargs(mask)
......@@ -118,7 +125,7 @@ def resize_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def resize_segmentation_mask():
for mask, max_size in itertools.product(
make_segmentation_masks(),
make_detection_and_segmentation_masks(),
[None, 34], # max_size
):
height, width = mask.shape[-2:]
......@@ -173,7 +180,7 @@ def affine_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def affine_segmentation_mask():
for mask, angle, translate, scale, shear in itertools.product(
make_segmentation_masks(),
make_detection_and_segmentation_masks(),
[-87, 15, 90], # angle
[5, -5], # translate
[0.77, 1.27], # scale
......@@ -226,7 +233,7 @@ def rotate_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def rotate_segmentation_mask():
for mask, angle, expand, center in itertools.product(
make_segmentation_masks(),
make_detection_and_segmentation_masks(),
[-87, 15, 90], # angle
[True, False], # expand
[None, [12, 23]], # center
......@@ -269,7 +276,7 @@ def crop_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def crop_segmentation_mask():
for mask, top, left, height, width in itertools.product(
make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20]
make_detection_and_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20]
):
yield ArgsKwargs(
mask,
......@@ -307,7 +314,7 @@ def resized_crop_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def resized_crop_segmentation_mask():
for mask, top, left, height, width, size in itertools.product(
make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20], [(32, 32), (16, 18)]
make_detection_and_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20], [(32, 32), (16, 18)]
):
yield ArgsKwargs(mask, top=top, left=left, height=height, width=width, size=size)
......@@ -326,7 +333,7 @@ def pad_image_tensor():
@register_kernel_info_from_sample_inputs_fn
def pad_segmentation_mask():
for mask, padding, padding_mode in itertools.product(
make_segmentation_masks(),
make_detection_and_segmentation_masks(),
[[1], [1, 1], [1, 1, 2, 2]], # padding
["constant", "symmetric", "edge", "reflect"], # padding mode,
):
......@@ -374,7 +381,7 @@ def perspective_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def perspective_segmentation_mask():
for mask, perspective_coeffs in itertools.product(
make_segmentation_masks(extra_dims=((), (4,))),
make_detection_and_segmentation_masks(extra_dims=((), (4,))),
[
[1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
[0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
......@@ -411,7 +418,7 @@ def elastic_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def elastic_segmentation_mask():
for mask in make_segmentation_masks(extra_dims=((), (4,))):
for mask in make_detection_and_segmentation_masks(extra_dims=((), (4,))):
h, w = mask.shape[-2:]
displacement = torch.rand(1, h, w, 2)
yield ArgsKwargs(
......@@ -440,7 +447,7 @@ def center_crop_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def center_crop_segmentation_mask():
for mask, output_size in itertools.product(
make_segmentation_masks(sizes=((16, 16), (7, 33), (31, 9))),
make_detection_and_segmentation_masks(sizes=((16, 16), (7, 33), (31, 9))),
[[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size
):
yield ArgsKwargs(mask, output_size)
......@@ -771,7 +778,8 @@ def test_correctness_affine_segmentation_mask(angle, translate, scale, shear, ce
expected_mask[i, out_y, out_x] = mask[i, in_y, in_x]
return expected_mask.to(mask.device)
for mask in make_segmentation_masks(extra_dims=((), (4,))):
# FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
for mask in make_detection_masks(extra_dims=((), (4,))):
output_mask = F.affine_segmentation_mask(
mask,
angle=angle,
......@@ -1011,7 +1019,8 @@ def test_correctness_rotate_segmentation_mask(angle, expand, center):
expected_mask[i, out_y, out_x] = mask[i, in_y, in_x]
return expected_mask.to(mask.device)
for mask in make_segmentation_masks(extra_dims=((), (4,))):
# FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
for mask in make_detection_masks(extra_dims=((), (4,))):
output_mask = F.rotate_segmentation_mask(
mask,
angle=angle,
......@@ -1138,7 +1147,7 @@ def test_correctness_crop_segmentation_mask(device, top, left, height, width):
return expected
for mask in make_segmentation_masks():
for mask in make_detection_and_segmentation_masks():
if mask.device != torch.device(device):
mask = mask.to(device)
output_mask = F.crop_segmentation_mask(mask, top, left, height, width)
......@@ -1358,7 +1367,7 @@ def test_correctness_pad_segmentation_mask(padding, padding_mode):
return output
for mask in make_segmentation_masks():
for mask in make_detection_and_segmentation_masks():
out_mask = F.pad_segmentation_mask(mask, padding, padding_mode=padding_mode)
expected_mask = _compute_expected_mask(mask, padding, padding_mode)
......@@ -1487,7 +1496,8 @@ def test_correctness_perspective_segmentation_mask(device, startpoints, endpoint
pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
for mask in make_segmentation_masks(extra_dims=((), (4,))):
# FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
for mask in make_detection_masks(extra_dims=((), (4,))):
mask = mask.to(device)
output_mask = F.perspective_segmentation_mask(
......@@ -1649,14 +1659,18 @@ def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, s
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
"fn, make_samples", [(F.elastic_image_tensor, make_images), (F.elastic_segmentation_mask, make_segmentation_masks)]
"fn, make_samples",
[
(F.elastic_image_tensor, make_images),
# FIXME: This test currently only works for "detection" masks. Extend it for "segmentation" masks.
(F.elastic_segmentation_mask, make_detection_masks),
],
)
def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples):
in_box = [10, 15, 25, 35]
for sample in make_samples(sizes=((64, 76),), extra_dims=((), (4,))):
c, h, w = sample.shape[-3:]
# Setup a dummy image with 4 points
print(sample.shape)
sample[..., in_box[1], in_box[0]] = torch.arange(10, 10 + c)
sample[..., in_box[3] - 1, in_box[0]] = torch.arange(20, 20 + c)
sample[..., in_box[3] - 1, in_box[2] - 1] = torch.arange(30, 30 + c)
......
......@@ -3,7 +3,7 @@ import pytest
import torch
from prototype_common_utils import make_bounding_box, make_image, make_segmentation_mask
from prototype_common_utils import make_bounding_box, make_detection_mask, make_image
from torchvision.prototype import features
from torchvision.prototype.transforms._utils import has_all, has_any
......@@ -12,7 +12,7 @@ from torchvision.prototype.transforms.functional import to_image_pil
IMAGE = make_image(color_space=features.ColorSpace.RGB)
BOUNDING_BOX = make_bounding_box(format=features.BoundingBoxFormat.XYXY, image_size=IMAGE.image_size)
SEGMENTATION_MASK = make_segmentation_mask(size=IMAGE.image_size)
SEGMENTATION_MASK = make_detection_mask(size=IMAGE.image_size)
@pytest.mark.parametrize(
......
......@@ -367,15 +367,21 @@ def affine_bounding_box(
def affine_segmentation_mask(
mask: torch.Tensor,
segmentation_mask: torch.Tensor,
angle: float,
translate: List[float],
scale: float,
shear: List[float],
center: Optional[List[float]] = None,
) -> torch.Tensor:
return affine_image_tensor(
mask,
if segmentation_mask.ndim < 3:
segmentation_mask = segmentation_mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = affine_image_tensor(
segmentation_mask,
angle=angle,
translate=translate,
scale=scale,
......@@ -384,6 +390,11 @@ def affine_segmentation_mask(
center=center,
)
if needs_squeeze:
output = output.squeeze(0)
return output
def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]) -> Optional[List[float]]:
# Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
......@@ -522,19 +533,30 @@ def rotate_bounding_box(
def rotate_segmentation_mask(
img: torch.Tensor,
segmentation_mask: torch.Tensor,
angle: float,
expand: bool = False,
center: Optional[List[float]] = None,
) -> torch.Tensor:
return rotate_image_tensor(
img,
if segmentation_mask.ndim < 3:
segmentation_mask = segmentation_mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = rotate_image_tensor(
segmentation_mask,
angle=angle,
expand=expand,
interpolation=InterpolationMode.NEAREST,
center=center,
)
if needs_squeeze:
output = output.squeeze(0)
return output
def rotate(
inpt: DType,
......@@ -607,7 +629,18 @@ def _pad_with_vector_fill(
def pad_segmentation_mask(
segmentation_mask: torch.Tensor, padding: Union[int, List[int]], padding_mode: str = "constant"
) -> torch.Tensor:
return pad_image_tensor(img=segmentation_mask, padding=padding, fill=0, padding_mode=padding_mode)
if segmentation_mask.ndim < 3:
segmentation_mask = segmentation_mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = pad_image_tensor(img=segmentation_mask, padding=padding, fill=0, padding_mode=padding_mode)
if needs_squeeze:
output = output.squeeze(0)
return output
def pad_bounding_box(
......@@ -793,11 +826,22 @@ def perspective_bounding_box(
).view(original_shape)
def perspective_segmentation_mask(mask: torch.Tensor, perspective_coeffs: List[float]) -> torch.Tensor:
return perspective_image_tensor(
mask, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST
def perspective_segmentation_mask(segmentation_mask: torch.Tensor, perspective_coeffs: List[float]) -> torch.Tensor:
if segmentation_mask.ndim < 3:
segmentation_mask = segmentation_mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = perspective_image_tensor(
segmentation_mask, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST
)
if needs_squeeze:
output = output.squeeze(0)
return output
def perspective(
inpt: DType,
......@@ -880,8 +924,19 @@ def elastic_bounding_box(
).view(original_shape)
def elastic_segmentation_mask(mask: torch.Tensor, displacement: torch.Tensor) -> torch.Tensor:
return elastic_image_tensor(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST)
def elastic_segmentation_mask(segmentation_mask: torch.Tensor, displacement: torch.Tensor) -> torch.Tensor:
if segmentation_mask.ndim < 3:
segmentation_mask = segmentation_mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = elastic_image_tensor(segmentation_mask, displacement=displacement, interpolation=InterpolationMode.NEAREST)
if needs_squeeze:
output = output.squeeze(0)
return output
def elastic(
......@@ -973,7 +1028,18 @@ def center_crop_bounding_box(
def center_crop_segmentation_mask(segmentation_mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
return center_crop_image_tensor(img=segmentation_mask, output_size=output_size)
if segmentation_mask.ndim < 3:
segmentation_mask = segmentation_mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = center_crop_image_tensor(img=segmentation_mask, output_size=output_size)
if needs_squeeze:
output = output.squeeze(0)
return output
def center_crop(inpt: DType, output_size: List[int]) -> DType:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment