"git@developer.sourcefind.cn:Wenxuan/LightX2V.git" did not exist on "407b85081e156e575fde594071bbddee660c40af"
Commit 6c1a6676 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 375526479
parent 3fa47266
...@@ -31,6 +31,7 @@ class RandAugment(hyperparams.Config): ...@@ -31,6 +31,7 @@ class RandAugment(hyperparams.Config):
magnitude: float = 10 magnitude: float = 10
cutout_const: float = 40 cutout_const: float = 40
translate_const: float = 10 translate_const: float = 10
prob_to_apply: Optional[float] = None
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -109,7 +109,8 @@ class Parser(parser.Parser): ...@@ -109,7 +109,8 @@ class Parser(parser.Parser):
num_layers=aug_type.randaug.num_layers, num_layers=aug_type.randaug.num_layers,
magnitude=aug_type.randaug.magnitude, magnitude=aug_type.randaug.magnitude,
cutout_const=aug_type.randaug.cutout_const, cutout_const=aug_type.randaug.cutout_const,
translate_const=aug_type.randaug.translate_const) translate_const=aug_type.randaug.translate_const,
prob_to_apply=aug_type.randaug.prob_to_apply)
else: else:
raise ValueError('Augmentation policy {} not supported.'.format( raise ValueError('Augmentation policy {} not supported.'.format(
aug_type.type)) aug_type.type))
......
...@@ -1183,7 +1183,8 @@ class RandAugment(ImageAugment): ...@@ -1183,7 +1183,8 @@ class RandAugment(ImageAugment):
num_layers: int = 2, num_layers: int = 2,
magnitude: float = 10., magnitude: float = 10.,
cutout_const: float = 40., cutout_const: float = 40.,
translate_const: float = 100.): translate_const: float = 100.,
prob_to_apply: Optional[float] = None):
"""Applies the RandAugment policy to images. """Applies the RandAugment policy to images.
Args: Args:
...@@ -1195,6 +1196,8 @@ class RandAugment(ImageAugment): ...@@ -1195,6 +1196,8 @@ class RandAugment(ImageAugment):
[5, 10]. [5, 10].
cutout_const: multiplier for applying cutout. cutout_const: multiplier for applying cutout.
translate_const: multiplier for applying translation. translate_const: multiplier for applying translation.
prob_to_apply: The probability to apply the selected augmentation at each
layer.
""" """
super(RandAugment, self).__init__() super(RandAugment, self).__init__()
...@@ -1202,6 +1205,7 @@ class RandAugment(ImageAugment): ...@@ -1202,6 +1205,7 @@ class RandAugment(ImageAugment):
self.magnitude = float(magnitude) self.magnitude = float(magnitude)
self.cutout_const = float(cutout_const) self.cutout_const = float(cutout_const)
self.translate_const = float(translate_const) self.translate_const = float(translate_const)
self.prob_to_apply = prob_to_apply
self.available_ops = [ self.available_ops = [
'AutoContrast', 'Equalize', 'Invert', 'Rotate', 'Posterize', 'Solarize', 'AutoContrast', 'Equalize', 'Invert', 'Rotate', 'Posterize', 'Solarize',
'Color', 'Contrast', 'Brightness', 'Sharpness', 'ShearX', 'ShearY', 'Color', 'Contrast', 'Brightness', 'Sharpness', 'ShearX', 'ShearY',
...@@ -1226,6 +1230,8 @@ class RandAugment(ImageAugment): ...@@ -1226,6 +1230,8 @@ class RandAugment(ImageAugment):
replace_value = [128] * 3 replace_value = [128] * 3
min_prob, max_prob = 0.2, 0.8 min_prob, max_prob = 0.2, 0.8
aug_image = image
for _ in range(self.num_layers): for _ in range(self.num_layers):
op_to_select = tf.random.uniform([], op_to_select = tf.random.uniform([],
maxval=len(self.available_ops) + 1, maxval=len(self.available_ops) + 1,
...@@ -1247,10 +1253,16 @@ class RandAugment(ImageAugment): ...@@ -1247,10 +1253,16 @@ class RandAugment(ImageAugment):
image, *selected_args))) image, *selected_args)))
# pylint:enable=g-long-lambda # pylint:enable=g-long-lambda
image = tf.switch_case( aug_image = tf.switch_case(
branch_index=op_to_select, branch_index=op_to_select,
branch_fns=branch_fns, branch_fns=branch_fns,
default=lambda: tf.identity(image)) default=lambda: tf.identity(image))
if self.prob_to_apply is not None:
aug_image = tf.cond(
tf.random.uniform(shape=[], dtype=tf.float32) < self.prob_to_apply,
lambda: tf.identity(aug_image), lambda: tf.identity(image))
image = aug_image
image = tf.cast(image, dtype=input_image_type) image = tf.cast(image, dtype=input_image_type)
return image return image
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment