Unverified Commit 1ea73f58 authored by Philip Meier, committed by GitHub

Rename features.SegmentationMask to features.Mask (#6579)

* rename features.SegmentationMask -> features.Mask

* rename kernels *_segmentation_mask -> *_mask and cleanup input name

* cleanup

* rename module _segmentation_mask.py -> _mask.py

* fix test
parent f007a5e1
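For reference, a minimal sketch of what the rename means for downstream code (assuming the prototype API as of this commit; the tensor values are illustrative):

```python
import torch
from torchvision.prototype import features

data = torch.randint(0, 2, (1, 16, 16), dtype=torch.uint8)

# Before this commit:
#   mask = features.SegmentationMask(data)
# After this commit:
mask = features.Mask(data)
```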
......@@ -184,7 +184,7 @@ def make_detection_mask(size=None, *, num_objects=None, extra_dims=(), dtype=tor
num_objects = num_objects if num_objects is not None else int(torch.randint(1, 11, ()))
shape = (*extra_dims, num_objects, *size)
data = make_tensor(shape, low=0, high=2, dtype=dtype)
return features.SegmentationMask(data)
return features.Mask(data)
def make_detection_masks(
......@@ -207,7 +207,7 @@ def make_segmentation_mask(size=None, *, num_categories=None, extra_dims=(), dty
num_categories = num_categories if num_categories is not None else int(torch.randint(1, 11, ()))
shape = (*extra_dims, *size)
data = make_tensor(shape, low=0, high=num_categories, dtype=dtype)
return features.SegmentationMask(data)
return features.Mask(data)
def make_segmentation_masks(
......@@ -224,7 +224,7 @@ def make_segmentation_masks(
yield make_segmentation_mask(size=sizes[0], num_categories=num_categories_, dtype=dtype, extra_dims=extra_dims_)
def make_detection_and_segmentation_masks(
def make_masks(
sizes=((16, 16), (7, 33), (31, 9)),
dtypes=(torch.uint8,),
extra_dims=((), (0,), (4,), (2, 3), (5, 0), (0, 5)),
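The consolidated `make_masks` helper replaces `make_detection_and_segmentation_masks` and yields both detection-style and segmentation-style masks. A hedged usage sketch (this helper lives in the test suite's `prototype_common_utils` module, so it is only importable from within the test environment):

```python
from prototype_common_utils import make_masks  # test-suite helper

for mask in make_masks():
    # every yielded value is a features.Mask now, regardless of mask kind
    assert mask.ndim >= 2
```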
......
......@@ -10,11 +10,11 @@ from common_utils import assert_equal, cpu_and_gpu
from prototype_common_utils import (
make_bounding_box,
make_bounding_boxes,
make_detection_and_segmentation_masks,
make_detection_mask,
make_image,
make_images,
make_label,
make_masks,
make_one_hot_labels,
make_segmentation_mask,
)
......@@ -64,7 +64,7 @@ def parametrize_from_transforms(*transforms):
make_one_hot_labels,
make_vanilla_tensor_images,
make_pil_images,
make_detection_and_segmentation_masks,
make_masks,
]:
inputs = list(creation_fn())
try:
......@@ -132,7 +132,7 @@ class TestSmoke:
transform(input_copy)
# Check if we raise an error if sample contains bbox or mask or label
err_msg = "does not support bounding boxes, segmentation masks and plain labels"
err_msg = "does not support bounding boxes, masks and plain labels"
input_copy = dict(input)
for unsup_data in [
make_label(),
......@@ -241,7 +241,7 @@ class TestSmoke:
color_space=features.ColorSpace.RGB, old_color_space=features.ColorSpace.GRAY
)
for inpt in [make_bounding_box(format="XYXY"), make_detection_and_segmentation_masks()]:
for inpt in [make_bounding_box(format="XYXY"), make_masks()]:
output = transform(inpt)
assert output is inpt
......@@ -278,13 +278,13 @@ class TestRandomHorizontalFlip:
assert_equal(features.Image(expected), actual)
def test_features_segmentation_mask(self, p):
def test_features_mask(self, p):
input, expected = self.input_expected_image_tensor(p)
transform = transforms.RandomHorizontalFlip(p=p)
actual = transform(features.SegmentationMask(input))
actual = transform(features.Mask(input))
assert_equal(features.SegmentationMask(expected), actual)
assert_equal(features.Mask(expected), actual)
def test_features_bounding_box(self, p):
input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))
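As the test above shows, dispatch now keys on `features.Mask`. A minimal sketch with a deterministic flip (`p=1.0` is an illustrative choice):

```python
import torch
from torchvision.prototype import features, transforms

mask = features.Mask(torch.zeros(1, 3, 3, dtype=torch.uint8))
out = transforms.RandomHorizontalFlip(p=1.0)(mask)
assert isinstance(out, features.Mask)  # the feature type survives the transform
```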
......@@ -331,13 +331,13 @@ class TestRandomVerticalFlip:
assert_equal(features.Image(expected), actual)
def test_features_segmentation_mask(self, p):
def test_features_mask(self, p):
input, expected = self.input_expected_image_tensor(p)
transform = transforms.RandomVerticalFlip(p=p)
actual = transform(features.SegmentationMask(input))
actual = transform(features.Mask(input))
assert_equal(features.SegmentationMask(expected), actual)
assert_equal(features.Mask(expected), actual)
def test_features_bounding_box(self, p):
input = features.BoundingBox([0, 0, 5, 5], format=features.BoundingBoxFormat.XYXY, image_size=(10, 10))
......@@ -1253,7 +1253,7 @@ class TestRandomIoUCrop:
torch.testing.assert_close(output_ohe_label, ohe_label[is_within_crop_area])
output_masks = output[4]
assert isinstance(output_masks, features.SegmentationMask)
assert isinstance(output_masks, features.Mask)
assert len(output_masks) == expected_within_targets
......@@ -1372,10 +1372,10 @@ class TestSimpleCopyPaste:
# labels, bboxes, masks
mocker.MagicMock(spec=features.Label),
mocker.MagicMock(spec=features.BoundingBox),
mocker.MagicMock(spec=features.SegmentationMask),
mocker.MagicMock(spec=features.Mask),
# labels, bboxes, masks
mocker.MagicMock(spec=features.BoundingBox),
mocker.MagicMock(spec=features.SegmentationMask),
mocker.MagicMock(spec=features.Mask),
]
with pytest.raises(TypeError, match="requires input sample to contain equal sized list of Images"):
......@@ -1393,11 +1393,11 @@ class TestSimpleCopyPaste:
# labels, bboxes, masks
mocker.MagicMock(spec=label_type),
mocker.MagicMock(spec=features.BoundingBox),
mocker.MagicMock(spec=features.SegmentationMask),
mocker.MagicMock(spec=features.Mask),
# labels, bboxes, masks
mocker.MagicMock(spec=label_type),
mocker.MagicMock(spec=features.BoundingBox),
mocker.MagicMock(spec=features.SegmentationMask),
mocker.MagicMock(spec=features.Mask),
]
images, targets = transform._extract_image_targets(flat_sample)
......@@ -1413,7 +1413,7 @@ class TestSimpleCopyPaste:
for target in targets:
for key, type_ in [
("boxes", features.BoundingBox),
("masks", features.SegmentationMask),
("masks", features.Mask),
("labels", label_type),
]:
assert key in target
......@@ -1436,7 +1436,7 @@ class TestSimpleCopyPaste:
"boxes": features.BoundingBox(
torch.tensor([[2.0, 3.0, 8.0, 9.0], [20.0, 20.0, 30.0, 30.0]]), format="XYXY", image_size=(32, 32)
),
"masks": features.SegmentationMask(masks),
"masks": features.Mask(masks),
"labels": label_type(labels),
}
......@@ -1451,7 +1451,7 @@ class TestSimpleCopyPaste:
"boxes": features.BoundingBox(
torch.tensor([[12.0, 13.0, 19.0, 18.0], [1.0, 15.0, 8.0, 19.0]]), format="XYXY", image_size=(32, 32)
),
"masks": features.SegmentationMask(paste_masks),
"masks": features.Mask(paste_masks),
"labels": label_type(paste_labels),
}
......@@ -1586,7 +1586,7 @@ class TestFixedSizeCrop:
bounding_boxes = make_bounding_box(
format=features.BoundingBoxFormat.XYXY, image_size=image_size, extra_dims=(batch_size,)
)
segmentation_masks = make_detection_mask(size=image_size, extra_dims=(batch_size,))
masks = make_detection_mask(size=image_size, extra_dims=(batch_size,))
labels = make_label(size=(batch_size,))
transform = transforms.FixedSizeCrop((-1, -1))
......@@ -1596,13 +1596,13 @@ class TestFixedSizeCrop:
output = transform(
dict(
bounding_boxes=bounding_boxes,
segmentation_masks=segmentation_masks,
masks=masks,
labels=labels,
)
)
assert_equal(output["bounding_boxes"], bounding_boxes[is_valid])
assert_equal(output["segmentation_masks"], segmentation_masks[is_valid])
assert_equal(output["masks"], masks[is_valid])
assert_equal(output["labels"], labels[is_valid])
def test__transform_bounding_box_clamping(self, mocker):
......
......@@ -11,10 +11,10 @@ from common_utils import cpu_and_gpu
from prototype_common_utils import (
ArgsKwargs,
make_bounding_boxes,
make_detection_and_segmentation_masks,
make_detection_masks,
make_image,
make_images,
make_masks,
)
from torch import jit
from torchvision.prototype import features
......@@ -61,8 +61,8 @@ def horizontal_flip_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def horizontal_flip_segmentation_mask():
for mask in make_detection_and_segmentation_masks():
def horizontal_flip_mask():
for mask in make_masks():
yield ArgsKwargs(mask)
......@@ -79,8 +79,8 @@ def vertical_flip_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def vertical_flip_segmentation_mask():
for mask in make_detection_and_segmentation_masks():
def vertical_flip_mask():
for mask in make_masks():
yield ArgsKwargs(mask)
......@@ -123,9 +123,9 @@ def resize_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def resize_segmentation_mask():
def resize_mask():
for mask, max_size in itertools.product(
make_detection_and_segmentation_masks(),
make_masks(),
[None, 34], # max_size
):
height, width = mask.shape[-2:]
......@@ -178,9 +178,9 @@ def affine_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def affine_segmentation_mask():
def affine_mask():
for mask, angle, translate, scale, shear in itertools.product(
make_detection_and_segmentation_masks(),
make_masks(),
[-87, 15, 90], # angle
[5, -5], # translate
[0.77, 1.27], # scale
......@@ -231,9 +231,9 @@ def rotate_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def rotate_segmentation_mask():
def rotate_mask():
for mask, angle, expand, center in itertools.product(
make_detection_and_segmentation_masks(),
make_masks(),
[-87, 15, 90], # angle
[True, False], # expand
[None, [12, 23]], # center
......@@ -274,10 +274,8 @@ def crop_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def crop_segmentation_mask():
for mask, top, left, height, width in itertools.product(
make_detection_and_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20]
):
def crop_mask():
for mask, top, left, height, width in itertools.product(make_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20]):
yield ArgsKwargs(
mask,
top=top,
......@@ -312,9 +310,9 @@ def resized_crop_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def resized_crop_segmentation_mask():
def resized_crop_mask():
for mask, top, left, height, width, size in itertools.product(
make_detection_and_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20], [(32, 32), (16, 18)]
make_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20], [(32, 32), (16, 18)]
):
yield ArgsKwargs(mask, top=top, left=left, height=height, width=width, size=size)
......@@ -331,9 +329,9 @@ def pad_image_tensor():
@register_kernel_info_from_sample_inputs_fn
def pad_segmentation_mask():
def pad_mask():
for mask, padding, padding_mode in itertools.product(
make_detection_and_segmentation_masks(),
make_masks(),
[[1], [1, 1], [1, 1, 2, 2]], # padding
["constant", "symmetric", "edge", "reflect"], # padding mode,
):
......@@ -379,9 +377,9 @@ def perspective_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def perspective_segmentation_mask():
def perspective_mask():
for mask, perspective_coeffs in itertools.product(
make_detection_and_segmentation_masks(extra_dims=((), (4,))),
make_masks(extra_dims=((), (4,))),
[
[1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018],
[0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063],
......@@ -417,8 +415,8 @@ def elastic_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def elastic_segmentation_mask():
for mask in make_detection_and_segmentation_masks(extra_dims=((), (4,))):
def elastic_mask():
for mask in make_masks(extra_dims=((), (4,))):
h, w = mask.shape[-2:]
displacement = torch.rand(1, h, w, 2)
yield ArgsKwargs(
......@@ -445,9 +443,9 @@ def center_crop_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def center_crop_segmentation_mask():
def center_crop_mask():
for mask, output_size in itertools.product(
make_detection_and_segmentation_masks(sizes=((16, 16), (7, 33), (31, 9))),
make_masks(sizes=((16, 16), (7, 33), (31, 9))),
[[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size
):
yield ArgsKwargs(mask, output_size)
......@@ -528,7 +526,7 @@ def erase_image_tensor():
for name, kernel in F.__dict__.items()
if not name.startswith("_")
and callable(kernel)
and any(feature_type in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label"})
and any(feature_type in name for feature_type in {"image", "mask", "bounding_box", "label"})
and "pil" not in name
and name
not in {
......@@ -553,9 +551,7 @@ def test_scriptable(kernel):
for name, func in F.__dict__.items()
if not name.startswith("_")
and callable(func)
and all(
feature_type not in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label", "pil"}
)
and all(feature_type not in name for feature_type in {"image", "mask", "bounding_box", "label", "pil"})
and name
not in {
"to_image_tensor",
......@@ -760,7 +756,7 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
@pytest.mark.parametrize("scale", [0.89, 1.12])
@pytest.mark.parametrize("shear", [4])
@pytest.mark.parametrize("center", [None, (12, 14)])
def test_correctness_affine_segmentation_mask(angle, translate, scale, shear, center):
def test_correctness_affine_mask(angle, translate, scale, shear, center):
def _compute_expected_mask(mask, angle_, translate_, scale_, shear_, center_):
assert mask.ndim == 3
affine_matrix = _compute_affine_matrix(angle_, translate_, scale_, shear_, center_)
......@@ -780,7 +776,7 @@ def test_correctness_affine_segmentation_mask(angle, translate, scale, shear, ce
# FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
for mask in make_detection_masks(extra_dims=((), (4,))):
output_mask = F.affine_segmentation_mask(
output_mask = F.affine_mask(
mask,
angle=angle,
translate=(translate, translate),
......@@ -824,7 +820,7 @@ def test_correctness_affine_segmentation_mask_on_fixed_input(device):
expected_mask = torch.nn.functional.interpolate(expected_mask[None, :].float(), size=(64, 64), mode="nearest")
expected_mask = expected_mask[0, :, 16 : 64 - 16, 16 : 64 - 16].long()
out_mask = F.affine_segmentation_mask(mask, 90, [0.0, 0.0], 64.0 / 32.0, [0.0, 0.0])
out_mask = F.affine_mask(mask, 90, [0.0, 0.0], 64.0 / 32.0, [0.0, 0.0])
torch.testing.assert_close(out_mask, expected_mask)
......@@ -976,7 +972,7 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
@pytest.mark.parametrize("angle", range(-89, 90, 37))
@pytest.mark.parametrize("expand, center", [(True, None), (False, None), (False, (12, 14))])
def test_correctness_rotate_segmentation_mask(angle, expand, center):
def test_correctness_rotate_mask(angle, expand, center):
def _compute_expected_mask(mask, angle_, expand_, center_):
assert mask.ndim == 3
c, *image_size = mask.shape
......@@ -1021,7 +1017,7 @@ def test_correctness_rotate_segmentation_mask(angle, expand, center):
# FIXME: `_compute_expected_mask` currently only works for "detection" masks. Extend it for "segmentation" masks.
for mask in make_detection_masks(extra_dims=((), (4,))):
output_mask = F.rotate_segmentation_mask(
output_mask = F.rotate_mask(
mask,
angle=angle,
expand=expand,
......@@ -1060,7 +1056,7 @@ def test_correctness_rotate_segmentation_mask_on_fixed_input(device):
# Rotate 90 degrees
expected_mask = torch.rot90(mask, k=1, dims=(-2, -1))
out_mask = F.rotate_segmentation_mask(mask, 90, expand=False)
out_mask = F.rotate_mask(mask, 90, expand=False)
torch.testing.assert_close(out_mask, expected_mask)
......@@ -1123,7 +1119,7 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
[-8, -6, 70, 8],
],
)
def test_correctness_crop_segmentation_mask(device, top, left, height, width):
def test_correctness_crop_mask(device, top, left, height, width):
def _compute_expected_mask(mask, top_, left_, height_, width_):
h, w = mask.shape[-2], mask.shape[-1]
if top_ >= 0 and left_ >= 0 and top_ + height_ < h and left_ + width_ < w:
......@@ -1147,10 +1143,10 @@ def test_correctness_crop_segmentation_mask(device, top, left, height, width):
return expected
for mask in make_detection_and_segmentation_masks():
for mask in make_masks():
if mask.device != torch.device(device):
mask = mask.to(device)
output_mask = F.crop_segmentation_mask(mask, top, left, height, width)
output_mask = F.crop_mask(mask, top, left, height, width)
expected_mask = _compute_expected_mask(mask, top, left, height, width)
torch.testing.assert_close(output_mask, expected_mask)
......@@ -1160,7 +1156,7 @@ def test_correctness_horizontal_flip_segmentation_mask_on_fixed_input(device):
mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
mask[:, :, 0] = 1
out_mask = F.horizontal_flip_segmentation_mask(mask)
out_mask = F.horizontal_flip_mask(mask)
expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
expected_mask[:, :, -1] = 1
......@@ -1172,7 +1168,7 @@ def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device):
mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
mask[:, 0, :] = 1
out_mask = F.vertical_flip_segmentation_mask(mask)
out_mask = F.vertical_flip_mask(mask)
expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device)
expected_mask[:, -1, :] = 1
......@@ -1233,7 +1229,7 @@ def test_correctness_resized_crop_bounding_box(device, format, top, left, height
[5, 5, 35, 45, (32, 34)],
],
)
def test_correctness_resized_crop_segmentation_mask(device, top, left, height, width, size):
def test_correctness_resized_crop_mask(device, top, left, height, width, size):
def _compute_expected_mask(mask, top_, left_, height_, width_, size_):
output = mask.clone()
output = output[:, top_ : top_ + height_, left_ : left_ + width_]
......@@ -1246,7 +1242,7 @@ def test_correctness_resized_crop_segmentation_mask(device, top, left, height, w
in_mask[0, 5:15, 12:23] = 2
expected_mask = _compute_expected_mask(in_mask, top, left, height, width, size)
output_mask = F.resized_crop_segmentation_mask(in_mask, top, left, height, width, size)
output_mask = F.resized_crop_mask(in_mask, top, left, height, width, size)
torch.testing.assert_close(output_mask, expected_mask)
......@@ -1310,7 +1306,7 @@ def test_correctness_pad_bounding_box(device, padding):
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)
out_mask = F.pad_segmentation_mask(mask, padding=[1, 1, 1, 1])
out_mask = F.pad_mask(mask, padding=[1, 1, 1, 1])
expected_mask = torch.zeros((1, 5, 5), dtype=torch.long, device=device)
expected_mask[:, 1:-1, 1:-1] = 1
......@@ -1319,7 +1315,7 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device):
@pytest.mark.parametrize("padding", [[1, 2, 3, 4], [1], 1, [1, 2]])
@pytest.mark.parametrize("padding_mode", ["constant", "edge", "reflect", "symmetric"])
def test_correctness_pad_segmentation_mask(padding, padding_mode):
def test_correctness_pad_mask(padding, padding_mode):
def _compute_expected_mask(mask, padding_, padding_mode_):
h, w = mask.shape[-2], mask.shape[-1]
pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_)
......@@ -1367,8 +1363,8 @@ def test_correctness_pad_segmentation_mask(padding, padding_mode):
return output
for mask in make_detection_and_segmentation_masks():
out_mask = F.pad_segmentation_mask(mask, padding, padding_mode=padding_mode)
for mask in make_masks():
out_mask = F.pad_mask(mask, padding, padding_mode=padding_mode)
expected_mask = _compute_expected_mask(mask, padding, padding_mode)
torch.testing.assert_close(out_mask, expected_mask)
......@@ -1473,7 +1469,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
[[[3, 2], [32, 3], [30, 24], [2, 25]], [[5, 5], [30, 3], [33, 19], [4, 25]]],
],
)
def test_correctness_perspective_segmentation_mask(device, startpoints, endpoints):
def test_correctness_perspective_mask(device, startpoints, endpoints):
def _compute_expected_mask(mask, pcoeffs_):
assert mask.ndim == 3
m1 = np.array([[pcoeffs_[0], pcoeffs_[1], pcoeffs_[2]], [pcoeffs_[3], pcoeffs_[4], pcoeffs_[5]]])
......@@ -1500,7 +1496,7 @@ def test_correctness_perspective_segmentation_mask(device, startpoints, endpoint
for mask in make_detection_masks(extra_dims=((), (4,))):
mask = mask.to(device)
output_mask = F.perspective_segmentation_mask(
output_mask = F.perspective_mask(
mask,
perspective_coeffs=pcoeffs,
)
......@@ -1579,8 +1575,8 @@ def test_correctness_center_crop_bounding_box(device, output_size):
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("output_size", [[4, 2], [4], [7, 6]])
def test_correctness_center_crop_segmentation_mask(device, output_size):
def _compute_expected_segmentation_mask(mask, output_size):
def test_correctness_center_crop_mask(device, output_size):
def _compute_expected_mask(mask, output_size):
crop_height, crop_width = output_size if len(output_size) > 1 else [output_size[0], output_size[0]]
_, image_height, image_width = mask.shape
......@@ -1594,9 +1590,9 @@ def test_correctness_center_crop_segmentation_mask(device, output_size):
return mask[:, top : top + crop_height, left : left + crop_width]
mask = torch.randint(0, 2, size=(1, 6, 6), dtype=torch.long, device=device)
actual = F.center_crop_segmentation_mask(mask, output_size)
actual = F.center_crop_mask(mask, output_size)
expected = _compute_expected_segmentation_mask(mask, output_size)
expected = _compute_expected_mask(mask, output_size)
torch.testing.assert_close(expected, actual)
......@@ -1663,7 +1659,7 @@ def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, s
[
(F.elastic_image_tensor, make_images),
# FIXME: This test currently only works for "detection" masks. Extend it for "segmentation" masks.
(F.elastic_segmentation_mask, make_detection_masks),
(F.elastic_mask, make_detection_masks),
],
)
def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples):
......@@ -1681,7 +1677,7 @@ def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples):
sample = features.Image(sample)
kwargs = {"interpolation": F.InterpolationMode.NEAREST}
else:
sample = features.SegmentationMask(sample)
sample = features.Mask(sample)
kwargs = {}
# Create a displacement grid using sin
......
......@@ -12,30 +12,30 @@ from torchvision.prototype.transforms.functional import to_image_pil
IMAGE = make_image(color_space=features.ColorSpace.RGB)
BOUNDING_BOX = make_bounding_box(format=features.BoundingBoxFormat.XYXY, image_size=IMAGE.image_size)
SEGMENTATION_MASK = make_detection_mask(size=IMAGE.image_size)
MASK = make_detection_mask(size=IMAGE.image_size)
@pytest.mark.parametrize(
("sample", "types", "expected"),
[
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image,), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox,), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.SegmentationMask,), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), True),
((SEGMENTATION_MASK,), (features.Image, features.BoundingBox), False),
((BOUNDING_BOX,), (features.Image, features.SegmentationMask), False),
((IMAGE,), (features.BoundingBox, features.SegmentationMask), False),
((IMAGE, BOUNDING_BOX, MASK), (features.Image,), True),
((IMAGE, BOUNDING_BOX, MASK), (features.BoundingBox,), True),
((IMAGE, BOUNDING_BOX, MASK), (features.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK), (features.Image, features.BoundingBox), True),
((IMAGE, BOUNDING_BOX, MASK), (features.Image, features.Mask), True),
((IMAGE, BOUNDING_BOX, MASK), (features.BoundingBox, features.Mask), True),
((MASK,), (features.Image, features.BoundingBox), False),
((BOUNDING_BOX,), (features.Image, features.Mask), False),
((IMAGE,), (features.BoundingBox, features.Mask), False),
(
(IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
(features.Image, features.BoundingBox, features.SegmentationMask),
(IMAGE, BOUNDING_BOX, MASK),
(features.Image, features.BoundingBox, features.Mask),
True,
),
((), (features.Image, features.BoundingBox, features.SegmentationMask), False),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda obj: isinstance(obj, features.Image),), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: True,), True),
((), (features.Image, features.BoundingBox, features.Mask), False),
((IMAGE, BOUNDING_BOX, MASK), (lambda obj: isinstance(obj, features.Image),), True),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
((IMAGE,), (features.Image, PIL.Image.Image, features.is_simple_tensor), True),
((torch.Tensor(IMAGE),), (features.Image, PIL.Image.Image, features.is_simple_tensor), True),
((to_image_pil(IMAGE),), (features.Image, PIL.Image.Image, features.is_simple_tensor), True),
......@@ -48,35 +48,35 @@ def test_has_any(sample, types, expected):
@pytest.mark.parametrize(
("sample", "types", "expected"),
[
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image,), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox,), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.SegmentationMask,), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), True),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), True),
((IMAGE, BOUNDING_BOX, MASK), (features.Image,), True),
((IMAGE, BOUNDING_BOX, MASK), (features.BoundingBox,), True),
((IMAGE, BOUNDING_BOX, MASK), (features.Mask,), True),
((IMAGE, BOUNDING_BOX, MASK), (features.Image, features.BoundingBox), True),
((IMAGE, BOUNDING_BOX, MASK), (features.Image, features.Mask), True),
((IMAGE, BOUNDING_BOX, MASK), (features.BoundingBox, features.Mask), True),
(
(IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
(features.Image, features.BoundingBox, features.SegmentationMask),
(IMAGE, BOUNDING_BOX, MASK),
(features.Image, features.BoundingBox, features.Mask),
True,
),
((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox), False),
((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.SegmentationMask), False),
((IMAGE, SEGMENTATION_MASK), (features.BoundingBox, features.SegmentationMask), False),
((BOUNDING_BOX, MASK), (features.Image, features.BoundingBox), False),
((BOUNDING_BOX, MASK), (features.Image, features.Mask), False),
((IMAGE, MASK), (features.BoundingBox, features.Mask), False),
(
(IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
(features.Image, features.BoundingBox, features.SegmentationMask),
(IMAGE, BOUNDING_BOX, MASK),
(features.Image, features.BoundingBox, features.Mask),
True,
),
((BOUNDING_BOX, SEGMENTATION_MASK), (features.Image, features.BoundingBox, features.SegmentationMask), False),
((IMAGE, SEGMENTATION_MASK), (features.Image, features.BoundingBox, features.SegmentationMask), False),
((IMAGE, BOUNDING_BOX), (features.Image, features.BoundingBox, features.SegmentationMask), False),
((BOUNDING_BOX, MASK), (features.Image, features.BoundingBox, features.Mask), False),
((IMAGE, MASK), (features.Image, features.BoundingBox, features.Mask), False),
((IMAGE, BOUNDING_BOX), (features.Image, features.BoundingBox, features.Mask), False),
(
(IMAGE, BOUNDING_BOX, SEGMENTATION_MASK),
(lambda obj: isinstance(obj, (features.Image, features.BoundingBox, features.SegmentationMask)),),
(IMAGE, BOUNDING_BOX, MASK),
(lambda obj: isinstance(obj, (features.Image, features.BoundingBox, features.Mask)),),
True,
),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, SEGMENTATION_MASK), (lambda _: True,), True),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: False,), False),
((IMAGE, BOUNDING_BOX, MASK), (lambda _: True,), True),
],
)
def test_has_all(sample, types, expected):
......
......@@ -3,4 +3,4 @@ from ._encoded import EncodedData, EncodedImage, EncodedVideo
from ._feature import _Feature, is_simple_tensor
from ._image import ColorSpace, Image
from ._label import Label, OneHotLabel
from ._segmentation_mask import SegmentationMask
from ._mask import Mask
......@@ -8,14 +8,14 @@ from torchvision.transforms import InterpolationMode
from ._feature import _Feature
class SegmentationMask(_Feature):
def horizontal_flip(self) -> SegmentationMask:
output = self._F.horizontal_flip_segmentation_mask(self)
return SegmentationMask.new_like(self, output)
class Mask(_Feature):
def horizontal_flip(self) -> Mask:
output = self._F.horizontal_flip_mask(self)
return Mask.new_like(self, output)
def vertical_flip(self) -> SegmentationMask:
output = self._F.vertical_flip_segmentation_mask(self)
return SegmentationMask.new_like(self, output)
def vertical_flip(self) -> Mask:
output = self._F.vertical_flip_mask(self)
return Mask.new_like(self, output)
def resize( # type: ignore[override]
self,
......@@ -23,17 +23,17 @@ class SegmentationMask(_Feature):
interpolation: InterpolationMode = InterpolationMode.NEAREST,
max_size: Optional[int] = None,
antialias: bool = False,
) -> SegmentationMask:
output = self._F.resize_segmentation_mask(self, size, max_size=max_size)
return SegmentationMask.new_like(self, output)
) -> Mask:
output = self._F.resize_mask(self, size, max_size=max_size)
return Mask.new_like(self, output)
def crop(self, top: int, left: int, height: int, width: int) -> SegmentationMask:
output = self._F.crop_segmentation_mask(self, top, left, height, width)
return SegmentationMask.new_like(self, output)
def crop(self, top: int, left: int, height: int, width: int) -> Mask:
output = self._F.crop_mask(self, top, left, height, width)
return Mask.new_like(self, output)
def center_crop(self, output_size: List[int]) -> SegmentationMask:
output = self._F.center_crop_segmentation_mask(self, output_size=output_size)
return SegmentationMask.new_like(self, output)
def center_crop(self, output_size: List[int]) -> Mask:
output = self._F.center_crop_mask(self, output_size=output_size)
return Mask.new_like(self, output)
def resized_crop(
self,
......@@ -44,22 +44,22 @@ class SegmentationMask(_Feature):
size: List[int],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
antialias: bool = False,
) -> SegmentationMask:
output = self._F.resized_crop_segmentation_mask(self, top, left, height, width, size=size)
return SegmentationMask.new_like(self, output)
) -> Mask:
output = self._F.resized_crop_mask(self, top, left, height, width, size=size)
return Mask.new_like(self, output)
def pad(
self,
padding: Union[int, Sequence[int]],
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
padding_mode: str = "constant",
) -> SegmentationMask:
) -> Mask:
# This cast does Sequence[int] -> List[int] and is required to make mypy happy
if not isinstance(padding, int):
padding = list(padding)
output = self._F.pad_segmentation_mask(self, padding, padding_mode=padding_mode)
return SegmentationMask.new_like(self, output)
output = self._F.pad_mask(self, padding, padding_mode=padding_mode)
return Mask.new_like(self, output)
def rotate(
self,
......@@ -68,9 +68,9 @@ class SegmentationMask(_Feature):
expand: bool = False,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
center: Optional[List[float]] = None,
) -> SegmentationMask:
output = self._F.rotate_segmentation_mask(self, angle, expand=expand, center=center)
return SegmentationMask.new_like(self, output)
) -> Mask:
output = self._F.rotate_mask(self, angle, expand=expand, center=center)
return Mask.new_like(self, output)
def affine(
self,
......@@ -81,8 +81,8 @@ class SegmentationMask(_Feature):
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
center: Optional[List[float]] = None,
) -> SegmentationMask:
output = self._F.affine_segmentation_mask(
) -> Mask:
output = self._F.affine_mask(
self,
angle,
translate=translate,
......@@ -90,22 +90,22 @@ class SegmentationMask(_Feature):
shear=shear,
center=center,
)
return SegmentationMask.new_like(self, output)
return Mask.new_like(self, output)
def perspective(
self,
perspective_coeffs: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> SegmentationMask:
output = self._F.perspective_segmentation_mask(self, perspective_coeffs)
return SegmentationMask.new_like(self, output)
) -> Mask:
output = self._F.perspective_mask(self, perspective_coeffs)
return Mask.new_like(self, output)
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> SegmentationMask:
output = self._F.elastic_segmentation_mask(self, displacement)
return SegmentationMask.new_like(self, output, dtype=output.dtype)
) -> Mask:
output = self._F.elastic_mask(self, displacement)
return Mask.new_like(self, output, dtype=output.dtype)
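Taken together, every method of the renamed `Mask` feature now routes through the matching `*_mask` kernel and re-wraps the result via `Mask.new_like`. A brief sketch:

```python
import torch
from torchvision.prototype import features

mask = features.Mask(torch.randint(0, 2, (1, 8, 8), dtype=torch.uint8))

flipped = mask.horizontal_flip()  # delegates to F.horizontal_flip_mask
padded = mask.pad([1, 1, 1, 1])   # delegates to F.pad_mask
assert isinstance(flipped, features.Mask) and isinstance(padded, features.Mask)
```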
......@@ -108,10 +108,8 @@ class _BaseMixupCutmix(_RandomApplyTransform):
def forward(self, *inputs: Any) -> Any:
if not (has_any(inputs, features.Image, features.is_simple_tensor) and has_any(inputs, features.OneHotLabel)):
raise TypeError(f"{type(self).__name__}() is only defined for tensor images and one-hot labels.")
if has_any(inputs, features.BoundingBox, features.SegmentationMask, features.Label):
raise TypeError(
f"{type(self).__name__}() does not support bounding boxes, segmentation masks and plain labels."
)
if has_any(inputs, features.BoundingBox, features.Mask, features.Label):
raise TypeError(f"{type(self).__name__}() does not support bounding boxes, masks and plain labels.")
return super().forward(*inputs)
def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel:
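The guard is built on `has_any`; a sketch of the check that triggers the updated error message (the `has_any` import path below is an assumption, not confirmed by this diff):

```python
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms._utils import has_any  # assumed path

sample = (
    features.Image(torch.rand(3, 4, 4)),
    features.Mask(torch.zeros(1, 4, 4, dtype=torch.uint8)),
)
# True here means the transform above would raise the TypeError
print(has_any(sample, features.BoundingBox, features.Mask, features.Label))
```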
......@@ -280,7 +278,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
def _extract_image_targets(self, flat_sample: List[Any]) -> Tuple[List[Any], List[Dict[str, Any]]]:
# fetch all images, bboxes, masks and labels from unstructured input
# with List[image], List[BoundingBox], List[SegmentationMask], List[Label]
# with List[image], List[BoundingBox], List[Mask], List[Label]
images, bboxes, masks, labels = [], [], [], []
for obj in flat_sample:
if isinstance(obj, features.Image) or features.is_simple_tensor(obj):
......@@ -289,7 +287,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
images.append(F.to_image_tensor(obj))
elif isinstance(obj, features.BoundingBox):
bboxes.append(obj)
elif isinstance(obj, features.SegmentationMask):
elif isinstance(obj, features.Mask):
masks.append(obj)
elif isinstance(obj, (features.Label, features.OneHotLabel)):
labels.append(obj)
......@@ -297,7 +295,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
if not (len(images) == len(bboxes) == len(masks) == len(labels)):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain equal sized list of Images, "
"BoundingBoxes, Segmentation Masks and Labels or OneHotLabels."
"BoundingBoxes, Masks and Labels or OneHotLabels."
)
targets = []
......@@ -323,8 +321,8 @@ class SimpleCopyPaste(_RandomApplyTransform):
elif isinstance(obj, features.BoundingBox):
flat_sample[i] = features.BoundingBox.new_like(obj, output_targets[c1]["boxes"])
c1 += 1
elif isinstance(obj, features.SegmentationMask):
flat_sample[i] = features.SegmentationMask.new_like(obj, output_targets[c2]["masks"])
elif isinstance(obj, features.Mask):
flat_sample[i] = features.Mask.new_like(obj, output_targets[c2]["masks"])
c2 += 1
elif isinstance(obj, (features.Label, features.OneHotLabel)):
flat_sample[i] = obj.new_like(obj, output_targets[c3]["labels"]) # type: ignore[arg-type]
......
......@@ -38,7 +38,7 @@ class _AutoAugmentBase(Transform):
def _extract_image(
self,
sample: Any,
unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask),
unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.Mask),
) -> Tuple[int, Union[PIL.Image.Image, torch.Tensor, features.Image]]:
sample_flat, _ = tree_flatten(sample)
images = []
......
......@@ -170,8 +170,8 @@ class FiveCrop(Transform):
return F.five_crop(inpt, self.size)
def forward(self, *inputs: Any) -> Any:
if has_any(inputs, features.BoundingBox, features.SegmentationMask):
raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
if has_any(inputs, features.BoundingBox, features.Mask):
raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
return super().forward(*inputs)
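A hedged sketch of the updated error path (the `FiveCrop` constructor argument is illustrative):

```python
import pytest
import torch
from torchvision.prototype import features, transforms

mask = features.Mask(torch.zeros(1, 8, 8, dtype=torch.uint8))
with pytest.raises(TypeError, match="not supported by FiveCrop"):
    transforms.FiveCrop(4)(mask)
```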
......@@ -191,8 +191,8 @@ class TenCrop(Transform):
return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
def forward(self, *inputs: Any) -> Any:
if has_any(inputs, features.BoundingBox, features.SegmentationMask):
raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
if has_any(inputs, features.BoundingBox, features.Mask):
raise TypeError(f"BoundingBox'es and Mask's are not supported by {type(self).__name__}()")
return super().forward(*inputs)
......@@ -690,10 +690,10 @@ class RandomIoUCrop(Transform):
bboxes = output[is_within_crop_area]
bboxes = F.clamp_bounding_box(bboxes, output.format, output.image_size)
output = features.BoundingBox.new_like(output, bboxes)
elif isinstance(output, features.SegmentationMask) and output.shape[-3] > 1:
elif isinstance(output, features.Mask) and output.shape[-3] > 1:
# apply is_within_crop_area if mask is one-hot encoded
masks = output[is_within_crop_area]
output = features.SegmentationMask.new_like(output, masks)
output = features.Mask.new_like(output, masks)
return output
......@@ -705,7 +705,7 @@ class RandomIoUCrop(Transform):
):
raise TypeError(
f"{type(self).__name__}() requires input sample to contain Images or PIL Images, "
"BoundingBoxes and Labels or OneHotLabels. Sample can also contain Segmentation Masks."
"BoundingBoxes and Labels or OneHotLabels. Sample can also contain Masks."
)
return super().forward(*inputs)
......@@ -842,7 +842,7 @@ class FixedSizeCrop(Transform):
)
if params["is_valid"] is not None:
if isinstance(inpt, (features.Label, features.OneHotLabel, features.SegmentationMask)):
if isinstance(inpt, (features.Label, features.OneHotLabel, features.Mask)):
inpt = inpt.new_like(inpt, inpt[params["is_valid"]]) # type: ignore[arg-type]
elif isinstance(inpt, features.BoundingBox):
inpt = features.BoundingBox.new_like(
......
......@@ -150,7 +150,7 @@ class ToDtype(Lambda):
class RemoveSmallBoundingBoxes(Transform):
_transformed_types = (features.BoundingBox, features.SegmentationMask, features.Label, features.OneHotLabel)
_transformed_types = (features.BoundingBox, features.Mask, features.Label, features.OneHotLabel)
def __init__(self, min_size: float = 1.0) -> None:
super().__init__()
......
......@@ -53,22 +53,22 @@ from ._geometry import (
affine_bounding_box,
affine_image_pil,
affine_image_tensor,
affine_segmentation_mask,
affine_mask,
center_crop,
center_crop_bounding_box,
center_crop_image_pil,
center_crop_image_tensor,
center_crop_segmentation_mask,
center_crop_mask,
crop,
crop_bounding_box,
crop_image_pil,
crop_image_tensor,
crop_segmentation_mask,
crop_mask,
elastic,
elastic_bounding_box,
elastic_image_pil,
elastic_image_tensor,
elastic_segmentation_mask,
elastic_mask,
elastic_transform,
five_crop,
five_crop_image_pil,
......@@ -78,32 +78,32 @@ from ._geometry import (
horizontal_flip_bounding_box,
horizontal_flip_image_pil,
horizontal_flip_image_tensor,
horizontal_flip_segmentation_mask,
horizontal_flip_mask,
pad,
pad_bounding_box,
pad_image_pil,
pad_image_tensor,
pad_segmentation_mask,
pad_mask,
perspective,
perspective_bounding_box,
perspective_image_pil,
perspective_image_tensor,
perspective_segmentation_mask,
perspective_mask,
resize,
resize_bounding_box,
resize_image_pil,
resize_image_tensor,
resize_segmentation_mask,
resize_mask,
resized_crop,
resized_crop_bounding_box,
resized_crop_image_pil,
resized_crop_image_tensor,
resized_crop_segmentation_mask,
resized_crop_mask,
rotate,
rotate_bounding_box,
rotate_image_pil,
rotate_image_tensor,
rotate_segmentation_mask,
rotate_mask,
ten_crop,
ten_crop_image_pil,
ten_crop_image_tensor,
......@@ -111,7 +111,7 @@ from ._geometry import (
vertical_flip_bounding_box,
vertical_flip_image_pil,
vertical_flip_image_tensor,
vertical_flip_segmentation_mask,
vertical_flip_mask,
vflip,
)
from ._misc import gaussian_blur, gaussian_blur_image_pil, gaussian_blur_image_tensor, normalize, normalize_image_tensor
......
......@@ -28,8 +28,8 @@ horizontal_flip_image_tensor = _FT.hflip
horizontal_flip_image_pil = _FP.hflip
def horizontal_flip_segmentation_mask(segmentation_mask: torch.Tensor) -> torch.Tensor:
return horizontal_flip_image_tensor(segmentation_mask)
def horizontal_flip_mask(mask: torch.Tensor) -> torch.Tensor:
return horizontal_flip_image_tensor(mask)
def horizontal_flip_bounding_box(
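The renamed kernel is a thin alias over the image kernel. A sketch mirroring the fixed-input test earlier in this diff:

```python
import torch
from torchvision.prototype.transforms import functional as F

mask = torch.zeros((3, 3, 3), dtype=torch.long)
mask[:, :, 0] = 1
out = F.horizontal_flip_mask(mask)  # was F.horizontal_flip_segmentation_mask
assert bool((out[:, :, -1] == 1).all())  # the marked column moved to the right edge
```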
......@@ -61,8 +61,8 @@ vertical_flip_image_tensor = _FT.vflip
vertical_flip_image_pil = _FP.vflip
def vertical_flip_segmentation_mask(segmentation_mask: torch.Tensor) -> torch.Tensor:
return vertical_flip_image_tensor(segmentation_mask)
def vertical_flip_mask(mask: torch.Tensor) -> torch.Tensor:
return vertical_flip_image_tensor(mask)
def vertical_flip_bounding_box(
......@@ -132,18 +132,14 @@ def resize_image_pil(
return _FP.resize(img, size, interpolation=pil_modes_mapping[interpolation])
def resize_segmentation_mask(
segmentation_mask: torch.Tensor, size: List[int], max_size: Optional[int] = None
) -> torch.Tensor:
if segmentation_mask.ndim < 3:
segmentation_mask = segmentation_mask.unsqueeze(0)
def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = resize_image_tensor(
segmentation_mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size
)
output = resize_image_tensor(mask, size=size, interpolation=InterpolationMode.NEAREST, max_size=max_size)
if needs_squeeze:
output = output.squeeze(0)
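Note the unsqueeze/squeeze dance: a 2D mask temporarily gains a channel dimension so the image kernel can run, then is squeezed back. A sketch:

```python
import torch
from torchvision.prototype.transforms import functional as F

mask_2d = torch.randint(0, 10, (16, 16), dtype=torch.uint8)
out = F.resize_mask(mask_2d, size=[8, 8])  # always nearest interpolation
assert out.shape == (8, 8)  # squeezed back to the input's 2D shape
```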
......@@ -379,22 +375,22 @@ def affine_bounding_box(
).view(original_shape)
def affine_segmentation_mask(
segmentation_mask: torch.Tensor,
def affine_mask(
mask: torch.Tensor,
angle: float,
translate: List[float],
scale: float,
shear: List[float],
center: Optional[List[float]] = None,
) -> torch.Tensor:
if segmentation_mask.ndim < 3:
segmentation_mask = segmentation_mask.unsqueeze(0)
if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = affine_image_tensor(
segmentation_mask,
mask,
angle=angle,
translate=translate,
scale=scale,
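A sketch mirroring the fixed-input affine test earlier in the diff (90° rotation with a 2x scale; the mask contents are illustrative):

```python
import torch
from torchvision.prototype.transforms import functional as F

mask = torch.zeros((1, 32, 32), dtype=torch.long)
mask[:, 12:20, 12:20] = 1
out = F.affine_mask(mask, angle=90, translate=[0.0, 0.0], scale=2.0, shear=[0.0, 0.0])
assert out.shape == mask.shape  # an affine warp keeps the spatial size
```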
......@@ -545,20 +541,20 @@ def rotate_bounding_box(
).view(original_shape)
def rotate_segmentation_mask(
segmentation_mask: torch.Tensor,
def rotate_mask(
mask: torch.Tensor,
angle: float,
expand: bool = False,
center: Optional[List[float]] = None,
) -> torch.Tensor:
if segmentation_mask.ndim < 3:
segmentation_mask = segmentation_mask.unsqueeze(0)
if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = rotate_image_tensor(
segmentation_mask,
mask,
angle=angle,
expand=expand,
interpolation=InterpolationMode.NEAREST,
......@@ -639,16 +635,14 @@ def _pad_with_vector_fill(
return output
def pad_segmentation_mask(
segmentation_mask: torch.Tensor, padding: Union[int, List[int]], padding_mode: str = "constant"
) -> torch.Tensor:
if segmentation_mask.ndim < 3:
segmentation_mask = segmentation_mask.unsqueeze(0)
def pad_mask(mask: torch.Tensor, padding: Union[int, List[int]], padding_mode: str = "constant") -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = pad_image_tensor(img=segmentation_mask, padding=padding, fill=0, padding_mode=padding_mode)
output = pad_image_tensor(img=mask, padding=padding, fill=0, padding_mode=padding_mode)
if needs_squeeze:
output = output.squeeze(0)
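A sketch mirroring the fixed-input pad test: with the hard-coded `fill=0`, padding a mask of ones yields a zero border:

```python
import torch
from torchvision.prototype.transforms import functional as F

mask = torch.ones((1, 3, 3), dtype=torch.long)
out = F.pad_mask(mask, padding=[1, 1, 1, 1])
assert out.shape == (1, 5, 5)
assert bool((out[:, 1:-1, 1:-1] == 1).all())  # interior preserved, border is 0
```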
......@@ -723,8 +717,8 @@ def crop_bounding_box(
)
def crop_segmentation_mask(img: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
return crop_image_tensor(img, top, left, height, width)
def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
return crop_image_tensor(mask, top, left, height, width)
def crop(inpt: DType, top: int, left: int, height: int, width: int) -> DType:
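`crop_mask` simply forwards to the image crop kernel; a sketch:

```python
import torch
from torchvision.prototype.transforms import functional as F

mask = torch.randint(0, 2, (1, 16, 16), dtype=torch.uint8)
out = F.crop_mask(mask, top=2, left=3, height=4, width=5)
assert out.shape == (1, 4, 5)
```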
......@@ -839,15 +833,15 @@ def perspective_bounding_box(
).view(original_shape)
def perspective_segmentation_mask(segmentation_mask: torch.Tensor, perspective_coeffs: List[float]) -> torch.Tensor:
if segmentation_mask.ndim < 3:
segmentation_mask = segmentation_mask.unsqueeze(0)
def perspective_mask(mask: torch.Tensor, perspective_coeffs: List[float]) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = perspective_image_tensor(
segmentation_mask, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST
mask, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST
)
if needs_squeeze:
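A sketch using one of the coefficient sets from the kernel tests earlier in this diff:

```python
import torch
from torchvision.prototype.transforms import functional as F

mask = torch.randint(0, 2, (1, 32, 32), dtype=torch.uint8)
coeffs = [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018]
out = F.perspective_mask(mask, perspective_coeffs=coeffs)
assert out.shape == mask.shape  # a perspective warp keeps the spatial size
```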
......@@ -937,14 +931,14 @@ def elastic_bounding_box(
).view(original_shape)
def elastic_segmentation_mask(segmentation_mask: torch.Tensor, displacement: torch.Tensor) -> torch.Tensor:
if segmentation_mask.ndim < 3:
segmentation_mask = segmentation_mask.unsqueeze(0)
def elastic_mask(mask: torch.Tensor, displacement: torch.Tensor) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = elastic_image_tensor(segmentation_mask, displacement=displacement, interpolation=InterpolationMode.NEAREST)
output = elastic_image_tensor(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST)
if needs_squeeze:
output = output.squeeze(0)
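The displacement grid has shape `(1, H, W, 2)`, matching the test generators earlier in the diff; with an all-zero displacement the nearest-neighbor warp should act as an identity. A sketch:

```python
import torch
from torchvision.prototype.transforms import functional as F

mask = torch.randint(0, 2, (1, 16, 16), dtype=torch.uint8)
displacement = torch.zeros(1, 16, 16, 2)  # zero displacement ~ identity warp
out = F.elastic_mask(mask, displacement=displacement)
assert out.shape == mask.shape
```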
......@@ -1040,14 +1034,14 @@ def center_crop_bounding_box(
return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left)
def center_crop_segmentation_mask(segmentation_mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
if segmentation_mask.ndim < 3:
segmentation_mask = segmentation_mask.unsqueeze(0)
def center_crop_mask(mask: torch.Tensor, output_size: List[int]) -> torch.Tensor:
if mask.ndim < 3:
mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False
output = center_crop_image_tensor(img=segmentation_mask, output_size=output_size)
output = center_crop_image_tensor(img=mask, output_size=output_size)
if needs_squeeze:
output = output.squeeze(0)
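A sketch mirroring the center-crop correctness test earlier in the diff:

```python
import torch
from torchvision.prototype.transforms import functional as F

mask = torch.randint(0, 2, (1, 6, 6), dtype=torch.long)
out = F.center_crop_mask(mask, output_size=[4, 2])
assert out.shape == (1, 4, 2)
```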
......@@ -1104,7 +1098,7 @@ def resized_crop_bounding_box(
return resize_bounding_box(bounding_box, size, (height, width))
def resized_crop_segmentation_mask(
def resized_crop_mask(
mask: torch.Tensor,
top: int,
left: int,
......@@ -1112,8 +1106,8 @@ def resized_crop_segmentation_mask(
width: int,
size: List[int],
) -> torch.Tensor:
mask = crop_segmentation_mask(mask, top, left, height, width)
return resize_segmentation_mask(mask, size)
mask = crop_mask(mask, top, left, height, width)
return resize_mask(mask, size)
def resized_crop(
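`resized_crop_mask` is just `crop_mask` followed by `resize_mask`; a sketch:

```python
import torch
from torchvision.prototype.transforms import functional as F

mask = torch.randint(0, 3, (1, 32, 32), dtype=torch.long)
out = F.resized_crop_mask(mask, top=0, left=5, height=10, width=10, size=[16, 18])
assert out.shape == (1, 16, 18)
```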
......