Unverified Commit 11a39aaa authored by Francisco Massa's avatar Francisco Massa Committed by GitHub
Browse files

Make RandomHorizontalFlip torchscriptable (#2282)

* Make RandomHorizontalFlip torchscriptable

* Make _is_tensor_a_torch_image more generic

* Make RandomVerticalFlip torchscriptable (#2283)

* Make RandomVerticalFlip torchscriptable

* Fix lint
parent de52437c
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL import Image
import numpy as np
import unittest
class Tester(unittest.TestCase):
def _create_data(self, height=3, width=3, channels=3):
tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8)
pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().numpy())
return tensor, pil_img
def compareTensorToPIL(self, tensor, pil_image):
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
self.assertTrue(tensor.equal(pil_tensor))
def _test_flip(self, func, method):
tensor, pil_img = self._create_data()
flip_tensor = getattr(F, func)(tensor)
flip_pil_img = getattr(F, func)(pil_img)
self.compareTensorToPIL(flip_tensor, flip_pil_img)
scripted_fn = torch.jit.script(getattr(F, func))
flip_tensor_script = scripted_fn(tensor)
self.assertTrue(flip_tensor.equal(flip_tensor_script))
# test for class interface
f = getattr(T, method)()
scripted_fn = torch.jit.script(f)
scripted_fn(tensor)
def test_random_horizontal_flip(self):
self._test_flip('hflip', 'RandomHorizontalFlip')
def test_random_vertical_flip(self):
self._test_flip('vflip', 'RandomVerticalFlip')
if __name__ == '__main__':
unittest.main()
import torch import torch
from torch import Tensor
import math import math
from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
try: try:
...@@ -11,6 +12,9 @@ import numbers ...@@ -11,6 +12,9 @@ import numbers
from collections.abc import Sequence, Iterable from collections.abc import Sequence, Iterable
import warnings import warnings
from . import functional_pil as F_pil
from . import functional_tensor as F_t
def _is_pil_image(img): def _is_pil_image(img):
if accimage is not None: if accimage is not None:
...@@ -434,19 +438,22 @@ def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINE ...@@ -434,19 +438,22 @@ def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINE
return img return img
def hflip(img): def hflip(img: Tensor) -> Tensor:
"""Horizontally flip the given PIL Image. """Horizontally flip the given PIL Image or torch Tensor.
Args: Args:
img (PIL Image): Image to be flipped. img (PIL Image or Torch Tensor): Image to be flipped. If img
is a Tensor, it is expected to be in [..., H, W] format,
where ... means it can have an arbitrary number of trailing
dimensions.
Returns: Returns:
PIL Image: Horizontally flipped image. PIL Image: Horizontally flipped image.
""" """
if not _is_pil_image(img): if not isinstance(img, torch.Tensor):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) return F_pil.hflip(img)
return img.transpose(Image.FLIP_LEFT_RIGHT) return F_t.hflip(img)
def _parse_fill(fill, img, min_pil_version): def _parse_fill(fill, img, min_pil_version):
...@@ -536,19 +543,22 @@ def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC, fill=N ...@@ -536,19 +543,22 @@ def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC, fill=N
return img.transform(img.size, Image.PERSPECTIVE, coeffs, interpolation, **opts) return img.transform(img.size, Image.PERSPECTIVE, coeffs, interpolation, **opts)
def vflip(img): def vflip(img: Tensor) -> Tensor:
"""Vertically flip the given PIL Image. """Vertically flip the given PIL Image or torch Tensor.
Args: Args:
img (PIL Image): Image to be flipped. img (PIL Image or Torch Tensor): Image to be flipped. If img
is a Tensor, it is expected to be in [..., H, W] format,
where ... means it can have an arbitrary number of trailing
dimensions.
Returns: Returns:
PIL Image: Vertically flipped image. PIL Image: Vertically flipped image.
""" """
if not _is_pil_image(img): if not isinstance(img, torch.Tensor):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) return F_pil.vflip(img)
return img.transpose(Image.FLIP_TOP_BOTTOM) return F_t.vflip(img)
def five_crop(img, size): def five_crop(img, size):
......
import torch
try:
import accimage
except ImportError:
accimage = None
from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
@torch.jit.unused
def _is_pil_image(img):
if accimage is not None:
return isinstance(img, (Image.Image, accimage.Image))
else:
return isinstance(img, Image.Image)
@torch.jit.unused
def hflip(img):
"""Horizontally flip the given PIL Image.
Args:
img (PIL Image): Image to be flipped.
Returns:
PIL Image: Horizontally flipped image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return img.transpose(Image.FLIP_LEFT_RIGHT)
@torch.jit.unused
def vflip(img):
"""Vertically flip the given PIL Image.
Args:
img (PIL Image): Image to be flipped.
Returns:
PIL Image: Vertically flipped image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return img.transpose(Image.FLIP_TOP_BOTTOM)
import torch import torch
import torchvision.transforms.functional as F
from torch import Tensor from torch import Tensor
from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple
def _is_tensor_a_torch_image(input): def _is_tensor_a_torch_image(input):
return len(input.shape) == 3 return input.ndim >= 2
def vflip(img): def vflip(img):
......
...@@ -500,25 +500,29 @@ class RandomCrop(object): ...@@ -500,25 +500,29 @@ class RandomCrop(object):
return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding) return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
class RandomHorizontalFlip(object): class RandomHorizontalFlip(torch.nn.Module):
"""Horizontally flip the given PIL Image randomly with a given probability. """Horizontally flip the given image randomly with a given probability.
The image can be a PIL Image or a torch Tensor, in which case it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading
dimensions
Args: Args:
p (float): probability of the image being flipped. Default value is 0.5 p (float): probability of the image being flipped. Default value is 0.5
""" """
def __init__(self, p=0.5): def __init__(self, p=0.5):
super().__init__()
self.p = p self.p = p
def __call__(self, img): def forward(self, img):
""" """
Args: Args:
img (PIL Image): Image to be flipped. img (PIL Image or Tensor): Image to be flipped.
Returns: Returns:
PIL Image: Randomly flipped image. PIL Image or Tensor: Randomly flipped image.
""" """
if random.random() < self.p: if torch.rand(1) < self.p:
return F.hflip(img) return F.hflip(img)
return img return img
...@@ -526,25 +530,29 @@ class RandomHorizontalFlip(object): ...@@ -526,25 +530,29 @@ class RandomHorizontalFlip(object):
return self.__class__.__name__ + '(p={})'.format(self.p) return self.__class__.__name__ + '(p={})'.format(self.p)
class RandomVerticalFlip(object): class RandomVerticalFlip(torch.nn.Module):
"""Vertically flip the given PIL Image randomly with a given probability. """Vertically flip the given PIL Image randomly with a given probability.
The image can be a PIL Image or a torch Tensor, in which case it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading
dimensions
Args: Args:
p (float): probability of the image being flipped. Default value is 0.5 p (float): probability of the image being flipped. Default value is 0.5
""" """
def __init__(self, p=0.5): def __init__(self, p=0.5):
super().__init__()
self.p = p self.p = p
def __call__(self, img): def forward(self, img):
""" """
Args: Args:
img (PIL Image): Image to be flipped. img (PIL Image or Tensor): Image to be flipped.
Returns: Returns:
PIL Image: Randomly flipped image. PIL Image or Tensor: Randomly flipped image.
""" """
if random.random() < self.p: if torch.rand(1) < self.p:
return F.vflip(img) return F.vflip(img)
return img return img
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment