Unverified Commit 91b44590 authored by Willie Maddox's avatar Willie Maddox Committed by GitHub
Browse files

Add Perspective fill option (#1973)

* Add fill option to RandomPerspective #1972

* Minor fix to docstring syntax

* Add _parse_fill() to get fillcolor (#1972)

* Minor refactoring as per comments.

* Added test for RandomPerspective with fillcolor.

* Force perspective transform in test.
parent e61b68e0
...@@ -177,6 +177,41 @@ class Tester(unittest.TestCase): ...@@ -177,6 +177,41 @@ class Tester(unittest.TestCase):
self.assertGreater(torch.nn.functional.mse_loss(tr_img, F.to_tensor(img)) + 0.3, self.assertGreater(torch.nn.functional.mse_loss(tr_img, F.to_tensor(img)) + 0.3,
torch.nn.functional.mse_loss(tr_img2, F.to_tensor(img))) torch.nn.functional.mse_loss(tr_img2, F.to_tensor(img)))
def test_randomperspective_fill(self):
height = 100
width = 100
img = torch.ones(3, height, width)
to_pil_image = transforms.ToPILImage()
img = to_pil_image(img)
modes = ("L", "RGB", "F")
nums_bands = [len(mode) for mode in modes]
fill = 127
for mode, num_bands in zip(modes, nums_bands):
img_conv = img.convert(mode)
perspective = transforms.RandomPerspective(p=1, fill=fill)
tr_img = perspective(img_conv)
pixel = tr_img.getpixel((0, 0))
if not isinstance(pixel, tuple):
pixel = (pixel,)
self.assertTupleEqual(pixel, tuple([fill] * num_bands))
for mode, num_bands in zip(modes, nums_bands):
img_conv = img.convert(mode)
startpoints, endpoints = transforms.RandomPerspective.get_params(width, height, 0.5)
tr_img = F.perspective(img_conv, startpoints, endpoints, fill=fill)
pixel = tr_img.getpixel((0, 0))
if not isinstance(pixel, tuple):
pixel = (pixel,)
self.assertTupleEqual(pixel, tuple([fill] * num_bands))
for wrong_num_bands in set(nums_bands) - {num_bands}:
with self.assertRaises(ValueError):
F.perspective(img_conv, startpoints, endpoints, fill=tuple([fill] * wrong_num_bands))
def test_resize(self): def test_resize(self):
height = random.randint(24, 32) * 2 height = random.randint(24, 32) * 2
width = random.randint(24, 32) * 2 width = random.randint(24, 32) * 2
......
...@@ -425,6 +425,41 @@ def hflip(img): ...@@ -425,6 +425,41 @@ def hflip(img):
return img.transpose(Image.FLIP_LEFT_RIGHT) return img.transpose(Image.FLIP_LEFT_RIGHT)
def _parse_fill(fill, img, min_pil_version):
"""Helper function to get the fill color for rotate and perspective transforms.
Args:
fill (n-tuple or int or float): Pixel fill value for area outside the transformed
image. If int or float, the value is used for all bands respectively.
Defaults to 0 for all bands.
img (PIL Image): Image to be filled.
min_pil_version (str): The minimum PILLOW version for when the ``fillcolor`` option
was first introduced in the calling function. (e.g. rotate->5.2.0, perspective->5.0.0)
Returns:
dict: kwarg for ``fillcolor``
"""
if PILLOW_VERSION < min_pil_version:
if fill is None:
return {}
else:
msg = ("The option to fill background area of the transformed image, "
"requires pillow>={}")
raise RuntimeError(msg.format(min_pil_version))
num_bands = len(img.getbands())
if fill is None:
fill = 0
if isinstance(fill, (int, float)) and num_bands > 1:
fill = tuple([fill] * num_bands)
if not isinstance(fill, (int, float)) and len(fill) != num_bands:
msg = ("The number of elements in 'fill' does not match the number of "
"bands of the image ({} != {})")
raise ValueError(msg.format(len(fill), num_bands))
return {"fillcolor": fill}
def _get_perspective_coeffs(startpoints, endpoints): def _get_perspective_coeffs(startpoints, endpoints):
"""Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms. """Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.
...@@ -450,7 +485,7 @@ def _get_perspective_coeffs(startpoints, endpoints): ...@@ -450,7 +485,7 @@ def _get_perspective_coeffs(startpoints, endpoints):
return res.squeeze_(1).tolist() return res.squeeze_(1).tolist()
def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC): def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC, fill=None):
"""Perform perspective transform of the given PIL Image. """Perform perspective transform of the given PIL Image.
Args: Args:
...@@ -458,14 +493,21 @@ def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC): ...@@ -458,14 +493,21 @@ def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC):
startpoints: List containing [top-left, top-right, bottom-right, bottom-left] of the orignal image startpoints: List containing [top-left, top-right, bottom-right, bottom-left] of the orignal image
endpoints: List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image endpoints: List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image
interpolation: Default- Image.BICUBIC interpolation: Default- Image.BICUBIC
fill (n-tuple or int or float): Pixel fill value for area outside the rotated
image. If int or float, the value is used for all bands respectively.
This option is only available for ``pillow>=5.0.0``.
Returns: Returns:
PIL Image: Perspectively transformed Image. PIL Image: Perspectively transformed Image.
""" """
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
opts = _parse_fill(fill, img, '5.0.0')
coeffs = _get_perspective_coeffs(startpoints, endpoints) coeffs = _get_perspective_coeffs(startpoints, endpoints)
return img.transform(img.size, Image.PERSPECTIVE, coeffs, interpolation) return img.transform(img.size, Image.PERSPECTIVE, coeffs, interpolation, **opts)
def vflip(img): def vflip(img):
...@@ -721,30 +763,10 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=None): ...@@ -721,30 +763,10 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=None):
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
""" """
def parse_fill(fill, num_bands):
if PILLOW_VERSION < "5.2.0":
if fill is None:
return {}
else:
msg = ("The option to fill background area of the rotated image, "
"requires pillow>=5.2.0")
raise RuntimeError(msg)
if fill is None:
fill = 0
if isinstance(fill, (int, float)) and num_bands > 1:
fill = tuple([fill] * num_bands)
if not isinstance(fill, (int, float)) and len(fill) != num_bands:
msg = ("The number of elements in 'fill' does not match the number of "
"bands of the image ({} != {})")
raise ValueError(msg.format(len(fill), num_bands))
return {"fillcolor": fill}
if not _is_pil_image(img): if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
opts = parse_fill(fill, len(img.getbands())) opts = _parse_fill(fill, img, '5.2.0')
return img.rotate(angle, resample, expand, center, **opts) return img.rotate(angle, resample, expand, center, **opts)
......
...@@ -550,12 +550,15 @@ class RandomPerspective(object): ...@@ -550,12 +550,15 @@ class RandomPerspective(object):
distortion_scale(float): it controls the degree of distortion and ranges from 0 to 1. Default value is 0.5. distortion_scale(float): it controls the degree of distortion and ranges from 0 to 1. Default value is 0.5.
fill (3-tuple or int): RGB pixel fill value for area outside the rotated image.
If int, it is used for all channels respectively. Default value is 0.
""" """
def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC): def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC, fill=0):
self.p = p self.p = p
self.interpolation = interpolation self.interpolation = interpolation
self.distortion_scale = distortion_scale self.distortion_scale = distortion_scale
self.fill = fill
def __call__(self, img): def __call__(self, img):
""" """
...@@ -571,7 +574,7 @@ class RandomPerspective(object): ...@@ -571,7 +574,7 @@ class RandomPerspective(object):
if random.random() < self.p: if random.random() < self.p:
width, height = img.size width, height = img.size
startpoints, endpoints = self.get_params(width, height, self.distortion_scale) startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
return F.perspective(img, startpoints, endpoints, self.interpolation) return F.perspective(img, startpoints, endpoints, self.interpolation, self.fill)
return img return img
@staticmethod @staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment