Unverified Commit 8f61f4ca authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Better handling for Pad's fill argument (#5596)

parent 9acec20d
...@@ -452,12 +452,12 @@ def test_resize_size_equals_small_edge_size(height, width): ...@@ -452,12 +452,12 @@ def test_resize_size_equals_small_edge_size(height, width):
class TestPad: class TestPad:
def test_pad(self): @pytest.mark.parametrize("fill", [85, 85.0])
def test_pad(self, fill):
height = random.randint(10, 32) * 2 height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2 width = random.randint(10, 32) * 2
img = torch.ones(3, height, width, dtype=torch.uint8) img = torch.ones(3, height, width, dtype=torch.uint8)
padding = random.randint(1, 20) padding = random.randint(1, 20)
fill = random.randint(1, 50)
result = transforms.Compose( result = transforms.Compose(
[ [
transforms.ToPILImage(), transforms.ToPILImage(),
...@@ -484,7 +484,7 @@ class TestPad: ...@@ -484,7 +484,7 @@ class TestPad:
output = transforms.Pad(padding)(img) output = transforms.Pad(padding)(img)
assert output.size == (width + padding[0] * 2, height + padding[1] * 2) assert output.size == (width + padding[0] * 2, height + padding[1] * 2)
padding = tuple(random.randint(1, 20) for _ in range(4)) padding = [random.randint(1, 20) for _ in range(4)]
output = transforms.Pad(padding)(img) output = transforms.Pad(padding)(img)
assert output.size[0] == width + padding[0] + padding[2] assert output.size[0] == width + padding[0] + padding[2]
assert output.size[1] == height + padding[1] + padding[3] assert output.size[1] == height + padding[1] + padding[3]
......
...@@ -154,7 +154,7 @@ def pad( ...@@ -154,7 +154,7 @@ def pad(
if not isinstance(padding, (numbers.Number, tuple, list)): if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg") raise TypeError("Got inappropriate padding arg")
if not isinstance(fill, (numbers.Number, str, tuple)): if not isinstance(fill, (numbers.Number, str, tuple, list)):
raise TypeError("Got inappropriate fill arg") raise TypeError("Got inappropriate fill arg")
if not isinstance(padding_mode, str): if not isinstance(padding_mode, str):
raise TypeError("Got inappropriate padding_mode arg") raise TypeError("Got inappropriate padding_mode arg")
...@@ -301,6 +301,12 @@ def _parse_fill( ...@@ -301,6 +301,12 @@ def _parse_fill(
fill = tuple(fill) fill = tuple(fill)
if img.mode != "F":
if isinstance(fill, (list, tuple)):
fill = tuple(int(x) for x in fill)
else:
fill = int(fill)
return {name: fill} return {name: fill}
......
...@@ -428,7 +428,7 @@ class Pad(torch.nn.Module): ...@@ -428,7 +428,7 @@ class Pad(torch.nn.Module):
if not isinstance(padding, (numbers.Number, tuple, list)): if not isinstance(padding, (numbers.Number, tuple, list)):
raise TypeError("Got inappropriate padding arg") raise TypeError("Got inappropriate padding arg")
if not isinstance(fill, (numbers.Number, str, tuple)): if not isinstance(fill, (numbers.Number, str, tuple, list)):
raise TypeError("Got inappropriate fill arg") raise TypeError("Got inappropriate fill arg")
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]: if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment