Unverified Commit c5257708 authored by Anthony Kantsemal's avatar Anthony Kantsemal Committed by GitHub
Browse files

RandomRotation and fill (#3303)



* initial fix

* fill=0

* docstrings

* fill type check

* fill type check

* set None to zero

* unit tests

* set instead of NotImplemented

* fix W293
Co-authored-by: default avatarvfdev <vfdev.5@gmail.com>
Co-authored-by: default avatarVasilis Vryniotis <datumbox@users.noreply.github.com>
parent 1703e4ca
...@@ -180,6 +180,14 @@ class Tester(unittest.TestCase): ...@@ -180,6 +180,14 @@ class Tester(unittest.TestCase):
torch.nn.functional.mse_loss(tr_img2, F.to_tensor(img))) torch.nn.functional.mse_loss(tr_img2, F.to_tensor(img)))
def test_randomperspective_fill(self): def test_randomperspective_fill(self):
# assert fill being either a Sequence or a Number
with self.assertRaises(TypeError):
transforms.RandomPerspective(fill={})
t = transforms.RandomPerspective(fill=None)
self.assertTrue(t.fill == 0)
height = 100 height = 100
width = 100 width = 100
img = torch.ones(3, height, width) img = torch.ones(3, height, width)
...@@ -1531,6 +1539,13 @@ class Tester(unittest.TestCase): ...@@ -1531,6 +1539,13 @@ class Tester(unittest.TestCase):
transforms.RandomRotation([-0.7]) transforms.RandomRotation([-0.7])
transforms.RandomRotation([-0.7, 0, 0.7]) transforms.RandomRotation([-0.7, 0, 0.7])
# assert fill being either a Sequence or a Number
with self.assertRaises(TypeError):
transforms.RandomRotation(0, fill={})
t = transforms.RandomRotation(0, fill=None)
self.assertTrue(t.fill == 0)
t = transforms.RandomRotation(10) t = transforms.RandomRotation(10)
angle = t.get_params(t.degrees) angle = t.get_params(t.degrees)
self.assertTrue(angle > -10 and angle < 10) self.assertTrue(angle > -10 and angle < 10)
...@@ -1573,6 +1588,13 @@ class Tester(unittest.TestCase): ...@@ -1573,6 +1588,13 @@ class Tester(unittest.TestCase):
transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 0, 10]) transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 0, 10])
transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 0, 10, 0, 10]) transforms.RandomAffine([-90, 90], translate=[0.2, 0.2], scale=[0.5, 0.5], shear=[-10, 0, 10, 0, 10])
# assert fill being either a Sequence or a Number
with self.assertRaises(TypeError):
transforms.RandomAffine(0, fill={})
t = transforms.RandomAffine(0, fill=None)
self.assertTrue(t.fill == 0)
x = np.zeros((100, 100, 3), dtype=np.uint8) x = np.zeros((100, 100, 3), dtype=np.uint8)
img = F.to_pil_image(x) img = F.to_pil_image(x)
......
...@@ -673,8 +673,8 @@ class RandomPerspective(torch.nn.Module): ...@@ -673,8 +673,8 @@ class RandomPerspective(torch.nn.Module):
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed fill (sequence or number): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively. image. Default is ``0``. If given a number, the value is used for all bands respectively.
If input is PIL Image, the options is only available for ``Pillow>=5.0.0``. If input is PIL Image, the options is only available for ``Pillow>=5.0.0``.
""" """
...@@ -692,6 +692,12 @@ class RandomPerspective(torch.nn.Module): ...@@ -692,6 +692,12 @@ class RandomPerspective(torch.nn.Module):
self.interpolation = interpolation self.interpolation = interpolation
self.distortion_scale = distortion_scale self.distortion_scale = distortion_scale
if fill is None:
fill = 0
elif not isinstance(fill, (Sequence, numbers.Number)):
raise TypeError("Fill should be either a sequence or a number.")
self.fill = fill self.fill = fill
def forward(self, img): def forward(self, img):
...@@ -1175,8 +1181,8 @@ class RandomRotation(torch.nn.Module): ...@@ -1175,8 +1181,8 @@ class RandomRotation(torch.nn.Module):
Note that the expand flag assumes rotation around the center and no translation. Note that the expand flag assumes rotation around the center and no translation.
center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner. center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner.
Default is the center of the image. Default is the center of the image.
fill (sequence or number, optional): Pixel fill value for the area outside the rotated fill (sequence or number): Pixel fill value for the area outside the rotated
image. If given a number, the value is used for all bands respectively. image. Default is ``0``. If given a number, the value is used for all bands respectively.
If input is PIL Image, the options is only available for ``Pillow>=5.2.0``. If input is PIL Image, the options is only available for ``Pillow>=5.2.0``.
resample (int, optional): deprecated argument and will be removed since v0.10.0. resample (int, optional): deprecated argument and will be removed since v0.10.0.
Please use the ``interpolation`` parameter instead. Please use the ``interpolation`` parameter instead.
...@@ -1186,7 +1192,7 @@ class RandomRotation(torch.nn.Module): ...@@ -1186,7 +1192,7 @@ class RandomRotation(torch.nn.Module):
""" """
def __init__( def __init__(
self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=None, resample=None self, degrees, interpolation=InterpolationMode.NEAREST, expand=False, center=None, fill=0, resample=None
): ):
super().__init__() super().__init__()
if resample is not None: if resample is not None:
...@@ -1212,6 +1218,12 @@ class RandomRotation(torch.nn.Module): ...@@ -1212,6 +1218,12 @@ class RandomRotation(torch.nn.Module):
self.resample = self.interpolation = interpolation self.resample = self.interpolation = interpolation
self.expand = expand self.expand = expand
if fill is None:
fill = 0
elif not isinstance(fill, (Sequence, numbers.Number)):
raise TypeError("Fill should be either a sequence or a number.")
self.fill = fill self.fill = fill
@staticmethod @staticmethod
...@@ -1280,8 +1292,8 @@ class RandomAffine(torch.nn.Module): ...@@ -1280,8 +1292,8 @@ class RandomAffine(torch.nn.Module):
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed fill (sequence or number): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively. image. Default is ``0``. If given a number, the value is used for all bands respectively.
If input is PIL Image, the options is only available for ``Pillow>=5.0.0``. If input is PIL Image, the options is only available for ``Pillow>=5.0.0``.
fillcolor (sequence or number, optional): deprecated argument and will be removed since v0.10.0. fillcolor (sequence or number, optional): deprecated argument and will be removed since v0.10.0.
Please use the ``fill`` parameter instead. Please use the ``fill`` parameter instead.
...@@ -1339,6 +1351,12 @@ class RandomAffine(torch.nn.Module): ...@@ -1339,6 +1351,12 @@ class RandomAffine(torch.nn.Module):
self.shear = shear self.shear = shear
self.resample = self.interpolation = interpolation self.resample = self.interpolation = interpolation
if fill is None:
fill = 0
elif not isinstance(fill, (Sequence, numbers.Number)):
raise TypeError("Fill should be either a sequence or a number.")
self.fillcolor = self.fill = fill self.fillcolor = self.fill = fill
@staticmethod @staticmethod
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment