Commit c1097033 authored by Hye Soo Yang's avatar Hye Soo Yang Committed by A. Unique TensorFlower
Browse files

[keras_cv] Add more AutoAugment policies (svhn, reduced imagenet, reduced cifar10).

PiperOrigin-RevId: 347020444
parent dc2eebf8
......@@ -739,7 +739,8 @@ class AutoAugment(ImageAugment):
Args:
augmentation_name: The name of the AutoAugment policy to use. The
available options are `v0` and `test`. `v0` is the policy used for all
available options are `v0`, `test`, `reduced_cifar10`, `svhn` and
`reduced_imagenet`. `v0` is the policy used for all
of the results in the paper and was found to achieve the best results on
the COCO dataset. `v1`, `v2` and `v3` are additional good policies found
on the COCO dataset that have slight variation in what operations were
......@@ -759,6 +760,9 @@ class AutoAugment(ImageAugment):
'v0': self.policy_v0(),
'test': self.policy_test(),
'simple': self.policy_simple(),
'reduced_cifar10': self.policy_reduced_cifar10(),
'svhn': self.policy_svhn(),
'reduced_imagenet': self.policy_reduced_imagenet(),
}
if augmentation_name not in self.available_policies:
......@@ -868,6 +872,132 @@ class AutoAugment(ImageAugment):
]
return policy
@staticmethod
def policy_reduced_cifar10():
"""Autoaugment policy for reduced CIFAR-10 dataset.
Result is from the AutoAugment paper: https://arxiv.org/abs/1805.09501.
Each tuple is an augmentation operation of the form
(operation, probability, magnitude). Each element in policy is a
sub-policy that will be applied sequentially on the image.
Returns:
the policy.
"""
policy = [
[('Invert', 0.1, 7), ('Contrast', 0.2, 6)],
[('Rotate', 0.7, 2), ('TranslateX', 0.3, 9)],
[('Sharpness', 0.8, 1), ('Sharpness', 0.9, 3)],
[('ShearY', 0.5, 8), ('TranslateY', 0.7, 9)],
[('AutoContrast', 0.5, 8), ('Equalize', 0.9, 2)],
[('ShearY', 0.2, 7), ('Posterize', 0.3, 7)],
[('Color', 0.4, 3), ('Brightness', 0.6, 7)],
[('Sharpness', 0.3, 9), ('Brightness', 0.7, 9)],
[('Equalize', 0.6, 5), ('Equalize', 0.5, 1)],
[('Contrast', 0.6, 7), ('Sharpness', 0.6, 5)],
[('Color', 0.7, 7), ('TranslateX', 0.5, 8)],
[('Equalize', 0.3, 7), ('AutoContrast', 0.4, 8)],
[('TranslateY', 0.4, 3), ('Sharpness', 0.2, 6)],
[('Brightness', 0.9, 6), ('Color', 0.2, 8)],
[('Solarize', 0.5, 2), ('Invert', 0.0, 3)],
[('Equalize', 0.2, 0), ('AutoContrast', 0.6, 0)],
[('Equalize', 0.2, 8), ('Equalize', 0.6, 4)],
[('Color', 0.9, 9), ('Equalize', 0.6, 6)],
[('AutoContrast', 0.8, 4), ('Solarize', 0.2, 8)],
[('Brightness', 0.1, 3), ('Color', 0.7, 0)],
[('Solarize', 0.4, 5), ('AutoContrast', 0.9, 3)],
[('TranslateY', 0.9, 9), ('TranslateY', 0.7, 9)],
[('AutoContrast', 0.9, 2), ('Solarize', 0.8, 3)],
[('Equalize', 0.8, 8), ('Invert', 0.1, 3)],
[('TranslateY', 0.7, 9), ('AutoContrast', 0.9, 1)],
]
return policy
@staticmethod
def policy_svhn():
"""Autoaugment policy for SVHN dataset.
Result is from the AutoAugment paper: https://arxiv.org/abs/1805.09501.
Each tuple is an augmentation operation of the form
(operation, probability, magnitude). Each element in policy is a
sub-policy that will be applied sequentially on the image.
Returns:
the policy.
"""
policy = [
[('ShearX', 0.9, 4), ('Invert', 0.2, 3)],
[('ShearY', 0.9, 8), ('Invert', 0.7, 5)],
[('Equalize', 0.6, 5), ('Solarize', 0.6, 6)],
[('Invert', 0.9, 3), ('Equalize', 0.6, 3)],
[('Equalize', 0.6, 1), ('Rotate', 0.9, 3)],
[('ShearX', 0.9, 4), ('AutoContrast', 0.8, 3)],
[('ShearY', 0.9, 8), ('Invert', 0.4, 5)],
[('ShearY', 0.9, 5), ('Solarize', 0.2, 6)],
[('Invert', 0.9, 6), ('AutoContrast', 0.8, 1)],
[('Equalize', 0.6, 3), ('Rotate', 0.9, 3)],
[('ShearX', 0.9, 4), ('Solarize', 0.3, 3)],
[('ShearY', 0.8, 8), ('Invert', 0.7, 4)],
[('Equalize', 0.9, 5), ('TranslateY', 0.6, 6)],
[('Invert', 0.9, 4), ('Equalize', 0.6, 7)],
[('Contrast', 0.3, 3), ('Rotate', 0.8, 4)],
[('Invert', 0.8, 5), ('TranslateY', 0.0, 2)],
[('ShearY', 0.7, 6), ('Solarize', 0.4, 8)],
[('Invert', 0.6, 4), ('Rotate', 0.8, 4)],
[('ShearY', 0.3, 7), ('TranslateX', 0.9, 3)],
[('ShearX', 0.1, 6), ('Invert', 0.6, 5)],
[('Solarize', 0.7, 2), ('TranslateY', 0.6, 7)],
[('ShearY', 0.8, 4), ('Invert', 0.8, 8)],
[('ShearX', 0.7, 9), ('TranslateY', 0.8, 3)],
[('ShearY', 0.8, 5), ('AutoContrast', 0.7, 3)],
[('ShearX', 0.7, 2), ('Invert', 0.1, 5)],
]
return policy
@staticmethod
def policy_reduced_imagenet():
"""Autoaugment policy for reduced ImageNet dataset.
Result is from the AutoAugment paper: https://arxiv.org/abs/1805.09501.
Each tuple is an augmentation operation of the form
(operation, probability, magnitude). Each element in policy is a
sub-policy that will be applied sequentially on the image.
Returns:
the policy.
"""
policy = [
[('Posterize', 0.4, 8), ('Rotate', 0.6, 9)],
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)],
[('Posterize', 0.6, 7), ('Posterize', 0.6, 6)],
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
[('Equalize', 0.4, 4), ('Rotate', 0.8, 8)],
[('Solarize', 0.6, 3), ('Equalize', 0.6, 7)],
[('Posterize', 0.8, 5), ('Equalize', 1.0, 2)],
[('Rotate', 0.2, 3), ('Solarize', 0.6, 8)],
[('Equalize', 0.6, 8), ('Posterize', 0.4, 6)],
[('Rotate', 0.8, 8), ('Color', 0.4, 0)],
[('Rotate', 0.4, 9), ('Equalize', 0.6, 2)],
[('Equalize', 0.0, 7), ('Equalize', 0.8, 8)],
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
[('Rotate', 0.8, 8), ('Color', 1.0, 2)],
[('Color', 0.8, 8), ('Solarize', 0.8, 7)],
[('Sharpness', 0.4, 7), ('Invert', 0.6, 8)],
[('ShearX', 0.6, 5), ('Equalize', 1.0, 9)],
[('Color', 0.4, 0), ('Equalize', 0.6, 3)],
[('Equalize', 0.4, 7), ('Solarize', 0.2, 4)],
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5)],
[('Invert', 0.6, 4), ('Equalize', 1.0, 8)],
[('Color', 0.6, 4), ('Contrast', 1.0, 8)],
[('Equalize', 0.8, 8), ('Equalize', 0.6, 3)]
]
return policy
@staticmethod
def policy_simple():
"""Same as `policy_v0`, except with custom ops removed."""
......
......@@ -88,14 +88,24 @@ class TransformsTest(parameterized.TestCase, tf.test.TestCase):
class AutoaugmentTest(tf.test.TestCase):
AVAILABLE_POLICIES = [
'v0',
'test',
'simple',
'reduced_cifar10',
'svhn',
'reduced_imagenet',
]
def test_autoaugment(self):
"""Smoke test to be sure there are no syntax errors."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
augmenter = augment.AutoAugment()
aug_image = augmenter.distort(image)
for policy in self.AVAILABLE_POLICIES:
augmenter = augment.AutoAugment(augmentation_name=policy)
aug_image = augmenter.distort(image)
self.assertEqual((224, 224, 3), aug_image.shape)
self.assertEqual((224, 224, 3), aug_image.shape)
def test_randaug(self):
"""Smoke test to be sure there are no syntax errors."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment