Commit c52a287f authored by Fan Yang, committed by A. Unique TensorFlower
Browse files

Internal change to image classification.

PiperOrigin-RevId: 368957441
parent c7823825
......@@ -15,6 +15,7 @@
# Lint as: python3
"""Common configurations."""
from typing import Optional
# Import libraries
import dataclasses
......@@ -23,6 +24,37 @@ from official.core import config_definitions as cfg
from official.modeling import hyperparams
@dataclasses.dataclass
class RandAugment(hyperparams.Config):
  """Hyperparameters for the RandAugment augmentation policy.

  The image classification input parser forwards each field, under the same
  keyword name, to `augment.RandAugment` when the enclosing augmentation
  config selects 'randaug'.

  Attributes:
    num_layers: Forwarded as `num_layers`; presumably the number of
      transformations applied per image -- confirm against
      `augment.RandAugment`.
    magnitude: Forwarded as `magnitude`; presumably the shared strength of the
      applied transformations -- confirm against `augment.RandAugment`.
    cutout_const: Forwarded as `cutout_const`.
    translate_const: Forwarded as `translate_const`.
  """
  # Field order is part of the generated constructor signature; keep it stable.
  num_layers: int = 2
  magnitude: float = 10
  cutout_const: float = 40
  translate_const: float = 10
@dataclasses.dataclass
class AutoAugment(hyperparams.Config):
  """Hyperparameters for the AutoAugment augmentation policy.

  The image classification input parser forwards each field, under the same
  keyword name, to `augment.AutoAugment` when the enclosing augmentation
  config selects 'autoaug'.

  Attributes:
    augmentation_name: Forwarded as `augmentation_name`; presumably the name of
      a predefined policy set -- confirm the accepted values against
      `augment.AutoAugment`.
    cutout_const: Forwarded as `cutout_const`.
    translate_const: Forwarded as `translate_const`.
  """
  # Field order is part of the generated constructor signature; keep it stable.
  augmentation_name: str = 'v0'
  cutout_const: float = 100
  translate_const: float = 250
@dataclasses.dataclass
class Augmentation(hyperparams.OneOfConfig):
  """Configuration for input data augmentation.

  Attributes:
    type: 'str', type of augmentation to be used, one of the fields below.
      `None` (the default) selects no augmentation policy.
    randaug: RandAugment config.
    autoaug: AutoAugment config.
  """
  type: Optional[str] = None
  # Build the sub-config defaults through `default_factory` instead of using a
  # pre-built instance as the class-level default. Non-frozen dataclass
  # instances are unhashable, and Python 3.11+ rejects unhashable field
  # defaults with `ValueError: mutable default ... use default_factory` when
  # the class is defined. A factory also gives every `Augmentation` its own
  # fresh default sub-config rather than one shared object.
  randaug: RandAugment = dataclasses.field(default_factory=RandAugment)
  autoaug: AutoAugment = dataclasses.field(default_factory=AutoAugment)
@dataclasses.dataclass
class NormActivation(hyperparams.Config):
activation: str = 'relu'
......@@ -35,5 +67,8 @@ class NormActivation(hyperparams.Config):
class PseudoLabelDataConfig(cfg.DataConfig):
"""Psuedo Label input config for training."""
input_path: str = ''
data_ratio: float = 1.0 # Per-batch ratio of pseudo-labeled to labeled data
data_ratio: float = 1.0 # Per-batch ratio of pseudo-labeled to labeled data.
aug_rand_hflip: bool = True
aug_type: Optional[
Augmentation] = None # Choose from AutoAugment and RandAugment.
file_type: str = 'tfrecord'
......@@ -12,19 +12,19 @@ task:
model_id: 50
losses:
l2_weight_decay: 0.0001
one_hot: True
one_hot: true
label_smoothing: 0.1
train_data:
input_path: 'imagenet-2012-tfrecord/train*'
is_training: True
is_training: true
global_batch_size: 2048
dtype: 'float16'
validation_data:
input_path: 'imagenet-2012-tfrecord/valid*'
is_training: False
is_training: false
global_batch_size: 2048
dtype: 'float16'
drop_remainder: False
drop_remainder: false
trainer:
train_steps: 56160
validation_steps: 25
......
......@@ -29,8 +29,10 @@ task:
is_training: true
global_batch_size: 4096
dtype: 'bfloat16'
aug_policy: 'randaug'
randaug_magnitude: 15
aug_type:
type: 'randaug'
randaug:
magnitude: 15
validation_data:
input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false
......
......@@ -29,8 +29,10 @@ task:
is_training: true
global_batch_size: 4096
dtype: 'bfloat16'
aug_policy: 'randaug'
randaug_magnitude: 15
aug_type:
type: 'randaug'
randaug:
magnitude: 15
validation_data:
input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false
......
......@@ -29,8 +29,10 @@ task:
is_training: true
global_batch_size: 4096
dtype: 'bfloat16'
aug_policy: 'randaug'
randaug_magnitude: 15
aug_type:
type: 'randaug'
randaug:
magnitude: 15
validation_data:
input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false
......
......@@ -29,8 +29,10 @@ task:
is_training: true
global_batch_size: 4096
dtype: 'bfloat16'
aug_policy: 'randaug'
randaug_magnitude: 15
aug_type:
type: 'randaug'
randaug:
magnitude: 15
validation_data:
input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false
......
......@@ -29,8 +29,10 @@ task:
is_training: true
global_batch_size: 4096
dtype: 'bfloat16'
aug_policy: 'randaug'
randaug_magnitude: 15
aug_type:
type: 'randaug'
randaug:
magnitude: 15
validation_data:
input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false
......
......@@ -29,8 +29,10 @@ task:
is_training: true
global_batch_size: 4096
dtype: 'bfloat16'
aug_policy: 'randaug'
randaug_magnitude: 15
aug_type:
type: 'randaug'
randaug:
magnitude: 15
validation_data:
input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false
......
......@@ -29,8 +29,10 @@ task:
is_training: true
global_batch_size: 4096
dtype: 'bfloat16'
aug_policy: 'randaug'
randaug_magnitude: 15
aug_type:
type: 'randaug'
randaug:
magnitude: 15
validation_data:
input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false
......
......@@ -29,8 +29,10 @@ task:
is_training: true
global_batch_size: 4096
dtype: 'bfloat16'
aug_policy: 'randaug'
randaug_magnitude: 15
aug_type:
type: 'randaug'
randaug:
magnitude: 15
validation_data:
input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false
......
......@@ -29,8 +29,10 @@ task:
is_training: true
global_batch_size: 4096
dtype: 'bfloat16'
aug_policy: 'randaug'
randaug_magnitude: 15
aug_type:
type: 'randaug'
randaug:
magnitude: 15
validation_data:
input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false
......
......@@ -28,8 +28,10 @@ task:
is_training: true
global_batch_size: 4096
dtype: 'bfloat16'
aug_policy: 'randaug'
randaug_magnitude: 15
aug_type:
type: 'randaug'
randaug:
magnitude: 15
validation_data:
input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false
......
......@@ -29,8 +29,10 @@ task:
is_training: true
global_batch_size: 4096
dtype: 'bfloat16'
aug_policy: 'randaug'
randaug_magnitude: 10
aug_type:
type: 'randaug'
randaug:
magnitude: 10
validation_data:
input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false
......
......@@ -34,12 +34,17 @@ class DataConfig(cfg.DataConfig):
dtype: str = 'float32'
shuffle_buffer_size: int = 10000
cycle_length: int = 10
aug_policy: Optional[str] = None # None, 'autoaug', or 'randaug'
randaug_magnitude: Optional[int] = 10
aug_rand_hflip: bool = True
aug_type: Optional[
common.Augmentation] = None # Choose from AutoAugment and RandAugment.
file_type: str = 'tfrecord'
image_field_key: str = 'image/encoded'
label_field_key: str = 'image/class/label'
# Keep for backward compatibility.
aug_policy: Optional[str] = None # None, 'autoaug', or 'randaug'.
randaug_magnitude: Optional[int] = 10
@dataclasses.dataclass
class ImageClassificationModel(hyperparams.Config):
......@@ -198,8 +203,8 @@ def image_classification_imagenet_resnetrs() -> cfg.ExperimentConfig:
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size,
aug_policy='randaug',
randaug_magnitude=10),
aug_type=common.Augmentation(
type='randaug', randaug=common.RandAugment(magnitude=10))),
validation_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
is_training=False,
......
......@@ -17,6 +17,7 @@ from typing import Dict, List, Optional
# Import libraries
import tensorflow as tf
from official.vision.beta.configs import common
from official.vision.beta.dataloaders import decoder
from official.vision.beta.dataloaders import parser
from official.vision.beta.ops import augment
......@@ -52,8 +53,7 @@ class Parser(parser.Parser):
image_field_key: str = 'image/encoded',
label_field_key: str = 'image/class/label',
aug_rand_hflip: bool = True,
aug_policy: Optional[str] = None,
randaug_magnitude: Optional[int] = 10,
aug_type: Optional[common.Augmentation] = None,
dtype: str = 'float32'):
"""Initializes parameters for parsing annotations in the dataset.
......@@ -65,8 +65,8 @@ class Parser(parser.Parser):
label_field_key: A `str` of the key name to label in TFExample.
aug_rand_hflip: `bool`, if True, augment training with random
horizontal flip.
aug_policy: `str`, augmentation policies. None, 'autoaug', or 'randaug'.
randaug_magnitude: `int`, magnitude of the randaugment policy.
aug_type: An optional Augmentation object to choose from AutoAugment and
RandAugment.
dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
or 'bfloat16'.
"""
......@@ -84,15 +84,21 @@ class Parser(parser.Parser):
self._dtype = tf.bfloat16
else:
raise ValueError('dtype {!r} is not supported!'.format(dtype))
if aug_policy:
if aug_policy == 'autoaug':
self._augmenter = augment.AutoAugment()
elif aug_policy == 'randaug':
if aug_type:
if aug_type.type == 'autoaug':
self._augmenter = augment.AutoAugment(
augmentation_name=aug_type.autoaug.augmentation_name,
cutout_const=aug_type.autoaug.cutout_const,
translate_const=aug_type.autoaug.translate_const)
elif aug_type.type == 'randaug':
self._augmenter = augment.RandAugment(
num_layers=2, magnitude=randaug_magnitude)
num_layers=aug_type.randaug.num_layers,
magnitude=aug_type.randaug.magnitude,
cutout_const=aug_type.randaug.cutout_const,
translate_const=aug_type.randaug.translate_const)
else:
raise ValueError(
'Augmentation policy {} not supported.'.format(aug_policy))
raise ValueError('Augmentation policy {} not supported.'.format(
aug_type.type))
else:
self._augmenter = None
......
......@@ -100,8 +100,8 @@ class ImageClassificationTask(base_task.Task):
num_classes=num_classes,
image_field_key=image_field_key,
label_field_key=label_field_key,
aug_policy=params.aug_policy,
randaug_magnitude=params.randaug_magnitude,
aug_rand_hflip=params.aug_rand_hflip,
aug_type=params.aug_type,
dtype=params.dtype)
reader = input_reader_factory.input_reader_generator(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment