Commit 10f34160 authored by Surgan Jandial, committed by Francisco Massa

Scriptability checks for Tensor Transforms (#1690)

* scriptability checks

* tests upds

* linter upds

* linter upds

* upds

* tuple list changes

* linter updates
parent 900c88c7
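
Here "scriptability" means that each op in torchvision.transforms.functional_tensor can be compiled with torch.jit.script and produce the same result as eager execution. A minimal standalone sketch of that idea (the vflip body matches the diff below; the surrounding harness is illustrative only):

    import torch
    from torch import Tensor

    def vflip(img):
        # type: (Tensor) -> Tensor
        # A mypy-style type comment gives TorchScript the signature it needs.
        return img.flip(-2)

    script_vflip = torch.jit.script(vflip)  # compile to TorchScript
    img = torch.randn(3, 16, 16)
    # The scripted function must agree exactly with the eager one.
    assert torch.equal(vflip(img), script_vflip(img))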
test/test_functional_tensor.py

 from __future__ import division
 import torch
+from torch import Tensor
 import torchvision.transforms as transforms
 import torchvision.transforms.functional_tensor as F_t
 import torchvision.transforms.functional as F
 import numpy as np
 import unittest
 import random
+from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple


 class Tester(unittest.TestCase):

     def test_vflip(self):
+        script_vflip = torch.jit.script(F_t.vflip)
         img_tensor = torch.randn(3, 16, 16)
         img_tensor_clone = img_tensor.clone()
         vflipped_img = F_t.vflip(img_tensor)
@@ -18,8 +21,12 @@ class Tester(unittest.TestCase):
         self.assertEqual(vflipped_img.shape, img_tensor.shape)
         self.assertTrue(torch.equal(img_tensor, vflipped_img_again))
         self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
+        # scriptable function test
+        vflipped_img_script = script_vflip(img_tensor)
+        self.assertTrue(torch.equal(vflipped_img, vflipped_img_script))

     def test_hflip(self):
+        script_hflip = torch.jit.script(F_t.hflip)
         img_tensor = torch.randn(3, 16, 16)
         img_tensor_clone = img_tensor.clone()
         hflipped_img = F_t.hflip(img_tensor)
@@ -27,8 +34,12 @@ class Tester(unittest.TestCase):
         self.assertEqual(hflipped_img.shape, img_tensor.shape)
         self.assertTrue(torch.equal(img_tensor, hflipped_img_again))
         self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
+        # scriptable function test
+        hflipped_img_script = script_hflip(img_tensor)
+        self.assertTrue(torch.equal(hflipped_img, hflipped_img_script))

     def test_crop(self):
+        script_crop = torch.jit.script(F_t.crop)
         img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
         img_tensor_clone = img_tensor.clone()
         top = random.randint(0, 15)
@@ -42,11 +53,18 @@ class Tester(unittest.TestCase):
         self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
         self.assertTrue(torch.equal(img_cropped, (img_cropped_GT * 255).to(torch.uint8)),
                         "functional_tensor crop not working")
+        # scriptable function test
+        cropped_img_script = script_crop(img_tensor, top, left, height, width)
+        self.assertTrue(torch.equal(img_cropped, cropped_img_script))

     def test_adjustments(self):
-        fns = ((F.adjust_brightness, F_t.adjust_brightness),
-               (F.adjust_contrast, F_t.adjust_contrast),
-               (F.adjust_saturation, F_t.adjust_saturation))
+        script_adjust_brightness = torch.jit.script(F_t.adjust_brightness)
+        script_adjust_contrast = torch.jit.script(F_t.adjust_contrast)
+        script_adjust_saturation = torch.jit.script(F_t.adjust_saturation)
+        fns = ((F.adjust_brightness, F_t.adjust_brightness, script_adjust_brightness),
+               (F.adjust_contrast, F_t.adjust_contrast, script_adjust_contrast),
+               (F.adjust_saturation, F_t.adjust_saturation, script_adjust_saturation))

         for _ in range(20):
             channels = 3
@@ -60,11 +78,13 @@ class Tester(unittest.TestCase):
             factor = 3 * torch.rand(1)
             img_clone = img.clone()
-            for f, ft in fns:
+            for f, ft, sft in fns:
                 ft_img = ft(img, factor)
+                sft_img = sft(img, factor)
                 if not img.dtype.is_floating_point:
                     ft_img = ft_img.to(torch.float) / 255
+                    sft_img = sft_img.to(torch.float) / 255

                 img_pil = transforms.ToPILImage()(img)
                 f_img_pil = f(img_pil, factor)
@@ -73,10 +93,13 @@ class Tester(unittest.TestCase):
                 # F uses uint8 and F_t uses float, so there is a small
                 # difference in values caused by (at most 5) truncations.
                 max_diff = (ft_img - f_img).abs().max()
+                max_diff_scripted = (sft_img - f_img).abs().max()
                 self.assertLess(max_diff, 5 / 255 + 1e-5)
+                self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)
             self.assertTrue(torch.equal(img, img_clone))

     def test_rgb_to_grayscale(self):
+        script_rgb_to_grayscale = torch.jit.script(F_t.rgb_to_grayscale)
         img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
         img_tensor_clone = img_tensor.clone()
         grayscale_tensor = F_t.rgb_to_grayscale(img_tensor).to(int)
@@ -84,8 +107,12 @@ class Tester(unittest.TestCase):
         max_diff = (grayscale_tensor - grayscale_pil_img).abs().max()
         self.assertLess(max_diff, 1.0001)
         self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
+        # scriptable function test
+        grayscale_script = script_rgb_to_grayscale(img_tensor).to(int)
+        self.assertTrue(torch.equal(grayscale_script, grayscale_tensor))

     def test_center_crop(self):
+        script_center_crop = torch.jit.script(F_t.center_crop)
         img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
         img_tensor_clone = img_tensor.clone()
         cropped_tensor = F_t.center_crop(img_tensor, [10, 10])
@@ -93,8 +120,12 @@ class Tester(unittest.TestCase):
         cropped_pil_tensor = (transforms.ToTensor()(cropped_pil_image) * 255).to(torch.uint8)
         self.assertTrue(torch.equal(cropped_tensor, cropped_pil_tensor))
         self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
+        # scriptable function test
+        cropped_script = script_center_crop(img_tensor, [10, 10])
+        self.assertTrue(torch.equal(cropped_script, cropped_tensor))

     def test_five_crop(self):
+        script_five_crop = torch.jit.script(F_t.five_crop)
         img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
         img_tensor_clone = img_tensor.clone()
         cropped_tensor = F_t.five_crop(img_tensor, [10, 10])
@@ -110,8 +141,13 @@ class Tester(unittest.TestCase):
         self.assertTrue(torch.equal(cropped_tensor[4],
                                     (transforms.ToTensor()(cropped_pil_image[4]) * 255).to(torch.uint8)))
         self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
+        # scriptable function test
+        cropped_script = script_five_crop(img_tensor, [10, 10])
+        for cropped_script_img, cropped_tensor_img in zip(cropped_script, cropped_tensor):
+            self.assertTrue(torch.equal(cropped_script_img, cropped_tensor_img))

     def test_ten_crop(self):
+        script_ten_crop = torch.jit.script(F_t.ten_crop)
         img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
         img_tensor_clone = img_tensor.clone()
         cropped_tensor = F_t.ten_crop(img_tensor, [10, 10])
@@ -137,6 +173,10 @@ class Tester(unittest.TestCase):
         self.assertTrue(torch.equal(cropped_tensor[9],
                                     (transforms.ToTensor()(cropped_pil_image[9]) * 255).to(torch.uint8)))
         self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
+        # scriptable function test
+        cropped_script = script_ten_crop(img_tensor, [10, 10])
+        for cropped_script_img, cropped_tensor_img in zip(cropped_script, cropped_tensor):
+            self.assertTrue(torch.equal(cropped_script_img, cropped_tensor_img))

 if __name__ == '__main__':
...
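
Every test above follows the same pattern: script the functional_tensor op once, run the scripted and eager versions on the same input, and require exact equality, comparing crop by crop for the list-valued five_crop/ten_crop. A condensed sketch of that pattern; check_scripted_matches_eager is a hypothetical helper, not part of the commit:

    import torch
    import torchvision.transforms.functional_tensor as F_t

    def check_scripted_matches_eager(fn, *args):
        # Compile fn to TorchScript, then compare scripted vs eager output.
        scripted = torch.jit.script(fn)
        eager_out, scripted_out = fn(*args), scripted(*args)
        if isinstance(eager_out, list):
            # five_crop / ten_crop return List[Tensor]; compare elementwise.
            for a, b in zip(eager_out, scripted_out):
                assert torch.equal(a, b)
        else:
            assert torch.equal(eager_out, scripted_out)

    img = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
    check_scripted_matches_eager(F_t.vflip, img)
    check_scripted_matches_eager(F_t.five_crop, img, [10, 10])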
torchvision/transforms/functional_tensor.py

+from __future__ import division
 import torch
 import torchvision.transforms.functional as F
+from torch import Tensor
+from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple


+def _is_tensor_a_torch_image(input):
+    return len(input.shape) == 3


-def vflip(img_tensor):
+def vflip(img):
+    # type: (Tensor) -> Tensor
     """Vertically flip the given Image Tensor.

     Args:
-        img_tensor (Tensor): Image Tensor to be flipped in the form [C, H, W].
+        img (Tensor): Image Tensor to be flipped in the form [C, H, W].

     Returns:
         Tensor: Vertically flipped image Tensor.
     """
-    if not F._is_tensor_image(img_tensor):
+    if not _is_tensor_a_torch_image(img):
         raise TypeError('tensor is not a torch image.')

-    return img_tensor.flip(-2)
+    return img.flip(-2)


-def hflip(img_tensor):
+def hflip(img):
+    # type: (Tensor) -> Tensor
     """Horizontally flip the given Image Tensor.

     Args:
-        img_tensor (Tensor): Image Tensor to be flipped in the form [C, H, W].
+        img (Tensor): Image Tensor to be flipped in the form [C, H, W].

     Returns:
         Tensor: Horizontally flipped image Tensor.
     """
-    if not F._is_tensor_image(img_tensor):
+    if not _is_tensor_a_torch_image(img):
         raise TypeError('tensor is not a torch image.')

-    return img_tensor.flip(-1)
+    return img.flip(-1)


 def crop(img, top, left, height, width):
+    # type: (Tensor, int, int, int, int) -> Tensor
     """Crop the given Image Tensor.

     Args:
@@ -45,13 +55,14 @@ def crop(img, top, left, height, width):
     Returns:
         Tensor: Cropped image.
     """
-    if not F._is_tensor_image(img):
+    if not _is_tensor_a_torch_image(img):
         raise TypeError('tensor is not a torch image.')

     return img[..., top:top + height, left:left + width]


 def rgb_to_grayscale(img):
+    # type: (Tensor) -> Tensor
     """Convert the given RGB Image Tensor to Grayscale.
     For RGB to Grayscale conversion, ITU-R 601-2 luma transform is performed which
     is L = R * 0.2989 + G * 0.5870 + B * 0.1140
@@ -70,6 +81,7 @@ def rgb_to_grayscale(img):


 def adjust_brightness(img, brightness_factor):
+    # type: (Tensor, float) -> Tensor
     """Adjust brightness of an RGB image.

     Args:
@@ -81,13 +93,14 @@ def adjust_brightness(img, brightness_factor):
     Returns:
         Tensor: Brightness adjusted image.
     """
-    if not F._is_tensor_image(img):
+    if not _is_tensor_a_torch_image(img):
         raise TypeError('tensor is not a torch image.')

-    return _blend(img, 0, brightness_factor)
+    return _blend(img, torch.zeros_like(img), brightness_factor)


 def adjust_contrast(img, contrast_factor):
+    # type: (Tensor, float) -> Tensor
     """Adjust contrast of an RGB image.

     Args:
@@ -99,7 +112,7 @@ def adjust_contrast(img, contrast_factor):
     Returns:
         Tensor: Contrast adjusted image.
     """
-    if not F._is_tensor_image(img):
+    if not _is_tensor_a_torch_image(img):
         raise TypeError('tensor is not a torch image.')

     mean = torch.mean(rgb_to_grayscale(img).to(torch.float))
@@ -108,6 +121,7 @@ def adjust_contrast(img, contrast_factor):


 def adjust_saturation(img, saturation_factor):
+    # type: (Tensor, float) -> Tensor
     """Adjust color saturation of an RGB image.

     Args:
@@ -119,13 +133,14 @@ def adjust_saturation(img, saturation_factor):
     Returns:
         Tensor: Saturation adjusted image.
     """
-    if not F._is_tensor_image(img):
+    if not _is_tensor_a_torch_image(img):
         raise TypeError('tensor is not a torch image.')

     return _blend(img, rgb_to_grayscale(img), saturation_factor)


 def center_crop(img, output_size):
+    # type: (Tensor, BroadcastingList2[int]) -> Tensor
     """Crop the Image Tensor and resize it to desired size.

     Args:
@@ -136,7 +151,7 @@ def center_crop(img, output_size):
     Returns:
         Tensor: Cropped image.
     """
-    if not F._is_tensor_image(img):
+    if not _is_tensor_a_torch_image(img):
         raise TypeError('tensor is not a torch image.')

     _, image_width, image_height = img.size()
@@ -148,9 +163,10 @@ def center_crop(img, output_size):


 def five_crop(img, size):
+    # type: (Tensor, BroadcastingList2[int]) -> List[Tensor]
     """Crop the given Image Tensor into four corners and the central crop.
     .. Note::
-        This transform returns a tuple of Tensors and there may be a
+        This transform returns a List of Tensors and there may be a
         mismatch in the number of inputs and targets your ``Dataset`` returns.

     Args:
@@ -159,10 +175,10 @@ def five_crop(img, size):
            made.

     Returns:
-       tuple: tuple (tl, tr, bl, br, center)
+       List: List (tl, tr, bl, br, center)
               Corresponding top left, top right, bottom left, bottom right and center crop.
     """
-    if not F._is_tensor_image(img):
+    if not _is_tensor_a_torch_image(img):
         raise TypeError('tensor is not a torch image.')

     assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
@@ -179,14 +195,15 @@ def five_crop(img, size):
     br = crop(img, image_width - crop_width, image_height - crop_height, image_width, image_height)
     center = center_crop(img, (crop_height, crop_width))

-    return (tl, tr, bl, br, center)
+    return [tl, tr, bl, br, center]


 def ten_crop(img, size, vertical_flip=False):
+    # type: (Tensor, BroadcastingList2[int], bool) -> List[Tensor]
     """Crop the given Image Tensor into four corners and the central crop plus the
        flipped version of these (horizontal flipping is used by default).
     .. Note::
-        This transform returns a tuple of images and there may be a
+        This transform returns a List of images and there may be a
         mismatch in the number of inputs and targets your ``Dataset`` returns.

     Args:
@@ -196,11 +213,11 @@ def ten_crop(img, size, vertical_flip=False):
        vertical_flip (bool): Use vertical flipping instead of horizontal

     Returns:
-       tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
+       List: List (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
              Corresponding top left, top right, bottom left, bottom right and center crop
              and same for the flipped image's tensor.
     """
-    if not F._is_tensor_image(img):
+    if not _is_tensor_a_torch_image(img):
         raise TypeError('tensor is not a torch image.')

     assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
@@ -217,5 +234,6 @@ def ten_crop(img, size, vertical_flip=False):


 def _blend(img1, img2, ratio):
+    # type: (Tensor, Tensor, float) -> Tensor
-    bound = 1 if img1.dtype.is_floating_point else 255
+    bound = 1 if img1.dtype in [torch.half, torch.float32, torch.float64] else 255
     return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)
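
Two of the smaller changes above appear to exist purely to satisfy TorchScript's static typing: adjust_brightness now passes torch.zeros_like(img) instead of the literal 0, because the annotated _blend requires a Tensor for its second argument, and the explicit dtype list avoids the img1.dtype.is_floating_point attribute, which we assume TorchScript could not compile at the time. A standalone sketch exercising the scripted _blend under those assumptions:

    import torch
    from torch import Tensor

    def _blend(img1, img2, ratio):
        # type: (Tensor, Tensor, float) -> Tensor
        # Explicit dtype list instead of img1.dtype.is_floating_point
        # (assumed TorchScript limitation at the time of this commit).
        bound = 1 if img1.dtype in [torch.half, torch.float32, torch.float64] else 255
        return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)

    scripted_blend = torch.jit.script(_blend)
    img = torch.rand(3, 8, 8)
    # Brightness-style call: blend against zeros, as adjust_brightness now does.
    out = scripted_blend(img, torch.zeros_like(img), 1.2)
    assert out.shape == img.shape and out.dtype == img.dtype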