Unverified Commit 6fe11d5c authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Issue 2350 - support of all padding modes with tensors (#2368)

* [WIP] functional_tensor supports more padding modes

* [WIP] Support all padding modes

* Removed wip symmetric mode

* Improvements according to the review
parent a99b6bd7
......@@ -247,18 +247,36 @@ class Tester(unittest.TestCase):
def test_pad(self):
script_fn = torch.jit.script(F_t.pad)
tensor, pil_img = self._create_data(7, 8)
for pad in [1, [1, ], [0, 1], (2, 2), [1, 0, 1, 2]]:
padding_mode = "constant"
for fill in [0, 10, 20]:
pad_tensor = F_t.pad(tensor, pad, fill=fill, padding_mode=padding_mode)
pad_pil_img = F_pil.pad(pil_img, pad, fill=fill, padding_mode=padding_mode)
self.compareTensorToPIL(pad_tensor, pad_pil_img, msg="{}, {}".format(pad, fill))
if isinstance(pad, int):
script_pad = [pad, ]
else:
script_pad = pad
pad_tensor_script = script_fn(tensor, script_pad, fill=fill, padding_mode=padding_mode)
self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, fill))
for dt in [None, torch.float32, torch.float64]:
if dt is not None:
# This is a trivial cast to float of uint8 data to test all cases
tensor = tensor.to(dt)
for pad in [2, [3, ], [0, 3], (3, 3), [4, 2, 4, 3]]:
configs = [
{"padding_mode": "constant", "fill": 0},
{"padding_mode": "constant", "fill": 10},
{"padding_mode": "constant", "fill": 20},
{"padding_mode": "edge"},
{"padding_mode": "reflect"},
]
for kwargs in configs:
pad_tensor = F_t.pad(tensor, pad, **kwargs)
pad_pil_img = F_pil.pad(pil_img, pad, **kwargs)
pad_tensor_8b = pad_tensor
# we need to cast to uint8 to compare with PIL image
if pad_tensor_8b.dtype != torch.uint8:
pad_tensor_8b = pad_tensor_8b.to(torch.uint8)
self.compareTensorToPIL(pad_tensor_8b, pad_pil_img, msg="{}, {}".format(pad, kwargs))
if isinstance(pad, int):
script_pad = [pad, ]
else:
script_pad = pad
pad_tensor_script = script_fn(tensor, script_pad, **kwargs)
self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, kwargs))
if __name__ == '__main__':
......
......@@ -346,7 +346,7 @@ def _hsv2rgb(img):
return torch.einsum("ijk, xijk -> xjk", mask.to(dtype=img.dtype), a4)
def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constant") -> Tensor:
def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
r"""Pad the given Tensor Image on all sides with specified padding mode and fill value.
Args:
......@@ -363,6 +363,13 @@ def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constan
- constant: pads with a constant value, this value is specified with fill
- edge: pads with the last value on the edge of the image
- reflect: pads with reflection of image (without repeating the last value on the edge)
padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
will result in [3, 2, 1, 2, 3, 4, 3, 2]
Returns:
Tensor: Padded image.
"""
......@@ -383,8 +390,8 @@ def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constan
raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " +
"{} element tuple".format(len(padding)))
if padding_mode not in ["constant", ]:
raise ValueError("Only constant padding_mode supported for torch tensors")
if padding_mode not in ["constant", "edge", "reflect"]:
raise ValueError("Padding mode should be either constant, edge or reflect")
if isinstance(padding, int):
if torch.jit.is_scripting():
......@@ -403,5 +410,30 @@ def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constan
p = [pad_left, pad_right, pad_top, pad_bottom]
if padding_mode == "edge":
# remap padding_mode str
padding_mode = "replicate"
need_squeeze = False
if img.ndim < 4:
img = img.unsqueeze(dim=0)
need_squeeze = True
out_dtype = img.dtype
need_cast = False
if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64):
# Here we temporary cast input tensor to float
# until pytorch issue is resolved :
# https://github.com/pytorch/pytorch/issues/40763
need_cast = True
img = img.to(torch.float32)
img = torch.nn.functional.pad(img, p, mode=padding_mode, value=float(fill))
if need_squeeze:
img = img.squeeze(dim=0)
if need_cast:
img = img.to(out_dtype)
return img
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment