Unverified commit 104073cc, authored by Federico Pozzi and committed by GitHub
Browse files

feat: add functional pad on segmentation mask (#5866)



* feat: add functional pad on segmentation mask

* test: add basic correctness test with random masks

* test: add all padding options

* fix: pr comments

* fix: tests

* refactor: reshape tensor in 4d, then pad
Co-authored-by: Federico Pozzi <federico.pozzi@argo.vision>
parent ecbff88a
......@@ -370,6 +370,16 @@ def resized_crop_segmentation_mask():
yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size)
@register_kernel_info_from_sample_inputs_fn
def pad_segmentation_mask():
    """Yield one ``SampleInput`` per (mask, padding, padding mode) combination.

    Every mask produced by ``make_segmentation_masks()`` is combined with each
    padding specification (one, two and four values) and each of the four
    padding mode names listed below.
    """
    padding_specs = [[1], [1, 1], [1, 1, 2, 2]]
    padding_modes = ["constant", "symmetric", "edge", "reflect"]

    combinations = itertools.product(make_segmentation_masks(), padding_specs, padding_modes)
    for sample_mask, pad_spec, pad_mode in combinations:
        yield SampleInput(sample_mask, padding=pad_spec, padding_mode=pad_mode)
@pytest.mark.parametrize(
"kernel",
[
......@@ -1031,3 +1041,47 @@ def test_correctness_resized_crop_segmentation_mask(device, top, left, height, w
expected_mask = _compute_expected(in_mask, top, left, height, width, size)
output_mask = F.resized_crop_segmentation_mask(in_mask, top, left, height, width, size)
torch.testing.assert_close(output_mask, expected_mask)
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
    """Pad an all-ones 3x3 mask by one pixel per side and expect a zero border."""
    tensor_kwargs = dict(dtype=torch.long, device=device)

    source = torch.ones((1, 3, 3), **tensor_kwargs)

    # Reference result: a 5x5 mask whose one-pixel border is zero and whose
    # 3x3 interior keeps the original ones.
    target = torch.zeros((1, 5, 5), **tensor_kwargs)
    target[:, 1:-1, 1:-1] = 1

    result = F.pad_segmentation_mask(source, padding=[1, 1, 1, 1])
    torch.testing.assert_close(result, target)
@pytest.mark.parametrize("padding", [[1, 2, 3, 4], [1], 1, [1, 2]])
def test_correctness_pad_segmentation_mask(padding):
    """Check constant-mode mask padding against a manually built reference.

    For every mask from ``make_segmentation_masks()`` the kernel output is
    compared with a zero tensor of the padded size into which the original
    mask has been copied at the padded offset.
    """

    def _parse_padding():
        # Normalize every accepted padding form to [left, up, right, down].
        if isinstance(padding, int):
            return [padding] * 4
        if isinstance(padding, list):
            if len(padding) == 1:
                return padding * 4
            if len(padding) == 2:
                return padding * 2  # [left/right, up/down] -> [left, up, right, down]
        return padding

    def _compute_expected_mask(mask):
        h, w = mask.shape[-2], mask.shape[-1]
        pad_left, pad_up, pad_right, pad_down = _parse_padding()

        new_h = h + pad_up + pad_down
        new_w = w + pad_left + pad_right
        new_shape = (*mask.shape[:-2], new_h, new_w) if len(mask.shape) > 2 else (new_h, new_w)

        expected_mask = torch.zeros(new_shape, dtype=torch.long)
        # Use explicit end indices: a negative end such as ``-pad_down`` turns
        # into ``-0`` (an empty slice) as soon as a bottom/right pad is zero.
        expected_mask[..., pad_up : pad_up + h, pad_left : pad_left + w] = mask
        return expected_mask

    for mask in make_segmentation_masks():
        out_mask = F.pad_segmentation_mask(mask, padding, "constant")

        expected_mask = _compute_expected_mask(mask)
        torch.testing.assert_close(out_mask, expected_mask)
......@@ -62,6 +62,7 @@ from ._geometry import (
pad_bounding_box,
pad_image_tensor,
pad_image_pil,
pad_segmentation_mask,
crop_bounding_box,
crop_image_tensor,
crop_image_pil,
......
......@@ -396,6 +396,20 @@ pad_image_tensor = _FT.pad
# Padding kernel for PIL images: a direct alias of `_FP.pad`
# (presumably the PIL functional backend — confirm against this module's imports).
pad_image_pil = _FP.pad
def pad_segmentation_mask(
    segmentation_mask: torch.Tensor, padding: List[int], padding_mode: str = "constant"
) -> torch.Tensor:
    """Pad the last two (height, width) dimensions of a segmentation mask.

    The mask is flattened to a 4D tensor, padded with ``pad_image_tensor``
    using a fill value of 0, and then restored to its original leading
    dimensions.

    Args:
        segmentation_mask: Mask tensor of shape ``(..., num_masks, H, W)``.
            A plain 2D ``(H, W)`` mask is also accepted and yields a 2D result.
        padding: Padding specification, forwarded unchanged to
            ``pad_image_tensor``.
        padding_mode: Padding mode name, forwarded unchanged to
            ``pad_image_tensor``.

    Returns:
        The padded mask with the same leading dimensions as the input.

    Raises:
        ValueError: If ``segmentation_mask`` has fewer than two dimensions.
    """
    if segmentation_mask.ndim < 2:
        raise ValueError(
            f"segmentation_mask must have at least 2 dimensions, got {segmentation_mask.ndim}"
        )

    # A bare (H, W) mask gets a temporary leading dimension so that the
    # 4D reshape below applies to it as well; it is removed before returning.
    needs_squeeze = segmentation_mask.ndim == 2
    if needs_squeeze:
        segmentation_mask = segmentation_mask.unsqueeze(0)

    num_masks, height, width = segmentation_mask.shape[-3:]
    extra_dims = tuple(segmentation_mask.shape[:-3])

    # ``reshape`` instead of ``view``: ``view`` raises on non-contiguous inputs
    # (e.g. a transposed or sliced mask), while ``reshape`` copies when needed.
    padded_mask = pad_image_tensor(
        img=segmentation_mask.reshape(-1, num_masks, height, width),
        padding=padding,
        fill=0,
        padding_mode=padding_mode,
    )

    new_height, new_width = padded_mask.shape[-2:]
    padded_mask = padded_mask.reshape(extra_dims + (num_masks, new_height, new_width))
    return padded_mask.squeeze(0) if needs_squeeze else padded_mask
def pad_bounding_box(
bounding_box: torch.Tensor, padding: List[int], format: features.BoundingBoxFormat
) -> torch.Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment