Unverified Commit 83171d6a authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Implement all AutoAugment transforms + Policies (#3123)



* Invert Transform (#3104)

* Adding invert operator.

* Make use of the _assert_channels().

* Update upper bound value.

* Remove private doc from invert, create or reuse generic testing methods to avoid duplication of code in the tests. (#3106)

* Create posterize transformation and refactor common methods to assist reuse. (#3108)

* Implement the solarize transform. (#3112)

* Implement the adjust_sharpness transform (#3114)

* Adding functional operator for sharpness.

* Adding transforms for sharpness.

* Handling tiny images and adding a test.

* Implement the autocontrast transform. (#3117)

* Implement the equalize transform (#3119)

* Implement the equalize transform.

* Turn off deterministic for histogram.

* Fixing test. (#3126)

* Force ratio to be float to avoid numeric overflows on blend. (#3127)

* Separate the tests of Adjust Sharpness from ColorJitter. (#3128)

* Add AutoAugment Policies and main Transform (#3142)

* Separate the tests of Adjust Sharpness from ColorJitter.

* Initial implementation, not-jitable.

* AutoAugment passing JIT.

* Adding tests/docs, changing formatting.

* Update test.

* Fix formats

* Fix documentation and imports.

* Apply changes from code review:
- Move the transformations outside of AutoAugment on a separate method.
- Renamed degenerate method for sharpness for better clarity.

* Update torchvision/transforms/functional.py
Co-authored-by: default avatarvfdev <vfdev.5@gmail.com>

* Apply more changes from code review:
- Add InterpolationMode parameter.
- Move all declarations away from AutoAugment constructor and into the private method.

* Update documentation.

* Apply suggestions from code review
Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>

* Apply changes from code review:
- Refactor code to eliminate as any to() and clamp() as possible.
- Reuse methods where possible.
- Apply speed ups.

* Replacing pad.
Co-authored-by: default avatarvfdev <vfdev.5@gmail.com>
Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
parent 4eab7a67
...@@ -289,13 +289,14 @@ class Tester(TransformsTester): ...@@ -289,13 +289,14 @@ class Tester(TransformsTester):
self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs) self._test_fn_on_batch(batch_tensors, F.pad, padding=script_pad, **kwargs)
def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max"): def _test_adjust_fn(self, fn, fn_pil, fn_t, configs, tol=2.0 + 1e-10, agg_method="max",
dts=(None, torch.float32, torch.float64)):
script_fn = torch.jit.script(fn) script_fn = torch.jit.script(fn)
torch.manual_seed(15) torch.manual_seed(15)
tensor, pil_img = self._create_data(26, 34, device=self.device) tensor, pil_img = self._create_data(26, 34, device=self.device)
batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device) batch_tensors = self._create_data_batch(16, 18, num_samples=4, device=self.device)
for dt in [None, torch.float32, torch.float64]: for dt in dts:
if dt is not None: if dt is not None:
tensor = F.convert_image_dtype(tensor, dt) tensor = F.convert_image_dtype(tensor, dt)
...@@ -862,6 +863,77 @@ class Tester(TransformsTester): ...@@ -862,6 +863,77 @@ class Tester(TransformsTester):
msg="{}, {}".format(ksize, sigma) msg="{}, {}".format(ksize, sigma)
) )
def test_invert(self):
self._test_adjust_fn(
F.invert,
F_pil.invert,
F_t.invert,
[{}],
tol=1.0,
agg_method="max"
)
def test_posterize(self):
self._test_adjust_fn(
F.posterize,
F_pil.posterize,
F_t.posterize,
[{"bits": bits} for bits in range(0, 8)],
tol=1.0,
agg_method="max",
dts=(None,)
)
def test_solarize(self):
self._test_adjust_fn(
F.solarize,
F_pil.solarize,
F_t.solarize,
[{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]],
tol=1.0,
agg_method="max",
dts=(None,)
)
self._test_adjust_fn(
F.solarize,
lambda img, threshold: F_pil.solarize(img, 255 * threshold),
F_t.solarize,
[{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]],
tol=1.0,
agg_method="max",
dts=(torch.float32, torch.float64)
)
def test_adjust_sharpness(self):
self._test_adjust_fn(
F.adjust_sharpness,
F_pil.adjust_sharpness,
F_t.adjust_sharpness,
[{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]
)
def test_autocontrast(self):
self._test_adjust_fn(
F.autocontrast,
F_pil.autocontrast,
F_t.autocontrast,
[{}],
tol=1.0,
agg_method="max"
)
def test_equalize(self):
torch.set_deterministic(False)
self._test_adjust_fn(
F.equalize,
F_pil.equalize,
F_t.equalize,
[{}],
tol=1.0,
agg_method="max",
dts=(None,)
)
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester): class CUDATester(Tester):
......
...@@ -1234,6 +1234,48 @@ class Tester(unittest.TestCase): ...@@ -1234,6 +1234,48 @@ class Tester(unittest.TestCase):
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape) y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans)) self.assertTrue(np.allclose(y_np, y_ans))
def test_adjust_sharpness(self):
x_shape = [4, 4, 3]
x_data = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0,
0, 65, 108, 101, 120, 97, 110, 100, 101, 114, 32, 86, 114, 121, 110, 105,
111, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
x_pil = Image.fromarray(x_np, mode='RGB')
# test 0
y_pil = F.adjust_sharpness(x_pil, 1)
y_np = np.array(y_pil)
self.assertTrue(np.allclose(y_np, x_np))
# test 1
y_pil = F.adjust_sharpness(x_pil, 0.5)
y_np = np.array(y_pil)
y_ans = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 30,
30, 74, 103, 96, 114, 97, 110, 100, 101, 114, 32, 81, 103, 108, 102, 101,
107, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans))
# test 2
y_pil = F.adjust_sharpness(x_pil, 2)
y_np = np.array(y_pil)
y_ans = [75, 121, 114, 105, 97, 107, 105, 32, 66, 111, 117, 114, 99, 104, 97, 0,
0, 46, 118, 111, 132, 97, 110, 100, 101, 114, 32, 95, 135, 146, 126, 112,
119, 116, 105, 115, 0, 0, 73, 32, 108, 111, 118, 101, 32, 121, 111, 117]
y_ans = np.array(y_ans, dtype=np.uint8).reshape(x_shape)
self.assertTrue(np.allclose(y_np, y_ans))
# test 3
x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape)
x_pil = Image.fromarray(x_np, mode='RGB')
x_th = torch.tensor(x_np.transpose(2, 0, 1))
y_pil = F.adjust_sharpness(x_pil, 2)
y_np = np.array(y_pil).transpose(2, 0, 1)
y_th = F.adjust_sharpness(x_th, 2)
self.assertTrue(np.allclose(y_np, y_th.numpy()))
def test_adjust_gamma(self): def test_adjust_gamma(self):
x_shape = [2, 2, 3] x_shape = [2, 2, 3]
x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
...@@ -1270,6 +1312,7 @@ class Tester(unittest.TestCase): ...@@ -1270,6 +1312,7 @@ class Tester(unittest.TestCase):
self.assertEqual(F.adjust_saturation(x_l, 2).mode, 'L') self.assertEqual(F.adjust_saturation(x_l, 2).mode, 'L')
self.assertEqual(F.adjust_contrast(x_l, 2).mode, 'L') self.assertEqual(F.adjust_contrast(x_l, 2).mode, 'L')
self.assertEqual(F.adjust_hue(x_l, 0.4).mode, 'L') self.assertEqual(F.adjust_hue(x_l, 0.4).mode, 'L')
self.assertEqual(F.adjust_sharpness(x_l, 2).mode, 'L')
self.assertEqual(F.adjust_gamma(x_l, 0.5).mode, 'L') self.assertEqual(F.adjust_gamma(x_l, 0.5).mode, 'L')
def test_color_jitter(self): def test_color_jitter(self):
...@@ -1751,6 +1794,86 @@ class Tester(unittest.TestCase): ...@@ -1751,6 +1794,86 @@ class Tester(unittest.TestCase):
with self.assertRaisesRegex(ValueError, r"sigma should be a single number or a list/tuple with length 2"): with self.assertRaisesRegex(ValueError, r"sigma should be a single number or a list/tuple with length 2"):
transforms.GaussianBlur(3, "sigma_string") transforms.GaussianBlur(3, "sigma_string")
def _test_randomness(self, fn, trans, configs):
random_state = random.getstate()
random.seed(42)
img = transforms.ToPILImage()(torch.rand(3, 16, 18))
for p in [0.5, 0.7]:
for config in configs:
inv_img = fn(img, **config)
num_samples = 250
counts = 0
for _ in range(num_samples):
tranformation = trans(p=p, **config)
tranformation.__repr__()
out = tranformation(img)
if out == inv_img:
counts += 1
p_value = stats.binom_test(counts, num_samples, p=p)
random.setstate(random_state)
self.assertGreater(p_value, 0.0001)
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_invert(self):
self._test_randomness(
F.invert,
transforms.RandomInvert,
[{}]
)
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_posterize(self):
self._test_randomness(
F.posterize,
transforms.RandomPosterize,
[{"bits": 4}]
)
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_solarize(self):
self._test_randomness(
F.solarize,
transforms.RandomSolarize,
[{"threshold": 192}]
)
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_adjust_sharpness(self):
self._test_randomness(
F.adjust_sharpness,
transforms.RandomAdjustSharpness,
[{"sharpness_factor": 2.0}]
)
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_autocontrast(self):
self._test_randomness(
F.autocontrast,
transforms.RandomAutocontrast,
[{}]
)
@unittest.skipIf(stats is None, 'scipy.stats not available')
def test_random_equalize(self):
self._test_randomness(
F.equalize,
transforms.RandomEqualize,
[{}]
)
def test_autoaugment(self):
for policy in transforms.AutoAugmentPolicy:
for fill in [None, 85, (128, 128, 128)]:
random.seed(42)
img = Image.open(GRACE_HOPPER)
transform = transforms.AutoAugment(policy=policy, fill=fill)
for _ in range(100):
img = transform(img)
transform.__repr__()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -89,6 +89,34 @@ class Tester(TransformsTester): ...@@ -89,6 +89,34 @@ class Tester(TransformsTester):
def test_random_vertical_flip(self): def test_random_vertical_flip(self):
self._test_op('vflip', 'RandomVerticalFlip') self._test_op('vflip', 'RandomVerticalFlip')
def test_random_invert(self):
self._test_op('invert', 'RandomInvert')
def test_random_posterize(self):
fn_kwargs = meth_kwargs = {"bits": 4}
self._test_op(
'posterize', 'RandomPosterize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
def test_random_solarize(self):
fn_kwargs = meth_kwargs = {"threshold": 192.0}
self._test_op(
'solarize', 'RandomSolarize', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
def test_random_adjust_sharpness(self):
fn_kwargs = meth_kwargs = {"sharpness_factor": 2.0}
self._test_op(
'adjust_sharpness', 'RandomAdjustSharpness', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
)
def test_random_autocontrast(self):
self._test_op('autocontrast', 'RandomAutocontrast')
def test_random_equalize(self):
torch.set_deterministic(False)
self._test_op('equalize', 'RandomEqualize')
def test_color_jitter(self): def test_color_jitter(self):
tol = 1.0 + 1e-10 tol = 1.0 + 1e-10
...@@ -598,6 +626,22 @@ class Tester(TransformsTester): ...@@ -598,6 +626,22 @@ class Tester(TransformsTester):
with get_tmp_dir() as tmp_dir: with get_tmp_dir() as tmp_dir:
scripted_fn.save(os.path.join(tmp_dir, "t_convert_dtype.pt")) scripted_fn.save(os.path.join(tmp_dir, "t_convert_dtype.pt"))
def test_autoaugment(self):
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
for policy in T.AutoAugmentPolicy:
for fill in [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
for _ in range(100):
transform = T.AutoAugment(policy=policy, fill=fill)
s_transform = torch.jit.script(transform)
self._test_transform_vs_scripted(transform, s_transform, tensor)
self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
with get_tmp_dir() as tmp_dir:
s_transform.save(os.path.join(tmp_dir, "t_autoaugment.pt"))
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device") @unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
class CUDATester(Tester): class CUDATester(Tester):
......
from .transforms import * from .transforms import *
from .autoaugment import *
import math
import torch
from enum import Enum
from torch import Tensor
from torch.jit.annotations import List, Tuple
from typing import Optional
from . import functional as F, InterpolationMode
class AutoAugmentPolicy(Enum):
"""AutoAugment policies learned on different datasets.
"""
IMAGENET = "imagenet"
CIFAR10 = "cifar10"
SVHN = "svhn"
def _get_transforms(policy: AutoAugmentPolicy):
if policy == AutoAugmentPolicy.IMAGENET:
return [
(("Posterize", 0.4, 8), ("Rotate", 0.6, 9)),
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
(("Posterize", 0.6, 7), ("Posterize", 0.6, 6)),
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
(("Equalize", 0.4, None), ("Rotate", 0.8, 8)),
(("Solarize", 0.6, 3), ("Equalize", 0.6, None)),
(("Posterize", 0.8, 5), ("Equalize", 1.0, None)),
(("Rotate", 0.2, 3), ("Solarize", 0.6, 8)),
(("Equalize", 0.6, None), ("Posterize", 0.4, 6)),
(("Rotate", 0.8, 8), ("Color", 0.4, 0)),
(("Rotate", 0.4, 9), ("Equalize", 0.6, None)),
(("Equalize", 0.0, None), ("Equalize", 0.8, None)),
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
(("Rotate", 0.8, 8), ("Color", 1.0, 2)),
(("Color", 0.8, 8), ("Solarize", 0.8, 7)),
(("Sharpness", 0.4, 7), ("Invert", 0.6, None)),
(("ShearX", 0.6, 5), ("Equalize", 1.0, None)),
(("Color", 0.4, 0), ("Equalize", 0.6, None)),
(("Equalize", 0.4, None), ("Solarize", 0.2, 4)),
(("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)),
(("Invert", 0.6, None), ("Equalize", 1.0, None)),
(("Color", 0.6, 4), ("Contrast", 1.0, 8)),
(("Equalize", 0.8, None), ("Equalize", 0.6, None)),
]
elif policy == AutoAugmentPolicy.CIFAR10:
return [
(("Invert", 0.1, None), ("Contrast", 0.2, 6)),
(("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)),
(("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)),
(("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)),
(("AutoContrast", 0.5, None), ("Equalize", 0.9, None)),
(("ShearY", 0.2, 7), ("Posterize", 0.3, 7)),
(("Color", 0.4, 3), ("Brightness", 0.6, 7)),
(("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)),
(("Equalize", 0.6, None), ("Equalize", 0.5, None)),
(("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)),
(("Color", 0.7, 7), ("TranslateX", 0.5, 8)),
(("Equalize", 0.3, None), ("AutoContrast", 0.4, None)),
(("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)),
(("Brightness", 0.9, 6), ("Color", 0.2, 8)),
(("Solarize", 0.5, 2), ("Invert", 0.0, None)),
(("Equalize", 0.2, None), ("AutoContrast", 0.6, None)),
(("Equalize", 0.2, None), ("Equalize", 0.6, None)),
(("Color", 0.9, 9), ("Equalize", 0.6, None)),
(("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)),
(("Brightness", 0.1, 3), ("Color", 0.7, 0)),
(("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)),
(("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)),
(("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)),
(("Equalize", 0.8, None), ("Invert", 0.1, None)),
(("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)),
]
elif policy == AutoAugmentPolicy.SVHN:
return [
(("ShearX", 0.9, 4), ("Invert", 0.2, None)),
(("ShearY", 0.9, 8), ("Invert", 0.7, None)),
(("Equalize", 0.6, None), ("Solarize", 0.6, 6)),
(("Invert", 0.9, None), ("Equalize", 0.6, None)),
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
(("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)),
(("ShearY", 0.9, 8), ("Invert", 0.4, None)),
(("ShearY", 0.9, 5), ("Solarize", 0.2, 6)),
(("Invert", 0.9, None), ("AutoContrast", 0.8, None)),
(("Equalize", 0.6, None), ("Rotate", 0.9, 3)),
(("ShearX", 0.9, 4), ("Solarize", 0.3, 3)),
(("ShearY", 0.8, 8), ("Invert", 0.7, None)),
(("Equalize", 0.9, None), ("TranslateY", 0.6, 6)),
(("Invert", 0.9, None), ("Equalize", 0.6, None)),
(("Contrast", 0.3, 3), ("Rotate", 0.8, 4)),
(("Invert", 0.8, None), ("TranslateY", 0.0, 2)),
(("ShearY", 0.7, 6), ("Solarize", 0.4, 8)),
(("Invert", 0.6, None), ("Rotate", 0.8, 4)),
(("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)),
(("ShearX", 0.1, 6), ("Invert", 0.6, None)),
(("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)),
(("ShearY", 0.8, 4), ("Invert", 0.8, None)),
(("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)),
(("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)),
(("ShearX", 0.7, 2), ("Invert", 0.1, None)),
]
def _get_magnitudes():
_BINS = 10
return {
# name: (magnitudes, signed)
"ShearX": (torch.linspace(0.0, 0.3, _BINS), True),
"ShearY": (torch.linspace(0.0, 0.3, _BINS), True),
"TranslateX": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True),
"TranslateY": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True),
"Rotate": (torch.linspace(0.0, 30.0, _BINS), True),
"Brightness": (torch.linspace(0.0, 0.9, _BINS), True),
"Color": (torch.linspace(0.0, 0.9, _BINS), True),
"Contrast": (torch.linspace(0.0, 0.9, _BINS), True),
"Sharpness": (torch.linspace(0.0, 0.9, _BINS), True),
"Posterize": (torch.tensor([8, 8, 7, 7, 6, 6, 5, 5, 4, 4]), False),
"Solarize": (torch.linspace(256.0, 0.0, _BINS), False),
"AutoContrast": (None, None),
"Equalize": (None, None),
"Invert": (None, None),
}
class AutoAugment(torch.nn.Module):
r"""AutoAugment data augmentation method based on
`"AutoAugment: Learning Augmentation Strategies from Data" <https://arxiv.org/pdf/1805.09501.pdf>`_.
The image can be a PIL Image or a Tensor, in which case it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
Args:
policy (AutoAugmentPolicy): Desired policy enum defined by
:class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed
image. If int or float, the value is used for all bands respectively.
This option is supported for PIL image and Tensor inputs.
If input is PIL Image, the options is only available for ``Pillow>=5.0.0``.
Example:
>>> t = transforms.AutoAugment()
>>> transformed = t(image)
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> transforms.AutoAugment(),
>>> transforms.ToTensor()])
"""
def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET,
interpolation: InterpolationMode = InterpolationMode.NEAREST, fill: Optional[List[float]] = None):
super().__init__()
self.policy = policy
self.interpolation = interpolation
self.fill = fill
self.transforms = _get_transforms(policy)
if self.transforms is None:
raise ValueError("The provided policy {} is not recognized.".format(policy))
self._op_meta = _get_magnitudes()
@staticmethod
def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]:
"""Get parameters for autoaugment transformation
Returns:
params required by the autoaugment transformation
"""
policy_id = torch.randint(transform_num, (1,)).item()
probs = torch.rand((2,))
signs = torch.randint(2, (2,))
return policy_id, probs, signs
def _get_op_meta(self, name: str) -> Tuple[Optional[Tensor], Optional[bool]]:
return self._op_meta[name]
def forward(self, img: Tensor):
"""
img (PIL Image or Tensor): Image to be transformed.
Returns:
PIL Image or Tensor: AutoAugmented image.
"""
fill = self.fill
if isinstance(img, Tensor):
if isinstance(fill, (int, float)):
fill = [float(fill)] * F._get_image_num_channels(img)
elif fill is not None:
fill = [float(f) for f in fill]
transform_id, probs, signs = self.get_params(len(self.transforms))
for i, (op_name, p, magnitude_id) in enumerate(self.transforms[transform_id]):
if probs[i] <= p:
magnitudes, signed = self._get_op_meta(op_name)
magnitude = float(magnitudes[magnitude_id].item()) \
if magnitudes is not None and magnitude_id is not None else 0.0
if signed is not None and signed and signs[i] == 0:
magnitude *= -1.0
if op_name == "ShearX":
img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0],
interpolation=self.interpolation, fill=fill)
elif op_name == "ShearY":
img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)],
interpolation=self.interpolation, fill=fill)
elif op_name == "TranslateX":
img = F.affine(img, angle=0.0, translate=[int(F._get_image_size(img)[0] * magnitude), 0], scale=1.0,
interpolation=self.interpolation, shear=[0.0, 0.0], fill=fill)
elif op_name == "TranslateY":
img = F.affine(img, angle=0.0, translate=[0, int(F._get_image_size(img)[1] * magnitude)], scale=1.0,
interpolation=self.interpolation, shear=[0.0, 0.0], fill=fill)
elif op_name == "Rotate":
img = F.rotate(img, magnitude, interpolation=self.interpolation, fill=fill)
elif op_name == "Brightness":
img = F.adjust_brightness(img, 1.0 + magnitude)
elif op_name == "Color":
img = F.adjust_saturation(img, 1.0 + magnitude)
elif op_name == "Contrast":
img = F.adjust_contrast(img, 1.0 + magnitude)
elif op_name == "Sharpness":
img = F.adjust_sharpness(img, 1.0 + magnitude)
elif op_name == "Posterize":
img = F.posterize(img, int(magnitude))
elif op_name == "Solarize":
img = F.solarize(img, magnitude)
elif op_name == "AutoContrast":
img = F.autocontrast(img)
elif op_name == "Equalize":
img = F.equalize(img)
elif op_name == "Invert":
img = F.invert(img)
else:
raise ValueError("The provided operator {} is not recognized.".format(op_name))
return img
def __repr__(self):
return self.__class__.__name__ + '(policy={}, fill={})'.format(self.policy, self.fill)
...@@ -1173,3 +1173,118 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa ...@@ -1173,3 +1173,118 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[floa
if not isinstance(img, torch.Tensor): if not isinstance(img, torch.Tensor):
output = to_pil_image(output) output = to_pil_image(output)
return output return output
def invert(img: Tensor) -> Tensor:
"""Invert the colors of an RGB/grayscale PIL Image or torch Tensor.
Args:
img (PIL Image or Tensor): Image to have its colors inverted.
If img is a Tensor, it is expected to be in [..., H, W] format,
where ... means it can have an arbitrary number of trailing
dimensions.
Returns:
PIL Image or Tensor: Color inverted image.
"""
if not isinstance(img, torch.Tensor):
return F_pil.invert(img)
return F_t.invert(img)
def posterize(img: Tensor, bits: int) -> Tensor:
"""Posterize a PIL Image or torch Tensor by reducing the number of bits for each color channel.
Args:
img (PIL Image or Tensor): Image to have its colors posterized.
If img is a Tensor, it should be of type torch.uint8 and
it is expected to be in [..., H, W] format, where ... means
it can have an arbitrary number of trailing dimensions.
bits (int): The number of bits to keep for each channel (0-8).
Returns:
PIL Image or Tensor: Posterized image.
"""
if not (0 <= bits <= 8):
raise ValueError('The number if bits should be between 0 and 8. Got {}'.format(bits))
if not isinstance(img, torch.Tensor):
return F_pil.posterize(img, bits)
return F_t.posterize(img, bits)
def solarize(img: Tensor, threshold: float) -> Tensor:
"""Solarize a PIL Image or torch Tensor by inverting all pixel values above a threshold.
Args:
img (PIL Image or Tensor): Image to have its colors inverted.
If img is a Tensor, it is expected to be in [..., H, W] format,
where ... means it can have an arbitrary number of trailing
dimensions.
threshold (float): All pixels equal or above this value are inverted.
Returns:
PIL Image or Tensor: Solarized image.
"""
if not isinstance(img, torch.Tensor):
return F_pil.solarize(img, threshold)
return F_t.solarize(img, threshold)
def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
"""Adjust the sharpness of an Image.
Args:
img (PIL Image or Tensor): Image to be adjusted.
sharpness_factor (float): How much to adjust the sharpness. Can be
any non negative number. 0 gives a blurred image, 1 gives the
original image while 2 increases the sharpness by a factor of 2.
Returns:
PIL Image or Tensor: Sharpness adjusted image.
"""
if not isinstance(img, torch.Tensor):
return F_pil.adjust_sharpness(img, sharpness_factor)
return F_t.adjust_sharpness(img, sharpness_factor)
def autocontrast(img: Tensor) -> Tensor:
"""Maximize contrast of a PIL Image or torch Tensor by remapping its
pixels per channel so that the lowest becomes black and the lightest
becomes white.
Args:
img (PIL Image or Tensor): Image on which autocontrast is applied.
If img is a Tensor, it is expected to be in [..., H, W] format,
where ... means it can have an arbitrary number of trailing
dimensions.
Returns:
PIL Image or Tensor: An image that was autocontrasted.
"""
if not isinstance(img, torch.Tensor):
return F_pil.autocontrast(img)
return F_t.autocontrast(img)
def equalize(img: Tensor) -> Tensor:
"""Equalize the histogram of a PIL Image or torch Tensor by applying
a non-linear mapping to the input in order to create a uniform
distribution of grayscale values in the output.
Args:
img (PIL Image or Tensor): Image on which equalize is applied.
If img is a Tensor, it is expected to be in [..., H, W] format,
where ... means it can have an arbitrary number of trailing
dimensions.
Returns:
PIL Image or Tensor: An image that was equalized.
"""
if not isinstance(img, torch.Tensor):
return F_pil.equalize(img)
return F_t.equalize(img)
...@@ -606,3 +606,48 @@ def to_grayscale(img, num_output_channels): ...@@ -606,3 +606,48 @@ def to_grayscale(img, num_output_channels):
raise ValueError('num_output_channels should be either 1 or 3') raise ValueError('num_output_channels should be either 1 or 3')
return img return img
@torch.jit.unused
def invert(img):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.invert(img)
@torch.jit.unused
def posterize(img, bits):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.posterize(img, bits)
@torch.jit.unused
def solarize(img, threshold):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.solarize(img, threshold)
@torch.jit.unused
def adjust_sharpness(img, sharpness_factor):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
enhancer = ImageEnhance.Sharpness(img)
img = enhancer.enhance(sharpness_factor)
return img
@torch.jit.unused
def autocontrast(img):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.autocontrast(img)
@torch.jit.unused
def equalize(img):
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
return ImageOps.equalize(img)
...@@ -570,6 +570,7 @@ def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = Fa ...@@ -570,6 +570,7 @@ def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = Fa
def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor: def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
ratio = float(ratio)
bound = 1.0 if img1.is_floating_point() else 255.0 bound = 1.0 if img1.is_floating_point() else 255.0
return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype) return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)
...@@ -1180,3 +1181,133 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te ...@@ -1180,3 +1181,133 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te
img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype) img = _cast_squeeze_out(img, need_cast, need_squeeze, out_dtype)
return img return img
def invert(img: Tensor) -> Tensor:
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
if img.ndim < 3:
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
_assert_channels(img, [1, 3])
bound = torch.tensor(1 if img.is_floating_point() else 255, dtype=img.dtype, device=img.device)
return bound - img
def posterize(img: Tensor, bits: int) -> Tensor:
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
if img.ndim < 3:
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
if img.dtype != torch.uint8:
raise TypeError("Only torch.uint8 image tensors are supported, but found {}".format(img.dtype))
_assert_channels(img, [1, 3])
mask = -int(2**(8 - bits)) # JIT-friendly for: ~(2 ** (8 - bits) - 1)
return img & mask
def solarize(img: Tensor, threshold: float) -> Tensor:
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
if img.ndim < 3:
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
_assert_channels(img, [1, 3])
inverted_img = invert(img)
return torch.where(img >= threshold, inverted_img, img)
def _blurred_degenerate_image(img: Tensor) -> Tensor:
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
kernel = torch.ones((3, 3), dtype=dtype, device=img.device)
kernel[1, 1] = 5.0
kernel /= kernel.sum()
kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
result_tmp, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype, ])
result_tmp = conv2d(result_tmp, kernel, groups=result_tmp.shape[-3])
result_tmp = _cast_squeeze_out(result_tmp, need_cast, need_squeeze, out_dtype)
result = img.clone()
result[..., 1:-1, 1:-1] = result_tmp
return result
def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
if sharpness_factor < 0:
raise ValueError('sharpness_factor ({}) is not non-negative.'.format(sharpness_factor))
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
_assert_channels(img, [1, 3])
if img.size(-1) <= 2 or img.size(-2) <= 2:
return img
return _blend(img, _blurred_degenerate_image(img), sharpness_factor)
def autocontrast(img: Tensor) -> Tensor:
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
if img.ndim < 3:
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
_assert_channels(img, [1, 3])
bound = 1.0 if img.is_floating_point() else 255.0
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype)
maximum = img.amax(dim=(-2, -1), keepdim=True).to(dtype)
eq_idxs = torch.where(minimum == maximum)[0]
minimum[eq_idxs] = 0
maximum[eq_idxs] = bound
scale = bound / (maximum - minimum)
return ((img - minimum) * scale).clamp(0, bound).to(img.dtype)
def _scale_channel(img_chan):
hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
nonzero_hist = hist[hist != 0]
step = nonzero_hist[:-1].sum() // 255
if step == 0:
return img_chan
lut = (torch.cumsum(hist, 0) + (step // 2)) // step
lut = torch.nn.functional.pad(lut, [1, 0])[:-1].clamp(0, 255)
return lut[img_chan.to(torch.int64)].to(torch.uint8)
def _equalize_single_image(img: Tensor) -> Tensor:
return torch.stack([_scale_channel(img[c]) for c in range(img.size(0))])
def equalize(img: Tensor) -> Tensor:
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
if not (3 <= img.ndim <= 4):
raise TypeError("Input image tensor should have 3 or 4 dimensions, but found {}".format(img.ndim))
if img.dtype != torch.uint8:
raise TypeError("Only torch.uint8 image tensors are supported, but found {}".format(img.dtype))
_assert_channels(img, [1, 3])
if img.ndim == 3:
return _equalize_single_image(img)
return torch.stack([_equalize_single_image(x) for x in img])
...@@ -21,7 +21,8 @@ __all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImag ...@@ -21,7 +21,8 @@ __all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImag
"CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
"LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale",
"RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode"] "RandomPerspective", "RandomErasing", "GaussianBlur", "InterpolationMode", "RandomInvert", "RandomPosterize",
"RandomSolarize", "RandomAdjustSharpness", "RandomAutocontrast", "RandomEqualize"]
class Compose: class Compose:
...@@ -1038,7 +1039,7 @@ class LinearTransformation(torch.nn.Module): ...@@ -1038,7 +1039,7 @@ class LinearTransformation(torch.nn.Module):
class ColorJitter(torch.nn.Module): class ColorJitter(torch.nn.Module):
"""Randomly change the brightness, contrast and saturation of an image. """Randomly change the brightness, contrast, saturation and hue of an image.
Args: Args:
brightness (float or tuple of float (min, max)): How much to jitter brightness. brightness (float or tuple of float (min, max)): How much to jitter brightness.
...@@ -1699,3 +1700,190 @@ def _setup_angle(x, name, req_sizes=(2, )): ...@@ -1699,3 +1700,190 @@ def _setup_angle(x, name, req_sizes=(2, )):
_check_sequence_input(x, name, req_sizes) _check_sequence_input(x, name, req_sizes)
return [float(d) for d in x] return [float(d) for d in x]
class RandomInvert(torch.nn.Module):
"""Inverts the colors of the given image randomly with a given probability.
The image can be a PIL Image or a torch Tensor, in which case it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading
dimensions.
Args:
p (float): probability of the image being color inverted. Default value is 0.5
"""
def __init__(self, p=0.5):
super().__init__()
self.p = p
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be inverted.
Returns:
PIL Image or Tensor: Randomly color inverted image.
"""
if torch.rand(1).item() < self.p:
return F.invert(img)
return img
def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p)
class RandomPosterize(torch.nn.Module):
"""Posterize the image randomly with a given probability by reducing the
number of bits for each color channel. The image can be a PIL Image or a torch
Tensor, in which case it is expected to have [..., H, W] shape, where ... means
an arbitrary number of leading dimensions.
Args:
bits (int): number of bits to keep for each channel (0-8)
p (float): probability of the image being color inverted. Default value is 0.5
"""
def __init__(self, bits, p=0.5):
super().__init__()
self.bits = bits
self.p = p
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be posterized.
Returns:
PIL Image or Tensor: Randomly posterized image.
"""
if torch.rand(1).item() < self.p:
return F.posterize(img, self.bits)
return img
def __repr__(self):
return self.__class__.__name__ + '(bits={},p={})'.format(self.bits, self.p)
class RandomSolarize(torch.nn.Module):
"""Solarize the image randomly with a given probability by inverting all pixel
values above a threshold. The image can be a PIL Image or a torch Tensor, in
which case it is expected to have [..., H, W] shape, where ... means an arbitrary
number of leading dimensions.
Args:
threshold (float): all pixels equal or above this value are inverted.
p (float): probability of the image being color inverted. Default value is 0.5
"""
def __init__(self, threshold, p=0.5):
super().__init__()
self.threshold = threshold
self.p = p
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be solarized.
Returns:
PIL Image or Tensor: Randomly solarized image.
"""
if torch.rand(1).item() < self.p:
return F.solarize(img, self.threshold)
return img
def __repr__(self):
return self.__class__.__name__ + '(threshold={},p={})'.format(self.threshold, self.p)
class RandomAdjustSharpness(torch.nn.Module):
"""Adjust the sharpness of the image randomly with a given probability. The image
can be a PIL Image or a torch Tensor, in which case it is expected to have [..., H, W]
shape, where ... means an arbitrary number of leading dimensions.
Args:
sharpness_factor (float): How much to adjust the sharpness. Can be
any non negative number. 0 gives a blurred image, 1 gives the
original image while 2 increases the sharpness by a factor of 2.
p (float): probability of the image being color inverted. Default value is 0.5
"""
def __init__(self, sharpness_factor, p=0.5):
super().__init__()
self.sharpness_factor = sharpness_factor
self.p = p
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be sharpened.
Returns:
PIL Image or Tensor: Randomly sharpened image.
"""
if torch.rand(1).item() < self.p:
return F.adjust_sharpness(img, self.sharpness_factor)
return img
def __repr__(self):
return self.__class__.__name__ + '(sharpness_factor={},p={})'.format(self.sharpness_factor, self.p)
class RandomAutocontrast(torch.nn.Module):
"""Autocontrast the pixels of the given image randomly with a given probability.
The image can be a PIL Image or a torch Tensor, in which case it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading
dimensions.
Args:
p (float): probability of the image being autocontrasted. Default value is 0.5
"""
def __init__(self, p=0.5):
super().__init__()
self.p = p
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be autocontrasted.
Returns:
PIL Image or Tensor: Randomly autocontrasted image.
"""
if torch.rand(1).item() < self.p:
return F.autocontrast(img)
return img
def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p)
class RandomEqualize(torch.nn.Module):
"""Equalize the histogram of the given image randomly with a given probability.
The image can be a PIL Image or a torch Tensor, in which case it is expected
to have [..., H, W] shape, where ... means an arbitrary number of leading
dimensions.
Args:
p (float): probability of the image being equalized. Default value is 0.5
"""
def __init__(self, p=0.5):
super().__init__()
self.p = p
def forward(self, img):
"""
Args:
img (PIL Image or Tensor): Image to be equalized.
Returns:
PIL Image or Tensor: Randomly equalized image.
"""
if torch.rand(1).item() < self.p:
return F.equalize(img)
return img
def __repr__(self):
return self.__class__.__name__ + '(p={})'.format(self.p)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment