Unverified commit 1ea73f58 authored by Philip Meier, committed by GitHub

Rename features.SegmentationMask to features.Mask (#6579)

* rename features.SegmentationMask -> features.Mask

* rename kernels *_segmentation_mask -> *_mask and cleanup input name

* cleanup

* rename module _segmentation_mask.py -> _mask.py

* fix test
parent f007a5e1
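For orientation, here is the renamed feature class in use. This is a minimal sketch: only `features.Mask` and its flip method are taken from this commit; the tensor values are made up.

```python
import torch
from torchvision.prototype import features

# A (num_objects, H, W) detection-style mask wrapped in the renamed class.
mask = features.Mask(torch.randint(0, 2, (3, 16, 16), dtype=torch.uint8))

flipped = mask.horizontal_flip()  # dispatches to the renamed kernel F.horizontal_flip_mask
assert isinstance(flipped, features.Mask) and flipped.shape == mask.shape
```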
@@ -184,7 +184,7 @@ def make_detection_mask(size=None, *, num_objects=None, extra_dims=(), dtype=tor
     num_objects = num_objects if num_objects is not None else int(torch.randint(1, 11, ()))
     shape = (*extra_dims, num_objects, *size)
     data = make_tensor(shape, low=0, high=2, dtype=dtype)
-    return features.SegmentationMask(data)
+    return features.Mask(data)


 def make_detection_masks(
@@ -207,7 +207,7 @@ def make_segmentation_mask(size=None, *, num_categories=None, extra_dims=(), dty
     num_categories = num_categories if num_categories is not None else int(torch.randint(1, 11, ()))
     shape = (*extra_dims, *size)
     data = make_tensor(shape, low=0, high=num_categories, dtype=dtype)
-    return features.SegmentationMask(data)
+    return features.Mask(data)


 def make_segmentation_masks(
@@ -224,7 +224,7 @@ def make_segmentation_masks(
         yield make_segmentation_mask(size=sizes[0], num_categories=num_categories_, dtype=dtype, extra_dims=extra_dims_)


-def make_detection_and_segmentation_masks(
+def make_masks(
     sizes=((16, 16), (7, 33), (31, 9)),
     dtypes=(torch.uint8,),
     extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
......
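The body of `make_masks` is collapsed above. A hypothetical reconstruction, assuming it simply chains the two factories whose outputs it replaces (`make_detection_masks` and `make_segmentation_masks`, both visible in the hunks above):

```python
# Hypothetical reconstruction -- the actual body is collapsed in the diff above.
def make_masks(
    sizes=((16, 16), (7, 33), (31, 9)),
    dtypes=(torch.uint8,),
    extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
):
    # Yield both kinds of masks so one fixture covers every *_mask kernel.
    yield from make_detection_masks(sizes=sizes, dtypes=dtypes, extra_dims=extra_dims)
    yield from make_segmentation_masks(sizes=sizes, dtypes=dtypes, extra_dims=extra_dims)
```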
@@ -10,11 +10,11 @@ from common_utils import assert_equal, cpu_and_gpu
 from prototype_common_utils import (
     make_bounding_box,
     make_bounding_boxes,
-    make_detection_and_segmentation_masks,
     make_detection_mask,
     make_image,
     make_images,
     make_label,
+    make_masks,
     make_one_hot_labels,
     make_segmentation_mask,
 )
@@ -64,7 +64,7 @@ def parametrize_from_transforms(*transforms):
             make_one_hot_labels,
             make_vanilla_tensor_images,
             make_pil_images,
-            make_detection_and_segmentation_masks,
+            make_masks,
         ]:
             inputs = list(creation_fn())
             try:
@@ -132,7 +132,7 @@ class TestSmoke:
             transform(input_copy)

             # Check if we raise an error if sample contains bbox or mask or label
-            err_msg = "does not support bounding boxes, segmentation masks and plain labels"
+            err_msg = "does not support bounding boxes, masks and plain labels"
             input_copy = dict(input)
             for unsup_data in [
                 make_label(),
@@ -241,7 +241,7 @@ class TestSmoke:
             color_space=features.ColorSpace.RGB, old_color_space=features.ColorSpace.GRAY
         )

-        for inpt in [make_bounding_box(format="XYXY"), make_detection_and_segmentation_masks()]:
+        for inpt in [make_bounding_box(format="XYXY"), make_masks()]:
             output = transform(inpt)
             assert output is inpt
@@ -278,13 +278,13 @@ class TestRandomHorizontalFlip:
         assert_equal(features.Image(expected), actual)

-    def test_features_segmentation_mask(self, p):
+    def test_features_mask(self, p):
         input, expected = self.input_expected_image_tensor(p)
         transform = transforms.RandomHorizontalFlip(p=p)

-        actual = transform(features.SegmentationMask(input))
+        actual = transform(features.Mask(input))

-        assert_equal(features.SegmentationMask(expected), actual)
+        assert_equal(features.Mask(expected), actual)

     def test_features_bounding_box(self, p):
         input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))
@@ -331,13 +331,13 @@ class TestRandomVerticalFlip:
         assert_equal(features.Image(expected), actual)

-    def test_features_segmentation_mask(self, p):
+    def test_features_mask(self, p):
         input, expected = self.input_expected_image_tensor(p)
         transform = transforms.RandomVerticalFlip(p=p)

-        actual = transform(features.SegmentationMask(input))
+        actual = transform(features.Mask(input))

-        assert_equal(features.SegmentationMask(expected), actual)
+        assert_equal(features.Mask(expected), actual)

     def test_features_bounding_box(self, p):
         input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))
@@ -1253,7 +1253,7 @@ class TestRandomIoUCrop:
         torch.testing.assert_close(output_ohe_label, ohe_label[is_within_crop_area])

         output_masks = output[4]
-        assert isinstance(output_masks, features.SegmentationMask)
+        assert isinstance(output_masks, features.Mask)
         assert len(output_masks) == expected_within_targets
@@ -1372,10 +1372,10 @@ class TestSimpleCopyPaste:
             # labels, bboxes, masks
             mocker.MagicMock(spec=features.Label),
             mocker.MagicMock(spec=features.BoundingBox),
-            mocker.MagicMock(spec=features.SegmentationMask),
+            mocker.MagicMock(spec=features.Mask),
             # labels, bboxes, masks
             mocker.MagicMock(spec=features.BoundingBox),
-            mocker.MagicMock(spec=features.SegmentationMask),
+            mocker.MagicMock(spec=features.Mask),
         ]

         with pytest.raises(TypeError, match="requires input sample to contain equal sized list of Images"):
@@ -1393,11 +1393,11 @@ class TestSimpleCopyPaste:
             # labels, bboxes, masks
             mocker.MagicMock(spec=label_type),
             mocker.MagicMock(spec=features.BoundingBox),
-            mocker.MagicMock(spec=features.SegmentationMask),
+            mocker.MagicMock(spec=features.Mask),
             # labels, bboxes, masks
             mocker.MagicMock(spec=label_type),
             mocker.MagicMock(spec=features.BoundingBox),
-            mocker.MagicMock(spec=features.SegmentationMask),
+            mocker.MagicMock(spec=features.Mask),
         ]

         images, targets = transform._extract_image_targets(flat_sample)
@@ -1413,7 +1413,7 @@ class TestSimpleCopyPaste:
         for target in targets:
             for key, type_ in [
                 ("boxes", features.BoundingBox),
-                ("masks", features.SegmentationMask),
+                ("masks", features.Mask),
                 ("labels", label_type),
             ]:
                 assert key in target
@@ -1436,7 +1436,7 @@ class TestSimpleCopyPaste:
             "boxes": features.BoundingBox(
                 torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", image_size=(32, 32)
             ),
-            "masks": features.SegmentationMask(masks),
+            "masks": features.Mask(masks),
             "labels": label_type(labels),
         }
@@ -1451,7 +1451,7 @@ class TestSimpleCopyPaste:
             "boxes": features.BoundingBox(
                 torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", image_size=(32, 32)
             ),
-            "masks": features.SegmentationMask(paste_masks),
+            "masks": features.Mask(paste_masks),
             "labels": label_type(paste_labels),
         }
@@ -1586,7 +1586,7 @@ class TestFixedSizeCrop:
         bounding_boxes = make_bounding_box(
             format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,)
         )
-        segmentation_masks = make_detection_mask(size=image_size, extra_dims=(batch_size,))
+        masks = make_detection_mask(size=image_size, extra_dims=(batch_size,))
         labels = make_label(size=(batch_size,))

         transform = transforms.FixedSizeCrop((-1, -1))
@@ -1596,13 +1596,13 @@ class TestFixedSizeCrop:
         output = transform(
             dict(
                 bounding_boxes=bounding_boxes,
-                segmentation_masks=segmentation_masks,
+                masks=masks,
                 labels=labels,
             )
         )

         assert_equal(output["bounding_boxes"], bounding_boxes[is_valid])
-        assert_equal(output["segmentation_masks"], segmentation_masks[is_valid])
+        assert_equal(output["masks"], masks[is_valid])
         assert_equal(output["labels"], labels[is_valid])

     def test__transform_bounding_box_clamping(self, mocker):
......
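The flip tests above exercise the renamed class end to end. The same check as a standalone sketch (values illustrative; class and transform names come from the diff):

```python
import torch
from torchvision.prototype import features, transforms

data = torch.arange(9, dtype=torch.uint8).reshape(1, 3, 3)
mask = features.Mask(data)

out = transforms.RandomHorizontalFlip(p=1.0)(mask)  # p=1.0 forces the flip

assert isinstance(out, features.Mask)
assert (out == data.flip(-1)).all()  # columns are mirrored, feature type is preserved
```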
@@ -11,10 +11,10 @@ from common_utils import cpu_and_gpu
 from prototype_common_utils import (
     ArgsKwargs,
     make_bounding_boxes,
-    make_detection_and_segmentation_masks,
     make_detection_masks,
     make_image,
     make_images,
+    make_masks,
 )
 from torch import jit
 from torchvision.prototype import features
@@ -61,8 +61,8 @@ def horizontal_flip_bounding_box():

 @register_kernel_info_from_sample_inputs_fn
-def horizontal_flip_segmentation_mask():
-    for mask in make_detection_and_segmentation_masks():
+def horizontal_flip_mask():
+    for mask in make_masks():
         yield ArgsKwargs(mask)
@@ -79,8 +79,8 @@ def vertical_flip_bounding_box():

 @register_kernel_info_from_sample_inputs_fn
-def vertical_flip_segmentation_mask():
-    for mask in make_detection_and_segmentation_masks():
+def vertical_flip_mask():
+    for mask in make_masks():
         yield ArgsKwargs(mask)
@@ -123,9 +123,9 @@ def resize_bounding_box():

 @register_kernel_info_from_sample_inputs_fn
-def resize_segmentation_mask():
+def resize_mask():
     for mask, max_size in itertools.product(
-        make_detection_and_segmentation_masks(),
+        make_masks(),
         [None, 34],  # max_size
     ):
         height, width = mask.shape[-2:]
@@ -178,9 +178,9 @@ def affine_bounding_box():

 @register_kernel_info_from_sample_inputs_fn
-def affine_segmentation_mask():
+def affine_mask():
     for mask, angle, translate, scale, shear in itertools.product(
-        make_detection_and_segmentation_masks(),
+        make_masks(),
         [-87, 15, 90],  # angle
         [5, -5],  # translate
         [0.77, 1.27],  # scale
@@ -231,9 +231,9 @@ def rotate_bounding_box():

 @register_kernel_info_from_sample_inputs_fn
-def rotate_segmentation_mask():
+def rotate_mask():
     for mask, angle, expand, center in itertools.product(
-        make_detection_and_segmentation_masks(),
+        make_masks(),
         [-87, 15, 90],  # angle
         [True, False],  # expand
         [None, [12, 23]],  # center
@@ -274,10 +274,8 @@ def crop_bounding_box():

 @register_kernel_info_from_sample_inputs_fn
-def crop_segmentation_mask():
-    for mask, top, left, height, width in itertools.product(
-        make_detection_and_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20]
-    ):
+def crop_mask():
+    for mask, top, left, height, width in itertools.product(make_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20]):
         yield ArgsKwargs(
             mask,
             top=top,
@@ -312,9 +310,9 @@ def resized_crop_bounding_box():

 @register_kernel_info_from_sample_inputs_fn
-def resized_crop_segmentation_mask():
+def resized_crop_mask():
     for mask, top, left, height, width, size in itertools.product(
-        make_detection_and_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20], [(32, 32), (16, 18)]
+        make_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20], [(32, 32), (16, 18)]
     ):
         yield ArgsKwargs(mask, top=top, left=left, height=height, width=width, size=size)
@@ -331,9 +329,9 @@ def pad_image_tensor():

 @register_kernel_info_from_sample_inputs_fn
-def pad_segmentation_mask():
+def pad_mask():
     for mask, padding, padding_mode in itertools.product(
-        make_detection_and_segmentation_masks(),
+        make_masks(),
         [[1], [1, 1], [1, 1, 2, 2]],  # padding
         ["constant", "symmetric", "edge", "reflect"],  # padding mode,
     ):
@@ -379,9 +377,9 @@ def perspective_bounding_box():

 @register_kernel_info_from_sample_inputs_fn
-def perspective_segmentation_mask():
+def perspective_mask():
     for mask, perspective_coeffs in itertools.product(
-        make_detection_and_segmentation_masks(extra_dims=((), (4,))),
+        make_masks(extra_dims=((), (4,))),
         [
             [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
             [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
@@ -417,8 +415,8 @@ def elastic_bounding_box():

 @register_kernel_info_from_sample_inputs_fn
-def elastic_segmentation_mask():
-    for mask in make_detection_and_segmentation_masks(extra_dims=((), (4,))):
+def elastic_mask():
+    for mask in make_masks(extra_dims=((), (4,))):
         h, w = mask.shape[-2:]
         displacement = torch.rand(1, h, w, 2)
         yield ArgsKwargs(
@@ -445,9 +443,9 @@ def center_crop_bounding_box():

 @register_kernel_info_from_sample_inputs_fn
-def center_crop_segmentation_mask():
+def center_crop_mask():
     for mask, output_size in itertools.product(
-        make_detection_and_segmentation_masks(sizes=((16, 16), (7, 33), (31, 9))),
+        make_masks(sizes=((16, 16), (7, 33), (31, 9))),
         [[4, 3], [42, 70], [4]],  # crop sizes < image sizes, crop_sizes > image sizes, single crop size
     ):
         yield ArgsKwargs(mask, output_size)
@@ -528,7 +526,7 @@ def erase_image_tensor():
        for name, kernel in F.__dict__.items()
        if not name.startswith("_")
        and callable(kernel)
-       and any(feature_type in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label"})
+       and any(feature_type in name for feature_type in {"image", "mask", "bounding_box", "label"})
        and "pil" not in name
        and name
        not in {
@@ -553,9 +551,7 @@ def test_scriptable(kernel):
        for name, func in F.__dict__.items()
        if not name.startswith("_")
        and callable(func)
-       and all(
-           feature_type not in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label", "pil"}
-       )
+       and all(feature_type not in name for feature_type in {"image", "mask", "bounding_box", "label", "pil"})
        and name
        not in {
            "to_image_tensor",
@@ -760,7 +756,7 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
 @pytest.mark.parametrize("scale", [0.89, 1.12])
 @pytest.mark.parametrize("shear", [4])
 @pytest.mark.parametrize("center", [None, (12, 14)])
-def test_correctness_affine_segmentation_mask(angle, translate, scale, shear, center):
+def test_correctness_affine_mask(angle, translate, scale, shear, center):
     def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_):
         assert mask.ndim == 3
         affine_matrix = _compute_affine_matrix(angle_, translate_, scale_, shear_, center_)
@@ -780,7 +776,7 @@ def test_correctness_affine_segmentation_mask(angle, translate, scale, shear, ce
     # FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
     for mask in make_detection_masks(extra_dims=((), (4,))):
-        output_mask = F.affine_segmentation_mask(
+        output_mask = F.affine_mask(
             mask,
             angle=angle,
             translate=(translate, translate),
@@ -824,7 +820,7 @@ def test_correctness_affine_segmentation_mask_on_fixed_input(device):
     expected_mask = torch.nn.functional.interpolate(expected_mask[None, :].float(), size=(64, 64), mode="nearest")
     expected_mask = expected_mask[0, :, 16 : 64 - 16, 16 : 64 - 16].long()

-    out_mask = F.affine_segmentation_mask(mask, 90, [0.0, 0.0], 64.0 / 32.0, [0.0, 0.0])
+    out_mask = F.affine_mask(mask, 90, [0.0, 0.0], 64.0 / 32.0, [0.0, 0.0])

     torch.testing.assert_close(out_mask, expected_mask)
@@ -976,7 +972,7 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
 @pytest.mark.parametrize("angle", range(-89, 90, 37))
 @pytest.mark.parametrize("expand, center", [(True, None), (False, None), (False, (12, 14))])
-def test_correctness_rotate_segmentation_mask(angle, expand, center):
+def test_correctness_rotate_mask(angle, expand, center):
     def _compute_expected_mask(mask, angle_, expand_, center_):
         assert mask.ndim == 3
         c, *image_size = mask.shape
@@ -1021,7 +1017,7 @@ def test_correctness_rotate_segmentation_mask(angle, expand, center):
     # FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
     for mask in make_detection_masks(extra_dims=((), (4,))):
-        output_mask = F.rotate_segmentation_mask(
+        output_mask = F.rotate_mask(
             mask,
             angle=angle,
             expand=expand,
@@ -1060,7 +1056,7 @@ def test_correctness_rotate_segmentation_mask_on_fixed_input(device):
     # Rotate 90 degrees
     expected_mask = torch.rot90(mask, k=1, dims=(-2, -1))

-    out_mask = F.rotate_segmentation_mask(mask, 90, expand=False)
+    out_mask = F.rotate_mask(mask, 90, expand=False)

     torch.testing.assert_close(out_mask, expected_mask)
@@ -1123,7 +1119,7 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
         [-8, -6, 70, 8],
     ],
 )
-def test_correctness_crop_segmentation_mask(device, top, left, height, width):
+def test_correctness_crop_mask(device, top, left, height, width):
     def _compute_expected_mask(mask, top_, left_, height_, width_):
         h, w = mask.shape[-2], mask.shape[-1]
         if top_ >= 0 and left_ >= 0 and top_ + height_ < h and left_ + width_ < w:
@@ -1147,10 +1143,10 @@ def test_correctness_crop_segmentation_mask(device, top, left, height, width):
         return expected

-    for mask in make_detection_and_segmentation_masks():
+    for mask in make_masks():
         if mask.device != torch.device(device):
             mask = mask.to(device)
-        output_mask = F.crop_segmentation_mask(mask, top, left, height, width)
+        output_mask = F.crop_mask(mask, top, left, height, width)
         expected_mask = _compute_expected_mask(mask, top, left, height, width)
         torch.testing.assert_close(output_mask, expected_mask)
@@ -1160,7 +1156,7 @@ def test_correctness_horizontal_flip_segmentation_mask_on_fixed_input(device):
     mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
     mask[:, :, 0] = 1

-    out_mask = F.horizontal_flip_segmentation_mask(mask)
+    out_mask = F.horizontal_flip_mask(mask)

     expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
     expected_mask[:, :, -1] = 1
@@ -1172,7 +1168,7 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
     mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
     mask[:, 0, :] = 1

-    out_mask = F.vertical_flip_segmentation_mask(mask)
+    out_mask = F.vertical_flip_mask(mask)

     expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
     expected_mask[:, -1, :] = 1
@@ -1233,7 +1229,7 @@ def test_correctness_resized_crop_bounding_box(device, format, top, left, height
         [5, 5, 35, 45, (32, 34)],
     ],
 )
-def test_correctness_resized_crop_segmentation_mask(device, top, left, height, width, size):
+def test_correctness_resized_crop_mask(device, top, left, height, width, size):
     def _compute_expected_mask(mask, top_, left_, height_, width_, size_):
         output = mask.clone()
         output = output[:, top_ : top_ + height_, left_ : left_ + width_]
@@ -1246,7 +1242,7 @@ def test_correctness_resized_crop_segmentation_mask(device, top, left, height, w
     in_mask[0, 5:15, 12:23] = 2
     expected_mask = _compute_expected_mask(in_mask, top, left, height, width, size)

-    output_mask = F.resized_crop_segmentation_mask(in_mask, top, left, height, width, size)
+    output_mask = F.resized_crop_mask(in_mask, top, left, height, width, size)
     torch.testing.assert_close(output_mask, expected_mask)
@@ -1310,7 +1306,7 @@ def test_correctness_pad_bounding_box(device, padding):
 def test_correctness_pad_segmentation_mask_on_fixed_input(device):
     mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)

-    out_mask = F.pad_segmentation_mask(mask, padding=[1, 1, 1, 1])
+    out_mask = F.pad_mask(mask, padding=[1, 1, 1, 1])

     expected_mask = torch.zeros((1, 5, 5), dtype=torch.long, device=device)
     expected_mask[:, 1:-1, 1:-1] = 1
@@ -1319,7 +1315,7 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device):
 @pytest.mark.parametrize("padding", [[1, 2, 3, 4], [1], 1, [1, 2]])
 @pytest.mark.parametrize("padding_mode", ["constant", "edge", "reflect", "symmetric"])
-def test_correctness_pad_segmentation_mask(padding, padding_mode):
+def test_correctness_pad_mask(padding, padding_mode):
     def _compute_expected_mask(mask, padding_, padding_mode_):
         h, w = mask.shape[-2], mask.shape[-1]
         pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_)
@@ -1367,8 +1363,8 @@ def test_correctness_pad_segmentation_mask(padding, padding_mode):
         return output

-    for mask in make_detection_and_segmentation_masks():
-        out_mask = F.pad_segmentation_mask(mask, padding, padding_mode=padding_mode)
+    for mask in make_masks():
+        out_mask = F.pad_mask(mask, padding, padding_mode=padding_mode)
         expected_mask = _compute_expected_mask(mask, padding, padding_mode)
         torch.testing.assert_close(out_mask, expected_mask)
@@ -1473,7 +1469,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
         [[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
     ],
 )
-def test_correctness_perspective_segmentation_mask(device, startpoints, endpoints):
+def test_correctness_perspective_mask(device, startpoints, endpoints):
     def _compute_expected_mask(mask, pcoeffs_):
         assert mask.ndim == 3
         m1 = np.array([[pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]], [pcoeffs_[3], pcoeffs_[4], pcoeffs_[5]]])
@@ -1500,7 +1496,7 @@ def test_correctness_perspective_segmentation_mask(device, startpoints, endpoint
     for mask in make_detection_masks(extra_dims=((), (4,))):
         mask = mask.to(device)

-        output_mask = F.perspective_segmentation_mask(
+        output_mask = F.perspective_mask(
             mask,
             perspective_coeffs=pcoeffs,
         )
@@ -1579,8 +1575,8 @@ def test_correctness_center_crop_bounding_box(device, output_size):
 @pytest.mark.parametrize("device", cpu_and_gpu())
 @pytest.mark.parametrize("output_size", [[4, 2], [4], [7, 6]])
-def test_correctness_center_crop_segmentation_mask(device, output_size):
-    def _compute_expected_segmentation_mask(mask, output_size):
+def test_correctness_center_crop_mask(device, output_size):
+    def _compute_expected_mask(mask, output_size):
         crop_height, crop_width = output_size if len(output_size) > 1 else [output_size[0], output_size[0]]

         _, image_height, image_width = mask.shape
@@ -1594,9 +1590,9 @@ def test_correctness_center_crop_segmentation_mask(device, output_size):
         return mask[:, top : top + crop_height, left : left + crop_width]

     mask = torch.randint(0, 2, size=(1, 6, 6), dtype=torch.long, device=device)
-    actual = F.center_crop_segmentation_mask(mask, output_size)
+    actual = F.center_crop_mask(mask, output_size)

-    expected = _compute_expected_segmentation_mask(mask, output_size)
+    expected = _compute_expected_mask(mask, output_size)
     torch.testing.assert_close(expected, actual)
@@ -1663,7 +1659,7 @@ def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, s
     [
         (F.elastic_image_tensor, make_images),
         # FIXME: This test currently only works for "detection" masks. Extend it for "segmentation" masks.
-        (F.elastic_segmentation_mask, make_detection_masks),
+        (F.elastic_mask, make_detection_masks),
     ],
 )
 def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples):
@@ -1681,7 +1677,7 @@ def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples):
         sample = features.Image(sample)
         kwargs = {"interpolation": F.InterpolationMode.NEAREST}
     else:
-        sample = features.SegmentationMask(sample)
+        sample = features.Mask(sample)
         kwargs = {}

     # Create a displacement grid using sin
......
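The `F.__dict__` filters above discover kernels purely by name, which is why the rename had to touch the `{"image", "mask", "bounding_box", "label"}` marker sets. The discovery pattern in isolation (a sketch; `FEATURE_MARKERS` and `dispatch_kernels` are names invented here, the sets are inlined in the actual tests):

```python
from torchvision.prototype.transforms import functional as F

# Markers that tag which feature type a low-level kernel handles.
FEATURE_MARKERS = {"image", "mask", "bounding_box", "label"}

# Collect every public, callable, non-PIL kernel whose name carries a marker.
dispatch_kernels = [
    kernel
    for name, kernel in F.__dict__.items()
    if not name.startswith("_")
    and callable(kernel)
    and any(marker in name for marker in FEATURE_MARKERS)
    and "pil" not in name
]
```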
@@ -12,30 +12,30 @@ from torchvision.prototype.transforms.functional import to_image_pil
 IMAGE = make_image(color_space=features.ColorSpace.RGB)
 BOUNDING_BOX = make_bounding_box(format=features.BoundingBoxFormat.XYXY, image_size=IMAGE.image_size)
-SEGMENTATION_MASK = make_detection_mask(size=IMAGE.image_size)
+MASK = make_detection_mask(size=IMAGE.image_size)


 @pytest.mark.parametrize(
     ("sample", "types", "expected"),
     [
-        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image,), True),
-        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox,), True),
-        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.SegmentationMask,), True),
-        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), True),
-        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), True),
-        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), True),
-        ((SEGMENTATION_MASK,), (features.Image, features.BoundingBox), False),
-        ((BOUNDING_BOX,), (features.Image, features.SegmentationMask), False),
-        ((IMAGE,), (features.BoundingBox, features.SegmentationMask), False),
+        ((IMAGE, BOUNDING_BOX, MASK), (features.Image,), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (features.BoundingBox,), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (features.Mask,), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (features.Image, features.BoundingBox), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (features.Image, features.Mask), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (features.BoundingBox, features.Mask), True),
+        ((MASK,), (features.Image, features.BoundingBox), False),
+        ((BOUNDING_BOX,), (features.Image, features.Mask), False),
+        ((IMAGE,), (features.BoundingBox, features.Mask), False),
         (
-            (IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
-            (features.Image, features.BoundingBox, features.SegmentationMask),
+            (IMAGE, BOUNDING_BOX, MASK),
+            (features.Image, features.BoundingBox, features.Mask),
             True,
         ),
-        ((), (features.Image, features.BoundingBox, features.SegmentationMask), False),
-        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda obj: isinstance(obj, features.Image),), True),
-        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: False,), False),
-        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: True,), True),
+        ((), (features.Image, features.BoundingBox, features.Mask), False),
+        ((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, features.Image),), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
+        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
         ((IMAGE,), (features.Image, PIL.Image.Image, features.is_simple_tensor), True),
         ((torch.Tensor(IMAGE),), (features.Image, PIL.Image.Image, features.is_simple_tensor), True),
         ((to_image_pil(IMAGE),), (features.Image, PIL.Image.Image, features.is_simple_tensor), True),
@@ -48,35 +48,35 @@ def test_has_any(sample, types, expected):
 @pytest.mark.parametrize(
     ("sample", "types", "expected"),
     [
-        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image,), True),
-        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox,), True),
-        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.SegmentationMask,), True),
-        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), True),
-        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), True),
-        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (features.Image,), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (features.BoundingBox,), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (features.Mask,), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (features.Image, features.BoundingBox), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (features.Image, features.Mask), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (features.BoundingBox, features.Mask), True),
         (
-            (IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
-            (features.Image, features.BoundingBox, features.SegmentationMask),
+            (IMAGE, BOUNDING_BOX, MASK),
+            (features.Image, features.BoundingBox, features.Mask),
             True,
         ),
-        ((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), False),
-        ((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), False),
-        ((IMAGE, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), False),
+        ((BOUNDING_BOX, MASK), (features.Image, features.BoundingBox), False),
+        ((BOUNDING_BOX, MASK), (features.Image, features.Mask), False),
+        ((IMAGE, MASK), (features.BoundingBox, features.Mask), False),
         (
-            (IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
-            (features.Image, features.BoundingBox, features.SegmentationMask),
+            (IMAGE, BOUNDING_BOX, MASK),
+            (features.Image, features.BoundingBox, features.Mask),
             True,
         ),
-        ((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox, features.SegmentationMask), False),
-        ((IMAGE, SEGMENTATION_MASK), (features.Image, features.BoundingBox, features.SegmentationMask), False),
-        ((IMAGE, BOUNDING_BOX), (features.Image, features.BoundingBox, features.SegmentationMask), False),
+        ((BOUNDING_BOX, MASK), (features.Image, features.BoundingBox, features.Mask), False),
+        ((IMAGE, MASK), (features.Image, features.BoundingBox, features.Mask), False),
+        ((IMAGE, BOUNDING_BOX), (features.Image, features.BoundingBox, features.Mask), False),
         (
-            (IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
-            (lambda obj: isinstance(obj, (features.Image, features.BoundingBox, features.SegmentationMask)),),
+            (IMAGE, BOUNDING_BOX, MASK),
+            (lambda obj: isinstance(obj, (features.Image, features.BoundingBox, features.Mask)),),
             True,
         ),
-        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: False,), False),
-        ((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: True,), True),
+        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
+        ((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
     ],
 )
 def test_has_all(sample, types, expected):
......
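`has_any` and `has_all` are the sample-inspection helpers the transforms use for input validation. A usage sketch after the rename (the `_utils` import path is an assumption based on this test module):

```python
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms._utils import has_all, has_any  # import path assumed

image = features.Image(torch.rand(3, 8, 8))
mask = features.Mask(torch.zeros(1, 8, 8, dtype=torch.uint8))

assert has_any((image, mask), features.Mask)              # at least one Mask in the sample
assert not has_all((image, mask), features.BoundingBox)   # no box present, so has_all fails
```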
@@ -3,4 +3,4 @@ from ._encoded import EncodedData, EncodedImage, EncodedVideo
 from ._feature import _Feature, is_simple_tensor
 from ._image import ColorSpace, Image
 from ._label import Label, OneHotLabel
-from ._segmentation_mask import SegmentationMask
+from ._mask import Mask
......
@@ -8,14 +8,14 @@ from torchvision.transforms import InterpolationMode
 from ._feature import _Feature


-class SegmentationMask(_Feature):
-    def horizontal_flip(self) -> SegmentationMask:
-        output = self._F.horizontal_flip_segmentation_mask(self)
-        return SegmentationMask.new_like(self, output)
+class Mask(_Feature):
+    def horizontal_flip(self) -> Mask:
+        output = self._F.horizontal_flip_mask(self)
+        return Mask.new_like(self, output)

-    def vertical_flip(self) -> SegmentationMask:
-        output = self._F.vertical_flip_segmentation_mask(self)
-        return SegmentationMask.new_like(self, output)
+    def vertical_flip(self) -> Mask:
+        output = self._F.vertical_flip_mask(self)
+        return Mask.new_like(self, output)

     def resize(  # type: ignore[override]
         self,
@@ -23,17 +23,17 @@ class SegmentationMask(_Feature):
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         max_size: Optional[int] = None,
         antialias: bool = False,
-    ) -> SegmentationMask:
-        output = self._F.resize_segmentation_mask(self, size, max_size=max_size)
-        return SegmentationMask.new_like(self, output)
+    ) -> Mask:
+        output = self._F.resize_mask(self, size, max_size=max_size)
+        return Mask.new_like(self, output)

-    def crop(self, top: int, left: int, height: int, width: int) -> SegmentationMask:
-        output = self._F.crop_segmentation_mask(self, top, left, height, width)
-        return SegmentationMask.new_like(self, output)
+    def crop(self, top: int, left: int, height: int, width: int) -> Mask:
+        output = self._F.crop_mask(self, top, left, height, width)
+        return Mask.new_like(self, output)

-    def center_crop(self, output_size: List[int]) -> SegmentationMask:
-        output = self._F.center_crop_segmentation_mask(self, output_size=output_size)
-        return SegmentationMask.new_like(self, output)
+    def center_crop(self, output_size: List[int]) -> Mask:
+        output = self._F.center_crop_mask(self, output_size=output_size)
+        return Mask.new_like(self, output)

     def resized_crop(
         self,
@@ -44,22 +44,22 @@ class SegmentationMask(_Feature):
         size: List[int],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         antialias: bool = False,
-    ) -> SegmentationMask:
-        output = self._F.resized_crop_segmentation_mask(self, top, left, height, width, size=size)
-        return SegmentationMask.new_like(self, output)
+    ) -> Mask:
+        output = self._F.resized_crop_mask(self, top, left, height, width, size=size)
+        return Mask.new_like(self, output)

     def pad(
         self,
         padding: Union[int, Sequence[int]],
         fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         padding_mode: str = "constant",
-    ) -> SegmentationMask:
+    ) -> Mask:
         # This cast does Sequence[int] -> List[int] and is required to make mypy happy
         if not isinstance(padding, int):
             padding = list(padding)

-        output = self._F.pad_segmentation_mask(self, padding, padding_mode=padding_mode)
-        return SegmentationMask.new_like(self, output)
+        output = self._F.pad_mask(self, padding, padding_mode=padding_mode)
+        return Mask.new_like(self, output)

     def rotate(
         self,
@@ -68,9 +68,9 @@ class SegmentationMask(_Feature):
         expand: bool = False,
         fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
-    ) -> SegmentationMask:
-        output = self._F.rotate_segmentation_mask(self, angle, expand=expand, center=center)
-        return SegmentationMask.new_like(self, output)
+    ) -> Mask:
+        output = self._F.rotate_mask(self, angle, expand=expand, center=center)
+        return Mask.new_like(self, output)

     def affine(
         self,
@@ -81,8 +81,8 @@ class SegmentationMask(_Feature):
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
         center: Optional[List[float]] = None,
-    ) -> SegmentationMask:
-        output = self._F.affine_segmentation_mask(
+    ) -> Mask:
+        output = self._F.affine_mask(
             self,
             angle,
             translate=translate,
@@ -90,22 +90,22 @@ class SegmentationMask(_Feature):
             shear=shear,
             center=center,
         )
-        return SegmentationMask.new_like(self, output)
+        return Mask.new_like(self, output)

     def perspective(
         self,
         perspective_coeffs: List[float],
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
-    ) -> SegmentationMask:
-        output = self._F.perspective_segmentation_mask(self, perspective_coeffs)
-        return SegmentationMask.new_like(self, output)
+    ) -> Mask:
+        output = self._F.perspective_mask(self, perspective_coeffs)
+        return Mask.new_like(self, output)

     def elastic(
         self,
         displacement: torch.Tensor,
         interpolation: InterpolationMode = InterpolationMode.NEAREST,
         fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
-    ) -> SegmentationMask:
-        output = self._F.elastic_segmentation_mask(self, displacement)
-        return SegmentationMask.new_like(self, output, dtype=output.dtype)
+    ) -> Mask:
+        output = self._F.elastic_mask(self, displacement)
+        return Mask.new_like(self, output, dtype=output.dtype)
......
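Every `Mask` method above follows the same dispatch pattern: call the renamed low-level kernel through `self._F` and rewrap the plain tensor result with `new_like` so the feature type survives. Chained use looks like this (an illustrative sketch; the method signatures come from the diff above):

```python
import torch
from torchvision.prototype import features

mask = features.Mask(torch.randint(0, 2, (3, 32, 32), dtype=torch.uint8))

resized = mask.resize([16, 16])     # -> F.resize_mask (nearest interpolation by default)
cropped = resized.crop(0, 0, 8, 8)  # -> F.crop_mask
assert isinstance(cropped, features.Mask) and cropped.shape == (3, 8, 8)
```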
@@ -108,10 +108,8 @@ class _BaseMixupCutmix(_RandomApplyTransform):
     def forward(self, *inputs: Any) -> Any:
         if not (has_any(inputs, features.Image, features.is_simple_tensor) and has_any(inputs, features.OneHotLabel)):
             raise TypeError(f"{type(self).__name__}() is only defined for tensor images and one-hot labels.")
-        if has_any(inputs, features.BoundingBox, features.SegmentationMask, features.Label):
-            raise TypeError(
-                f"{type(self).__name__}() does not support bounding boxes, segmentation masks and plain labels."
-            )
+        if has_any(inputs, features.BoundingBox, features.Mask, features.Label):
+            raise TypeError(f"{type(self).__name__}() does not support bounding boxes, masks and plain labels.")
         return super().forward(*inputs)

     def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel:
@@ -280,7 +278,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
     def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], List[Dict[str, Any]]]:
         # fetch all images, bboxes, masks and labels from unstructured input
-        # with List[image], List[BoundingBox], List[SegmentationMask], List[Label]
+        # with List[image], List[BoundingBox], List[Mask], List[Label]
         images, bboxes, masks, labels = [], [], [], []
         for obj in flat_sample:
             if isinstance(obj, features.Image) or features.is_simple_tensor(obj):
@@ -289,7 +287,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
                 images.append(F.to_image_tensor(obj))
             elif isinstance(obj, features.BoundingBox):
                 bboxes.append(obj)
-            elif isinstance(obj, features.SegmentationMask):
+            elif isinstance(obj, features.Mask):
                 masks.append(obj)
             elif isinstance(obj, (features.Label, features.OneHotLabel)):
                 labels.append(obj)
@@ -297,7 +295,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
         if not (len(images) == len(bboxes) == len(masks) == len(labels)):
             raise TypeError(
                 f"{type(self).__name__}() requires input sample to contain equal sized list of Images, "
-                "BoundingBoxes, Segmentation Masks and Labels or OneHotLabels."
+                "BoundingBoxes, Masks and Labels or OneHotLabels."
             )

         targets = []
@@ -323,8 +321,8 @@ class SimpleCopyPaste(_RandomApplyTransform):
             elif isinstance(obj, features.BoundingBox):
                 flat_sample[i] = features.BoundingBox.new_like(obj, output_targets[c1]["boxes"])
                 c1 += 1
-            elif isinstance(obj, features.SegmentationMask):
-                flat_sample[i] = features.SegmentationMask.new_like(obj, output_targets[c2]["masks"])
+            elif isinstance(obj, features.Mask):
+                flat_sample[i] = features.Mask.new_like(obj, output_targets[c2]["masks"])
                 c2 += 1
             elif isinstance(obj, (features.Label, features.OneHotLabel)):
                 flat_sample[i] = obj.new_like(obj, output_targets[c3]["labels"])  # type: ignore[arg-type]
......
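`_extract_image_targets` above is plain `isinstance` routing over a flattened sample, with `features.Mask` replacing `features.SegmentationMask` as the match target. The pattern in isolation (inputs invented for illustration; the `Label` constructor usage is assumed from the tests):

```python
import torch
from torchvision.prototype import features

flat_sample = [
    features.Image(torch.rand(3, 8, 8)),
    features.BoundingBox([0, 0, 4, 4], format="XYXY", image_size=(8, 8)),
    features.Mask(torch.zeros(2, 8, 8, dtype=torch.uint8)),
    features.Label(torch.tensor([1, 2])),  # constructor use assumed
]

# Route each object into its bucket by type, exactly as the transform does.
images = [obj for obj in flat_sample if isinstance(obj, features.Image)]
masks = [obj for obj in flat_sample if isinstance(obj, features.Mask)]
assert len(images) == len(masks) == 1
```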
@@ -38,7 +38,7 @@ class _AutoAugmentBase(Transform):
     def _extract_image(
         self,
         sample: Any,
-        unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask),
+        unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.Mask),
     ) -> Tuple[int, Union[PIL.Image.Image, torch.Tensor, features.Image]]:
         sample_flat, _ = tree_flatten(sample)
         images = []
......
@@ -170,8 +170,8 @@ class FiveCrop(Transform):
         return F.five_crop(inpt, self.size)

     def forward(self, *inputs: Any) -> Any:
-        if has_any(inputs, features.BoundingBox, features.SegmentationMask):
-            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+        if has_any(inputs, features.BoundingBox, features.Mask):
+            raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
         return super().forward(*inputs)
@@ -191,8 +191,8 @@ class TenCrop(Transform):
         return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)

     def forward(self, *inputs: Any) -> Any:
-        if has_any(inputs, features.BoundingBox, features.SegmentationMask):
-            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+        if has_any(inputs, features.BoundingBox, features.Mask):
+            raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
         return super().forward(*inputs)
@@ -690,10 +690,10 @@ class RandomIoUCrop(Transform):
             bboxes = output[is_within_crop_area]
             bboxes = F.clamp_bounding_box(bboxes, output.format, output.image_size)
             output = features.BoundingBox.new_like(output, bboxes)
-        elif isinstance(output, features.SegmentationMask) and output.shape[-3] > 1:
+        elif isinstance(output, features.Mask) and output.shape[-3] > 1:
             # apply is_within_crop_area if mask is one-hot encoded
             masks = output[is_within_crop_area]
-            output = features.SegmentationMask.new_like(output, masks)
+            output = features.Mask.new_like(output, masks)

         return output
@@ -705,7 +705,7 @@ class RandomIoUCrop(Transform):
         ):
             raise TypeError(
                 f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
-                "BoundingBoxes and Labels or OneHotLabels. Sample can also contain Segmentation Masks."
+                "BoundingBoxes and Labels or OneHotLabels. Sample can also contain Masks."
             )
         return super().forward(*inputs)
@@ -842,7 +842,7 @@ class FixedSizeCrop(Transform):
         )

         if params["is_valid"] is not None:
-            if isinstance(inpt, (features.Label, features.OneHotLabel, features.SegmentationMask)):
+            if isinstance(inpt, (features.Label, features.OneHotLabel, features.Mask)):
                 inpt = inpt.new_like(inpt, inpt[params["is_valid"]])  # type: ignore[arg-type]
             elif isinstance(inpt, features.BoundingBox):
                 inpt = features.BoundingBox.new_like(
......
...@@ -150,7 +150,7 @@ class ToDtype(Lambda): ...@@ -150,7 +150,7 @@ class ToDtype(Lambda):
class RemoveSmallBoundingBoxes(Transform): class RemoveSmallBoundingBoxes(Transform):
_transformed_types = (features.BoundingBox, features.SegmentationMask, features.Label, features.OneHotLabel) _transformed_types = (features.BoundingBox, features.Mask, features.Label, features.OneHotLabel)
def __init__(self, min_size: float = 1.0) -> None: def __init__(self, min_size: float = 1.0) -> None:
super().__init__() super().__init__()
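_transformed_types is how the prototype Transform base class scopes dispatch: only sample leaves that are instances of one of the listed types are handed to the transform's kernel, everything else passes through untouched, which is why the rename has to be reflected here as well. A simplified sketch of that dispatch idea (not the real base-class code; assumes pytree-style samples):

import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

def transform_sample(sample, fn, transformed_types):
    # Apply `fn` only to leaves whose type is listed; leave the rest alone.
    flat, spec = tree_flatten(sample)
    flat = [fn(leaf) if isinstance(leaf, transformed_types) else leaf for leaf in flat]
    return tree_unflatten(flat, spec)

sample = {"image": torch.rand(3, 4, 4), "meta": "untouched"}
out = transform_sample(sample, lambda t: t * 0, (torch.Tensor,))
print(out["meta"])               # 'untouched'
print(out["image"].abs().sum())  # tensor(0.)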
......
...@@ -53,22 +53,22 @@ from ._geometry import ( ...@@ -53,22 +53,22 @@ from ._geometry import (
affine_bounding_box, affine_bounding_box,
affine_image_pil, affine_image_pil,
affine_image_tensor, affine_image_tensor,
affine_segmentation_mask, affine_mask,
center_crop, center_crop,
center_crop_bounding_box, center_crop_bounding_box,
center_crop_image_pil, center_crop_image_pil,
center_crop_image_tensor, center_crop_image_tensor,
center_crop_segmentation_mask, center_crop_mask,
crop, crop,
crop_bounding_box, crop_bounding_box,
crop_image_pil, crop_image_pil,
crop_image_tensor, crop_image_tensor,
crop_segmentation_mask, crop_mask,
elastic, elastic,
elastic_bounding_box, elastic_bounding_box,
elastic_image_pil, elastic_image_pil,
elastic_image_tensor, elastic_image_tensor,
elastic_segmentation_mask, elastic_mask,
elastic_transform, elastic_transform,
five_crop, five_crop,
five_crop_image_pil, five_crop_image_pil,
...@@ -78,32 +78,32 @@ from ._geometry import ( ...@@ -78,32 +78,32 @@ from ._geometry import (
horizontal_flip_bounding_box, horizontal_flip_bounding_box,
horizontal_flip_image_pil, horizontal_flip_image_pil,
horizontal_flip_image_tensor, horizontal_flip_image_tensor,
horizontal_flip_segmentation_mask, horizontal_flip_mask,
pad, pad,
pad_bounding_box, pad_bounding_box,
pad_image_pil, pad_image_pil,
pad_image_tensor, pad_image_tensor,
pad_segmentation_mask, pad_mask,
perspective, perspective,
perspective_bounding_box, perspective_bounding_box,
perspective_image_pil, perspective_image_pil,
perspective_image_tensor, perspective_image_tensor,
perspective_segmentation_mask, perspective_mask,
resize, resize,
resize_bounding_box, resize_bounding_box,
resize_image_pil, resize_image_pil,
resize_image_tensor, resize_image_tensor,
resize_segmentation_mask, resize_mask,
resized_crop, resized_crop,
resized_crop_bounding_box, resized_crop_bounding_box,
resized_crop_image_pil, resized_crop_image_pil,
resized_crop_image_tensor, resized_crop_image_tensor,
resized_crop_segmentation_mask, resized_crop_mask,
rotate, rotate,
rotate_bounding_box, rotate_bounding_box,
rotate_image_pil, rotate_image_pil,
rotate_image_tensor, rotate_image_tensor,
rotate_segmentation_mask, rotate_mask,
ten_crop, ten_crop,
ten_crop_image_pil, ten_crop_image_pil,
ten_crop_image_tensor, ten_crop_image_tensor,
...@@ -111,7 +111,7 @@ from ._geometry import ( ...@@ -111,7 +111,7 @@ from ._geometry import (
vertical_flip_bounding_box, vertical_flip_bounding_box,
vertical_flip_image_pil, vertical_flip_image_pil,
vertical_flip_image_tensor, vertical_flip_image_tensor,
vertical_flip_segmentation_mask, vertical_flip_mask,
vflip, vflip,
) )
from ._misc import gaussian_blur, gaussian_blur_image_pil, gaussian_blur_image_tensor, normalize, normalize_image_tensor from ._misc import gaussian_blur, gaussian_blur_image_pil, gaussian_blur_image_tensor, normalize, normalize_image_tensor
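Because every *_segmentation_mask kernel above is renamed rather than aliased, downstream code importing the old names fails immediately at import time (acceptable here, since the prototype namespace carries no stability guarantee). If a deprecation period were wanted, a hypothetical shim, not part of this commit, could keep the old names alive with a warning:

import warnings

def _deprecated_alias(new_fn, old_name):
    # Wrap a renamed kernel so the old name still works but warns on use.
    def wrapper(*args, **kwargs):
        warnings.warn(
            f"{old_name}() was renamed to {new_fn.__name__}(); the old name will be removed.",
            DeprecationWarning,
            stacklevel=2,
        )
        return new_fn(*args, **kwargs)
    wrapper.__name__ = old_name
    return wrapper

# e.g.: resize_segmentation_mask = _deprecated_alias(resize_mask, "resize_segmentation_mask")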
......
...@@ -28,8 +28,8 @@ horizontal_flip_image_tensor = _FT.hflip ...@@ -28,8 +28,8 @@ horizontal_flip_image_tensor = _FT.hflip
horizontal_flip_image_pil = _FP.hflip horizontal_flip_image_pil = _FP.hflip
def horizontal_flip_segmentation_mask(segmentation_mask: torch.Tensor) -> torch.Tensor: def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
return horizontal_flip_image_tensor(segmentation_mask) return horizontal_flip_image_tensor(mask)
def horizontal_flip_bounding_box( def horizontal_flip_bounding_box(
...@@ -61,8 +61,8 @@ vertical_flip_image_tensor = _FT.vflip ...@@ -61,8 +61,8 @@ vertical_flip_image_tensor = _FT.vflip
vertical_flip_image_pil = _FP.vflip vertical_flip_image_pil = _FP.vflip
def vertical_flip_segmentation_mask(segmentation_mask: torch.Tensor) -> torch.Tensor: def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
return vertical_flip_image_tensor(segmentation_mask) return vertical_flip_image_tensor(mask)
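The two flip kernels can delegate to the image kernels unchanged because flipping only reorders pixels: no values are interpolated, so category ids survive intact and no dimension juggling is needed. Plain-tensor equivalent:

import torch

mask = torch.tensor([[1, 0, 0],
                     [2, 2, 0]], dtype=torch.uint8)
print(mask.flip(-1))  # horizontal flip: each row reversed left-right
print(mask.flip(-2))  # vertical flip: row order reversed top-bottom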
def vertical_flip_bounding_box( def vertical_flip_bounding_box(
...@@ -132,18 +132,14 @@ def resize_image_pil( ...@@ -132,18 +132,14 @@ def resize_image_pil(
return _FP.resize(img, size, interpolation=pil_modes_mapping[interpolation]) return _FP.resize(img, size, interpolation=pil_modes_mapping[interpolation])
def resize_segmentation_mask( def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor:
segmentation_mask: torch.Tensor, size: List[int], max_size: Optional[int] = None if mask.ndim < 3:
) -> torch.Tensor: mask = mask.unsqueeze(0)
if segmentation_mask.ndim < 3:
segmentation_mask = segmentation_mask.unsqueeze(0)
needs_squeeze = True needs_squeeze = True
else: else:
needs_squeeze = False needs_squeeze = False
output = resize_image_tensor( output = resize_image_tensor(mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)
segmentation_mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size
)
if needs_squeeze: if needs_squeeze:
output = output.squeeze(0) output = output.squeeze(0)
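Two details in resize_mask carry over the old kernel's behavior: the unsqueeze/squeeze dance lets plain (H, W) category masks reuse an image kernel that expects a leading channel dimension, and InterpolationMode.NEAREST stays hard-coded because any averaging mode would blend integer category ids into values that name no class. A small demonstration of the second point with plain tensors:

import torch
import torch.nn.functional as F

# One row containing only the category ids 0 and 3.
mask = torch.tensor([[0, 0, 3, 3]], dtype=torch.float32)

nearest = F.interpolate(mask[None, None], scale_factor=2, mode="nearest")
bilinear = F.interpolate(mask[None, None], scale_factor=2, mode="bilinear", align_corners=False)

print(nearest.unique())   # tensor([0., 3.]) -- only ids that actually exist
print(bilinear.unique())  # also contains invented ids such as 0.7500 and 2.2500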
...@@ -379,22 +375,22 @@ def affine_bounding_box( ...@@ -379,22 +375,22 @@ def affine_bounding_box(
).view(original_shape) ).view(original_shape)
def affine_segmentation_mask( def affine_mask(
segmentation_mask: torch.Tensor, mask: torch.Tensor,
angle: float, angle: float,
translate: List[float], translate: List[float],
scale: float, scale: float,
shear: List[float], shear: List[float],
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if segmentation_mask.ndim < 3: if mask.ndim < 3:
segmentation_mask = segmentation_mask.unsqueeze(0) mask = mask.unsqueeze(0)
needs_squeeze = True needs_squeeze = True
else: else:
needs_squeeze = False needs_squeeze = False
output = affine_image_tensor( output = affine_image_tensor(
segmentation_mask, mask,
angle=angle, angle=angle,
translate=translate, translate=translate,
scale=scale, scale=scale,
...@@ -545,20 +541,20 @@ def rotate_bounding_box( ...@@ -545,20 +541,20 @@ def rotate_bounding_box(
).view(original_shape) ).view(original_shape)
def rotate_segmentation_mask( def rotate_mask(
segmentation_mask: torch.Tensor, mask: torch.Tensor,
angle: float, angle: float,
expand: bool = False, expand: bool = False,
center: Optional[List[float]] = None, center: Optional[List[float]] = None,
) -> torch.Tensor: ) -> torch.Tensor:
if segmentation_mask.ndim < 3: if mask.ndim < 3:
segmentation_mask = segmentation_mask.unsqueeze(0) mask = mask.unsqueeze(0)
needs_squeeze = True needs_squeeze = True
else: else:
needs_squeeze = False needs_squeeze = False
output = rotate_image_tensor( output = rotate_image_tensor(
segmentation_mask, mask,
angle=angle, angle=angle,
expand=expand, expand=expand,
interpolation=InterpolationMode.NEAREST, interpolation=InterpolationMode.NEAREST,
...@@ -639,16 +635,14 @@ def _pad_with_vector_fill( ...@@ -639,16 +635,14 @@ def _pad_with_vector_fill(
return output return output
def pad_segmentation_mask( def pad_mask(mask: torch.Tensor, padding: Union[int, List[int]], padding_mode: str = "constant") -> torch.Tensor:
segmentation_mask: torch.Tensor, padding: Union[int, List[int]], padding_mode: str = "constant" if mask.ndim < 3:
) -> torch.Tensor: mask = mask.unsqueeze(0)
if segmentation_mask.ndim < 3:
segmentation_mask = segmentation_mask.unsqueeze(0)
needs_squeeze = True needs_squeeze = True
else: else:
needs_squeeze = False needs_squeeze = False
output = pad_image_tensor(img=segmentation_mask, padding=padding, fill=0, padding_mode=padding_mode) output = pad_image_tensor(img=mask, padding=padding, fill=0, padding_mode=padding_mode)
if needs_squeeze: if needs_squeeze:
output = output.squeeze(0) output = output.squeeze(0)
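pad_mask pins fill=0 rather than exposing it the way the image kernel does, so freshly padded border pixels land in the background class instead of some arbitrary category id. Plain-tensor equivalent of the constant-mode case:

import torch
import torch.nn.functional as F

mask = torch.full((2, 2), 5, dtype=torch.uint8)  # every pixel is class 5
padded = F.pad(mask, (1, 1, 1, 1), value=0)      # one-pixel border of class 0
print(padded)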
...@@ -723,8 +717,8 @@ def crop_bounding_box( ...@@ -723,8 +717,8 @@ def crop_bounding_box(
) )
def crop_segmentation_mask(img: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor: def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
return crop_image_tensor(img, top, left, height, width) return crop_image_tensor(mask, top, left, height, width)
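crop_mask needs no special casing at all: cropping is pure slicing on the last two dimensions, so there is no interpolation to control and no channel-dimension dance, and the kernel simply forwards to the image one (with the parameter now named mask instead of the misleading img). Ignoring the out-of-bounds padding the real image kernel also handles, the core operation is just:

import torch

mask = torch.arange(16, dtype=torch.uint8).reshape(4, 4)
top, left, height, width = 1, 1, 2, 3
print(mask[..., top:top + height, left:left + width])  # a (2, 3) window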
def crop(inpt: DType, top: int, left: int, height: int, width: int) -> DType: def crop(inpt: DType, top: int, left: int, height: int, width: int) -> DType:
...@@ -839,15 +833,15 @@ def perspective_bounding_box( ...@@ -839,15 +833,15 @@ def perspective_bounding_box(
).view(original_shape) ).view(original_shape)
def perspective_segmentation_mask(segmentation_mask: torch.Tensor, perspective_coeffs: List[float]) -> torch.Tensor: def perspective_mask(mask: torch.Tensor, perspective_coeffs: List[float]) -> torch.Tensor:
if segmentation_mask.ndim < 3: if mask.ndim < 3:
segmentation_mask = segmentation_mask.unsqueeze(0) mask = mask.unsqueeze(0)
needs_squeeze = True needs_squeeze = True
else: else:
needs_squeeze = False needs_squeeze = False
output = perspective_image_tensor( output = perspective_image_tensor(
segmentation_mask, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST mask, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST
) )
if needs_squeeze: if needs_squeeze:
...@@ -937,14 +931,14 @@ def elastic_bounding_box( ...@@ -937,14 +931,14 @@ def elastic_bounding_box(
).view(original_shape) ).view(original_shape)
def elastic_segmentation_mask(segmentation_mask: torch.Tensor, displacement: torch.Tensor) -> torch.Tensor: def elastic_mask(mask: torch.Tensor, displacement: torch.Tensor) -> torch.Tensor:
if segmentation_mask.ndim < 3: if mask.ndim < 3:
segmentation_mask = segmentation_mask.unsqueeze(0) mask = mask.unsqueeze(0)
needs_squeeze = True needs_squeeze = True
else: else:
needs_squeeze = False needs_squeeze = False
output = elastic_image_tensor(segmentation_mask, displacement=displacement, interpolation=InterpolationMode.NEAREST) output = elastic_image_tensor(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST)
if needs_squeeze: if needs_squeeze:
output = output.squeeze(0) output = output.squeeze(0)
...@@ -1040,14 +1034,14 @@ def center_crop_bounding_box( ...@@ -1040,14 +1034,14 @@ def center_crop_bounding_box(
return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left) return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left)
def center_crop_segmentation_mask(segmentation_mask: torch.Tensor, output_size: List[int]) -> torch.Tensor: def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
if segmentation_mask.ndim < 3: if mask.ndim < 3:
segmentation_mask = segmentation_mask.unsqueeze(0) mask = mask.unsqueeze(0)
needs_squeeze = True needs_squeeze = True
else: else:
needs_squeeze = False needs_squeeze = False
output = center_crop_image_tensor(img=segmentation_mask, output_size=output_size) output = center_crop_image_tensor(img=mask, output_size=output_size)
if needs_squeeze: if needs_squeeze:
output = output.squeeze(0) output = output.squeeze(0)
...@@ -1104,7 +1098,7 @@ def resized_crop_bounding_box( ...@@ -1104,7 +1098,7 @@ def resized_crop_bounding_box(
return resize_bounding_box(bounding_box, size, (height, width)) return resize_bounding_box(bounding_box, size, (height, width))
def resized_crop_segmentation_mask( def resized_crop_mask(
mask: torch.Tensor, mask: torch.Tensor,
top: int, top: int,
left: int, left: int,
...@@ -1112,8 +1106,8 @@ def resized_crop_segmentation_mask( ...@@ -1112,8 +1106,8 @@ def resized_crop_segmentation_mask(
width: int, width: int,
size: List[int], size: List[int],
) -> torch.Tensor: ) -> torch.Tensor:
mask = crop_segmentation_mask(mask, top, left, height, width) mask = crop_mask(mask, top, left, height, width)
return resize_segmentation_mask(mask, size) return resize_mask(mask, size)
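resized_crop_mask is plain composition of the two kernels renamed above, so the nearest-neighbor guarantee of resize_mask automatically carries over to the fused op. An equivalent sketch on a plain (H, W) tensor:

import torch
import torch.nn.functional as F

def resized_crop_mask_sketch(mask, top, left, height, width, size):
    # Crop first, then resize with nearest-neighbor so category ids survive.
    cropped = mask[..., top:top + height, left:left + width]
    out = F.interpolate(cropped[None, None].float(), size=size, mode="nearest")
    return out[0, 0].to(mask.dtype)

mask = torch.randint(0, 4, (16, 16), dtype=torch.uint8)
print(resized_crop_mask_sketch(mask, top=2, left=2, height=8, width=8, size=[4, 4]).shape)
# torch.Size([4, 4])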
def resized_crop( def resized_crop(
......