Commit 68c04c17 authored by Hye Soo Yang's avatar Hye Soo Yang Committed by A. Unique TensorFlower
Browse files

[keras_cv] Support custom policy for `AutoAugment`.

PiperOrigin-RevId: 347069098
parent c1097033
......@@ -18,9 +18,10 @@ AutoAugment Reference: https://arxiv.org/abs/1805.09501
RandAugment Reference: https://arxiv.org/abs/1909.13719
"""
import math
from typing import Any, List, Optional, Text, Tuple, Iterable
import numpy as np
import tensorflow as tf
from typing import Any, Dict, List, Optional, Text, Tuple
from tensorflow.python.keras.layers.preprocessing import image_preprocessing as image_ops
......@@ -732,7 +733,8 @@ class AutoAugment(ImageAugment):
def __init__(self,
augmentation_name: Text = 'v0',
policies: Optional[Dict[Text, Any]] = None,
policies: Optional[Iterable[Iterable[Tuple[Text, float,
float]]]] = None,
cutout_const: float = 100,
translate_const: float = 250):
"""Applies the AutoAugment policy to images.
......@@ -745,34 +747,66 @@ class AutoAugment(ImageAugment):
the COCO dataset. `v1`, `v2` and `v3` are additional good policies found
on the COCO dataset that have slight variation in what operations were
used during the search procedure along with how many operations are
applied in parallel to a single image (2 vs 3).
applied in parallel to a single image (2 vs 3). Make sure to set
`policies` to `None` (the default) if you want to set options using
`augmentation_name`.
policies: list of lists of tuples in the form `(func, prob, level)`,
`func` is a string name of the augmentation function, `prob` is the
probability of applying the `func` operation, `level` is the input
argument for `func`.
probability of applying the `func` operation, `level` (or magnitude) is
the input argument for `func`. For example:
```
[[('Equalize', 0.9, 3), ('Color', 0.7, 8)],
[('Invert', 0.6, 5), ('Rotate', 0.2, 9), ('ShearX', 0.1, 2)], ...]
```
The outer-most list must be 3-d. The number of operations in a
sub-policy can vary from one sub-policy to another.
If you provide `policies` as input, any option set with
`augmentation_name` will get overriden as they are mutually exclusive.
cutout_const: multiplier for applying cutout.
translate_const: multiplier for applying translation.
Raises:
ValueError if `augmentation_name` is unsupported.
"""
super(AutoAugment, self).__init__()
if policies is None:
self.available_policies = {
'v0': self.policy_v0(),
'test': self.policy_test(),
'simple': self.policy_simple(),
'reduced_cifar10': self.policy_reduced_cifar10(),
'svhn': self.policy_svhn(),
'reduced_imagenet': self.policy_reduced_imagenet(),
}
if augmentation_name not in self.available_policies:
raise ValueError(
'Invalid augmentation_name: {}'.format(augmentation_name))
self.augmentation_name = augmentation_name
self.policies = self.available_policies[augmentation_name]
self.cutout_const = float(cutout_const)
self.translate_const = float(translate_const)
self.available_policies = {
'v0': self.policy_v0(),
'test': self.policy_test(),
'simple': self.policy_simple(),
'reduced_cifar10': self.policy_reduced_cifar10(),
'svhn': self.policy_svhn(),
'reduced_imagenet': self.policy_reduced_imagenet(),
}
if not policies:
if augmentation_name not in self.available_policies:
raise ValueError(
'Invalid augmentation_name: {}'.format(augmentation_name))
self.policies = self.available_policies[augmentation_name]
else:
self._check_policy_shape(policies)
self.policies = policies
def _check_policy_shape(self, policies):
"""Checks dimension and shape of the custom policy.
Args:
policies: List of list of tuples in the form `(func, prob, level)`. Must
have shape of `(:, :, 3)`.
Raises:
ValueError if the shape of `policies` is unexpected.
"""
in_shape = np.array(policies).shape
if len(in_shape) != 3 or in_shape[-1:] != (3,):
raise ValueError('Wrong shape detected for custom policy. Expected '
'(:, :, 3) but got {}.'.format(in_shape))
def distort(self, image: tf.Tensor) -> tf.Tensor:
"""Applies the AutoAugment policy to `image`.
......@@ -803,9 +837,15 @@ class AutoAugment(ImageAugment):
tf_policies = []
for policy in self.policies:
tf_policy = []
assert_ranges = []
# Link string name to the correct python function and make sure the
# correct argument is passed into that function.
for policy_info in policy:
_, prob, level = policy_info
assert_ranges.append(tf.Assert(tf.less_equal(prob, 1.), [prob]))
assert_ranges.append(
tf.Assert(tf.less_equal(level, int(_MAX_LEVEL)), [level]))
policy_info = list(policy_info) + [
replace_value, self.cutout_const, self.translate_const
]
......@@ -821,7 +861,8 @@ class AutoAugment(ImageAugment):
return final_policy
tf_policies.append(make_final_policy(tf_policy))
with tf.control_dependencies(assert_ranges):
tf_policies.append(make_final_policy(tf_policy))
image = select_and_apply_random_policy(tf_policies, image)
image = tf.cast(image, dtype=input_image_type)
......
......@@ -19,6 +19,7 @@ from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import random
from absl.testing import parameterized
import tensorflow as tf
......@@ -86,7 +87,16 @@ class TransformsTest(parameterized.TestCase, tf.test.TestCase):
self.assertAllEqual(image, augment.rotate(image, degrees))
class AutoaugmentTest(tf.test.TestCase):
class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
AVAILABLE_POLICIES = [
'v0',
'test',
'simple',
'reduced_cifar10',
'svhn',
'reduced_imagenet',
]
AVAILABLE_POLICIES = [
'v0',
......@@ -135,6 +145,76 @@ class AutoaugmentTest(tf.test.TestCase):
self.assertEqual((224, 224, 3), image.shape)
def _generate_test_policy(self):
"""Generate a test policy at random."""
op_list = list(augment.NAME_TO_FUNC.keys())
size = 6
prob = [round(random.uniform(0., 1.), 1) for _ in range(size)]
mag = [round(random.uniform(0, 10)) for _ in range(size)]
policy = []
for i in range(0, size, 2):
policy.append([(op_list[i], prob[i], mag[i]),
(op_list[i + 1], prob[i + 1], mag[i + 1])])
return policy
def test_custom_policy(self):
"""Test autoaugment with a custom policy."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
augmenter = augment.AutoAugment(policies=self._generate_test_policy())
aug_image = augmenter.distort(image)
self.assertEqual((224, 224, 3), aug_image.shape)
@parameterized.named_parameters(
{'testcase_name': '_OutOfRangeProb',
'sub_policy': ('Equalize', 1.1, 3), 'value': '1.1'},
{'testcase_name': '_OutOfRangeMag',
'sub_policy': ('Equalize', 0.9, 11), 'value': '11'},
)
def test_invalid_custom_sub_policy(self, sub_policy, value):
"""Test autoaugment with out-of-range values in the custom policy."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
policy = self._generate_test_policy()
policy[0][0] = sub_policy
augmenter = augment.AutoAugment(policies=policy)
with self.assertRaisesRegex(
tf.errors.InvalidArgumentError,
r'Expected \'tf.Tensor\(False, shape=\(\), dtype=bool\)\' to be true. '
r'Summarized data: ({})'.format(value)):
augmenter.distort(image)
def test_invalid_custom_policy_ndim(self):
"""Test autoaugment with wrong dimension in the custom policy."""
policy = [[('Equalize', 0.8, 1), ('Shear', 0.8, 4)],
[('TranslateY', 0.6, 3), ('Rotate', 0.9, 3)]]
policy = [[policy]]
with self.assertRaisesRegex(
ValueError,
r'Expected \(:, :, 3\) but got \(1, 1, 2, 2, 3\).'):
augment.AutoAugment(policies=policy)
def test_invalid_custom_policy_shape(self):
"""Test autoaugment with wrong shape in the custom policy."""
policy = [[('Equalize', 0.8, 1, 1), ('Shear', 0.8, 4, 1)],
[('TranslateY', 0.6, 3, 1), ('Rotate', 0.9, 3, 1)]]
with self.assertRaisesRegex(
ValueError,
r'Expected \(:, :, 3\) but got \(2, 2, 4\)'):
augment.AutoAugment(policies=policy)
def test_invalid_custom_policy_key(self):
"""Test autoaugment with invalid key in the custom policy."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
policy = [[('AAAAA', 0.8, 1), ('Shear', 0.8, 4)],
[('TranslateY', 0.6, 3), ('Rotate', 0.9, 3)]]
augmenter = augment.AutoAugment(policies=policy)
with self.assertRaisesRegex(KeyError, '\'AAAAA\''):
augmenter.distort(image)
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment