Unverified Commit 39e4057c authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Added symmetric padding mode for Tensors (#2373)

* [WIP] Added symmetric padding mode

* Added check and raise error if padding is negative for symmetric padding mode

* Added test check for raising error if negative pad
parent e4b9823f
......@@ -259,6 +259,7 @@ class Tester(unittest.TestCase):
{"padding_mode": "constant", "fill": 20},
{"padding_mode": "edge"},
{"padding_mode": "reflect"},
{"padding_mode": "symmetric"},
]
for kwargs in configs:
pad_tensor = F_t.pad(tensor, pad, **kwargs)
......@@ -278,6 +279,9 @@ class Tester(unittest.TestCase):
pad_tensor_script = script_fn(tensor, script_pad, **kwargs)
self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, kwargs))
with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
F_t.pad(tensor, (-2, -3), padding_mode="symmetric")
if __name__ == '__main__':
unittest.main()
......@@ -355,6 +355,29 @@ def _hsv2rgb(img):
return torch.einsum("ijk, xijk -> xjk", mask.to(dtype=img.dtype), a4)
def _pad_symmetric(img: Tensor, padding: List[int]) -> Tensor:
# padding is left, right, top, bottom
in_sizes = img.size()
x_indices = [i for i in range(in_sizes[-1])] # [0, 1, 2, 3, ...]
left_indices = [i for i in range(padding[0] - 1, -1, -1)] # e.g. [3, 2, 1, 0]
right_indices = [-(i + 1) for i in range(padding[1])] # e.g. [-1, -2, -3]
x_indices = torch.tensor(left_indices + x_indices + right_indices)
y_indices = [i for i in range(in_sizes[-2])]
top_indices = [i for i in range(padding[2] - 1, -1, -1)]
bottom_indices = [-(i + 1) for i in range(padding[3])]
y_indices = torch.tensor(top_indices + y_indices + bottom_indices)
ndim = img.ndim
if ndim == 3:
return img[:, y_indices[:, None], x_indices[None, :]]
elif ndim == 4:
return img[:, :, y_indices[:, None], x_indices[None, :]]
else:
raise RuntimeError("Symmetric padding of N-D tensors are not supported yet")
def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
r"""Pad the given Tensor Image on all sides with specified padding mode and fill value.
......@@ -380,6 +403,11 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
will result in [3, 2, 1, 2, 3, 4, 3, 2]
- symmetric: pads with reflection of image (repeating the last value on the edge)
padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
will result in [2, 1, 1, 2, 3, 4, 4, 3]
Returns:
Tensor: Padded image.
"""
......@@ -400,8 +428,8 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " +
"{} element tuple".format(len(padding)))
if padding_mode not in ["constant", "edge", "reflect"]:
raise ValueError("Padding mode should be either constant, edge or reflect")
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
if isinstance(padding, int):
if torch.jit.is_scripting():
......@@ -423,6 +451,11 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
if padding_mode == "edge":
# remap padding_mode str
padding_mode = "replicate"
elif padding_mode == "symmetric":
# route to another implementation
if p[0] < 0 or p[1] < 0 or p[2] < 0 or p[3] < 0: # no any support for torch script
raise ValueError("Padding can not be negative for symmetric padding_mode")
return _pad_symmetric(img, p)
need_squeeze = False
if img.ndim < 4:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment