Unverified Commit 37a0d8d6 authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

[BC-breaking] Fix for integer fill value in constant padding (#2284)

* Bugfix in pad

* Address review comments

* Fix lint
parent 39021408
...@@ -299,13 +299,22 @@ class Tester(unittest.TestCase): ...@@ -299,13 +299,22 @@ class Tester(unittest.TestCase):
width = random.randint(10, 32) * 2 width = random.randint(10, 32) * 2
img = torch.ones(3, height, width) img = torch.ones(3, height, width)
padding = random.randint(1, 20) padding = random.randint(1, 20)
fill = random.randint(1, 50)
result = transforms.Compose([ result = transforms.Compose([
transforms.ToPILImage(), transforms.ToPILImage(),
transforms.Pad(padding), transforms.Pad(padding, fill=fill),
transforms.ToTensor(), transforms.ToTensor(),
])(img) ])(img)
self.assertEqual(result.size(1), height + 2 * padding) self.assertEqual(result.size(1), height + 2 * padding)
self.assertEqual(result.size(2), width + 2 * padding) self.assertEqual(result.size(2), width + 2 * padding)
# check that all elements in the padded region correspond
# to the pad value
fill_v = fill / 255
eps = 1e-5
self.assertTrue((result[:, :padding, :] - fill_v).abs().max() < eps)
self.assertTrue((result[:, :, :padding] - fill_v).abs().max() < eps)
self.assertRaises(ValueError, transforms.Pad(padding, fill=(1, 2)),
transforms.ToPILImage()(img))
def test_pad_with_tuple_of_pad_values(self): def test_pad_with_tuple_of_pad_values(self):
height = random.randint(10, 32) * 2 height = random.randint(10, 32) * 2
......
...@@ -329,6 +329,12 @@ def pad(img, padding, fill=0, padding_mode='constant'): ...@@ -329,6 +329,12 @@ def pad(img, padding, fill=0, padding_mode='constant'):
'Padding mode should be either constant, edge, reflect or symmetric' 'Padding mode should be either constant, edge, reflect or symmetric'
if padding_mode == 'constant': if padding_mode == 'constant':
if isinstance(fill, numbers.Number):
fill = (fill,) * len(img.getbands())
if len(fill) != len(img.getbands()):
raise ValueError('fill should have the same number of elements '
'as the number of channels in the image '
'({}), got {} instead'.format(len(img.getbands()), len(fill)))
if img.mode == 'P': if img.mode == 'P':
palette = img.getpalette() palette = img.getpalette()
image = ImageOps.expand(img, border=padding, fill=fill) image = ImageOps.expand(img, border=padding, fill=fill)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment