"src/graph/vscode:/vscode.git/clone" did not exist on "cce31e9a26c5459631fe945db444f285bb00b730"
Commit 3f1ca33a authored by A. Unique TensorFlower
Browse files

Add data augmentation strategies to mitigate overfitting in ViTs.

PiperOrigin-RevId: 468079502
parent aa04c2de
...@@ -36,6 +36,8 @@ class Parser(hyperparams.Config): ...@@ -36,6 +36,8 @@ class Parser(hyperparams.Config):
aug_rand_hflip: bool = False aug_rand_hflip: bool = False
aug_scale_min: float = 1.0 aug_scale_min: float = 1.0
aug_scale_max: float = 1.0 aug_scale_max: float = 1.0
aug_type: Optional[
common.Augmentation] = None # Choose from AutoAugment and RandAugment.
skip_crowd_during_training: bool = True skip_crowd_during_training: bool = True
max_num_instances: int = 100 max_num_instances: int = 100
rpn_match_threshold: float = 0.7 rpn_match_threshold: float = 0.7
......
...@@ -14,13 +14,16 @@ ...@@ -14,13 +14,16 @@
"""Data parser and processing for Mask R-CNN.""" """Data parser and processing for Mask R-CNN."""
# Import libraries from typing import Optional
# Import libraries
import tensorflow as tf import tensorflow as tf
from official.vision.configs import common
from official.vision.dataloaders import parser from official.vision.dataloaders import parser
from official.vision.dataloaders import utils from official.vision.dataloaders import utils
from official.vision.ops import anchor from official.vision.ops import anchor
from official.vision.ops import augment
from official.vision.ops import box_ops from official.vision.ops import box_ops
from official.vision.ops import preprocess_ops from official.vision.ops import preprocess_ops
...@@ -42,6 +45,7 @@ class Parser(parser.Parser): ...@@ -42,6 +45,7 @@ class Parser(parser.Parser):
aug_rand_hflip=False, aug_rand_hflip=False,
aug_scale_min=1.0, aug_scale_min=1.0,
aug_scale_max=1.0, aug_scale_max=1.0,
aug_type: Optional[common.Augmentation] = None,
skip_crowd_during_training=True, skip_crowd_during_training=True,
max_num_instances=100, max_num_instances=100,
include_mask=False, include_mask=False,
...@@ -73,6 +77,9 @@ class Parser(parser.Parser): ...@@ -73,6 +77,9 @@ class Parser(parser.Parser):
data augmentation during training. data augmentation during training.
aug_scale_max: `float`, the maximum scale applied to `output_size` for aug_scale_max: `float`, the maximum scale applied to `output_size` for
data augmentation during training. data augmentation during training.
aug_type: An optional Augmentation object with params for AutoAugment or RandAugment.
The AutoAug policy should not use rotation/translation/shear.
Only in-place augmentations can be used.
skip_crowd_during_training: `bool`, if True, skip annotations labeled with skip_crowd_during_training: `bool`, if True, skip annotations labeled with
`is_crowd` equals to 1. `is_crowd` equals to 1.
max_num_instances: `int` number of maximum number of instances in an max_num_instances: `int` number of maximum number of instances in an
...@@ -104,6 +111,26 @@ class Parser(parser.Parser): ...@@ -104,6 +111,26 @@ class Parser(parser.Parser):
self._aug_scale_min = aug_scale_min self._aug_scale_min = aug_scale_min
self._aug_scale_max = aug_scale_max self._aug_scale_max = aug_scale_max
if aug_type and aug_type.type:
if aug_type.type == 'autoaug':
self._augmenter = augment.AutoAugment(
augmentation_name=aug_type.autoaug.augmentation_name,
cutout_const=aug_type.autoaug.cutout_const,
translate_const=aug_type.autoaug.translate_const)
elif aug_type.type == 'randaug':
self._augmenter = augment.RandAugment(
num_layers=aug_type.randaug.num_layers,
magnitude=aug_type.randaug.magnitude,
cutout_const=aug_type.randaug.cutout_const,
translate_const=aug_type.randaug.translate_const,
prob_to_apply=aug_type.randaug.prob_to_apply,
exclude_ops=aug_type.randaug.exclude_ops)
else:
raise ValueError('Augmentation policy {} not supported.'.format(
aug_type.type))
else:
self._augmenter = None
# Mask. # Mask.
self._include_mask = include_mask self._include_mask = include_mask
self._mask_crop_size = mask_crop_size self._mask_crop_size = mask_crop_size
...@@ -167,6 +194,9 @@ class Parser(parser.Parser): ...@@ -167,6 +194,9 @@ class Parser(parser.Parser):
# Gets original image and its size. # Gets original image and its size.
image = data['image'] image = data['image']
if self._augmenter is not None:
image = self._augmenter.distort(image)
image_shape = tf.shape(image)[0:2] image_shape = tf.shape(image)[0:2]
# Normalizes image with mean and std pixel values. # Normalizes image with mean and std pixel values.
......
...@@ -1623,6 +1623,7 @@ class AutoAugment(ImageAugment): ...@@ -1623,6 +1623,7 @@ class AutoAugment(ImageAugment):
'svhn': self.policy_svhn(), 'svhn': self.policy_svhn(),
'reduced_imagenet': self.policy_reduced_imagenet(), 'reduced_imagenet': self.policy_reduced_imagenet(),
'panoptic_deeplab_policy': self.panoptic_deeplab_policy(), 'panoptic_deeplab_policy': self.panoptic_deeplab_policy(),
'vit': self.vit(),
} }
if not policies: if not policies:
...@@ -1938,6 +1939,22 @@ class AutoAugment(ImageAugment): ...@@ -1938,6 +1939,22 @@ class AutoAugment(ImageAugment):
[('Sharpness', 0.2, 0.2), ('Equalize', 0.2, 1.4)]] [('Sharpness', 0.2, 0.2), ('Equalize', 0.2, 1.4)]]
return policy return policy
@staticmethod
def vit():
  """Autoaugment policy for a generic ViT.

  The policy consists of nine sub-policies. Each sub-policy is a list of three
  tuples: two color/intensity ops followed by the same trailing
  `('Cutout', 0.8, 8)` entry. Only Sharpness, Brightness, Equalize, Contrast,
  Color, Solarize, Invert, Posterize, AutoContrast and Cutout appear, i.e. no
  rotation, translation or shear ops are used.

  NOTE(review): each tuple presumably follows the
  `(op_name, probability, level)` convention of the other AutoAugment
  policies in this class -- confirm against the policy parser.

  Returns:
    A freshly built list of sub-policies; every call returns a new list, so
    callers may mutate the result.
  """
  policy = [
      [('Sharpness', 0.4, 1.4), ('Brightness', 0.2, 2.0), ('Cutout', 0.8, 8)],
      [('Equalize', 0.0, 1.8), ('Contrast', 0.2, 2.0), ('Cutout', 0.8, 8)],
      [('Sharpness', 0.2, 1.8), ('Color', 0.2, 1.8), ('Cutout', 0.8, 8)],
      [('Solarize', 0.2, 1.4), ('Equalize', 0.6, 1.8), ('Cutout', 0.8, 8)],
      [('Sharpness', 0.2, 0.2), ('Equalize', 0.2, 1.4), ('Cutout', 0.8, 8)],
      [('Sharpness', 0.4, 7), ('Invert', 0.6, 8), ('Cutout', 0.8, 8)],
      [('Invert', 0.6, 4), ('Equalize', 1.0, 8), ('Cutout', 0.8, 8)],
      [('Posterize', 0.6, 7), ('Posterize', 0.6, 6), ('Cutout', 0.8, 8)],
      [('Solarize', 0.6, 5), ('AutoContrast', 0.6, 5), ('Cutout', 0.8, 8)],
  ]
  return policy
@staticmethod @staticmethod
def policy_test(): def policy_test():
"""Autoaugment test policy for debugging.""" """Autoaugment test policy for debugging."""
......
...@@ -96,6 +96,7 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase): ...@@ -96,6 +96,7 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
'svhn', 'svhn',
'reduced_imagenet', 'reduced_imagenet',
'detection_v0', 'detection_v0',
'vit',
] ]
def test_autoaugment(self): def test_autoaugment(self):
......
...@@ -155,6 +155,7 @@ class MaskRCNNTask(base_task.Task): ...@@ -155,6 +155,7 @@ class MaskRCNNTask(base_task.Task):
aug_rand_hflip=params.parser.aug_rand_hflip, aug_rand_hflip=params.parser.aug_rand_hflip,
aug_scale_min=params.parser.aug_scale_min, aug_scale_min=params.parser.aug_scale_min,
aug_scale_max=params.parser.aug_scale_max, aug_scale_max=params.parser.aug_scale_max,
aug_type=params.parser.aug_type,
skip_crowd_during_training=params.parser.skip_crowd_during_training, skip_crowd_during_training=params.parser.skip_crowd_during_training,
max_num_instances=params.parser.max_num_instances, max_num_instances=params.parser.max_num_instances,
include_mask=self._task_config.model.include_mask, include_mask=self._task_config.model.include_mask,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment