Unverified Commit 6fe11d5c authored by vfdev's avatar vfdev Committed by GitHub
Browse files

Issue 2350 - support of all padding modes with tensors (#2368)

* [WIP] functional_tensor supports more padding modes

* [WIP] Support all padding modes

* Removed wip symmetric mode

* Improvements according to the review
parent a99b6bd7
...@@ -247,18 +247,36 @@ class Tester(unittest.TestCase): ...@@ -247,18 +247,36 @@ class Tester(unittest.TestCase):
def test_pad(self): def test_pad(self):
script_fn = torch.jit.script(F_t.pad) script_fn = torch.jit.script(F_t.pad)
tensor, pil_img = self._create_data(7, 8) tensor, pil_img = self._create_data(7, 8)
for pad in [1, [1, ], [0, 1], (2, 2), [1, 0, 1, 2]]:
padding_mode = "constant" for dt in [None, torch.float32, torch.float64]:
for fill in [0, 10, 20]: if dt is not None:
pad_tensor = F_t.pad(tensor, pad, fill=fill, padding_mode=padding_mode) # This is a trivial cast to float of uint8 data to test all cases
pad_pil_img = F_pil.pad(pil_img, pad, fill=fill, padding_mode=padding_mode) tensor = tensor.to(dt)
self.compareTensorToPIL(pad_tensor, pad_pil_img, msg="{}, {}".format(pad, fill)) for pad in [2, [3, ], [0, 3], (3, 3), [4, 2, 4, 3]]:
if isinstance(pad, int): configs = [
script_pad = [pad, ] {"padding_mode": "constant", "fill": 0},
else: {"padding_mode": "constant", "fill": 10},
script_pad = pad {"padding_mode": "constant", "fill": 20},
pad_tensor_script = script_fn(tensor, script_pad, fill=fill, padding_mode=padding_mode) {"padding_mode": "edge"},
self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, fill)) {"padding_mode": "reflect"},
]
for kwargs in configs:
pad_tensor = F_t.pad(tensor, pad, **kwargs)
pad_pil_img = F_pil.pad(pil_img, pad, **kwargs)
pad_tensor_8b = pad_tensor
# we need to cast to uint8 to compare with PIL image
if pad_tensor_8b.dtype != torch.uint8:
pad_tensor_8b = pad_tensor_8b.to(torch.uint8)
self.compareTensorToPIL(pad_tensor_8b, pad_pil_img, msg="{}, {}".format(pad, kwargs))
if isinstance(pad, int):
script_pad = [pad, ]
else:
script_pad = pad
pad_tensor_script = script_fn(tensor, script_pad, **kwargs)
self.assertTrue(pad_tensor.equal(pad_tensor_script), msg="{}, {}".format(pad, kwargs))
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -346,7 +346,7 @@ def _hsv2rgb(img): ...@@ -346,7 +346,7 @@ def _hsv2rgb(img):
return torch.einsum("ijk, xijk -> xjk", mask.to(dtype=img.dtype), a4) return torch.einsum("ijk, xijk -> xjk", mask.to(dtype=img.dtype), a4)
def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constant") -> Tensor: def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
r"""Pad the given Tensor Image on all sides with specified padding mode and fill value. r"""Pad the given Tensor Image on all sides with specified padding mode and fill value.
Args: Args:
...@@ -363,6 +363,13 @@ def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constan ...@@ -363,6 +363,13 @@ def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constan
- constant: pads with a constant value, this value is specified with fill - constant: pads with a constant value, this value is specified with fill
- edge: pads with the last value on the edge of the image
- reflect: pads with reflection of image (without repeating the last value on the edge)
padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
will result in [3, 2, 1, 2, 3, 4, 3, 2]
Returns: Returns:
Tensor: Padded image. Tensor: Padded image.
""" """
...@@ -383,8 +390,8 @@ def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constan ...@@ -383,8 +390,8 @@ def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constan
raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " + raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " +
"{} element tuple".format(len(padding))) "{} element tuple".format(len(padding)))
if padding_mode not in ["constant", ]: if padding_mode not in ["constant", "edge", "reflect"]:
raise ValueError("Only constant padding_mode supported for torch tensors") raise ValueError("Padding mode should be either constant, edge or reflect")
if isinstance(padding, int): if isinstance(padding, int):
if torch.jit.is_scripting(): if torch.jit.is_scripting():
...@@ -403,5 +410,30 @@ def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constan ...@@ -403,5 +410,30 @@ def pad(img: Tensor, padding: List[int], fill: int, padding_mode: str = "constan
p = [pad_left, pad_right, pad_top, pad_bottom] p = [pad_left, pad_right, pad_top, pad_bottom]
if padding_mode == "edge":
# remap padding_mode str
padding_mode = "replicate"
need_squeeze = False
if img.ndim < 4:
img = img.unsqueeze(dim=0)
need_squeeze = True
out_dtype = img.dtype
need_cast = False
if (padding_mode != "constant") and img.dtype not in (torch.float32, torch.float64):
# Here we temporary cast input tensor to float
# until pytorch issue is resolved :
# https://github.com/pytorch/pytorch/issues/40763
need_cast = True
img = img.to(torch.float32)
img = torch.nn.functional.pad(img, p, mode=padding_mode, value=float(fill)) img = torch.nn.functional.pad(img, p, mode=padding_mode, value=float(fill))
if need_squeeze:
img = img.squeeze(dim=0)
if need_cast:
img = img.to(out_dtype)
return img return img
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment