"vscode:/vscode.git/clone" did not exist on "30072aec376a2fc6c6a465fe632bc7be050ec5ea"
Unverified Commit adf0c8e9 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge pull request #17 from srihari-humbarwadi/panoptic-deeplab-augmentation

Added AutoAugment
parents b9bd2f11 170eb70b
...@@ -44,6 +44,7 @@ class Parser(hyperparams.Config): ...@@ -44,6 +44,7 @@ class Parser(hyperparams.Config):
aug_scale_min: float = 1.0 aug_scale_min: float = 1.0
aug_scale_max: float = 1.0 aug_scale_max: float = 1.0
aug_rand_hflip: bool = True aug_rand_hflip: bool = True
aug_type: common.Augmentation = common.Augmentation()
sigma: float = 8.0 sigma: float = 8.0
small_instance_area_threshold: int = 4096 small_instance_area_threshold: int = 4096
small_instance_weight: float = 3.0 small_instance_weight: float = 3.0
...@@ -177,6 +178,7 @@ def panoptic_deeplab_coco() -> cfg.ExperimentConfig: ...@@ -177,6 +178,7 @@ def panoptic_deeplab_coco() -> cfg.ExperimentConfig:
stem_type = 'v0' stem_type = 'v0'
level = int(np.math.log2(output_stride)) level = int(np.math.log2(output_stride))
config = cfg.ExperimentConfig( config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig( runtime=cfg.RuntimeConfig(
mixed_precision_dtype='bfloat16', enable_xla=True), mixed_precision_dtype='bfloat16', enable_xla=True),
...@@ -248,6 +250,10 @@ def panoptic_deeplab_coco() -> cfg.ExperimentConfig: ...@@ -248,6 +250,10 @@ def panoptic_deeplab_coco() -> cfg.ExperimentConfig:
aug_scale_min=0.5, aug_scale_min=0.5,
aug_scale_max=1.5, aug_scale_max=1.5,
aug_rand_hflip=True, aug_rand_hflip=True,
aug_type=common.Augmentation(
type='autoaug',
autoaug=common.AutoAugment(
augmentation_name='panoptic_deeplab_policy')),
sigma=8.0, sigma=8.0,
small_instance_area_threshold=4096, small_instance_area_threshold=4096,
small_instance_weight=3.0)), small_instance_weight=3.0)),
...@@ -261,6 +267,7 @@ def panoptic_deeplab_coco() -> cfg.ExperimentConfig: ...@@ -261,6 +267,7 @@ def panoptic_deeplab_coco() -> cfg.ExperimentConfig:
aug_scale_min=1.0, aug_scale_min=1.0,
aug_scale_max=1.0, aug_scale_max=1.0,
aug_rand_hflip=False, aug_rand_hflip=False,
aug_type=None,
sigma=8.0, sigma=8.0,
small_instance_area_threshold=4096, small_instance_area_threshold=4096,
small_instance_weight=3.0), small_instance_weight=3.0),
......
...@@ -14,12 +14,16 @@ ...@@ -14,12 +14,16 @@
"""Data parser and processing for Panoptic Deeplab.""" """Data parser and processing for Panoptic Deeplab."""
from typing import List, Optional
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from official.vision.configs import common
from official.vision.dataloaders import parser from official.vision.dataloaders import parser
from official.vision.dataloaders import tf_example_decoder from official.vision.dataloaders import tf_example_decoder
from official.vision.ops import preprocess_ops from official.vision.ops import preprocess_ops
from official.vision.ops import augment
def _compute_gaussian_from_std(sigma): def _compute_gaussian_from_std(sigma):
...@@ -75,17 +79,18 @@ class Parser(parser.Parser): ...@@ -75,17 +79,18 @@ class Parser(parser.Parser):
def __init__( def __init__(
self, self,
output_size, output_size: List[int],
resize_eval_groundtruth=True, resize_eval_groundtruth: bool = True,
groundtruth_padded_size=None, groundtruth_padded_size: Optional[List[int]] = None,
ignore_label=0, ignore_label: int = 0,
aug_rand_hflip=False, aug_rand_hflip: bool = False,
aug_scale_min=1.0, aug_scale_min: float = 1.0,
aug_scale_max=1.0, aug_scale_max: float = 1.0,
sigma=8.0, aug_type: Optional[common.Augmentation] = None,
small_instance_area_threshold=4096, sigma: float = 8.0,
small_instance_weight=3.0, small_instance_area_threshold: int = 4096,
dtype='float32'): small_instance_weight: float = 3.0,
dtype: str = 'float32'):
"""Initializes parameters for parsing annotations in the dataset. """Initializes parameters for parsing annotations in the dataset.
Args: Args:
...@@ -104,6 +109,7 @@ class Parser(parser.Parser): ...@@ -104,6 +109,7 @@ class Parser(parser.Parser):
data augmentation during training. data augmentation during training.
aug_scale_max: `float`, the maximum scale applied to `output_size` for aug_scale_max: `float`, the maximum scale applied to `output_size` for
data augmentation during training. data augmentation during training.
aug_type: An optional Augmentation object with params for AutoAugment.
sigma: `float`, standard deviation for generating 2D Gaussian to encode sigma: `float`, standard deviation for generating 2D Gaussian to encode
centers. centers.
small_instance_area_threshold: `int`, small_instance_area_threshold: `int`,
...@@ -124,6 +130,18 @@ class Parser(parser.Parser): ...@@ -124,6 +130,18 @@ class Parser(parser.Parser):
self._aug_scale_min = aug_scale_min self._aug_scale_min = aug_scale_min
self._aug_scale_max = aug_scale_max self._aug_scale_max = aug_scale_max
if aug_type:
if aug_type.type == 'autoaug':
self._augmenter = augment.AutoAugment(
augmentation_name=aug_type.autoaug.augmentation_name,
cutout_const=aug_type.autoaug.cutout_const,
translate_const=aug_type.autoaug.translate_const)
else:
raise ValueError('Augmentation policy {} not supported.'.format(
aug_type.type))
else:
self._augmenter = None
# dtype. # dtype.
self._dtype = dtype self._dtype = dtype
...@@ -134,7 +152,6 @@ class Parser(parser.Parser): ...@@ -134,7 +152,6 @@ class Parser(parser.Parser):
self._small_instance_area_threshold = small_instance_area_threshold self._small_instance_area_threshold = small_instance_area_threshold
self._small_instance_weight = small_instance_weight self._small_instance_weight = small_instance_weight
def _resize_and_crop_mask(self, mask, image_info, is_training): def _resize_and_crop_mask(self, mask, image_info, is_training):
"""Resizes and crops mask using `image_info` dict""" """Resizes and crops mask using `image_info` dict"""
height = image_info[0][0] height = image_info[0][0]
...@@ -156,7 +173,7 @@ class Parser(parser.Parser): ...@@ -156,7 +173,7 @@ class Parser(parser.Parser):
self._groundtruth_padded_size[0], self._groundtruth_padded_size[0],
self._groundtruth_padded_size[1]) self._groundtruth_padded_size[1])
mask -= 1 mask -= 1
# Assign ignore label to the padded region. # Assign ignore label to the padded region.
mask = tf.where( mask = tf.where(
tf.equal(mask, -1), tf.equal(mask, -1),
...@@ -167,6 +184,10 @@ class Parser(parser.Parser): ...@@ -167,6 +184,10 @@ class Parser(parser.Parser):
def _parse_data(self, data, is_training): def _parse_data(self, data, is_training):
image = data['image'] image = data['image']
if self._augmenter is not None and is_training:
image = self._augmenter.distort(image)
image = preprocess_ops.normalize_image(image) image = preprocess_ops.normalize_image(image)
category_mask = tf.cast( category_mask = tf.cast(
...@@ -255,7 +276,7 @@ class Parser(parser.Parser): ...@@ -255,7 +276,7 @@ class Parser(parser.Parser):
padding_start = int(3 * self._sigma + 1) padding_start = int(3 * self._sigma + 1)
padding_end = int(3 * self._sigma + 2) padding_end = int(3 * self._sigma + 2)
# padding should be equal to self._gaussian_size which is calculated # padding should be equal to self._gaussian_size which is calculated
# as size = int(6 * sigma + 3) # as size = int(6 * sigma + 3)
padding = padding_start + padding_end padding = padding_start + padding_end
...@@ -276,7 +297,7 @@ class Parser(parser.Parser): ...@@ -276,7 +297,7 @@ class Parser(parser.Parser):
unique_instance_ids, _ = tf.unique(tf.reshape(instance_mask, [-1])) unique_instance_ids, _ = tf.unique(tf.reshape(instance_mask, [-1]))
# The following method for encoding center heatmaps and offsets is inspired # The following method for encoding center heatmaps and offsets is inspired
# by the reference implementation available at # by the reference implementation available at
# https://github.com/google-research/deeplab2/blob/main/data/sample_generator.py # pylint: disable=line-too-long # https://github.com/google-research/deeplab2/blob/main/data/sample_generator.py # pylint: disable=line-too-long
for instance_id in unique_instance_ids: for instance_id in unique_instance_ids:
if instance_id == self._ignore_label: if instance_id == self._ignore_label:
...@@ -327,7 +348,7 @@ class Parser(parser.Parser): ...@@ -327,7 +348,7 @@ class Parser(parser.Parser):
instance_centers_offset = tf.stack( instance_centers_offset = tf.stack(
[centers_offset_y, centers_offset_x], [centers_offset_y, centers_offset_x],
axis=-1) axis=-1)
return (instance_centers_heatmap, return (instance_centers_heatmap,
instance_centers_offset, instance_centers_offset,
semantic_weights) semantic_weights)
...@@ -102,6 +102,7 @@ class PanopticDeeplabTask(base_task.Task): ...@@ -102,6 +102,7 @@ class PanopticDeeplabTask(base_task.Task):
aug_scale_min=params.parser.aug_scale_min, aug_scale_min=params.parser.aug_scale_min,
aug_scale_max=params.parser.aug_scale_max, aug_scale_max=params.parser.aug_scale_max,
aug_rand_hflip=params.parser.aug_rand_hflip, aug_rand_hflip=params.parser.aug_rand_hflip,
aug_type=params.parser.aug_type,
sigma=params.parser.sigma, sigma=params.parser.sigma,
dtype=params.parser.dtype) dtype=params.parser.dtype)
......
...@@ -1583,6 +1583,7 @@ class AutoAugment(ImageAugment): ...@@ -1583,6 +1583,7 @@ class AutoAugment(ImageAugment):
'reduced_cifar10': self.policy_reduced_cifar10(), 'reduced_cifar10': self.policy_reduced_cifar10(),
'svhn': self.policy_svhn(), 'svhn': self.policy_svhn(),
'reduced_imagenet': self.policy_reduced_imagenet(), 'reduced_imagenet': self.policy_reduced_imagenet(),
'panoptic_deeplab_policy': self.panoptic_deeplab_policy(),
} }
if not policies: if not policies:
...@@ -1888,6 +1889,16 @@ class AutoAugment(ImageAugment): ...@@ -1888,6 +1889,16 @@ class AutoAugment(ImageAugment):
] ]
return policy return policy
@staticmethod
def panoptic_deeplab_policy():
policy = [
[('Sharpness', 0.4, 1.4), ('Brightness', 0.2, 2.0)],
[('Equalize', 0.0, 1.8), ('Contrast', 0.2, 2.0)],
[('Sharpness', 0.2, 1.8), ('Color', 0.2, 1.8)],
[('Solarize', 0.2, 1.4), ('Equalize', 0.6, 1.8)],
[('Sharpness', 0.2, 0.2), ('Equalize', 0.2, 1.4)]]
return policy
@staticmethod @staticmethod
def policy_test(): def policy_test():
"""Autoaugment test policy for debugging.""" """Autoaugment test policy for debugging."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment