Commit 7325e1d0 authored by Zhengyang Feng, committed by GitHub

Adjust adjust_* transforms (#3222)

* adjust_hue

* adjust_*

* colorjitter
parent 6315358d
@@ -709,7 +709,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
     Args:
         img (PIL Image or Tensor): Image to be adjusted.
-            If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
+            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
             where ... means it can have an arbitrary number of leading dimensions.
         brightness_factor (float): How much to adjust the brightness. Can be
             any non negative number. 0 gives a black image, 1 gives the
@@ -729,6 +729,8 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
     Args:
         img (PIL Image or Tensor): Image to be adjusted.
+            If img is torch Tensor, it is expected to be in [..., 3, H, W] format,
+            where ... means it can have an arbitrary number of leading dimensions.
         contrast_factor (float): How much to adjust the contrast. Can be any
             non negative number. 0 gives a solid gray image, 1 gives the
             original image while 2 increases the contrast by a factor of 2.
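The `contrast_factor` semantics described in this docstring (0 gives solid gray, 1 gives the original, 2 doubles the contrast) can be pictured as a blend between each pixel and the mean gray level. The following is a minimal sketch on a flat list of grayscale values in [0, 1]. `blend_contrast` is a hypothetical helper, and the blend-with-mean formula is an assumption used for illustration rather than a copy of the library code.

```python
from statistics import fmean


def blend_contrast(pixels, contrast_factor):
    """Hypothetical helper: blend each value with the mean gray level."""
    if contrast_factor < 0:
        raise ValueError("contrast_factor must be non-negative")
    mean = fmean(pixels)
    out = []
    for p in pixels:
        # factor 0 -> mean only, factor 1 -> original, factor 2 -> doubled deviation
        v = contrast_factor * p + (1.0 - contrast_factor) * mean
        out.append(min(1.0, max(0.0, v)))  # clamp to the valid range
    return out


pixels = [0.2, 0.4, 0.6]
gray = blend_contrast(pixels, 0.0)    # every value collapses to the mean
same = blend_contrast(pixels, 1.0)    # values are unchanged
double = blend_contrast(pixels, 2.0)  # deviations from the mean are doubled
```

With these inputs the mean is 0.4, so a factor of 2 moves 0.2 and 0.6 to roughly 0.0 and 0.8 while the middle value stays put.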
@@ -747,6 +749,8 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
     Args:
         img (PIL Image or Tensor): Image to be adjusted.
+            If img is torch Tensor, it is expected to be in [..., 3, H, W] format,
+            where ... means it can have an arbitrary number of leading dimensions.
         saturation_factor (float): How much to adjust the saturation. 0 will
             give a black and white image, 1 will give the original image while
             2 will enhance the saturation by a factor of 2.
@@ -776,6 +780,9 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
     Args:
         img (PIL Image or Tensor): Image to be adjusted.
+            If img is torch Tensor, it is expected to be in [..., 3, H, W] format,
+            where ... means it can have an arbitrary number of leading dimensions.
+            If img is PIL Image mode "1", "L", "I", "F" and modes with transparency (alpha channel) are not supported.
         hue_factor (float): How much to shift the hue channel. Should be in
             [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
             HSV space in positive and negative direction respectively.
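To make the `hue_factor` range concrete, here is a minimal sketch that rotates the hue of a single RGB pixel with the standard-library `colorsys` module. `shift_hue` is a hypothetical helper written for this illustration, it works on floats in [0, 1], and it is not the function changed by this commit.

```python
import colorsys


def shift_hue(rgb, hue_factor):
    """Hypothetical helper: rotate the hue of one RGB pixel (floats in [0, 1])."""
    if not -0.5 <= hue_factor <= 0.5:
        raise ValueError("hue_factor must be in [-0.5, 0.5]")
    h, s, v = colorsys.rgb_to_hsv(*rgb)
    h = (h + hue_factor) % 1.0  # hue is cyclic, so wrap around
    return colorsys.hsv_to_rgb(h, s, v)


red = (1.0, 0.0, 0.0)
unchanged = shift_hue(red, 0.0)     # a shift of 0 keeps the pixel as it is
plus_half = shift_hue(red, 0.5)     # half a turn forward on the hue circle
minus_half = shift_hue(red, -0.5)   # half a turn backward lands on the same hue
```

Because the hue axis is circular, shifts of 0.5 and -0.5 both reach the opposite side of the circle, which for pure red is cyan.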
@@ -806,8 +813,9 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
     Args:
         img (PIL Image or Tensor): PIL Image to be adjusted.
-            If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
+            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
             where ... means it can have an arbitrary number of leading dimensions.
+            If img is PIL Image, modes with transparency (alpha channel) are not supported.
         gamma (float): Non negative real number, same as :math:`\gamma` in the equation.
             gamma larger than 1 make the shadows darker,
             while gamma smaller than 1 make dark regions lighter.
@@ -1185,7 +1193,7 @@ def invert(img: Tensor) -> Tensor:
     Args:
         img (PIL Image or Tensor): Image to have its colors inverted.
-            If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
+            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
             where ... means it can have an arbitrary number of leading dimensions.
             If img is PIL Image, it is expected to be in mode "L" or "RGB".
@@ -1203,7 +1211,7 @@ def posterize(img: Tensor, bits: int) -> Tensor:
     Args:
         img (PIL Image or Tensor): Image to have its colors posterized.
-            If img is a Tensor, it should be of type torch.uint8 and
+            If img is torch Tensor, it should be of type torch.uint8 and
             it is expected to be in [..., 1 or 3, H, W] format, where ... means
             it can have an arbitrary number of leading dimensions.
             If img is PIL Image, it is expected to be in mode "L" or "RGB".
@@ -1225,7 +1233,7 @@ def solarize(img: Tensor, threshold: float) -> Tensor:
     Args:
         img (PIL Image or Tensor): Image to have its colors inverted.
-            If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
+            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
             where ... means it can have an arbitrary number of leading dimensions.
             If img is PIL Image, it is expected to be in mode "L" or "RGB".
         threshold (float): All pixels equal or above this value are inverted.
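The `threshold` rule in this docstring can be sketched in a few lines of plain Python. `solarize_sketch` and its `bound` parameter are hypothetical names used only here, and the example assumes 8-bit values in [0, 255].

```python
def solarize_sketch(pixels, threshold, bound=255):
    """Hypothetical helper: invert values at or above threshold, keep the rest."""
    return [bound - p if p >= threshold else p for p in pixels]


result = solarize_sketch([0, 100, 128, 255], threshold=128)
# Values below 128 are untouched; 128 and 255 become 127 and 0.
```

Note that the comparison is inclusive, so a pixel exactly equal to the threshold is inverted as well.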
@@ -1243,7 +1251,7 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
     Args:
         img (PIL Image or Tensor): Image to be adjusted.
-            If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
+            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
             where ... means it can have an arbitrary number of leading dimensions.
         sharpness_factor (float): How much to adjust the sharpness. Can be
             any non negative number. 0 gives a blurred image, 1 gives the
@@ -1265,7 +1273,7 @@ def autocontrast(img: Tensor) -> Tensor:
     Args:
         img (PIL Image or Tensor): Image on which autocontrast is applied.
-            If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
+            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
             where ... means it can have an arbitrary number of leading dimensions.
             If img is PIL Image, it is expected to be in mode "L" or "RGB".
@@ -1285,7 +1293,7 @@ def equalize(img: Tensor) -> Tensor:
     Args:
         img (PIL Image or Tensor): Image on which equalize is applied.
-            If img is a Tensor, it is expected to be in [..., 1 or 3, H, W] format,
+            If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format,
             where ... means it can have an arbitrary number of leading dimensions.
             If img is PIL Image, it is expected to be in mode "P", "L" or "RGB".
......
@@ -180,7 +180,9 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
     _assert_image_tensor(img)
-    _assert_channels(img, [3])
+    _assert_channels(img, [1, 3])
+    if _get_image_num_channels(img) == 1:  # Match PIL behaviour
+        return img
     orig_dtype = img.dtype
     if img.dtype == torch.uint8:
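The effect of this hunk is that single-channel input is now accepted and returned untouched, while three-channel input still goes through the hue shift. The sketch below mirrors that control flow on nested Python lists shaped [C][H][W] instead of tensors. `adjust_hue_sketch` is a hypothetical stand-in, and the per-pixel `colorsys` conversion is an assumption chosen to keep the example dependency-free.

```python
import colorsys


def adjust_hue_sketch(img, hue_factor):
    """Hypothetical stand-in operating on a nested list shaped [C][H][W]."""
    channels = len(img)
    if channels not in (1, 3):
        raise TypeError(f"expected 1 or 3 channels, got {channels}")
    if channels == 1:
        # A grayscale image carries no hue information, so return it as-is.
        return img
    r_plane, g_plane, b_plane = img
    out = [[], [], []]
    for r_row, g_row, b_row in zip(r_plane, g_plane, b_plane):
        new_rows = ([], [], [])
        for r, g, b in zip(r_row, g_row, b_row):
            h, s, v = colorsys.rgb_to_hsv(r, g, b)
            rgb = colorsys.hsv_to_rgb((h + hue_factor) % 1.0, s, v)
            for row, value in zip(new_rows, rgb):
                row.append(value)
        for plane, row in zip(out, new_rows):
            plane.append(row)
    return out


gray = [[[0.1, 0.9]]]               # C=1, H=1, W=2
rgb = [[[1.0]], [[0.0]], [[0.0]]]   # C=3, a single red pixel
gray_out = adjust_hue_sketch(gray, 0.25)
rgb_out = adjust_hue_sketch(rgb, 0.5)
```

Returning the input object directly for one channel keeps the grayscale path free of any color-space conversion, which is the behaviour the inline comment in the hunk attributes to PIL.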
......
@@ -1043,7 +1043,8 @@ class LinearTransformation(torch.nn.Module):
 class ColorJitter(torch.nn.Module):
     """Randomly change the brightness, contrast, saturation and hue of an image.
     If the image is torch Tensor, it is expected
-    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
+    to have [..., 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If img is PIL Image, mode "1", "L", "I", "F" and modes with transparency (alpha channel) are not supported.
     Args:
         brightness (float or tuple of float (min, max)): How much to jitter brightness.
......