"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "107f18438605884f580a656fc980233125382309"
Commit 8e375670 authored by Alykhan Tejani's avatar Alykhan Tejani Committed by Soumith Chintala
Browse files

improve docs for Pad + add tests (#239)

parent eec5ba44
...@@ -136,6 +136,30 @@ class Tester(unittest.TestCase): ...@@ -136,6 +136,30 @@ class Tester(unittest.TestCase):
assert result.size(1) == height + 2 * padding assert result.size(1) == height + 2 * padding
assert result.size(2) == width + 2 * padding assert result.size(2) == width + 2 * padding
def test_pad_with_tuple_of_pad_values(self):
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
img = transforms.ToPILImage()(torch.ones(3, height, width))
padding = tuple([random.randint(1, 20) for _ in range(2)])
output = transforms.Pad(padding)(img)
assert output.size == (width + padding[0] * 2, height + padding[1] * 2)
padding = tuple([random.randint(1, 20) for _ in range(4)])
output = transforms.Pad(padding)(img)
assert output.size[0] == width + padding[0] + padding[2]
assert output.size[1] == height + padding[1] + padding[3]
def test_pad_raises_with_invalide_pad_sequence_len(self):
with self.assertRaises(ValueError):
transforms.Pad(())
with self.assertRaises(ValueError):
transforms.Pad((1, 2, 3))
with self.assertRaises(ValueError):
transforms.Pad((1, 2, 3, 4, 5))
def test_lambda(self): def test_lambda(self):
trans = transforms.Lambda(lambda x: x.add(10)) trans = transforms.Lambda(lambda x: x.add(10))
x = torch.randn(10) x = torch.randn(10)
......
...@@ -233,15 +233,22 @@ class Pad(object): ...@@ -233,15 +233,22 @@ class Pad(object):
"""Pad the given PIL.Image on all sides with the given "pad" value. """Pad the given PIL.Image on all sides with the given "pad" value.
Args: Args:
padding (int or sequence): Padding on each border. If a sequence of padding (int or tuple): Padding on each border. If a single int is provided this
length 4, it is used to pad left, top, right and bottom borders respectively. is used to pad all borders. If tuple of length 2 is provided this is the padding
fill: Pixel fill value. Default is 0. If a sequence of on left/right and top/bottom respectively. If a tuple of length 4 is provided
this is the padding for the left, top, right and bottom borders
respectively.
fill: Pixel fill value. Default is 0. If a tuple of
length 3, it is used to fill R, G, B channels respectively. length 3, it is used to fill R, G, B channels respectively.
""" """
def __init__(self, padding, fill=0): def __init__(self, padding, fill=0):
assert isinstance(padding, numbers.Number) or isinstance(padding, tuple) assert isinstance(padding, (numbers.Number, tuple))
assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple) assert isinstance(fill, (numbers.Number, str, tuple))
if isinstance(padding, collections.Sequence) and len(padding) not in [2, 4]:
raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
"{} element tuple".format(len(padding)))
self.padding = padding self.padding = padding
self.fill = fill self.fill = fill
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment