Commit 11a39aaa (unverified), authored by Francisco Massa, committed by GitHub
Browse files

Make RandomHorizontalFlip torchscriptable (#2282)

* Make RandomHorizontalFlip torchscriptable

* Make _is_tensor_a_torch_image more generic

* Make RandomVerticalFlip torchscriptable (#2283)

* Make RandomVerticalFlip torchscriptable

* Fix lint
parent de52437c
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL import Image
import numpy as np
import unittest
class Tester(unittest.TestCase):
    """Check that flip transforms agree across PIL, Tensor and TorchScript paths."""

    def _create_data(self, height=3, width=3, channels=3):
        """Build one random uint8 image in two representations.

        Args:
            height (int): image height in pixels.
            width (int): image width in pixels.
            channels (int): number of channels.

        Returns:
            tuple: ``(tensor, pil_img)`` where ``tensor`` has shape
            ``(channels, height, width)`` and ``pil_img`` is built from the
            same data permuted to ``(height, width, channels)``.
        """
        # The upper bound of torch.randint is exclusive, so 256 is required
        # for the value 255 to be generated at all.
        tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8)
        pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().numpy())
        return tensor, pil_img

    def compareTensorToPIL(self, tensor, pil_image):
        """Assert that a CHW tensor and an HWC PIL image hold identical pixels."""
        pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
        self.assertTrue(tensor.equal(pil_tensor))

    def _test_flip(self, func, method):
        """Exercise one flip through the functional, scripted and class interfaces.

        Args:
            func (str): name of the functional op looked up on ``F``.
            method (str): name of the transform class looked up on ``T``.
        """
        tensor, pil_img = self._create_data()

        # The Tensor and PIL code paths must produce the same pixels.
        flip_tensor = getattr(F, func)(tensor)
        flip_pil_img = getattr(F, func)(pil_img)
        self.compareTensorToPIL(flip_tensor, flip_pil_img)

        # The scripted functional op must match eager execution.
        scripted_fn = torch.jit.script(getattr(F, func))
        flip_tensor_script = scripted_fn(tensor)
        self.assertTrue(flip_tensor.equal(flip_tensor_script))

        # test for class interface
        f = getattr(T, method)()
        scripted_fn = torch.jit.script(f)
        out = scripted_fn(tensor)
        # The flip is applied at random, so the result is either the input
        # itself or the flipped tensor, never anything else.
        self.assertTrue(out.equal(tensor) or out.equal(flip_tensor))

    def test_random_horizontal_flip(self):
        self._test_flip('hflip', 'RandomHorizontalFlip')

    def test_random_vertical_flip(self):
        self._test_flip('vflip', 'RandomVerticalFlip')
if __name__ == '__main__':
unittest.main()
import torch
from torch import Tensor
import math
from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
try:
......@@ -11,6 +12,9 @@ import numbers
from collections.abc import Sequence, Iterable
import warnings
from . import functional_pil as F_pil
from . import functional_tensor as F_t
def _is_pil_image(img):
if accimage is not None:
......@@ -434,19 +438,22 @@ def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINE
return img
def hflip(img: Tensor) -> Tensor:
    """Horizontally flip the given PIL Image or torch Tensor.

    Args:
        img (PIL Image or Torch Tensor): Image to be flipped. If img
            is a Tensor, it is expected to be in [..., H, W] format,
            where ... means it can have an arbitrary number of leading
            dimensions.

    Returns:
        PIL Image or Tensor: Horizontally flipped image.
    """
    # Anything that is not a Tensor is handed to the PIL backend; Tensors go
    # to the tensor backend so this function stays scriptable.
    if not isinstance(img, torch.Tensor):
        return F_pil.hflip(img)

    return F_t.hflip(img)
def _parse_fill(fill, img, min_pil_version):
......@@ -536,19 +543,22 @@ def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC, fill=N
return img.transform(img.size, Image.PERSPECTIVE, coeffs, interpolation, **opts)
def vflip(img: Tensor) -> Tensor:
    """Vertically flip the given PIL Image or torch Tensor.

    Args:
        img (PIL Image or Torch Tensor): Image to be flipped. If img
            is a Tensor, it is expected to be in [..., H, W] format,
            where ... means it can have an arbitrary number of leading
            dimensions.

    Returns:
        PIL Image or Tensor: Vertically flipped image.
    """
    # Anything that is not a Tensor is handed to the PIL backend; Tensors go
    # to the tensor backend so this function stays scriptable.
    if not isinstance(img, torch.Tensor):
        return F_pil.vflip(img)

    return F_t.vflip(img)
def five_crop(img, size):
......
import torch
try:
import accimage
except ImportError:
accimage = None
from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
@torch.jit.unused
def _is_pil_image(img):
    """Return True when ``img`` is a PIL image (or an accimage image, if available)."""
    # accimage is an optional backend; its Image type only counts when the
    # module-level import succeeded.
    accepted = (Image.Image,) if accimage is None else (Image.Image, accimage.Image)
    return isinstance(img, accepted)
@torch.jit.unused
def hflip(img):
    """Horizontally flip the given PIL Image.

    Args:
        img (PIL Image): Image to be flipped.

    Returns:
        PIL Image: Horizontally flipped image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if _is_pil_image(img):
        return img.transpose(Image.FLIP_LEFT_RIGHT)

    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
@torch.jit.unused
def vflip(img):
    """Vertically flip the given PIL Image.

    Args:
        img (PIL Image): Image to be flipped.

    Returns:
        PIL Image: Vertically flipped image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if _is_pil_image(img):
        return img.transpose(Image.FLIP_TOP_BOTTOM)

    raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
import torch
import torchvision.transforms.functional as F
from torch import Tensor
from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple
def _is_tensor_a_torch_image(input):
return len(input.shape) == 3
return input.ndim >= 2
def vflip(img):
......
......@@ -500,25 +500,29 @@ class RandomCrop(object):
return self.__class__.__name__ + '(size={0}, padding={1})'.format(self.size, self.padding)
class RandomHorizontalFlip(torch.nn.Module):
    """Horizontally flip the given image randomly with a given probability.
    The image can be a PIL Image or a torch Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image.
        """
        # torch.rand is used rather than the stdlib random module so the
        # module can be compiled with torch.jit.script.
        if torch.rand(1) < self.p:
            return F.hflip(img)
        return img
......@@ -526,25 +530,29 @@ class RandomHorizontalFlip(object):
return self.__class__.__name__ + '(p={})'.format(self.p)
class RandomVerticalFlip(torch.nn.Module):
    """Vertically flip the given image randomly with a given probability.
    The image can be a PIL Image or a torch Tensor, in which case it is expected
    to have [..., H, W] shape, where ... means an arbitrary number of leading
    dimensions

    Args:
        p (float): probability of the image being flipped. Default value is 0.5
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img):
        """
        Args:
            img (PIL Image or Tensor): Image to be flipped.

        Returns:
            PIL Image or Tensor: Randomly flipped image.
        """
        # torch.rand is used rather than the stdlib random module so the
        # module can be compiled with torch.jit.script.
        if torch.rand(1) < self.p:
            return F.vflip(img)
        return img
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment