Unverified Commit fa37d9bb authored by vfdev's avatar vfdev Committed by GitHub
Browse files

[proto] Added tests for other padding modes (#6104)

* Added tests for other padding modes

* Fixed expected mask dtype

* Applied comments from review
parent d5929257
......@@ -1101,17 +1101,6 @@ def test_correctness_resized_crop_segmentation_mask(device, top, left, height, w
torch.testing.assert_close(output_mask, expected_mask)
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)
out_mask = F.pad_segmentation_mask(mask, padding=[1, 1, 1, 1])
expected_mask = torch.zeros((1, 5, 5), dtype=torch.long, device=device)
expected_mask[:, 1:-1, 1:-1] = 1
torch.testing.assert_close(out_mask, expected_mask)
def _parse_padding(padding):
if isinstance(padding, int):
return [padding] * 4
......@@ -1168,25 +1157,71 @@ def test_correctness_pad_bounding_box(device, padding):
torch.testing.assert_close(output_boxes, expected_bboxes)
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_correctness_pad_segmentation_mask_on_fixed_input(device):
mask = torch.ones((1, 3, 3), dtype=torch.long, device=device)
out_mask = F.pad_segmentation_mask(mask, padding=[1, 1, 1, 1])
expected_mask = torch.zeros((1, 5, 5), dtype=torch.long, device=device)
expected_mask[:, 1:-1, 1:-1] = 1
torch.testing.assert_close(out_mask, expected_mask)
@pytest.mark.parametrize("padding", [[1, 2, 3, 4], [1], 1, [1, 2]])
def test_correctness_pad_segmentation_mask(padding):
def _compute_expected_mask(mask, padding_):
@pytest.mark.parametrize("padding_mode", ["constant", "edge", "reflect", "symmetric"])
def test_correctness_pad_segmentation_mask(padding, padding_mode):
def _compute_expected_mask(mask, padding_, padding_mode_):
h, w = mask.shape[-2], mask.shape[-1]
pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_)
if any(pad <= 0 for pad in [pad_left, pad_up, pad_right, pad_down]):
raise pytest.UsageError(
"Expected output can be computed on positive pad values only, "
"but F.pad_* can also crop for negative values"
)
new_h = h + pad_up + pad_down
new_w = w + pad_left + pad_right
new_shape = (*mask.shape[:-2], new_h, new_w) if len(mask.shape) > 2 else (new_h, new_w)
expected_mask = torch.zeros(new_shape, dtype=torch.long)
expected_mask[..., pad_up:-pad_down, pad_left:-pad_right] = mask
output = torch.zeros(new_shape, dtype=mask.dtype)
output[..., pad_up:-pad_down, pad_left:-pad_right] = mask
if padding_mode_ == "edge":
# pad top-left corner, left vertical block, bottom-left corner
output[..., :pad_up, :pad_left] = mask[..., 0, 0].unsqueeze(-1).unsqueeze(-2)
output[..., pad_up:-pad_down, :pad_left] = mask[..., :, 0].unsqueeze(-1)
output[..., -pad_down:, :pad_left] = mask[..., -1, 0].unsqueeze(-1).unsqueeze(-2)
# pad top-right corner, right vertical block, bottom-right corner
output[..., :pad_up, -pad_right:] = mask[..., 0, -1].unsqueeze(-1).unsqueeze(-2)
output[..., pad_up:-pad_down, -pad_right:] = mask[..., :, -1].unsqueeze(-1)
output[..., -pad_down:, -pad_right:] = mask[..., -1, -1].unsqueeze(-1).unsqueeze(-2)
# pad top and bottom horizontal blocks
output[..., :pad_up, pad_left:-pad_right] = mask[..., 0, :].unsqueeze(-2)
output[..., -pad_down:, pad_left:-pad_right] = mask[..., -1, :].unsqueeze(-2)
elif padding_mode_ in ("reflect", "symmetric"):
d1 = 1 if padding_mode_ == "reflect" else 0
d2 = -1 if padding_mode_ == "reflect" else None
both = (-1, -2)
# pad top-left corner, left vertical block, bottom-left corner
output[..., :pad_up, :pad_left] = mask[..., d1 : pad_up + d1, d1 : pad_left + d1].flip(both)
output[..., pad_up:-pad_down, :pad_left] = mask[..., :, d1 : pad_left + d1].flip(-1)
output[..., -pad_down:, :pad_left] = mask[..., -pad_down - d1 : d2, d1 : pad_left + d1].flip(both)
# pad top-right corner, right vertical block, bottom-right corner
output[..., :pad_up, -pad_right:] = mask[..., d1 : pad_up + d1, -pad_right - d1 : d2].flip(both)
output[..., pad_up:-pad_down, -pad_right:] = mask[..., :, -pad_right - d1 : d2].flip(-1)
output[..., -pad_down:, -pad_right:] = mask[..., -pad_down - d1 : d2, -pad_right - d1 : d2].flip(both)
# pad top and bottom horizontal blocks
output[..., :pad_up, pad_left:-pad_right] = mask[..., d1 : pad_up + d1, :].flip(-2)
output[..., -pad_down:, pad_left:-pad_right] = mask[..., -pad_down - d1 : d2, :].flip(-2)
return expected_mask
return output
for mask in make_segmentation_masks():
out_mask = F.pad_segmentation_mask(mask, padding, "constant")
out_mask = F.pad_segmentation_mask(mask, padding, padding_mode=padding_mode)
expected_mask = _compute_expected_mask(mask, padding)
expected_mask = _compute_expected_mask(mask, padding, padding_mode)
torch.testing.assert_close(out_mask, expected_mask)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment