Unverified Commit 91b44590 authored by Willie Maddox's avatar Willie Maddox Committed by GitHub
Browse files

Add Perspective fill option (#1973)

* Add fill option to RandomPerspective #1972

* Minor fix to docstring syntax

* Add _parse_fill() to get fillcolor (#1972)

* Minor refactoring as per comments.

* Added test for RandomPerspective with fillcolor.

* Force perspective transform in test.
parent e61b68e0
......@@ -177,6 +177,41 @@ class Tester(unittest.TestCase):
self.assertGreater(torch.nn.functional.mse_loss(tr_img, F.to_tensor(img)) + 0.3,
torch.nn.functional.mse_loss(tr_img2, F.to_tensor(img)))
def test_randomperspective_fill(self):
    """Check that ``fill`` sets the colour revealed by a perspective warp.

    Covers both ``transforms.RandomPerspective`` and ``F.perspective`` on
    "L", "RGB" and "F" images by inspecting the top-left pixel of the
    result, and verifies that a ``fill`` tuple whose length differs from
    the expected band count raises ``ValueError``.
    """
    height, width = 100, 100
    base_img = transforms.ToPILImage()(torch.ones(3, height, width))

    modes = ("L", "RGB", "F")
    # The test derives the expected band count from the mode-name length.
    nums_bands = [len(mode) for mode in modes]
    fill = 127

    # Transform-class interface; p=1 forces the warp to be applied.
    for mode, num_bands in zip(modes, nums_bands):
        converted = base_img.convert(mode)
        transform = transforms.RandomPerspective(p=1, fill=fill)
        warped = transform(converted)
        corner = warped.getpixel((0, 0))
        # getpixel may return a bare value rather than a tuple; normalise it.
        corner = corner if isinstance(corner, tuple) else (corner,)
        self.assertTupleEqual(corner, (fill,) * num_bands)

    # Functional interface, including rejection of mis-sized fill tuples.
    for mode, num_bands in zip(modes, nums_bands):
        converted = base_img.convert(mode)
        startpoints, endpoints = transforms.RandomPerspective.get_params(width, height, 0.5)
        warped = F.perspective(converted, startpoints, endpoints, fill=fill)
        corner = warped.getpixel((0, 0))
        corner = corner if isinstance(corner, tuple) else (corner,)
        self.assertTupleEqual(corner, (fill,) * num_bands)

        for wrong_num_bands in set(nums_bands).difference({num_bands}):
            bad_fill = (fill,) * wrong_num_bands
            with self.assertRaises(ValueError):
                F.perspective(converted, startpoints, endpoints, fill=bad_fill)
def test_resize(self):
height = random.randint(24, 32) * 2
width = random.randint(24, 32) * 2
......
......@@ -425,6 +425,41 @@ def hflip(img):
return img.transpose(Image.FLIP_LEFT_RIGHT)
def _is_older_version(version, minimum):
    """Return True if dotted version string ``version`` is numerically older than ``minimum``.

    Only the leading digits of each dot-separated component are compared and
    parsing stops at the first component without leading digits, so suffixes
    such as ``.dev0`` are ignored. Shorter versions are padded with zeros.
    """
    def as_ints(text):
        numbers = []
        for component in str(text).split("."):
            digits = ""
            for char in component:
                if char not in "0123456789":
                    break
                digits += char
            if not digits:
                break
            numbers.append(int(digits))
        return tuple(numbers)

    current, required = as_ints(version), as_ints(minimum)
    width = max(len(current), len(required))
    current += (0,) * (width - len(current))
    required += (0,) * (width - len(required))
    return current < required


def _parse_fill(fill, img, min_pil_version):
    """Helper function to get the fill color for rotate and perspective transforms.

    Args:
        fill (n-tuple or int or float): Pixel fill value for area outside the transformed
            image. If int or float, the value is used for all bands.
            ``None`` is treated as 0 for all bands.
        img (PIL Image): Image to be filled.
        min_pil_version (str): The minimum PILLOW version for when the ``fillcolor`` option
            was first introduced in the calling function. (e.g. rotate->5.2.0, perspective->5.0.0)

    Returns:
        dict: kwarg for ``fillcolor``; empty when the installed PILLOW is older than
        ``min_pil_version`` and no ``fill`` was requested.

    Raises:
        RuntimeError: if ``fill`` is given but the installed PILLOW is older than
            ``min_pil_version``.
        ValueError: if ``fill`` is a sequence whose length differs from the number
            of bands of ``img``.
    """
    # Compare versions numerically: a plain string comparison would order
    # "10.0.0" before "5.0.0" and wrongly treat a newer PILLOW as too old.
    if _is_older_version(PILLOW_VERSION, min_pil_version):
        if fill is None:
            return {}
        msg = ("The option to fill background area of the transformed image, "
               "requires pillow>={}")
        raise RuntimeError(msg.format(min_pil_version))

    num_bands = len(img.getbands())
    if fill is None:
        fill = 0
    if isinstance(fill, (int, float)):
        # A scalar is broadcast to every band of a multi-band image.
        if num_bands > 1:
            fill = (fill,) * num_bands
    elif len(fill) != num_bands:
        msg = ("The number of elements in 'fill' does not match the number of "
               "bands of the image ({} != {})")
        raise ValueError(msg.format(len(fill), num_bands))

    return {"fillcolor": fill}
def _get_perspective_coeffs(startpoints, endpoints):
"""Helper function to get the coefficients (a, b, c, d, e, f, g, h) for the perspective transforms.
......@@ -450,7 +485,7 @@ def _get_perspective_coeffs(startpoints, endpoints):
return res.squeeze_(1).tolist()
def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC):
def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC, fill=None):
"""Perform perspective transform of the given PIL Image.
Args:
......@@ -458,14 +493,21 @@ def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC):
startpoints: List containing [top-left, top-right, bottom-right, bottom-left] of the original image
endpoints: List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image
interpolation: Default- Image.BICUBIC
fill (n-tuple or int or float): Pixel fill value for area outside the transformed
image. If int or float, the value is used for all bands respectively.
This option is only available for ``pillow>=5.0.0``.
Returns:
PIL Image: Perspectively transformed Image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
opts = _parse_fill(fill, img, '5.0.0')
coeffs = _get_perspective_coeffs(startpoints, endpoints)
return img.transform(img.size, Image.PERSPECTIVE, coeffs, interpolation)
return img.transform(img.size, Image.PERSPECTIVE, coeffs, interpolation, **opts)
def vflip(img):
......@@ -721,30 +763,10 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=None):
.. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters
"""
def parse_fill(fill, num_bands):
    """Build the ``fillcolor`` keyword argument for the rotate call.

    Args:
        fill (n-tuple or int or float): Pixel fill value. ``None`` is treated
            as 0; an int or float is used for every band.
        num_bands (int): Number of bands of the image being rotated.

    Returns:
        dict: ``{"fillcolor": fill}``, or an empty dict when the installed
        PILLOW is older than 5.2.0 and no ``fill`` was requested.

    Raises:
        RuntimeError: if ``fill`` is given but PILLOW is older than 5.2.0.
        ValueError: if ``fill`` is a sequence whose length differs from
            ``num_bands``.
    """
    def as_ints(version):
        # Keep the leading digits of each dot-separated component and pad to
        # three components so versions compare numerically.
        numbers = []
        for component in str(version).split("."):
            digits = ""
            for char in component:
                if char not in "0123456789":
                    break
                digits += char
            if not digits:
                break
            numbers.append(int(digits))
        numbers += [0] * (3 - len(numbers))
        return tuple(numbers)

    # A plain string comparison would order "10.0.0" before "5.2.0" and
    # wrongly treat a newer PILLOW as too old.
    if as_ints(PILLOW_VERSION) < (5, 2, 0):
        if fill is None:
            return {}
        msg = ("The option to fill background area of the rotated image, "
               "requires pillow>=5.2.0")
        raise RuntimeError(msg)

    if fill is None:
        fill = 0
    if isinstance(fill, (int, float)):
        # A scalar is broadcast to every band of a multi-band image.
        if num_bands > 1:
            fill = (fill,) * num_bands
    elif len(fill) != num_bands:
        msg = ("The number of elements in 'fill' does not match the number of "
               "bands of the image ({} != {})")
        raise ValueError(msg.format(len(fill), num_bands))

    return {"fillcolor": fill}
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
opts = parse_fill(fill, len(img.getbands()))
opts = _parse_fill(fill, img, '5.2.0')
return img.rotate(angle, resample, expand, center, **opts)
......
......@@ -550,12 +550,15 @@ class RandomPerspective(object):
distortion_scale(float): it controls the degree of distortion and ranges from 0 to 1. Default value is 0.5.
fill (3-tuple or int): RGB pixel fill value for area outside the transformed image.
If int, it is used for all channels respectively. Default value is 0.
"""
def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC):
def __init__(self, distortion_scale=0.5, p=0.5, interpolation=Image.BICUBIC, fill=0):
    """Store the transform settings on the instance.

    Args:
        distortion_scale: Degree of distortion, forwarded to ``get_params``.
        p: Threshold compared against ``random.random()`` to decide whether
            the transform is applied.
        interpolation: Interpolation mode forwarded to ``F.perspective``.
        fill: Fill value forwarded to ``F.perspective``.
    """
    self.distortion_scale = distortion_scale
    self.interpolation = interpolation
    self.fill = fill
    self.p = p
def __call__(self, img):
"""
......@@ -571,7 +574,7 @@ class RandomPerspective(object):
if random.random() < self.p:
width, height = img.size
startpoints, endpoints = self.get_params(width, height, self.distortion_scale)
return F.perspective(img, startpoints, endpoints, self.interpolation)
return F.perspective(img, startpoints, endpoints, self.interpolation, self.fill)
return img
@staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment