Commit 10f34160 authored by Surgan Jandial, committed by Francisco Massa

Scriptability checks for Tensor Transforms (#1690)

* scriptability checks

* tests upds

* linter upds

* linter upds

* upds

* tuple list changes

* linter updates
parent 900c88c7
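The change compiles each tensor transform with torch.jit.script inside the tests and asserts that the scripted function reproduces the eager result. A minimal sketch of that pattern in isolation (the input is just an arbitrary [C, H, W] tensor):

    import torch
    import torchvision.transforms.functional_tensor as F_t

    # Compiling with the TorchScript compiler raises if the function body
    # uses constructs that scripting cannot handle.
    script_vflip = torch.jit.script(F_t.vflip)

    img = torch.randn(3, 16, 16)  # arbitrary [C, H, W] image tensor
    assert torch.equal(script_vflip(img), F_t.vflip(img))

Scripting the free functions directly is what drives the companion changes in functional_tensor.py further down: type-comment annotations on every function and the removal of helpers the compiler cannot process.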
test/test_functional_tensor.py

from __future__ import division
import torch
+from torch import Tensor
import torchvision.transforms as transforms
import torchvision.transforms.functional_tensor as F_t
import torchvision.transforms.functional as F
import numpy as np
import unittest
import random
+from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple


class Tester(unittest.TestCase):

    def test_vflip(self):
+        script_vflip = torch.jit.script(F_t.vflip)
        img_tensor = torch.randn(3, 16, 16)
        img_tensor_clone = img_tensor.clone()
        vflipped_img = F_t.vflip(img_tensor)
@@ -18,8 +21,12 @@ class Tester(unittest.TestCase):
        self.assertEqual(vflipped_img.shape, img_tensor.shape)
        self.assertTrue(torch.equal(img_tensor, vflipped_img_again))
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
+        # scriptable function test
+        vflipped_img_script = script_vflip(img_tensor)
+        self.assertTrue(torch.equal(vflipped_img, vflipped_img_script))

    def test_hflip(self):
+        script_hflip = torch.jit.script(F_t.hflip)
        img_tensor = torch.randn(3, 16, 16)
        img_tensor_clone = img_tensor.clone()
        hflipped_img = F_t.hflip(img_tensor)
@@ -27,8 +34,12 @@ class Tester(unittest.TestCase):
        self.assertEqual(hflipped_img.shape, img_tensor.shape)
        self.assertTrue(torch.equal(img_tensor, hflipped_img_again))
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
+        # scriptable function test
+        hflipped_img_script = script_hflip(img_tensor)
+        self.assertTrue(torch.equal(hflipped_img, hflipped_img_script))

    def test_crop(self):
+        script_crop = torch.jit.script(F_t.crop)
        img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
        img_tensor_clone = img_tensor.clone()
        top = random.randint(0, 15)
@@ -42,11 +53,18 @@ class Tester(unittest.TestCase):
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
        self.assertTrue(torch.equal(img_cropped, (img_cropped_GT * 255).to(torch.uint8)),
                        "functional_tensor crop not working")
+        # scriptable function test
+        cropped_img_script = script_crop(img_tensor, top, left, height, width)
+        self.assertTrue(torch.equal(img_cropped, cropped_img_script))

    def test_adjustments(self):
-        fns = ((F.adjust_brightness, F_t.adjust_brightness),
-               (F.adjust_contrast, F_t.adjust_contrast),
-               (F.adjust_saturation, F_t.adjust_saturation))
+        script_adjust_brightness = torch.jit.script(F_t.adjust_brightness)
+        script_adjust_contrast = torch.jit.script(F_t.adjust_contrast)
+        script_adjust_saturation = torch.jit.script(F_t.adjust_saturation)
+        fns = ((F.adjust_brightness, F_t.adjust_brightness, script_adjust_brightness),
+               (F.adjust_contrast, F_t.adjust_contrast, script_adjust_contrast),
+               (F.adjust_saturation, F_t.adjust_saturation, script_adjust_saturation))
        for _ in range(20):
            channels = 3
@@ -60,11 +78,13 @@ class Tester(unittest.TestCase):
            factor = 3 * torch.rand(1)
            img_clone = img.clone()
-            for f, ft in fns:
+            for f, ft, sft in fns:
                ft_img = ft(img, factor)
+                sft_img = sft(img, factor)
                if not img.dtype.is_floating_point:
                    ft_img = ft_img.to(torch.float) / 255
+                    sft_img = sft_img.to(torch.float) / 255

                img_pil = transforms.ToPILImage()(img)
                f_img_pil = f(img_pil, factor)
@@ -73,10 +93,13 @@ class Tester(unittest.TestCase):
                # F uses uint8 and F_t uses float, so there is a small
                # difference in values caused by (at most 5) truncations.
                max_diff = (ft_img - f_img).abs().max()
+                max_diff_scripted = (sft_img - f_img).abs().max()
                self.assertLess(max_diff, 5 / 255 + 1e-5)
+                self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)
                self.assertTrue(torch.equal(img, img_clone))

    def test_rgb_to_grayscale(self):
+        script_rgb_to_grayscale = torch.jit.script(F_t.rgb_to_grayscale)
        img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
        img_tensor_clone = img_tensor.clone()
        grayscale_tensor = F_t.rgb_to_grayscale(img_tensor).to(int)
@@ -84,8 +107,12 @@ class Tester(unittest.TestCase):
        max_diff = (grayscale_tensor - grayscale_pil_img).abs().max()
        self.assertLess(max_diff, 1.0001)
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
+        # scriptable function test
+        grayscale_script = script_rgb_to_grayscale(img_tensor).to(int)
+        self.assertTrue(torch.equal(grayscale_script, grayscale_tensor))

    def test_center_crop(self):
+        script_center_crop = torch.jit.script(F_t.center_crop)
        img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
        img_tensor_clone = img_tensor.clone()
        cropped_tensor = F_t.center_crop(img_tensor, [10, 10])
@@ -93,8 +120,12 @@ class Tester(unittest.TestCase):
        cropped_pil_tensor = (transforms.ToTensor()(cropped_pil_image) * 255).to(torch.uint8)
        self.assertTrue(torch.equal(cropped_tensor, cropped_pil_tensor))
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
+        # scriptable function test
+        cropped_script = script_center_crop(img_tensor, [10, 10])
+        self.assertTrue(torch.equal(cropped_script, cropped_tensor))

    def test_five_crop(self):
+        script_five_crop = torch.jit.script(F_t.five_crop)
        img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
        img_tensor_clone = img_tensor.clone()
        cropped_tensor = F_t.five_crop(img_tensor, [10, 10])
@@ -110,8 +141,13 @@ class Tester(unittest.TestCase):
        self.assertTrue(torch.equal(cropped_tensor[4],
                                    (transforms.ToTensor()(cropped_pil_image[4]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
+        # scriptable function test
+        cropped_script = script_five_crop(img_tensor, [10, 10])
+        for cropped_script_img, cropped_tensor_img in zip(cropped_script, cropped_tensor):
+            self.assertTrue(torch.equal(cropped_script_img, cropped_tensor_img))

    def test_ten_crop(self):
+        script_ten_crop = torch.jit.script(F_t.ten_crop)
        img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
        img_tensor_clone = img_tensor.clone()
        cropped_tensor = F_t.ten_crop(img_tensor, [10, 10])
@@ -137,6 +173,10 @@ class Tester(unittest.TestCase):
        self.assertTrue(torch.equal(cropped_tensor[9],
                                    (transforms.ToTensor()(cropped_pil_image[9]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
+        # scriptable function test
+        cropped_script = script_ten_crop(img_tensor, [10, 10])
+        for cropped_script_img, cropped_tensor_img in zip(cropped_script, cropped_tensor):
+            self.assertTrue(torch.equal(cropped_script_img, cropped_tensor_img))


if __name__ == '__main__':
    ...
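Because torch.jit.script returns an ordinary ScriptFunction, the compiled transforms can also be serialized and executed without their Python source, which is the practical payoff of these checks. A small round-trip sketch (the file name crop.pt is arbitrary, and save/load is standard torch.jit behavior rather than anything this commit adds):

    import torch
    import torchvision.transforms.functional_tensor as F_t

    scripted_crop = torch.jit.script(F_t.crop)
    scripted_crop.save("crop.pt")  # serialize the compiled transform

    reloaded = torch.jit.load("crop.pt")  # callable without the Python source
    img = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
    assert torch.equal(reloaded(img, 0, 0, 8, 8), F_t.crop(img, 0, 0, 8, 8))

The second file in the commit is the implementation side: functional_tensor.py gains type-comment annotations and swaps constructs the scripting compiler rejects for scriptable equivalents.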
torchvision/transforms/functional_tensor.py

from __future__ import division
import torch
-import torchvision.transforms.functional as F
+from torch import Tensor
+from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple


+def _is_tensor_a_torch_image(input):
+    return len(input.shape) == 3


-def vflip(img_tensor):
+def vflip(img):
+    # type: (Tensor) -> Tensor
"""Vertically flip the given the Image Tensor.
Args:
img_tensor (Tensor): Image Tensor to be flipped in the form [C, H, W].
img (Tensor): Image Tensor to be flipped in the form [C, H, W].
Returns:
Tensor: Vertically flipped image Tensor.
"""
if not F._is_tensor_image(img_tensor):
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
return img_tensor.flip(-2)
return img.flip(-2)
-def hflip(img_tensor):
+def hflip(img):
+    # type: (Tensor) -> Tensor
    """Horizontally flip the given Image Tensor.

    Args:
-        img_tensor (Tensor): Image Tensor to be flipped in the form [C, H, W].
+        img (Tensor): Image Tensor to be flipped in the form [C, H, W].

    Returns:
        Tensor: Horizontally flipped image Tensor.
    """
-    if not F._is_tensor_image(img_tensor):
+    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

-    return img_tensor.flip(-1)
+    return img.flip(-1)
def crop(img, top, left, height, width):
+    # type: (Tensor, int, int, int, int) -> Tensor
    """Crop the given Image Tensor.

    Args:
@@ -45,13 +55,14 @@ def crop(img, top, left, height, width):
    Returns:
        Tensor: Cropped image.
    """
-    if not F._is_tensor_image(img):
+    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    return img[..., top:top + height, left:left + width]
def rgb_to_grayscale(img):
+    # type: (Tensor) -> Tensor
    """Convert the given RGB Image Tensor to Grayscale.

    For RGB to Grayscale conversion, ITU-R 601-2 luma transform is performed which
    is L = R * 0.2989 + G * 0.5870 + B * 0.1140
@@ -70,6 +81,7 @@ def rgb_to_grayscale(img):
def adjust_brightness(img, brightness_factor):
+    # type: (Tensor, float) -> Tensor
    """Adjust brightness of an RGB image.

    Args:
@@ -81,13 +93,14 @@ def adjust_brightness(img, brightness_factor):
    Returns:
        Tensor: Brightness adjusted image.
    """
-    if not F._is_tensor_image(img):
+    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

-    return _blend(img, 0, brightness_factor)
+    return _blend(img, torch.zeros_like(img), brightness_factor)


def adjust_contrast(img, contrast_factor):
+    # type: (Tensor, float) -> Tensor
    """Adjust contrast of an RGB image.

    Args:
@@ -99,7 +112,7 @@ def adjust_contrast(img, contrast_factor):
    Returns:
        Tensor: Contrast adjusted image.
    """
-    if not F._is_tensor_image(img):
+    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    mean = torch.mean(rgb_to_grayscale(img).to(torch.float))
@@ -108,6 +121,7 @@ def adjust_contrast(img, contrast_factor):
def adjust_saturation(img, saturation_factor):
+    # type: (Tensor, float) -> Tensor
    """Adjust color saturation of an RGB image.

    Args:
@@ -119,13 +133,14 @@ def adjust_saturation(img, saturation_factor):
    Returns:
        Tensor: Saturation adjusted image.
    """
-    if not F._is_tensor_image(img):
+    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    return _blend(img, rgb_to_grayscale(img), saturation_factor)


def center_crop(img, output_size):
+    # type: (Tensor, BroadcastingList2[int]) -> Tensor
    """Crop the Image Tensor and resize it to desired size.

    Args:
@@ -136,7 +151,7 @@ def center_crop(img, output_size):
    Returns:
        Tensor: Cropped image.
    """
-    if not F._is_tensor_image(img):
+    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    _, image_width, image_height = img.size()
@@ -148,9 +163,10 @@ def center_crop(img, output_size):
def five_crop(img, size):
+    # type: (Tensor, BroadcastingList2[int]) -> List[Tensor]
    """Crop the given Image Tensor into four corners and the central crop.

    .. Note::
-        This transform returns a tuple of Tensors and there may be a
+        This transform returns a List of Tensors and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
@@ -159,10 +175,10 @@ def five_crop(img, size):
            made.

    Returns:
-        tuple: tuple (tl, tr, bl, br, center)
+        List: List (tl, tr, bl, br, center)
            Corresponding top left, top right, bottom left, bottom right and center crop.
    """
-    if not F._is_tensor_image(img):
+    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
@@ -179,14 +195,15 @@ def five_crop(img, size):
    br = crop(img, image_width - crop_width, image_height - crop_height, image_width, image_height)
    center = center_crop(img, (crop_height, crop_width))

-    return (tl, tr, bl, br, center)
+    return [tl, tr, bl, br, center]


def ten_crop(img, size, vertical_flip=False):
+    # type: (Tensor, BroadcastingList2[int], bool) -> List[Tensor]
    """Crop the given Image Tensor into four corners and the central crop plus the
       flipped version of these (horizontal flipping is used by default).

    .. Note::
-        This transform returns a tuple of images and there may be a
+        This transform returns a List of images and there may be a
        mismatch in the number of inputs and targets your ``Dataset`` returns.

    Args:
@@ -196,11 +213,11 @@ def ten_crop(img, size, vertical_flip=False):
        vertical_flip (bool): Use vertical flipping instead of horizontal

    Returns:
-        tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
+        List: List (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
            Corresponding top left, top right, bottom left, bottom right and center crop
            and same for the flipped image's tensor.
    """
-    if not F._is_tensor_image(img):
+    if not _is_tensor_a_torch_image(img):
        raise TypeError('tensor is not a torch image.')

    assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
@@ -217,5 +234,6 @@ def ten_crop(img, size, vertical_flip=False):
def _blend(img1, img2, ratio):
+    # type: (Tensor, Tensor, float) -> Tensor
-    bound = 1 if img1.dtype.is_floating_point else 255
+    bound = 1 if img1.dtype in [torch.half, torch.float32, torch.float64] else 255
    return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)
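Two details of _blend and its callers are worth noting. The bound is now computed from an explicit list of floating dtypes, presumably because TorchScript could not evaluate img1.dtype.is_floating_point at the time, and adjust_brightness blends with torch.zeros_like(img) rather than the scalar 0 so that both operands satisfy the (Tensor, Tensor, float) annotation. A toy example of the blend arithmetic for a uint8 image (values chosen only for illustration):

    import torch

    img = torch.tensor([100, 200], dtype=torch.uint8)  # two uint8 pixels
    ratio = 1.5  # brightness_factor > 1 brightens

    # adjust_brightness blends with black: ratio * img + (1 - ratio) * 0,
    # clamped to the uint8 bound of 255
    out = (ratio * img.to(torch.float)).clamp(0, 255).to(img.dtype)
    print(out)  # tensor([150, 255], dtype=torch.uint8): 1.5 * 200 = 300 clamps to 255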