Commit c1097033 authored by Hye Soo Yang's avatar Hye Soo Yang Committed by A. Unique TensorFlower
Browse files

[keras_cv] Add more AutoAugment policies (svhn, reduced imagenet, reduced cifar10).

PiperOrigin-RevId: 347020444
parent dc2eebf8
...@@ -739,7 +739,8 @@ class AutoAugment(ImageAugment): ...@@ -739,7 +739,8 @@ class AutoAugment(ImageAugment):
Args: Args:
augmentation_name: The name of the AutoAugment policy to use. The augmentation_name: The name of the AutoAugment policy to use. The
available options are `v0` and `test`. `v0` is the policy used for all available options are `v0`, `test`, `reduced_cifar10`, `svhn` and
`reduced_imagenet`. `v0` is the policy used for all
of the results in the paper and was found to achieve the best results on of the results in the paper and was found to achieve the best results on
the COCO dataset. `v1`, `v2` and `v3` are additional good policies found the COCO dataset. `v1`, `v2` and `v3` are additional good policies found
on the COCO dataset that have slight variation in what operations were on the COCO dataset that have slight variation in what operations were
...@@ -759,6 +760,9 @@ class AutoAugment(ImageAugment): ...@@ -759,6 +760,9 @@ class AutoAugment(ImageAugment):
'v0': self.policy_v0(), 'v0': self.policy_v0(),
'test': self.policy_test(), 'test': self.policy_test(),
'simple': self.policy_simple(), 'simple': self.policy_simple(),
'reduced_cifar10': self.policy_reduced_cifar10(),
'svhn': self.policy_svhn(),
'reduced_imagenet': self.policy_reduced_imagenet(),
} }
if augmentation_name not in self.available_policies: if augmentation_name not in self.available_policies:
...@@ -868,6 +872,132 @@ class AutoAugment(ImageAugment): ...@@ -868,6 +872,132 @@ class AutoAugment(ImageAugment):
] ]
return policy return policy
@staticmethod
def policy_reduced_cifar10():
"""Autoaugment policy for reduced CIFAR-10 dataset.
Result is from the AutoAugment paper: https://arxiv.org/abs/1805.09501.
Each tuple is an augmentation operation of the form
(operation, probability, magnitude). Each element in policy is a
sub-policy that will be applied sequentially on the image.
Returns:
the policy.
"""
policy = [
[('Invert', 0.1, 7), ('Contrast', 0.2, 6)],
[('Rotate', 0.7, 2), ('TranslateX', 0.3, 9)],
[('Sharpness', 0.8, 1), ('Sharpness', 0.9, 3)],
[('ShearY', 0.5, 8), ('TranslateY', 0.7, 9)],
[('AutoContrast', 0.5, 8), ('Equalize', 0.9, 2)],
[('ShearY', 0.2, 7), ('Posterize', 0.3, 7)],
[('Color', 0.4, 3), ('Brightness', 0.6, 7)],
[('Sharpness', 0.3, 9), ('Brightness', 0.7, 9)],
[('Equalize', 0.6, 5), ('Equalize', 0.5, 1)],
[('Contrast', 0.6, 7), ('Sharpness', 0.6, 5)],
[('Color', 0.7, 7), ('TranslateX', 0.5, 8)],
[('Equalize', 0.3, 7), ('AutoContrast', 0.4, 8)],
[('TranslateY', 0.4, 3), ('Sharpness', 0.2, 6)],
[('Brightness', 0.9, 6), ('Color', 0.2, 8)],
[('Solarize', 0.5, 2), ('Invert', 0.0, 3)],
[('Equalize', 0.2, 0), ('AutoContrast', 0.6, 0)],
[('Equalize', 0.2, 8), ('Equalize', 0.6, 4)],
[('Color', 0.9, 9), ('Equalize', 0.6, 6)],
[('AutoContrast', 0.8, 4), ('Solarize', 0.2, 8)],
[('Brightness', 0.1, 3), ('Color', 0.7, 0)],
[('Solarize', 0.4, 5), ('AutoContrast', 0.9, 3)],
[('TranslateY', 0.9, 9), ('TranslateY', 0.7, 9)],
[('AutoContrast', 0.9, 2), ('Solarize', 0.8, 3)],
[('Equalize', 0.8, 8), ('Invert', 0.1, 3)],
[('TranslateY', 0.7, 9), ('AutoContrast', 0.9, 1)],
]
return policy
@staticmethod
def policy_svhn():
"""Autoaugment policy for SVHN dataset.
Result is from the AutoAugment paper: https://arxiv.org/abs/1805.09501.
Each tuple is an augmentation operation of the form
(operation, probability, magnitude). Each element in policy is a
sub-policy that will be applied sequentially on the image.
Returns:
the policy.
"""
policy = [
[('ShearX', 0.9, 4), ('Invert', 0.2, 3)],
[('ShearY', 0.9, 8), ('Invert', 0.7, 5)],
[('Equalize', 0.6, 5), ('Solarize', 0.6, 6)],
[('Invert', 0.9, 3), ('Equalize', 0.6, 3)],
[('Equalize', 0.6, 1), ('Rotate', 0.9, 3)],
[('ShearX', 0.9, 4), ('AutoContrast', 0.8, 3)],
[('ShearY', 0.9, 8), ('Invert', 0.4, 5)],
[('ShearY', 0.9, 5), ('Solarize', 0.2, 6)],
[('Invert', 0.9, 6), ('AutoContrast', 0.8, 1)],
[('Equalize', 0.6, 3), ('Rotate', 0.9, 3)],
[('ShearX', 0.9, 4), ('Solarize', 0.3, 3)],
[('ShearY', 0.8, 8), ('Invert', 0.7, 4)],
[('Equalize', 0.9, 5), ('TranslateY', 0.6, 6)],
[('Invert', 0.9, 4), ('Equalize', 0.6, 7)],
[('Contrast', 0.3, 3), ('Rotate', 0.8, 4)],
[('Invert', 0.8, 5), ('TranslateY', 0.0, 2)],
[('ShearY', 0.7, 6), ('Solarize', 0.4, 8)],
[('Invert', 0.6, 4), ('Rotate', 0.8, 4)],
[('ShearY', 0.3, 7), ('TranslateX', 0.9, 3)],
[('ShearX', 0.1, 6), ('Invert', 0.6, 5)],
[('Solarize', 0.7, 2), ('TranslateY', 0.6, 7)],
[('ShearY', 0.8, 4), ('Invert', 0.8, 8)],
[('ShearX', 0.7, 9), ('TranslateY', 0.8, 3)],
[('ShearY', 0.8, 5), ('AutoContrast', 0.7, 3)],
[('ShearX', 0.7, 2), ('Invert', 0.1, 5)],
]
return policy
@staticmethod
def policy_reduced_imagenet():
"""Autoaugment policy for reduced ImageNet dataset.
Result is from the AutoAugment paper: https://arxiv.org/abs/1805.09501.
Each tuple is an augmentation operation of the form
(operation, probability, magnitude). Each element in policy is a
sub-policy that will be applied sequentially on the image.
Returns:
the policy.
"""
policy = [
[('Posterize', 0.4, 8), ('Rotate', 0.6, 9)],
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
[('Posterize', 0.6, 7), ('Posterize', 0.6, 6)],
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
[('Posterize', 0.8, 5), ('Equalize', 1.0, 2)],
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
[('Equalize', 0.6, 8), ('Posterize', 0.4, 6)],
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
[('Rotate', 0.8, 8), ('Color', 1.0, 2)],
[('Color', 0.8, 8), ('Solarize', 0.8, 7)],
[('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
[('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
[('Color', 0.4, 0), ('Equalize', 0.6, 3)],
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)]
]
return policy
@staticmethod @staticmethod
def policy_simple(): def policy_simple():
"""Same as `policy_v0`, except with custom ops removed.""" """Same as `policy_v0`, except with custom ops removed."""
......
...@@ -88,14 +88,24 @@ class TransformsTest(parameterized.TestCase, tf.test.TestCase): ...@@ -88,14 +88,24 @@ class TransformsTest(parameterized.TestCase, tf.test.TestCase):
class AutoaugmentTest(tf.test.TestCase): class AutoaugmentTest(tf.test.TestCase):
AVAILABLE_POLICIES = [
'v0',
'test',
'simple',
'reduced_cifar10',
'svhn',
'reduced_imagenet',
]
def test_autoaugment(self): def test_autoaugment(self):
"""Smoke test to be sure there are no syntax errors.""" """Smoke test to be sure there are no syntax errors."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8) image = tf.zeros((224, 224, 3), dtype=tf.uint8)
augmenter = augment.AutoAugment() for policy in self.AVAILABLE_POLICIES:
aug_image = augmenter.distort(image) augmenter = augment.AutoAugment(augmentation_name=policy)
aug_image = augmenter.distort(image)
self.assertEqual((224, 224, 3), aug_image.shape) self.assertEqual((224, 224, 3), aug_image.shape)
def test_randaug(self): def test_randaug(self):
"""Smoke test to be sure there are no syntax errors.""" """Smoke test to be sure there are no syntax errors."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment