Commit 618072bf authored by Sasank Chilamkurthy's avatar Sasank Chilamkurthy Committed by Francisco Massa
Browse files

Add Color transforms (#275)

* Add adjust_hue and adjust_saturation

* Add adjust_brightness, adjust_contrast

Also
* Change adjust_saturation to use pillow implementation
* Documentation made clear

* Add adjust_gamma

* Add ColorJitter

* Address review comments

* Fix documentation for ColorJitter

* Address review comments 2

* Fallback to adjust_hue in case of BW images

* Add tests

* fix dtype
parent 88e81cea
......@@ -422,6 +422,167 @@ class Tester(unittest.TestCase):
p_value = stats.binom_test(num_horizontal, 100, p=0.5)
assert p_value > 0.05
def test_adjust_brightness(self):
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
x_pil = Image.fromarray(x_np, mode='RGB')
# test 0
y_pil = transforms.adjust_brightness(x_pil, 1)
y_np = np.array(y_pil)
assert np.allclose(y_np, x_np)
# test 1
y_pil = transforms.adjust_brightness(x_pil, 0.5)
y_np = np.array(y_pil)
y_ans = [0, 2, 6, 27, 67, 113, 18, 4, 117, 45, 127, 0]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
assert np.allclose(y_np, y_ans)
# test 2
y_pil = transforms.adjust_brightness(x_pil, 2)
y_np = np.array(y_pil)
y_ans = [0, 10, 26, 108, 255, 255, 74, 16, 255, 180, 255, 2]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
assert np.allclose(y_np, y_ans)
def test_adjust_contrast(self):
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
x_pil = Image.fromarray(x_np, mode='RGB')
# test 0
y_pil = transforms.adjust_contrast(x_pil, 1)
y_np = np.array(y_pil)
assert np.allclose(y_np, x_np)
# test 1
y_pil = transforms.adjust_contrast(x_pil, 0.5)
y_np = np.array(y_pil)
y_ans = [43, 45, 49, 70, 110, 156, 61, 47, 160, 88, 170, 43]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
assert np.allclose(y_np, y_ans)
# test 2
y_pil = transforms.adjust_contrast(x_pil, 2)
y_np = np.array(y_pil)
y_ans = [0, 0, 0, 22, 184, 255, 0, 0, 255, 94, 255, 0]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
assert np.allclose(y_np, y_ans)
def test_adjust_saturation(self):
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
x_pil = Image.fromarray(x_np, mode='RGB')
# test 0
y_pil = transforms.adjust_saturation(x_pil, 1)
y_np = np.array(y_pil)
assert np.allclose(y_np, x_np)
# test 1
y_pil = transforms.adjust_saturation(x_pil, 0.5)
y_np = np.array(y_pil)
y_ans = [2, 4, 8, 87, 128, 173, 39, 25, 138, 133, 215, 88]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
assert np.allclose(y_np, y_ans)
# test 2
y_pil = transforms.adjust_saturation(x_pil, 2)
y_np = np.array(y_pil)
y_ans = [0, 6, 22, 0, 149, 255, 32, 0, 255, 4, 255, 0]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
assert np.allclose(y_np, y_ans)
def test_adjust_hue(self):
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
x_pil = Image.fromarray(x_np, mode='RGB')
with self.assertRaises(ValueError):
transforms.adjust_hue(x_pil, -0.7)
transforms.adjust_hue(x_pil, 1)
# test 0: almost same as x_data but not exact.
# probably because hsv <-> rgb floating point ops
y_pil = transforms.adjust_hue(x_pil, 0)
y_np = np.array(y_pil)
y_ans = [0, 5, 13, 54, 139, 226, 35, 8, 234, 91, 255, 1]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
assert np.allclose(y_np, y_ans)
# test 1
y_pil = transforms.adjust_hue(x_pil, 0.25)
y_np = np.array(y_pil)
y_ans = [13, 0, 12, 224, 54, 226, 234, 8, 99, 1, 222, 255]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
assert np.allclose(y_np, y_ans)
# test 2
y_pil = transforms.adjust_hue(x_pil, -0.25)
y_np = np.array(y_pil)
y_ans = [0, 13, 2, 54, 226, 58, 8, 234, 152, 255, 43, 1]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
assert np.allclose(y_np, y_ans)
def test_adjust_gamma(self):
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
x_pil = Image.fromarray(x_np, mode='RGB')
# test 0
y_pil = transforms.adjust_gamma(x_pil, 1)
y_np = np.array(y_pil)
assert np.allclose(y_np, x_np)
# test 1
y_pil = transforms.adjust_gamma(x_pil, 0.5)
y_np = np.array(y_pil)
y_ans = [0, 35, 57, 117, 185, 240, 97, 45, 244, 151, 255, 15]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
assert np.allclose(y_np, y_ans)
# test 2
y_pil = transforms.adjust_gamma(x_pil, 2)
y_np = np.array(y_pil)
y_ans = [0, 0, 0, 11, 71, 200, 5, 0, 214, 31, 255, 0]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
assert np.allclose(y_np, y_ans)
def test_adjusts_L_mode(self):
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
x_rgb = Image.fromarray(x_np, mode='RGB')
x_l = x_rgb.convert('L')
assert transforms.adjust_brightness(x_l, 2).mode == 'L'
assert transforms.adjust_saturation(x_l, 2).mode == 'L'
assert transforms.adjust_contrast(x_l, 2).mode == 'L'
assert transforms.adjust_hue(x_l, 0.4).mode == 'L'
assert transforms.adjust_gamma(x_l, 0.5).mode == 'L'
def test_color_jitter(self):
color_jitter = transforms.ColorJitter(2, 2, 2, 0.1)
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
x_pil = Image.fromarray(x_np, mode='RGB')
x_pil_2 = x_pil.convert('L')
for i in range(10):
y_pil = color_jitter(x_pil)
assert y_pil.mode == x_pil.mode
y_pil_2 = color_jitter(x_pil_2)
assert y_pil_2.mode == x_pil_2.mode
if __name__ == '__main__':
unittest.main()
......@@ -2,7 +2,7 @@ from __future__ import division
import torch
import math
import random
from PIL import Image, ImageOps
from PIL import Image, ImageOps, ImageEnhance
try:
import accimage
except ImportError:
......@@ -355,6 +355,145 @@ def ten_crop(img, size, vertical_flip=False):
return first_five + second_five
def adjust_brightness(img, brightness_factor):
"""Adjust brightness of an Image.
Args:
img (PIL.Image): PIL Image to be adjusted.
brightness_factor (float): How much to adjust the brightness. Can be
any non negative number. 0 gives a black image, 1 gives the
original image while 2 increases the brightness by a factor of 2.
Returns:
PIL.Image: Brightness adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
enhancer = ImageEnhance.Brightness(img)
img = enhancer.enhance(brightness_factor)
return img
def adjust_contrast(img, contrast_factor):
"""Adjust contrast of an Image.
Args:
img (PIL.Image): PIL Image to be adjusted.
contrast_factor (float): How much to adjust the contrast. Can be any
non negative number. 0 gives a solid gray image, 1 gives the
original image while 2 increases the contrast by a factor of 2.
Returns:
PIL.Image: Contrast adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
enhancer = ImageEnhance.Contrast(img)
img = enhancer.enhance(contrast_factor)
return img
def adjust_saturation(img, saturation_factor):
"""Adjust color saturation of an image.
Args:
img (PIL.Image): PIL Image to be adjusted.
saturation_factor (float): How much to adjust the saturation. 0 will
give a black and white image, 1 will give the original image while
2 will enhance the saturation by a factor of 2.
Returns:
PIL.Image: Saturation adjusted image.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
enhancer = ImageEnhance.Color(img)
img = enhancer.enhance(saturation_factor)
return img
def adjust_hue(img, hue_factor):
"""Adjust hue of an image.
The image hue is adjusted by converting the image to HSV and
cyclically shifting the intensities in the hue channel (H).
The image is then converted back to original image mode.
`hue_factor` is the amount of shift in H channel and must be in the
interval `[-0.5, 0.5]`.
See https://en.wikipedia.org/wiki/Hue for more details on Hue.
Args:
img (PIL.Image): PIL Image to be adjusted.
hue_factor (float): How much to shift the hue channel. Should be in
[-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
HSV space in positive and negative direction respectively.
0 means no shift. Therefore, both -0.5 and 0.5 will give an image
with complementary colors while 0 gives the original image.
Returns:
PIL.Image: Hue adjusted image.
"""
if not(-0.5 <= hue_factor <= 0.5):
raise ValueError('hue_factor is not in [-0.5, 0.5].'.format(hue_factor))
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
input_mode = img.mode
if input_mode in {'L', '1', 'I', 'F'}:
return img
h, s, v = img.convert('HSV').split()
np_h = np.array(h, dtype=np.uint8)
# uint8 addition take cares of rotation across boundaries
with np.errstate(over='ignore'):
np_h += np.uint8(hue_factor * 255)
h = Image.fromarray(np_h, 'L')
img = Image.merge('HSV', (h, s, v)).convert(input_mode)
return img
def adjust_gamma(img, gamma, gain=1):
"""Perform gamma correction on an image.
Also known as Power Law Transform. Intensities in RGB mode are adjusted
based on the following equation:
I_out = 255 * gain * ((I_in / 255) ** gamma)
See https://en.wikipedia.org/wiki/Gamma_correction for more details.
Args:
img (PIL.Image): PIL Image to be adjusted.
gamma (float): Non negative real number. gamma larger than 1 make the
shadows darker, while gamma smaller than 1 make dark regions
lighter.
gain (float): The constant multiplier.
"""
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if gamma < 0:
raise ValueError('Gamma should be a non-negative real number')
input_mode = img.mode
img = img.convert('RGB')
np_img = np.array(img, dtype=np.float32)
np_img = 255 * gain * ((np_img / 255) ** gamma)
np_img = np.uint8(np.clip(np_img, 0, 255))
img = Image.fromarray(np_img, 'RGB').convert(input_mode)
return img
class Compose(object):
"""Composes several transforms together.
......@@ -777,3 +916,67 @@ class TenCrop(object):
def __call__(self, img):
return ten_crop(img, self.size, self.vertical_flip)
class ColorJitter(object):
"""Randomly change the brightness, contrast and saturation of an image.
Args:
brightness (float): How much to jitter brightness. brightness_factor
is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
contrast (float): How much to jitter contrast. contrast_factor
is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
saturation (float): How much to jitter saturation. saturation_factor
is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
hue(float): How much to jitter hue. hue_factor is chosen uniformly from
[-hue, hue]. Should be >=0 and <= 0.5.
"""
def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
self.brightness = brightness
self.contrast = contrast
self.saturation = saturation
self.hue = hue
@staticmethod
def get_params(brightness, contrast, saturation, hue):
"""Get a randomized transform to be applied on image.
Arguments are same as that of __init__.
Returns:
Transform which randomly adjusts brightness, contrast and
saturation in a random order.
"""
transforms = []
if brightness > 0:
brightness_factor = np.random.uniform(max(0, 1 - brightness), 1 + brightness)
transforms.append(Lambda(lambda img: adjust_brightness(img, brightness_factor)))
if contrast > 0:
contrast_factor = np.random.uniform(max(0, 1 - contrast), 1 + contrast)
transforms.append(Lambda(lambda img: adjust_contrast(img, contrast_factor)))
if saturation > 0:
saturation_factor = np.random.uniform(max(0, 1 - saturation), 1 + saturation)
transforms.append(Lambda(lambda img: adjust_saturation(img, saturation_factor)))
if hue > 0:
hue_factor = np.random.uniform(-hue, hue)
transforms.append(Lambda(lambda img: adjust_hue(img, hue_factor)))
np.random.shuffle(transforms)
transform = Compose(transforms)
return transform
def __call__(self, img):
"""
Args:
img (PIL.Image): Input image.
Returns:
PIL.Image: Color jittered image.
"""
transform = self.get_params(self.brightness, self.contrast,
self.saturation, self.hue)
return transform(img)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment