Commit 6c1a6676 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 375526479
parent 3fa47266
......@@ -31,6 +31,7 @@ class RandAugment(hyperparams.Config):
magnitude: float = 10
cutout_const: float = 40
translate_const: float = 10
prob_to_apply: Optional[float] = None
@dataclasses.dataclass
......
......@@ -109,7 +109,8 @@ class Parser(parser.Parser):
num_layers=aug_type.randaug.num_layers,
magnitude=aug_type.randaug.magnitude,
cutout_const=aug_type.randaug.cutout_const,
translate_const=aug_type.randaug.translate_const)
translate_const=aug_type.randaug.translate_const,
prob_to_apply=aug_type.randaug.prob_to_apply)
else:
raise ValueError('Augmentation policy {} not supported.'.format(
aug_type.type))
......
......@@ -1183,7 +1183,8 @@ class RandAugment(ImageAugment):
num_layers: int = 2,
magnitude: float = 10.,
cutout_const: float = 40.,
translate_const: float = 100.):
translate_const: float = 100.,
prob_to_apply: Optional[float] = None):
"""Applies the RandAugment policy to images.
Args:
......@@ -1195,6 +1196,8 @@ class RandAugment(ImageAugment):
[5, 10].
cutout_const: multiplier for applying cutout.
translate_const: multiplier for applying translation.
prob_to_apply: The probability to apply the selected augmentation at each
layer.
"""
super(RandAugment, self).__init__()
......@@ -1202,6 +1205,7 @@ class RandAugment(ImageAugment):
self.magnitude = float(magnitude)
self.cutout_const = float(cutout_const)
self.translate_const = float(translate_const)
self.prob_to_apply = prob_to_apply
self.available_ops = [
'AutoContrast', 'Equalize', 'Invert', 'Rotate', 'Posterize', 'Solarize',
'Color', 'Contrast', 'Brightness', 'Sharpness', 'ShearX', 'ShearY',
......@@ -1226,6 +1230,8 @@ class RandAugment(ImageAugment):
replace_value = [128] * 3
min_prob, max_prob = 0.2, 0.8
aug_image = image
for _ in range(self.num_layers):
op_to_select = tf.random.uniform([],
maxval=len(self.available_ops) + 1,
......@@ -1247,10 +1253,16 @@ class RandAugment(ImageAugment):
image, *selected_args)))
# pylint:enable=g-long-lambda
image = tf.switch_case(
aug_image = tf.switch_case(
branch_index=op_to_select,
branch_fns=branch_fns,
default=lambda: tf.identity(image))
if self.prob_to_apply is not None:
aug_image = tf.cond(
tf.random.uniform(shape=[], dtype=tf.float32) < self.prob_to_apply,
lambda: tf.identity(aug_image), lambda: tf.identity(image))
image = aug_image
image = tf.cast(image, dtype=input_image_type)
return image
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment