Commit 68f301f7 authored by Hye Soo Yang's avatar Hye Soo Yang Committed by A. Unique TensorFlower
Browse files

[keras_cv] Support custom policy for `AutoAugment`.

PiperOrigin-RevId: 347069098
parent 590201bb
...@@ -18,9 +18,10 @@ AutoAugment Reference: https://arxiv.org/abs/1805.09501 ...@@ -18,9 +18,10 @@ AutoAugment Reference: https://arxiv.org/abs/1805.09501
RandAugment Reference: https://arxiv.org/abs/1909.13719 RandAugment Reference: https://arxiv.org/abs/1909.13719
""" """
import math import math
from typing import Any, List, Optional, Text, Tuple, Iterable
import numpy as np
import tensorflow as tf import tensorflow as tf
from typing import Any, Dict, List, Optional, Text, Tuple
from tensorflow.python.keras.layers.preprocessing import image_preprocessing as image_ops from tensorflow.python.keras.layers.preprocessing import image_preprocessing as image_ops
...@@ -732,7 +733,8 @@ class AutoAugment(ImageAugment): ...@@ -732,7 +733,8 @@ class AutoAugment(ImageAugment):
def __init__(self, def __init__(self,
augmentation_name: Text = 'v0', augmentation_name: Text = 'v0',
policies: Optional[Dict[Text, Any]] = None, policies: Optional[Iterable[Iterable[Tuple[Text, float,
float]]]] = None,
cutout_const: float = 100, cutout_const: float = 100,
translate_const: float = 250): translate_const: float = 250):
"""Applies the AutoAugment policy to images. """Applies the AutoAugment policy to images.
...@@ -745,34 +747,66 @@ class AutoAugment(ImageAugment): ...@@ -745,34 +747,66 @@ class AutoAugment(ImageAugment):
the COCO dataset. `v1`, `v2` and `v3` are additional good policies found the COCO dataset. `v1`, `v2` and `v3` are additional good policies found
on the COCO dataset that have slight variation in what operations were on the COCO dataset that have slight variation in what operations were
used during the search procedure along with how many operations are used during the search procedure along with how many operations are
applied in parallel to a single image (2 vs 3). applied in parallel to a single image (2 vs 3). Make sure to set
`policies` to `None` (the default) if you want to set options using
`augmentation_name`.
policies: list of lists of tuples in the form `(func, prob, level)`, policies: list of lists of tuples in the form `(func, prob, level)`,
`func` is a string name of the augmentation function, `prob` is the `func` is a string name of the augmentation function, `prob` is the
probability of applying the `func` operation, `level` is the input probability of applying the `func` operation, `level` (or magnitude) is
argument for `func`. the input argument for `func`. For example:
```
[[('Equalize', 0.9, 3), ('Color', 0.7, 8)],
[('Invert', 0.6, 5), ('Rotate', 0.2, 9), ('ShearX', 0.1, 2)], ...]
```
The outer-most list must be 3-d. The number of operations in a
sub-policy can vary from one sub-policy to another.
If you provide `policies` as input, any option set with
`augmentation_name` will get overriden as they are mutually exclusive.
cutout_const: multiplier for applying cutout. cutout_const: multiplier for applying cutout.
translate_const: multiplier for applying translation. translate_const: multiplier for applying translation.
Raises:
ValueError if `augmentation_name` is unsupported.
""" """
super(AutoAugment, self).__init__() super(AutoAugment, self).__init__()
if policies is None:
self.available_policies = {
'v0': self.policy_v0(),
'test': self.policy_test(),
'simple': self.policy_simple(),
'reduced_cifar10': self.policy_reduced_cifar10(),
'svhn': self.policy_svhn(),
'reduced_imagenet': self.policy_reduced_imagenet(),
}
if augmentation_name not in self.available_policies:
raise ValueError(
'Invalid augmentation_name: {}'.format(augmentation_name))
self.augmentation_name = augmentation_name self.augmentation_name = augmentation_name
self.policies = self.available_policies[augmentation_name]
self.cutout_const = float(cutout_const) self.cutout_const = float(cutout_const)
self.translate_const = float(translate_const) self.translate_const = float(translate_const)
self.available_policies = {
'v0': self.policy_v0(),
'test': self.policy_test(),
'simple': self.policy_simple(),
'reduced_cifar10': self.policy_reduced_cifar10(),
'svhn': self.policy_svhn(),
'reduced_imagenet': self.policy_reduced_imagenet(),
}
if not policies:
if augmentation_name not in self.available_policies:
raise ValueError(
'Invalid augmentation_name: {}'.format(augmentation_name))
self.policies = self.available_policies[augmentation_name]
else:
self._check_policy_shape(policies)
self.policies = policies
def _check_policy_shape(self, policies):
"""Checks dimension and shape of the custom policy.
Args:
policies: List of list of tuples in the form `(func, prob, level)`. Must
have shape of `(:, :, 3)`.
Raises:
ValueError if the shape of `policies` is unexpected.
"""
in_shape = np.array(policies).shape
if len(in_shape) != 3 or in_shape[-1:] != (3,):
raise ValueError('Wrong shape detected for custom policy. Expected '
'(:, :, 3) but got {}.'.format(in_shape))
def distort(self, image: tf.Tensor) -> tf.Tensor: def distort(self, image: tf.Tensor) -> tf.Tensor:
"""Applies the AutoAugment policy to `image`. """Applies the AutoAugment policy to `image`.
...@@ -803,9 +837,15 @@ class AutoAugment(ImageAugment): ...@@ -803,9 +837,15 @@ class AutoAugment(ImageAugment):
tf_policies = [] tf_policies = []
for policy in self.policies: for policy in self.policies:
tf_policy = [] tf_policy = []
assert_ranges = []
# Link string name to the correct python function and make sure the # Link string name to the correct python function and make sure the
# correct argument is passed into that function. # correct argument is passed into that function.
for policy_info in policy: for policy_info in policy:
_, prob, level = policy_info
assert_ranges.append(tf.Assert(tf.less_equal(prob, 1.), [prob]))
assert_ranges.append(
tf.Assert(tf.less_equal(level, int(_MAX_LEVEL)), [level]))
policy_info = list(policy_info) + [ policy_info = list(policy_info) + [
replace_value, self.cutout_const, self.translate_const replace_value, self.cutout_const, self.translate_const
] ]
...@@ -821,7 +861,8 @@ class AutoAugment(ImageAugment): ...@@ -821,7 +861,8 @@ class AutoAugment(ImageAugment):
return final_policy return final_policy
tf_policies.append(make_final_policy(tf_policy)) with tf.control_dependencies(assert_ranges):
tf_policies.append(make_final_policy(tf_policy))
image = select_and_apply_random_policy(tf_policies, image) image = select_and_apply_random_policy(tf_policies, image)
image = tf.cast(image, dtype=input_image_type) image = tf.cast(image, dtype=input_image_type)
......
...@@ -19,6 +19,7 @@ from __future__ import division ...@@ -19,6 +19,7 @@ from __future__ import division
# from __future__ import google_type_annotations # from __future__ import google_type_annotations
from __future__ import print_function from __future__ import print_function
import random
from absl.testing import parameterized from absl.testing import parameterized
import tensorflow as tf import tensorflow as tf
...@@ -86,7 +87,16 @@ class TransformsTest(parameterized.TestCase, tf.test.TestCase): ...@@ -86,7 +87,16 @@ class TransformsTest(parameterized.TestCase, tf.test.TestCase):
self.assertAllEqual(image, augment.rotate(image, degrees)) self.assertAllEqual(image, augment.rotate(image, degrees))
class AutoaugmentTest(tf.test.TestCase): class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
AVAILABLE_POLICIES = [
'v0',
'test',
'simple',
'reduced_cifar10',
'svhn',
'reduced_imagenet',
]
AVAILABLE_POLICIES = [ AVAILABLE_POLICIES = [
'v0', 'v0',
...@@ -135,6 +145,76 @@ class AutoaugmentTest(tf.test.TestCase): ...@@ -135,6 +145,76 @@ class AutoaugmentTest(tf.test.TestCase):
self.assertEqual((224, 224, 3), image.shape) self.assertEqual((224, 224, 3), image.shape)
def _generate_test_policy(self):
"""Generate a test policy at random."""
op_list = list(augment.NAME_TO_FUNC.keys())
size = 6
prob = [round(random.uniform(0., 1.), 1) for _ in range(size)]
mag = [round(random.uniform(0, 10)) for _ in range(size)]
policy = []
for i in range(0, size, 2):
policy.append([(op_list[i], prob[i], mag[i]),
(op_list[i + 1], prob[i + 1], mag[i + 1])])
return policy
def test_custom_policy(self):
"""Test autoaugment with a custom policy."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
augmenter = augment.AutoAugment(policies=self._generate_test_policy())
aug_image = augmenter.distort(image)
self.assertEqual((224, 224, 3), aug_image.shape)
@parameterized.named_parameters(
{'testcase_name': '_OutOfRangeProb',
'sub_policy': ('Equalize', 1.1, 3), 'value': '1.1'},
{'testcase_name': '_OutOfRangeMag',
'sub_policy': ('Equalize', 0.9, 11), 'value': '11'},
)
def test_invalid_custom_sub_policy(self, sub_policy, value):
"""Test autoaugment with out-of-range values in the custom policy."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
policy = self._generate_test_policy()
policy[0][0] = sub_policy
augmenter = augment.AutoAugment(policies=policy)
with self.assertRaisesRegex(
tf.errors.InvalidArgumentError,
r'Expected \'tf.Tensor\(False, shape=\(\), dtype=bool\)\' to be true. '
r'Summarized data: ({})'.format(value)):
augmenter.distort(image)
def test_invalid_custom_policy_ndim(self):
"""Test autoaugment with wrong dimension in the custom policy."""
policy = [[('Equalize', 0.8, 1), ('Shear', 0.8, 4)],
[('TranslateY', 0.6, 3), ('Rotate', 0.9, 3)]]
policy = [[policy]]
with self.assertRaisesRegex(
ValueError,
r'Expected \(:, :, 3\) but got \(1, 1, 2, 2, 3\).'):
augment.AutoAugment(policies=policy)
def test_invalid_custom_policy_shape(self):
"""Test autoaugment with wrong shape in the custom policy."""
policy = [[('Equalize', 0.8, 1, 1), ('Shear', 0.8, 4, 1)],
[('TranslateY', 0.6, 3, 1), ('Rotate', 0.9, 3, 1)]]
with self.assertRaisesRegex(
ValueError,
r'Expected \(:, :, 3\) but got \(2, 2, 4\)'):
augment.AutoAugment(policies=policy)
def test_invalid_custom_policy_key(self):
"""Test autoaugment with invalid key in the custom policy."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
policy = [[('AAAAA', 0.8, 1), ('Shear', 0.8, 4)],
[('TranslateY', 0.6, 3), ('Rotate', 0.9, 3)]]
augmenter = augment.AutoAugment(policies=policy)
with self.assertRaisesRegex(KeyError, '\'AAAAA\''):
augmenter.distort(image)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment