Commit 3f1ca33a authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Add data augmentation strategies to mitigate overfitting in ViTs.

PiperOrigin-RevId: 468079502
parent aa04c2de
......@@ -36,6 +36,8 @@ class Parser(hyperparams.Config):
aug_rand_hflip: bool = False
aug_scale_min: float = 1.0
aug_scale_max: float = 1.0
aug_type: Optional[
common.Augmentation] = None # Choose from AutoAugment and RandAugment.
skip_crowd_during_training: bool = True
max_num_instances: int = 100
rpn_match_threshold: float = 0.7
......
......@@ -14,13 +14,16 @@
"""Data parser and processing for Mask R-CNN."""
# Import libraries
from typing import Optional
# Import libraries
import tensorflow as tf
from official.vision.configs import common
from official.vision.dataloaders import parser
from official.vision.dataloaders import utils
from official.vision.ops import anchor
from official.vision.ops import augment
from official.vision.ops import box_ops
from official.vision.ops import preprocess_ops
......@@ -42,6 +45,7 @@ class Parser(parser.Parser):
aug_rand_hflip=False,
aug_scale_min=1.0,
aug_scale_max=1.0,
aug_type: Optional[common.Augmentation] = None,
skip_crowd_during_training=True,
max_num_instances=100,
include_mask=False,
......@@ -73,6 +77,9 @@ class Parser(parser.Parser):
data augmentation during training.
aug_scale_max: `float`, the maximum scale applied to `output_size` for
data augmentation during training.
aug_type: An optional Augmentation object with params for AutoAugment.
The AutoAug policy should not use rotation/translation/shear.
Only in-place augmentations can be used.
skip_crowd_during_training: `bool`, if True, skip annotations labeled with
`is_crowd` equals to 1.
max_num_instances: `int` number of maximum number of instances in an
......@@ -104,6 +111,26 @@ class Parser(parser.Parser):
self._aug_scale_min = aug_scale_min
self._aug_scale_max = aug_scale_max
if aug_type and aug_type.type:
if aug_type.type == 'autoaug':
self._augmenter = augment.AutoAugment(
augmentation_name=aug_type.autoaug.augmentation_name,
cutout_const=aug_type.autoaug.cutout_const,
translate_const=aug_type.autoaug.translate_const)
elif aug_type.type == 'randaug':
self._augmenter = augment.RandAugment(
num_layers=aug_type.randaug.num_layers,
magnitude=aug_type.randaug.magnitude,
cutout_const=aug_type.randaug.cutout_const,
translate_const=aug_type.randaug.translate_const,
prob_to_apply=aug_type.randaug.prob_to_apply,
exclude_ops=aug_type.randaug.exclude_ops)
else:
raise ValueError('Augmentation policy {} not supported.'.format(
aug_type.type))
else:
self._augmenter = None
# Mask.
self._include_mask = include_mask
self._mask_crop_size = mask_crop_size
......@@ -167,6 +194,9 @@ class Parser(parser.Parser):
# Gets original image and its size.
image = data['image']
if self._augmenter is not None:
image = self._augmenter.distort(image)
image_shape = tf.shape(image)[0:2]
# Normalizes image with mean and std pixel values.
......
......@@ -1623,6 +1623,7 @@ class AutoAugment(ImageAugment):
'svhn': self.policy_svhn(),
'reduced_imagenet': self.policy_reduced_imagenet(),
'panoptic_deeplab_policy': self.panoptic_deeplab_policy(),
'vit': self.vit(),
}
if not policies:
......@@ -1938,6 +1939,22 @@ class AutoAugment(ImageAugment):
[('Sharpness', 0.2, 0.2), ('Equalize', 0.2, 1.4)]]
return policy
@staticmethod
def vit():
"""Autoaugment policy for a generic ViT."""
policy = [
[('Sharpness', 0.4, 1.4), ('Brightness', 0.2, 2.0), ('Cutout', 0.8, 8)],
[('Equalize', 0.0, 1.8), ('Contrast', 0.2, 2.0), ('Cutout', 0.8, 8)],
[('Sharpness', 0.2, 1.8), ('Color', 0.2, 1.8), ('Cutout', 0.8, 8)],
[('Solarize', 0.2, 1.4), ('Equalize', 0.6, 1.8), ('Cutout', 0.8, 8)],
[('Sharpness', 0.2, 0.2), ('Equalize', 0.2, 1.4), ('Cutout', 0.8, 8)],
[('Sharpness', 0.4, 7), ('Invert', 0.6, 8), ('Cutout', 0.8, 8)],
[('Invert', 0.6, 4), ('Equalize', 1.0, 8), ('Cutout', 0.8, 8)],
[('Posterize', 0.6, 7), ('Posterize', 0.6, 6), ('Cutout', 0.8, 8)],
[('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5), ('Cutout', 0.8, 8)],
]
return policy
@staticmethod
def policy_test():
"""Autoaugment test policy for debugging."""
......
......@@ -96,6 +96,7 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
'svhn',
'reduced_imagenet',
'detection_v0',
'vit',
]
def test_autoaugment(self):
......
......@@ -155,6 +155,7 @@ class MaskRCNNTask(base_task.Task):
aug_rand_hflip=params.parser.aug_rand_hflip,
aug_scale_min=params.parser.aug_scale_min,
aug_scale_max=params.parser.aug_scale_max,
aug_type=params.parser.aug_type,
skip_crowd_during_training=params.parser.skip_crowd_during_training,
max_num_instances=params.parser.max_num_instances,
include_mask=self._task_config.model.include_mask,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment