Unverified Commit 2c19af37 authored by Philip Meier, committed by GitHub

Fix prototype transforms for `(*, H, W)` segmentation masks (#6574)

* add generator functions for segmentation masks

* update functional tests and fix kernels

* fix transforms tests
parent 52ecad8d
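
The naming convention used throughout this commit: "detection" masks hold one binary mask per object, i.e. `(*, N, H, W)`, while "segmentation" masks encode the category of each pixel in the values, i.e. `(*, H, W)`. A minimal sketch of the two layouts, with illustrative shapes and values that are not taken from this commit:

    import torch

    # "detection" mask: N=3 object masks, each a binary (16, 16) map
    detection_mask = torch.randint(0, 2, (3, 16, 16), dtype=torch.uint8)

    # "segmentation" mask: one (16, 16) map whose values are category ids, here from 10 categories
    segmentation_mask = torch.randint(0, 10, (16, 16), dtype=torch.uint8)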
@@ -178,7 +178,8 @@ def make_one_hot_labels(
         yield make_one_hot_label(extra_dims_)
 
 
-def make_segmentation_mask(size=None, *, num_objects=None, extra_dims=(), dtype=torch.uint8):
+def make_detection_mask(size=None, *, num_objects=None, extra_dims=(), dtype=torch.uint8):
+    # This produces "detection" masks, i.e. `(*, N, H, W)`, where `N` denotes the number of objects
     size = size if size is not None else torch.randint(16, 33, (2,)).tolist()
     num_objects = num_objects if num_objects is not None else int(torch.randint(1, 11, ()))
     shape = (*extra_dims, num_objects, *size)
@@ -186,14 +187,49 @@ def make_segmentation_mask(size=None, *, num_objects=None, extra_dims=(), dtype=
     return features.SegmentationMask(data)
 
 
+def make_detection_masks(
+    *,
+    sizes=((16, 16), (7, 33), (31, 9)),
+    dtypes=(torch.uint8,),
+    extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
+    num_objects=(1, 0, None),
+):
+    for size, dtype, extra_dims_ in itertools.product(sizes, dtypes, extra_dims):
+        yield make_detection_mask(size=size, dtype=dtype, extra_dims=extra_dims_)
+
+    for dtype, extra_dims_, num_objects_ in itertools.product(dtypes, extra_dims, num_objects):
+        yield make_detection_mask(size=sizes[0], num_objects=num_objects_, dtype=dtype, extra_dims=extra_dims_)
+
+
+def make_segmentation_mask(size=None, *, num_categories=None, extra_dims=(), dtype=torch.uint8):
+    # This produces "segmentation" masks, i.e. `(*, H, W)`, where the category is encoded in the values
+    size = size if size is not None else torch.randint(16, 33, (2,)).tolist()
+    num_categories = num_categories if num_categories is not None else int(torch.randint(1, 11, ()))
+    shape = (*extra_dims, *size)
+    data = make_tensor(shape, low=0, high=num_categories, dtype=dtype)
+    return features.SegmentationMask(data)
+
+
 def make_segmentation_masks(
     *,
     sizes=((16, 16), (7, 33), (31, 9)),
     dtypes=(torch.uint8,),
     extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
-    num_objects=(1, 0, 10),
+    num_categories=(1, 2, None),
 ):
     for size, dtype, extra_dims_ in itertools.product(sizes, dtypes, extra_dims):
         yield make_segmentation_mask(size=size, dtype=dtype, extra_dims=extra_dims_)
 
-    for dtype, extra_dims_, num_objects_ in itertools.product(dtypes, extra_dims, num_objects):
-        yield make_segmentation_mask(size=sizes[0], num_objects=num_objects_, dtype=dtype, extra_dims=extra_dims_)
+    for dtype, extra_dims_, num_categories_ in itertools.product(dtypes, extra_dims, num_categories):
+        yield make_segmentation_mask(size=sizes[0], num_categories=num_categories_, dtype=dtype, extra_dims=extra_dims_)
+
+
+def make_detection_and_segmentation_masks(
+    sizes=((16, 16), (7, 33), (31, 9)),
+    dtypes=(torch.uint8,),
+    extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
+    num_objects=(1, 0, None),
+    num_categories=(1, 2, None),
+):
+    yield from make_detection_masks(sizes=sizes, dtypes=dtypes, extra_dims=extra_dims, num_objects=num_objects)
+    yield from make_segmentation_masks(sizes=sizes, dtypes=dtypes, extra_dims=extra_dims, num_categories=num_categories)
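
`make_detection_and_segmentation_masks` simply chains the two generators above, so a single loop covers both mask layouts. A rough usage sketch, assuming it runs inside the test suite where `prototype_common_utils` is importable:

    from prototype_common_utils import make_detection_and_segmentation_masks

    for mask in make_detection_and_segmentation_masks(sizes=((16, 16),)):
        # detection masks are (*, N, H, W), segmentation masks are (*, H, W);
        # both end in the spatial dimensions
        assert mask.shape[-2:] == (16, 16)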
@@ -10,6 +10,8 @@ from common_utils import assert_equal, cpu_and_gpu
 from prototype_common_utils import (
     make_bounding_box,
     make_bounding_boxes,
+    make_detection_and_segmentation_masks,
+    make_detection_mask,
     make_image,
     make_images,
     make_label,
@@ -62,6 +64,7 @@ def parametrize_from_transforms(*transforms):
             make_one_hot_labels,
             make_vanilla_tensor_images,
             make_pil_images,
+            make_detection_and_segmentation_masks,
         ]:
             inputs = list(creation_fn())
             try:
@@ -131,7 +134,12 @@ class TestSmoke:
         # Check if we raise an error if sample contains bbox or mask or label
         err_msg = "does not support bounding boxes, segmentation masks and plain labels"
         input_copy = dict(input)
-        for unsup_data in [make_label(), make_bounding_box(format="XYXY"), make_segmentation_mask()]:
+        for unsup_data in [
+            make_label(),
+            make_bounding_box(format="XYXY"),
+            make_detection_mask(),
+            make_segmentation_mask(),
+        ]:
             input_copy["unsupported"] = unsup_data
             with pytest.raises(TypeError, match=err_msg):
                 transform(input_copy)
@@ -233,7 +241,7 @@ class TestSmoke:
             color_space=features.ColorSpace.RGB, old_color_space=features.ColorSpace.GRAY
         )
-        for inpt in [make_bounding_box(format="XYXY"), make_segmentation_mask()]:
+        for inpt in [make_bounding_box(format="XYXY"), *make_detection_and_segmentation_masks()]:
             output = transform(inpt)
             assert output is inpt
@@ -1206,7 +1214,7 @@ class TestRandomIoUCrop:
         bboxes = make_bounding_box(format="XYXY", image_size=(32, 24), extra_dims=(6,))
         label = features.Label(torch.randint(0, 10, size=(6,)))
         ohe_label = features.OneHotLabel(torch.zeros(6, 10).scatter_(1, label.unsqueeze(1), 1))
-        masks = make_segmentation_mask((32, 24), num_objects=6)
+        masks = make_detection_mask((32, 24), num_objects=6)
         sample = [image, bboxes, label, ohe_label, masks]
@@ -1578,7 +1586,7 @@ class TestFixedSizeCrop:
         bounding_boxes = make_bounding_box(
             format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,)
         )
-        segmentation_masks = make_segmentation_mask(size=image_size, extra_dims=(batch_size,))
+        segmentation_masks = make_detection_mask(size=image_size, extra_dims=(batch_size,))
         labels = make_label(size=(batch_size,))
 
         transform = transforms.FixedSizeCrop((-1, -1))
......
@@ -8,7 +8,14 @@ import pytest
 import torch.testing
 import torchvision.prototype.transforms.functional as F
 from common_utils import cpu_and_gpu
-from prototype_common_utils import ArgsKwargs, make_bounding_boxes, make_image, make_images, make_segmentation_masks
+from prototype_common_utils import (
+    ArgsKwargs,
+    make_bounding_boxes,
+    make_detection_and_segmentation_masks,
+    make_detection_masks,
+    make_image,
+    make_images,
+)
 from torch import jit
 from torchvision.prototype import features
 from torchvision.prototype.transforms.functional._geometry import _center_crop_compute_padding
@@ -55,7 +62,7 @@ def horizontal_flip_bounding_box():
 @register_kernel_info_from_sample_inputs_fn
 def horizontal_flip_segmentation_mask():
-    for mask in make_segmentation_masks():
+    for mask in make_detection_and_segmentation_masks():
         yield ArgsKwargs(mask)
@@ -73,7 +80,7 @@ def vertical_flip_bounding_box():
 @register_kernel_info_from_sample_inputs_fn
 def vertical_flip_segmentation_mask():
-    for mask in make_segmentation_masks():
+    for mask in make_detection_and_segmentation_masks():
         yield ArgsKwargs(mask)
@@ -118,7 +125,7 @@ def resize_bounding_box():
 @register_kernel_info_from_sample_inputs_fn
 def resize_segmentation_mask():
     for mask, max_size in itertools.product(
-        make_segmentation_masks(),
+        make_detection_and_segmentation_masks(),
         [None, 34],  # max_size
     ):
         height, width = mask.shape[-2:]
@@ -173,7 +180,7 @@ def affine_bounding_box():
 @register_kernel_info_from_sample_inputs_fn
 def affine_segmentation_mask():
     for mask, angle, translate, scale, shear in itertools.product(
-        make_segmentation_masks(),
+        make_detection_and_segmentation_masks(),
         [-87, 15, 90],  # angle
         [5, -5],  # translate
         [0.77, 1.27],  # scale
@@ -226,7 +233,7 @@ def rotate_bounding_box():
 @register_kernel_info_from_sample_inputs_fn
 def rotate_segmentation_mask():
     for mask, angle, expand, center in itertools.product(
-        make_segmentation_masks(),
+        make_detection_and_segmentation_masks(),
         [-87, 15, 90],  # angle
         [True, False],  # expand
         [None, [12, 23]],  # center
@@ -269,7 +276,7 @@ def crop_bounding_box():
 @register_kernel_info_from_sample_inputs_fn
 def crop_segmentation_mask():
     for mask, top, left, height, width in itertools.product(
-        make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20]
+        make_detection_and_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20]
     ):
         yield ArgsKwargs(
             mask,
@@ -307,7 +314,7 @@ def resized_crop_bounding_box():
 @register_kernel_info_from_sample_inputs_fn
 def resized_crop_segmentation_mask():
     for mask, top, left, height, width, size in itertools.product(
-        make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20], [(32, 32), (16, 18)]
+        make_detection_and_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20], [(32, 32), (16, 18)]
     ):
         yield ArgsKwargs(mask, top=top, left=left, height=height, width=width, size=size)
@@ -326,7 +333,7 @@ def pad_image_tensor():
 @register_kernel_info_from_sample_inputs_fn
 def pad_segmentation_mask():
     for mask, padding, padding_mode in itertools.product(
-        make_segmentation_masks(),
+        make_detection_and_segmentation_masks(),
         [[1], [1, 1], [1, 1, 2, 2]],  # padding
         ["constant", "symmetric", "edge", "reflect"],  # padding mode,
     ):
@@ -374,7 +381,7 @@ def perspective_bounding_box():
 @register_kernel_info_from_sample_inputs_fn
 def perspective_segmentation_mask():
     for mask, perspective_coeffs in itertools.product(
-        make_segmentation_masks(extra_dims=((), (4,))),
+        make_detection_and_segmentation_masks(extra_dims=((), (4,))),
         [
             [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
             [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
@@ -411,7 +418,7 @@ def elastic_bounding_box():
 @register_kernel_info_from_sample_inputs_fn
 def elastic_segmentation_mask():
-    for mask in make_segmentation_masks(extra_dims=((), (4,))):
+    for mask in make_detection_and_segmentation_masks(extra_dims=((), (4,))):
         h, w = mask.shape[-2:]
         displacement = torch.rand(1, h, w, 2)
         yield ArgsKwargs(
@@ -440,7 +447,7 @@ def center_crop_bounding_box():
 @register_kernel_info_from_sample_inputs_fn
 def center_crop_segmentation_mask():
     for mask, output_size in itertools.product(
-        make_segmentation_masks(sizes=((16, 16), (7, 33), (31, 9))),
+        make_detection_and_segmentation_masks(sizes=((16, 16), (7, 33), (31, 9))),
         [[4, 3], [42, 70], [4]],  # crop sizes < image sizes, crop_sizes > image sizes, single crop size
     ):
         yield ArgsKwargs(mask, output_size)
@@ -771,7 +778,8 @@ def test_correctness_affine_segmentation_mask(angle, translate, scale, shear, ce
             expected_mask[i, out_y, out_x] = mask[i, in_y, in_x]
         return expected_mask.to(mask.device)
 
-    for mask in make_segmentation_masks(extra_dims=((), (4,))):
+    # FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
+    for mask in make_detection_masks(extra_dims=((), (4,))):
         output_mask = F.affine_segmentation_mask(
             mask,
             angle=angle,
@@ -1011,7 +1019,8 @@ def test_correctness_rotate_segmentation_mask(angle, expand, center):
             expected_mask[i, out_y, out_x] = mask[i, in_y, in_x]
         return expected_mask.to(mask.device)
 
-    for mask in make_segmentation_masks(extra_dims=((), (4,))):
+    # FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
+    for mask in make_detection_masks(extra_dims=((), (4,))):
         output_mask = F.rotate_segmentation_mask(
             mask,
             angle=angle,
@@ -1138,7 +1147,7 @@ def test_correctness_crop_segmentation_mask(device, top, left, height, width):
         return expected
 
-    for mask in make_segmentation_masks():
+    for mask in make_detection_and_segmentation_masks():
         if mask.device != torch.device(device):
             mask = mask.to(device)
         output_mask = F.crop_segmentation_mask(mask, top, left, height, width)
@@ -1358,7 +1367,7 @@ def test_correctness_pad_segmentation_mask(padding, padding_mode):
         return output
 
-    for mask in make_segmentation_masks():
+    for mask in make_detection_and_segmentation_masks():
         out_mask = F.pad_segmentation_mask(mask, padding, padding_mode=padding_mode)
         expected_mask = _compute_expected_mask(mask, padding, padding_mode)
@@ -1487,7 +1496,8 @@ def test_correctness_perspective_segmentation_mask(device, startpoints, endpoint
     pcoeffs = _get_perspective_coeffs(startpoints, endpoints)
 
-    for mask in make_segmentation_masks(extra_dims=((), (4,))):
+    # FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
+    for mask in make_detection_masks(extra_dims=((), (4,))):
         mask = mask.to(device)
         output_mask = F.perspective_segmentation_mask(
@@ -1649,14 +1659,18 @@ def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, s
 @pytest.mark.parametrize("device", cpu_and_gpu())
 @pytest.mark.parametrize(
-    "fn, make_samples", [(F.elastic_image_tensor, make_images), (F.elastic_segmentation_mask, make_segmentation_masks)]
+    "fn, make_samples",
+    [
+        (F.elastic_image_tensor, make_images),
+        # FIXME: This test currently only works for "detection" masks. Extend it for "segmentation" masks.
+        (F.elastic_segmentation_mask, make_detection_masks),
+    ],
 )
 def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples):
     in_box = [10, 15, 25, 35]
     for sample in make_samples(sizes=((64, 76),), extra_dims=((), (4,))):
         c, h, w = sample.shape[-3:]
         # Setup a dummy image with 4 points
-        print(sample.shape)
         sample[..., in_box[1], in_box[0]] = torch.arange(10, 10 + c)
         sample[..., in_box[3] - 1, in_box[0]] = torch.arange(20, 20 + c)
         sample[..., in_box[3] - 1, in_box[2] - 1] = torch.arange(30, 30 + c)
......
@@ -3,7 +3,7 @@ import pytest
 import torch
 
-from prototype_common_utils import make_bounding_box, make_image, make_segmentation_mask
+from prototype_common_utils import make_bounding_box, make_detection_mask, make_image
 from torchvision.prototype import features
 from torchvision.prototype.transforms._utils import has_all, has_any
@@ -12,7 +12,7 @@ from torchvision.prototype.transforms.functional import to_image_pil
 IMAGE = make_image(color_space=features.ColorSpace.RGB)
 BOUNDING_BOX = make_bounding_box(format=features.BoundingBoxFormat.XYXY, image_size=IMAGE.image_size)
-SEGMENTATION_MASK = make_segmentation_mask(size=IMAGE.image_size)
+SEGMENTATION_MASK = make_detection_mask(size=IMAGE.image_size)
 
 
 @pytest.mark.parametrize(
......
@@ -367,15 +367,21 @@ def affine_bounding_box(
 def affine_segmentation_mask(
-    mask: torch.Tensor,
+    segmentation_mask: torch.Tensor,
     angle: float,
     translate: List[float],
     scale: float,
     shear: List[float],
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
-    return affine_image_tensor(
-        mask,
+    if segmentation_mask.ndim < 3:
+        segmentation_mask = segmentation_mask.unsqueeze(0)
+        needs_squeeze = True
+    else:
+        needs_squeeze = False
+
+    output = affine_image_tensor(
+        segmentation_mask,
         angle=angle,
         translate=translate,
         scale=scale,
@@ -384,6 +390,11 @@ def affine_segmentation_mask(
         center=center,
     )
 
+    if needs_squeeze:
+        output = output.squeeze(0)
+
+    return output
+
 
 def _convert_fill_arg(fill: Optional[Union[int, float, Sequence[int], Sequence[float]]]) -> Optional[List[float]]:
     # Fill = 0 is not equivalent to None, https://github.com/pytorch/vision/issues/6517
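
The same unsqueeze/squeeze pattern recurs in every mask kernel below (`rotate_segmentation_mask`, `pad_segmentation_mask`, `perspective_segmentation_mask`, `elastic_segmentation_mask`, `center_crop_segmentation_mask`): the image kernels expect at least three dimensions, so a `(*, H, W)` mask temporarily gains a fake channel dimension. A standalone sketch of the shared pattern; `call_kernel_on_mask` is a hypothetical helper for illustration, not part of this commit:

    import torch

    def call_kernel_on_mask(kernel, mask: torch.Tensor, **kwargs) -> torch.Tensor:
        # (*, H, W) masks get a fake leading dimension so the image kernel
        # sees the (C, H, W)-like input it expects...
        needs_squeeze = mask.ndim < 3
        if needs_squeeze:
            mask = mask.unsqueeze(0)
        output = kernel(mask, **kwargs)
        # ...and the fake dimension is dropped again so output shape == input shape
        if needs_squeeze:
            output = output.squeeze(0)
        return output

    # e.g. with an identity kernel, the shape round-trips:
    assert call_kernel_on_mask(lambda m: m, torch.zeros(16, 16)).shape == (16, 16)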
@@ -522,19 +533,30 @@ def rotate_bounding_box(
 def rotate_segmentation_mask(
-    img: torch.Tensor,
+    segmentation_mask: torch.Tensor,
     angle: float,
     expand: bool = False,
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
-    return rotate_image_tensor(
-        img,
+    if segmentation_mask.ndim < 3:
+        segmentation_mask = segmentation_mask.unsqueeze(0)
+        needs_squeeze = True
+    else:
+        needs_squeeze = False
+
+    output = rotate_image_tensor(
+        segmentation_mask,
         angle=angle,
         expand=expand,
         interpolation=InterpolationMode.NEAREST,
         center=center,
     )
 
+    if needs_squeeze:
+        output = output.squeeze(0)
+
+    return output
+
 
 def rotate(
     inpt: DType,
@@ -607,7 +629,18 @@ def _pad_with_vector_fill(
 def pad_segmentation_mask(
     segmentation_mask: torch.Tensor, padding: Union[int, List[int]], padding_mode: str = "constant"
 ) -> torch.Tensor:
-    return pad_image_tensor(img=segmentation_mask, padding=padding, fill=0, padding_mode=padding_mode)
+    if segmentation_mask.ndim < 3:
+        segmentation_mask = segmentation_mask.unsqueeze(0)
+        needs_squeeze = True
+    else:
+        needs_squeeze = False
+
+    output = pad_image_tensor(img=segmentation_mask, padding=padding, fill=0, padding_mode=padding_mode)
+
+    if needs_squeeze:
+        output = output.squeeze(0)
+
+    return output
 
 
 def pad_bounding_box(
@@ -793,11 +826,22 @@ def perspective_bounding_box(
     ).view(original_shape)
 
 
-def perspective_segmentation_mask(mask: torch.Tensor, perspective_coeffs: List[float]) -> torch.Tensor:
-    return perspective_image_tensor(
-        mask, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST
+def perspective_segmentation_mask(segmentation_mask: torch.Tensor, perspective_coeffs: List[float]) -> torch.Tensor:
+    if segmentation_mask.ndim < 3:
+        segmentation_mask = segmentation_mask.unsqueeze(0)
+        needs_squeeze = True
+    else:
+        needs_squeeze = False
+
+    output = perspective_image_tensor(
+        segmentation_mask, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST
     )
 
+    if needs_squeeze:
+        output = output.squeeze(0)
+
+    return output
+
 
 def perspective(
     inpt: DType,
@@ -880,8 +924,19 @@ def elastic_bounding_box(
     ).view(original_shape)
 
 
-def elastic_segmentation_mask(mask: torch.Tensor, displacement: torch.Tensor) -> torch.Tensor:
-    return elastic_image_tensor(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST)
+def elastic_segmentation_mask(segmentation_mask: torch.Tensor, displacement: torch.Tensor) -> torch.Tensor:
+    if segmentation_mask.ndim < 3:
+        segmentation_mask = segmentation_mask.unsqueeze(0)
+        needs_squeeze = True
+    else:
+        needs_squeeze = False
+
+    output = elastic_image_tensor(segmentation_mask, displacement=displacement, interpolation=InterpolationMode.NEAREST)
+
+    if needs_squeeze:
+        output = output.squeeze(0)
+
+    return output
 
 
 def elastic(
@@ -973,7 +1028,18 @@ def center_crop_bounding_box(
 def center_crop_segmentation_mask(segmentation_mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
-    return center_crop_image_tensor(img=segmentation_mask, output_size=output_size)
+    if segmentation_mask.ndim < 3:
+        segmentation_mask = segmentation_mask.unsqueeze(0)
+        needs_squeeze = True
+    else:
+        needs_squeeze = False
+
+    output = center_crop_image_tensor(img=segmentation_mask, output_size=output_size)
+
+    if needs_squeeze:
+        output = output.squeeze(0)
+
+    return output
 
 
 def center_crop(inpt: DType, output_size: List[int]) -> DType:
......