"lmdeploy/vscode:/vscode.git/clone" did not exist on "cde17e73e6dc3ce3e54182187f1612298a5558f8"
Unverified commit a7b55e81 authored by srihari-humbarwadi

enable `AutoAugment` augmentation

parent 0692e42f
@@ -44,6 +44,7 @@ class Parser(hyperparams.Config):
aug_scale_min: float = 1.0
aug_scale_max: float = 1.0
aug_rand_hflip: bool = True
+aug_type: common.Augmentation = common.Augmentation()
sigma: float = 8.0
small_instance_area_threshold: int = 4096
small_instance_weight: float = 3.0
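
For context, `common.Augmentation` (from official.vision.configs.common) is a oneof-style config: `type` names the active policy and the matching sub-config holds its parameters, with the default `type=None` selecting nothing. A minimal sketch of the two settings this commit relies on; only the field names come from the diff, the snippet itself is illustrative:

from official.vision.configs import common

# Default: `type` is None, i.e. no augmentation policy selected.
no_aug = common.Augmentation()

# AutoAugment, with the policy the COCO experiment below enables for training.
auto_aug = common.Augmentation(
    type='autoaug',
    autoaug=common.AutoAugment(augmentation_name='panoptic_deeplab_policy'))
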
@@ -177,9 +178,10 @@ def panoptic_deeplab_coco() -> cfg.ExperimentConfig:
stem_type = 'v0'
level = int(np.math.log2(output_stride))
config = cfg.ExperimentConfig(
runtime=cfg.RuntimeConfig(
-mixed_precision_dtype='bfloat16', enable_xla=True),
+mixed_precision_dtype='float32', enable_xla=True),
task=PanopticDeeplabTask(
init_checkpoint='gs://cloud-tpu-checkpoints/vision-2.0/deeplab/deeplab_resnet50_imagenet/ckpt-62400', # pylint: disable=line-too-long
init_checkpoint_modules=['backbone'],
@@ -248,6 +250,10 @@ def panoptic_deeplab_coco() -> cfg.ExperimentConfig:
aug_scale_min=0.5,
aug_scale_max=1.5,
aug_rand_hflip=True,
+aug_type=common.Augmentation(
+type='autoaug',
+autoaug=common.AutoAugment(
+augmentation_name='panoptic_deeplab_policy')),
sigma=8.0,
small_instance_area_threshold=4096,
small_instance_weight=3.0)),
@@ -261,6 +267,7 @@ def panoptic_deeplab_coco() -> cfg.ExperimentConfig:
aug_scale_min=1.0,
aug_scale_max=1.0,
aug_rand_hflip=False,
+aug_type=None,
sigma=8.0,
small_instance_area_threshold=4096,
small_instance_weight=3.0),
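
Net effect on the COCO experiment: training data now goes through the 'panoptic_deeplab_policy' AutoAugment policy, while evaluation keeps `aug_type=None` and remains un-augmented. A hedged sketch of reading or switching the new field through the experiment factory; the registered name 'panoptic_deeplab_coco', the `task.train_data.parser` path, and the alternative 'v0' policy name are assumptions, not shown in this diff:

from official.core import exp_factory

config = exp_factory.get_exp_config('panoptic_deeplab_coco')
print(config.task.train_data.parser.aug_type.type)  # expected: 'autoaug'

# Swap in a different AutoAugment policy (assumed to exist in augment.py).
config.task.train_data.parser.aug_type.autoaug.augmentation_name = 'v0'
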
@@ -14,12 +14,16 @@
"""Data parser and processing for Panoptic Deeplab."""
from typing import List, Optional
import numpy as np
import tensorflow as tf
+from official.vision.configs import common
from official.vision.dataloaders import parser
from official.vision.dataloaders import tf_example_decoder
from official.vision.ops import preprocess_ops
+from official.vision.ops import augment
def _compute_gaussian_from_std(sigma):
@@ -75,17 +79,18 @@ class Parser(parser.Parser):
def __init__(
self,
-output_size,
-resize_eval_groundtruth=True,
-groundtruth_padded_size=None,
-ignore_label=0,
-aug_rand_hflip=False,
-aug_scale_min=1.0,
-aug_scale_max=1.0,
-sigma=8.0,
-small_instance_area_threshold=4096,
-small_instance_weight=3.0,
-dtype='float32'):
+output_size: List[int],
+resize_eval_groundtruth: bool = True,
+groundtruth_padded_size: Optional[List[int]] = None,
+ignore_label: int = 0,
+aug_rand_hflip: bool = False,
+aug_scale_min: float = 1.0,
+aug_scale_max: float = 1.0,
+aug_type: Optional[common.Augmentation] = None,
+sigma: float = 8.0,
+small_instance_area_threshold: int = 4096,
+small_instance_weight: float = 3.0,
+dtype: str = 'float32'):
"""Initializes parameters for parsing annotations in the dataset.
Args:
@@ -104,6 +109,7 @@ class Parser(parser.Parser):
data augmentation during training.
aug_scale_max: `float`, the maximum scale applied to `output_size` for
data augmentation during training.
+aug_type: An optional Augmentation object with params for AutoAugment.
sigma: `float`, standard deviation for generating 2D Gaussian to encode
centers.
small_instance_area_threshold: `int`,
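
With the typed signature, the parser takes the `Augmentation` config object directly. A hedged construction sketch, written as if run alongside this module (`Parser` and `common` as defined and imported here); the scale and flip values mirror the training config above, while `output_size` is illustrative:

train_parser = Parser(
    output_size=[640, 640],  # illustrative, not taken from the diff
    aug_rand_hflip=True,
    aug_scale_min=0.5,
    aug_scale_max=1.5,
    aug_type=common.Augmentation(
        type='autoaug',
        autoaug=common.AutoAugment(augmentation_name='panoptic_deeplab_policy')),
    dtype='float32')
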
@@ -124,6 +130,18 @@ class Parser(parser.Parser):
self._aug_scale_min = aug_scale_min
self._aug_scale_max = aug_scale_max
+if aug_type:
+if aug_type.type == 'autoaug':
+self._augmenter = augment.AutoAugment(
+augmentation_name=aug_type.autoaug.augmentation_name,
+cutout_const=aug_type.autoaug.cutout_const,
+translate_const=aug_type.autoaug.translate_const)
+else:
+raise ValueError('Augmentation policy {} not supported.'.format(
+aug_type.type))
+else:
+self._augmenter = None
# dtype.
self._dtype = dtype
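
Only 'autoaug' is wired up here; the other option exposed by `common.Augmentation` ('randaug') falls into the error branch. A small sketch of that failure path, again assuming `Parser` and `common` from this module, with the message text taken from the code above:

try:
  Parser(output_size=[640, 640],
         aug_type=common.Augmentation(type='randaug'))
except ValueError as e:
  print(e)  # Augmentation policy randaug not supported.
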
@@ -134,7 +152,6 @@ class Parser(parser.Parser):
self._small_instance_area_threshold = small_instance_area_threshold
self._small_instance_weight = small_instance_weight
def _resize_and_crop_mask(self, mask, image_info, is_training):
"""Resizes and crops mask using `image_info` dict"""
height = image_info[0][0]
@@ -167,6 +184,10 @@ class Parser(parser.Parser):
def _parse_data(self, data, is_training):
image = data['image']
+if self._augmenter is not None and is_training:
+image = self._augmenter.distort(image)
image = preprocess_ops.normalize_image(image)
category_mask = tf.cast(
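
The augmenter runs before `preprocess_ops.normalize_image`, since the AutoAugment ops expect raw pixel values in [0, 255]; the masks are not distorted, which is presumably why a photometric-only policy is used. A standalone sketch of what `self._augmenter.distort` does to an image (shape and values are illustrative):

import tensorflow as tf
from official.vision.ops import augment

augmenter = augment.AutoAugment(augmentation_name='panoptic_deeplab_policy')
image = tf.cast(
    tf.random.uniform([640, 640, 3], maxval=256, dtype=tf.int32), tf.uint8)
augmented = augmenter.distort(image)  # same shape and dtype as the input
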
@@ -102,6 +102,7 @@ class PanopticDeeplabTask(base_task.Task):
aug_scale_min=params.parser.aug_scale_min,
aug_scale_max=params.parser.aug_scale_max,
aug_rand_hflip=params.parser.aug_rand_hflip,
+aug_type=params.parser.aug_type,
sigma=params.parser.sigma,
dtype=params.parser.dtype)