Unverified Commit 883f1fb0 authored by Clement Joudet's avatar Clement Joudet Committed by GitHub
Browse files

Make ColorJitter torchscriptable (#2298)



* feat: torchscriptable adjusments

* fix: tensor output type

* feat: ColorJitter torchscriptable

* fix: too many blank lines

* fix: documentation spacing and torchscript annotation

* refactor: list type for _check_input

* refactor: reverting to original syntax
Co-authored-by: default avatarclement.joudet <clement.joudet@inventia.life>
parent 2cfc360e
...@@ -97,6 +97,23 @@ class Tester(unittest.TestCase): ...@@ -97,6 +97,23 @@ class Tester(unittest.TestCase):
self.assertLess(max_diff_scripted, 5 / 255 + 1e-5) self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)
self.assertTrue(torch.equal(img, img_clone)) self.assertTrue(torch.equal(img, img_clone))
# test for class interface
f = transforms.ColorJitter(brightness=factor.item())
scripted_fn = torch.jit.script(f)
scripted_fn(img)
f = transforms.ColorJitter(contrast=factor.item())
scripted_fn = torch.jit.script(f)
scripted_fn(img)
f = transforms.ColorJitter(saturation=factor.item())
scripted_fn = torch.jit.script(f)
scripted_fn(img)
f = transforms.ColorJitter(brightness=1)
scripted_fn = torch.jit.script(f)
scripted_fn(img)
def test_rgb_to_grayscale(self): def test_rgb_to_grayscale(self):
script_rgb_to_grayscale = torch.jit.script(F_t.rgb_to_grayscale) script_rgb_to_grayscale = torch.jit.script(F_t.rgb_to_grayscale)
img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8) img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
......
...@@ -39,6 +39,32 @@ class Tester(unittest.TestCase): ...@@ -39,6 +39,32 @@ class Tester(unittest.TestCase):
def test_random_vertical_flip(self): def test_random_vertical_flip(self):
self._test_flip('vflip', 'RandomVerticalFlip') self._test_flip('vflip', 'RandomVerticalFlip')
def test_adjustments(self):
fns = ['adjust_brightness', 'adjust_contrast', 'adjust_saturation']
for _ in range(20):
factor = 3 * torch.rand(1).item()
tensor, _ = self._create_data()
pil_img = T.ToPILImage()(tensor)
for func in fns:
adjusted_tensor = getattr(F, func)(tensor, factor)
adjusted_pil_img = getattr(F, func)(pil_img, factor)
adjusted_pil_tensor = T.ToTensor()(adjusted_pil_img)
scripted_fn = torch.jit.script(getattr(F, func))
adjusted_tensor_script = scripted_fn(tensor, factor)
if not tensor.dtype.is_floating_point:
adjusted_tensor = adjusted_tensor.to(torch.float) / 255
adjusted_tensor_script = adjusted_tensor_script.to(torch.float) / 255
# F uses uint8 and F_t uses float, so there is a small
# difference in values caused by (at most 5) truncations.
max_diff = (adjusted_tensor - adjusted_pil_tensor).abs().max()
max_diff_scripted = (adjusted_tensor - adjusted_tensor_script).abs().max()
self.assertLess(max_diff, 5 / 255 + 1e-5)
self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -633,67 +633,61 @@ def ten_crop(img, size, vertical_flip=False): ...@@ -633,67 +633,61 @@ def ten_crop(img, size, vertical_flip=False):
return first_five + second_five return first_five + second_five
def adjust_brightness(img, brightness_factor): def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
"""Adjust brightness of an Image. """Adjust brightness of an Image.
Args: Args:
img (PIL Image): PIL Image to be adjusted. img (PIL Image or Torch Tensor): Image to be adjusted.
brightness_factor (float): How much to adjust the brightness. Can be brightness_factor (float): How much to adjust the brightness. Can be
any non negative number. 0 gives a black image, 1 gives the any non negative number. 0 gives a black image, 1 gives the
original image while 2 increases the brightness by a factor of 2. original image while 2 increases the brightness by a factor of 2.
Returns: Returns:
PIL Image: Brightness adjusted image. PIL Image or Torch Tensor: Brightness adjusted image.
""" """
if not _is_pil_image(img): if not isinstance(img, torch.Tensor):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) return F_pil.adjust_brightness(img, brightness_factor)
enhancer = ImageEnhance.Brightness(img) return F_t.adjust_brightness(img, brightness_factor)
img = enhancer.enhance(brightness_factor)
return img
def adjust_contrast(img, contrast_factor): def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
"""Adjust contrast of an Image. """Adjust contrast of an Image.
Args: Args:
img (PIL Image): PIL Image to be adjusted. img (PIL Image or Torch Tensor): Image to be adjusted.
contrast_factor (float): How much to adjust the contrast. Can be any contrast_factor (float): How much to adjust the contrast. Can be any
non negative number. 0 gives a solid gray image, 1 gives the non negative number. 0 gives a solid gray image, 1 gives the
original image while 2 increases the contrast by a factor of 2. original image while 2 increases the contrast by a factor of 2.
Returns: Returns:
PIL Image: Contrast adjusted image. PIL Image or Torch Tensor: Contrast adjusted image.
""" """
if not _is_pil_image(img): if not isinstance(img, torch.Tensor):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) return F_pil.adjust_contrast(img, contrast_factor)
enhancer = ImageEnhance.Contrast(img) return F_t.adjust_contrast(img, contrast_factor)
img = enhancer.enhance(contrast_factor)
return img
def adjust_saturation(img, saturation_factor): def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
"""Adjust color saturation of an image. """Adjust color saturation of an image.
Args: Args:
img (PIL Image): PIL Image to be adjusted. img (PIL Image or Torch Tensor): Image to be adjusted.
saturation_factor (float): How much to adjust the saturation. 0 will saturation_factor (float): How much to adjust the saturation. 0 will
give a black and white image, 1 will give the original image while give a black and white image, 1 will give the original image while
2 will enhance the saturation by a factor of 2. 2 will enhance the saturation by a factor of 2.
Returns: Returns:
PIL Image: Saturation adjusted image. PIL Image or Torch Tensor: Saturation adjusted image.
""" """
if not _is_pil_image(img): if not isinstance(img, torch.Tensor):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) return F_pil.adjust_saturation(img, saturation_factor)
enhancer = ImageEnhance.Color(img) return F_t.adjust_saturation(img, saturation_factor)
img = enhancer.enhance(saturation_factor)
return img
def adjust_hue(img, hue_factor): def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
"""Adjust hue of an image. """Adjust hue of an image.
The image hue is adjusted by converting the image to HSV and The image hue is adjusted by converting the image to HSV and
...@@ -718,26 +712,10 @@ def adjust_hue(img, hue_factor): ...@@ -718,26 +712,10 @@ def adjust_hue(img, hue_factor):
Returns: Returns:
PIL Image: Hue adjusted image. PIL Image: Hue adjusted image.
""" """
if not(-0.5 <= hue_factor <= 0.5): if not isinstance(img, torch.Tensor):
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor)) return F_pil.adjust_hue(img, hue_factor)
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
input_mode = img.mode
if input_mode in {'L', '1', 'I', 'F'}:
return img
h, s, v = img.convert('HSV').split()
np_h = np.array(h, dtype=np.uint8)
# uint8 addition take cares of rotation across boundaries
with np.errstate(over='ignore'):
np_h += np.uint8(hue_factor * 255)
h = Image.fromarray(np_h, 'L')
img = Image.merge('HSV', (h, s, v)).convert(input_mode) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return img
def adjust_gamma(img, gamma, gain=1): def adjust_gamma(img, gamma, gain=1):
......
...@@ -4,6 +4,7 @@ try: ...@@ -4,6 +4,7 @@ try:
except ImportError: except ImportError:
accimage = None accimage = None
from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
import numpy as np
@torch.jit.unused @torch.jit.unused
...@@ -44,3 +45,110 @@ def vflip(img): ...@@ -44,3 +45,110 @@ def vflip(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return img.transpose(Image.FLIP_TOP_BOTTOM) return img.transpose(Image.FLIP_TOP_BOTTOM)
@torch.jit.unused
def adjust_brightness(img, brightness_factor):
"""Adjust brightness of an RGB image.
Args:
img (PIL Image): Image to be adjusted.
brightness_factor (float): How much to adjust the brightness. Can be
any non negative number. 0 gives a black image, 1 gives the
original image while 2 increases the brightness by a factor of 2.
Returns:
PIL Image: Brightness adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
enhancer = ImageEnhance.Brightness(img)
img = enhancer.enhance(brightness_factor)
return img
@torch.jit.unused
def adjust_contrast(img, contrast_factor):
"""Adjust contrast of an Image.
Args:
img (PIL Image): PIL Image to be adjusted.
contrast_factor (float): How much to adjust the contrast. Can be any
non negative number. 0 gives a solid gray image, 1 gives the
original image while 2 increases the contrast by a factor of 2.
Returns:
PIL Image: Contrast adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
enhancer = ImageEnhance.Contrast(img)
img = enhancer.enhance(contrast_factor)
return img
@torch.jit.unused
def adjust_saturation(img, saturation_factor):
"""Adjust color saturation of an image.
Args:
img (PIL Image): PIL Image to be adjusted.
saturation_factor (float): How much to adjust the saturation. 0 will
give a black and white image, 1 will give the original image while
2 will enhance the saturation by a factor of 2.
Returns:
PIL Image: Saturation adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
enhancer = ImageEnhance.Color(img)
img = enhancer.enhance(saturation_factor)
return img
@torch.jit.unused
def adjust_hue(img, hue_factor):
"""Adjust hue of an image.
The image hue is adjusted by converting the image to HSV and
cyclically shifting the intensities in the hue channel (H).
The image is then converted back to original image mode.
`hue_factor` is the amount of shift in H channel and must be in the
interval `[-0.5, 0.5]`.
See `Hue`_ for more details.
.. _Hue: https://en.wikipedia.org/wiki/Hue
Args:
img (PIL Image): PIL Image to be adjusted.
hue_factor (float): How much to shift the hue channel. Should be in
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
HSV space in positive and negative direction respectively.
0 means no shift. Therefore, both -0.5 and 0.5 will give an image
with complementary colors while 0 gives the original image.
Returns:
PIL Image: Hue adjusted image.
"""
if not(-0.5 <= hue_factor <= 0.5):
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
input_mode = img.mode
if input_mode in {'L', '1', 'I', 'F'}:
return img
h, s, v = img.convert('HSV').split()
np_h = np.array(h, dtype=np.uint8)
# uint8 addition take cares of rotation across boundaries
with np.errstate(over='ignore'):
np_h += np.uint8(hue_factor * 255)
h = Image.fromarray(np_h, 'L')
img = Image.merge('HSV', (h, s, v)).convert(input_mode)
return img
...@@ -865,7 +865,7 @@ class LinearTransformation(object): ...@@ -865,7 +865,7 @@ class LinearTransformation(object):
return format_string return format_string
class ColorJitter(object): class ColorJitter(torch.nn.Module):
"""Randomly change the brightness, contrast and saturation of an image. """Randomly change the brightness, contrast and saturation of an image.
Args: Args:
...@@ -882,20 +882,23 @@ class ColorJitter(object): ...@@ -882,20 +882,23 @@ class ColorJitter(object):
hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5. Should have 0<= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
""" """
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0): def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
super().__init__()
self.brightness = self._check_input(brightness, 'brightness') self.brightness = self._check_input(brightness, 'brightness')
self.contrast = self._check_input(contrast, 'contrast') self.contrast = self._check_input(contrast, 'contrast')
self.saturation = self._check_input(saturation, 'saturation') self.saturation = self._check_input(saturation, 'saturation')
self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5), self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
clip_first_on_zero=False) clip_first_on_zero=False)
@torch.jit.unused
def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True): def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
if isinstance(value, numbers.Number): if isinstance(value, numbers.Number):
if value < 0: if value < 0:
raise ValueError("If {} is a single number, it must be non negative.".format(name)) raise ValueError("If {} is a single number, it must be non negative.".format(name))
value = [center - value, center + value] value = [center - float(value), center + float(value)]
if clip_first_on_zero: if clip_first_on_zero:
value[0] = max(value[0], 0) value[0] = max(value[0], 0.0)
elif isinstance(value, (tuple, list)) and len(value) == 2: elif isinstance(value, (tuple, list)) and len(value) == 2:
if not bound[0] <= value[0] <= value[1] <= bound[1]: if not bound[0] <= value[0] <= value[1] <= bound[1]:
raise ValueError("{} values should be between {}".format(name, bound)) raise ValueError("{} values should be between {}".format(name, bound))
...@@ -909,6 +912,7 @@ class ColorJitter(object): ...@@ -909,6 +912,7 @@ class ColorJitter(object):
return value return value
@staticmethod @staticmethod
@torch.jit.unused
def get_params(brightness, contrast, saturation, hue): def get_params(brightness, contrast, saturation, hue):
"""Get a randomized transform to be applied on image. """Get a randomized transform to be applied on image.
...@@ -941,17 +945,37 @@ class ColorJitter(object): ...@@ -941,17 +945,37 @@ class ColorJitter(object):
return transform return transform
def __call__(self, img): def forward(self, img):
""" """
Args: Args:
img (PIL Image): Input image. img (PIL Image or Tensor): Input image.
Returns: Returns:
PIL Image: Color jittered image. PIL Image or Tensor: Color jittered image.
""" """
transform = self.get_params(self.brightness, self.contrast, fn_idx = torch.randperm(4)
self.saturation, self.hue) for fn_id in fn_idx:
return transform(img) if fn_id == 0 and self.brightness is not None:
brightness = self.brightness
brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item()
img = F.adjust_brightness(img, brightness_factor)
if fn_id == 1 and self.contrast is not None:
contrast = self.contrast
contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item()
img = F.adjust_contrast(img, contrast_factor)
if fn_id == 2 and self.saturation is not None:
saturation = self.saturation
saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item()
img = F.adjust_saturation(img, saturation_factor)
if fn_id == 3 and self.hue is not None:
hue = self.hue
hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item()
img = F.adjust_hue(img, hue_factor)
return img
def __repr__(self): def __repr__(self):
format_string = self.__class__.__name__ + '(' format_string = self.__class__.__name__ + '('
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment