Commit 8fa62b84 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change to image classification.

PiperOrigin-RevId: 368957441
parent c2e19c97
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
# Lint as: python3 # Lint as: python3
"""Common configurations.""" """Common configurations."""
from typing import Optional
# Import libraries # Import libraries
import dataclasses import dataclasses
...@@ -23,6 +24,37 @@ from official.core import config_definitions as cfg ...@@ -23,6 +24,37 @@ from official.core import config_definitions as cfg
from official.modeling import hyperparams from official.modeling import hyperparams
@dataclasses.dataclass
class RandAugment(hyperparams.Config):
  """Hyperparameters for the RandAugment policy.

  The image classification parser forwards each field to
  `augment.RandAugment` under the same keyword name.

  Attributes:
    num_layers: Value supplied as the `num_layers` argument. Defaults to 2.
    magnitude: Value supplied as the `magnitude` argument. Defaults to 10.
    cutout_const: Value supplied as the `cutout_const` argument. Defaults
      to 40.
    translate_const: Value supplied as the `translate_const` argument.
      Defaults to 10.
  """

  num_layers: int = 2
  magnitude: float = 10
  cutout_const: float = 40
  translate_const: float = 10
@dataclasses.dataclass
class AutoAugment(hyperparams.Config):
  """Hyperparameters for the AutoAugment policy.

  The image classification parser forwards each field to
  `augment.AutoAugment` under the same keyword name.

  Attributes:
    augmentation_name: Value supplied as the `augmentation_name` argument.
      Defaults to 'v0'.
    cutout_const: Value supplied as the `cutout_const` argument. Defaults
      to 100.
    translate_const: Value supplied as the `translate_const` argument.
      Defaults to 250.
  """

  augmentation_name: str = 'v0'
  cutout_const: float = 100
  translate_const: float = 250
@dataclasses.dataclass
class Augmentation(hyperparams.OneOfConfig):
  """Configuration for input data augmentation.

  Attributes:
    type: `str`, type of augmentation to be used, one of the fields below.
      Defaults to None.
    randaug: RandAugment config.
    autoaug: AutoAugment config.
  """

  type: Optional[str] = None
  # Build the sub-configs per instance with default_factory. A dataclass
  # instance used directly as a default is unhashable (eq=True without
  # frozen=True), which Python 3.11+ rejects with ValueError at class
  # creation; on older versions the same instance becomes the shared
  # class-level default.
  randaug: RandAugment = dataclasses.field(default_factory=RandAugment)
  autoaug: AutoAugment = dataclasses.field(default_factory=AutoAugment)
@dataclasses.dataclass @dataclasses.dataclass
class NormActivation(hyperparams.Config): class NormActivation(hyperparams.Config):
activation: str = 'relu' activation: str = 'relu'
...@@ -35,5 +67,8 @@ class NormActivation(hyperparams.Config): ...@@ -35,5 +67,8 @@ class NormActivation(hyperparams.Config):
class PseudoLabelDataConfig(cfg.DataConfig): class PseudoLabelDataConfig(cfg.DataConfig):
"""Psuedo Label input config for training.""" """Psuedo Label input config for training."""
input_path: str = '' input_path: str = ''
data_ratio: float = 1.0 # Per-batch ratio of pseudo-labeled to labeled data data_ratio: float = 1.0 # Per-batch ratio of pseudo-labeled to labeled data.
aug_rand_hflip: bool = True
aug_type: Optional[
Augmentation] = None # Choose from AutoAugment and RandAugment.
file_type: str = 'tfrecord' file_type: str = 'tfrecord'
...@@ -12,19 +12,19 @@ task: ...@@ -12,19 +12,19 @@ task:
model_id: 50 model_id: 50
losses: losses:
l2_weight_decay: 0.0001 l2_weight_decay: 0.0001
one_hot: True one_hot: true
label_smoothing: 0.1 label_smoothing: 0.1
train_data: train_data:
input_path: 'imagenet-2012-tfrecord/train*' input_path: 'imagenet-2012-tfrecord/train*'
is_training: True is_training: true
global_batch_size: 2048 global_batch_size: 2048
dtype: 'float16' dtype: 'float16'
validation_data: validation_data:
input_path: 'imagenet-2012-tfrecord/valid*' input_path: 'imagenet-2012-tfrecord/valid*'
is_training: False is_training: false
global_batch_size: 2048 global_batch_size: 2048
dtype: 'float16' dtype: 'float16'
drop_remainder: False drop_remainder: false
trainer: trainer:
train_steps: 56160 train_steps: 56160
validation_steps: 25 validation_steps: 25
......
...@@ -29,8 +29,10 @@ task: ...@@ -29,8 +29,10 @@ task:
is_training: true is_training: true
global_batch_size: 4096 global_batch_size: 4096
dtype: 'bfloat16' dtype: 'bfloat16'
aug_policy: 'randaug' aug_type:
randaug_magnitude: 15 type: 'randaug'
randaug:
magnitude: 15
validation_data: validation_data:
input_path: 'imagenet-2012-tfrecord/valid*' input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false is_training: false
......
...@@ -29,8 +29,10 @@ task: ...@@ -29,8 +29,10 @@ task:
is_training: true is_training: true
global_batch_size: 4096 global_batch_size: 4096
dtype: 'bfloat16' dtype: 'bfloat16'
aug_policy: 'randaug' aug_type:
randaug_magnitude: 15 type: 'randaug'
randaug:
magnitude: 15
validation_data: validation_data:
input_path: 'imagenet-2012-tfrecord/valid*' input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false is_training: false
......
...@@ -29,8 +29,10 @@ task: ...@@ -29,8 +29,10 @@ task:
is_training: true is_training: true
global_batch_size: 4096 global_batch_size: 4096
dtype: 'bfloat16' dtype: 'bfloat16'
aug_policy: 'randaug' aug_type:
randaug_magnitude: 15 type: 'randaug'
randaug:
magnitude: 15
validation_data: validation_data:
input_path: 'imagenet-2012-tfrecord/valid*' input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false is_training: false
......
...@@ -29,8 +29,10 @@ task: ...@@ -29,8 +29,10 @@ task:
is_training: true is_training: true
global_batch_size: 4096 global_batch_size: 4096
dtype: 'bfloat16' dtype: 'bfloat16'
aug_policy: 'randaug' aug_type:
randaug_magnitude: 15 type: 'randaug'
randaug:
magnitude: 15
validation_data: validation_data:
input_path: 'imagenet-2012-tfrecord/valid*' input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false is_training: false
......
...@@ -29,8 +29,10 @@ task: ...@@ -29,8 +29,10 @@ task:
is_training: true is_training: true
global_batch_size: 4096 global_batch_size: 4096
dtype: 'bfloat16' dtype: 'bfloat16'
aug_policy: 'randaug' aug_type:
randaug_magnitude: 15 type: 'randaug'
randaug:
magnitude: 15
validation_data: validation_data:
input_path: 'imagenet-2012-tfrecord/valid*' input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false is_training: false
......
...@@ -29,8 +29,10 @@ task: ...@@ -29,8 +29,10 @@ task:
is_training: true is_training: true
global_batch_size: 4096 global_batch_size: 4096
dtype: 'bfloat16' dtype: 'bfloat16'
aug_policy: 'randaug' aug_type:
randaug_magnitude: 15 type: 'randaug'
randaug:
magnitude: 15
validation_data: validation_data:
input_path: 'imagenet-2012-tfrecord/valid*' input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false is_training: false
......
...@@ -29,8 +29,10 @@ task: ...@@ -29,8 +29,10 @@ task:
is_training: true is_training: true
global_batch_size: 4096 global_batch_size: 4096
dtype: 'bfloat16' dtype: 'bfloat16'
aug_policy: 'randaug' aug_type:
randaug_magnitude: 15 type: 'randaug'
randaug:
magnitude: 15
validation_data: validation_data:
input_path: 'imagenet-2012-tfrecord/valid*' input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false is_training: false
......
...@@ -29,8 +29,10 @@ task: ...@@ -29,8 +29,10 @@ task:
is_training: true is_training: true
global_batch_size: 4096 global_batch_size: 4096
dtype: 'bfloat16' dtype: 'bfloat16'
aug_policy: 'randaug' aug_type:
randaug_magnitude: 15 type: 'randaug'
randaug:
magnitude: 15
validation_data: validation_data:
input_path: 'imagenet-2012-tfrecord/valid*' input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false is_training: false
......
...@@ -29,8 +29,10 @@ task: ...@@ -29,8 +29,10 @@ task:
is_training: true is_training: true
global_batch_size: 4096 global_batch_size: 4096
dtype: 'bfloat16' dtype: 'bfloat16'
aug_policy: 'randaug' aug_type:
randaug_magnitude: 15 type: 'randaug'
randaug:
magnitude: 15
validation_data: validation_data:
input_path: 'imagenet-2012-tfrecord/valid*' input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false is_training: false
......
...@@ -28,8 +28,10 @@ task: ...@@ -28,8 +28,10 @@ task:
is_training: true is_training: true
global_batch_size: 4096 global_batch_size: 4096
dtype: 'bfloat16' dtype: 'bfloat16'
aug_policy: 'randaug' aug_type:
randaug_magnitude: 15 type: 'randaug'
randaug:
magnitude: 15
validation_data: validation_data:
input_path: 'imagenet-2012-tfrecord/valid*' input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false is_training: false
......
...@@ -29,8 +29,10 @@ task: ...@@ -29,8 +29,10 @@ task:
is_training: true is_training: true
global_batch_size: 4096 global_batch_size: 4096
dtype: 'bfloat16' dtype: 'bfloat16'
aug_policy: 'randaug' aug_type:
randaug_magnitude: 10 type: 'randaug'
randaug:
magnitude: 10
validation_data: validation_data:
input_path: 'imagenet-2012-tfrecord/valid*' input_path: 'imagenet-2012-tfrecord/valid*'
is_training: false is_training: false
......
...@@ -34,12 +34,17 @@ class DataConfig(cfg.DataConfig): ...@@ -34,12 +34,17 @@ class DataConfig(cfg.DataConfig):
dtype: str = 'float32' dtype: str = 'float32'
shuffle_buffer_size: int = 10000 shuffle_buffer_size: int = 10000
cycle_length: int = 10 cycle_length: int = 10
aug_policy: Optional[str] = None # None, 'autoaug', or 'randaug' aug_rand_hflip: bool = True
randaug_magnitude: Optional[int] = 10 aug_type: Optional[
common.Augmentation] = None # Choose from AutoAugment and RandAugment.
file_type: str = 'tfrecord' file_type: str = 'tfrecord'
image_field_key: str = 'image/encoded' image_field_key: str = 'image/encoded'
label_field_key: str = 'image/class/label' label_field_key: str = 'image/class/label'
# Keep for backward compatibility.
aug_policy: Optional[str] = None # None, 'autoaug', or 'randaug'.
randaug_magnitude: Optional[int] = 10
@dataclasses.dataclass @dataclasses.dataclass
class ImageClassificationModel(hyperparams.Config): class ImageClassificationModel(hyperparams.Config):
...@@ -198,8 +203,8 @@ def image_classification_imagenet_resnetrs() -> cfg.ExperimentConfig: ...@@ -198,8 +203,8 @@ def image_classification_imagenet_resnetrs() -> cfg.ExperimentConfig:
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'), input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
is_training=True, is_training=True,
global_batch_size=train_batch_size, global_batch_size=train_batch_size,
aug_policy='randaug', aug_type=common.Augmentation(
randaug_magnitude=10), type='randaug', randaug=common.RandAugment(magnitude=10))),
validation_data=DataConfig( validation_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'), input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
is_training=False, is_training=False,
......
...@@ -17,6 +17,7 @@ from typing import Dict, List, Optional ...@@ -17,6 +17,7 @@ from typing import Dict, List, Optional
# Import libraries # Import libraries
import tensorflow as tf import tensorflow as tf
from official.vision.beta.configs import common
from official.vision.beta.dataloaders import decoder from official.vision.beta.dataloaders import decoder
from official.vision.beta.dataloaders import parser from official.vision.beta.dataloaders import parser
from official.vision.beta.ops import augment from official.vision.beta.ops import augment
...@@ -52,8 +53,7 @@ class Parser(parser.Parser): ...@@ -52,8 +53,7 @@ class Parser(parser.Parser):
image_field_key: str = 'image/encoded', image_field_key: str = 'image/encoded',
label_field_key: str = 'image/class/label', label_field_key: str = 'image/class/label',
aug_rand_hflip: bool = True, aug_rand_hflip: bool = True,
aug_policy: Optional[str] = None, aug_type: Optional[common.Augmentation] = None,
randaug_magnitude: Optional[int] = 10,
dtype: str = 'float32'): dtype: str = 'float32'):
"""Initializes parameters for parsing annotations in the dataset. """Initializes parameters for parsing annotations in the dataset.
...@@ -65,8 +65,8 @@ class Parser(parser.Parser): ...@@ -65,8 +65,8 @@ class Parser(parser.Parser):
label_field_key: A `str` of the key name to label in TFExample. label_field_key: A `str` of the key name to label in TFExample.
aug_rand_hflip: `bool`, if True, augment training with random aug_rand_hflip: `bool`, if True, augment training with random
horizontal flip. horizontal flip.
aug_policy: `str`, augmentation policies. None, 'autoaug', or 'randaug'. aug_type: An optional Augmentation object to choose from AutoAugment and
randaug_magnitude: `int`, magnitude of the randaugment policy. RandAugment.
dtype: `str`, cast output image in dtype. It can be 'float32', 'float16', dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
or 'bfloat16'. or 'bfloat16'.
""" """
...@@ -84,15 +84,21 @@ class Parser(parser.Parser): ...@@ -84,15 +84,21 @@ class Parser(parser.Parser):
self._dtype = tf.bfloat16 self._dtype = tf.bfloat16
else: else:
raise ValueError('dtype {!r} is not supported!'.format(dtype)) raise ValueError('dtype {!r} is not supported!'.format(dtype))
if aug_policy: if aug_type:
if aug_policy == 'autoaug': if aug_type.type == 'autoaug':
self._augmenter = augment.AutoAugment() self._augmenter = augment.AutoAugment(
elif aug_policy == 'randaug': augmentation_name=aug_type.autoaug.augmentation_name,
cutout_const=aug_type.autoaug.cutout_const,
translate_const=aug_type.autoaug.translate_const)
elif aug_type.type == 'randaug':
self._augmenter = augment.RandAugment( self._augmenter = augment.RandAugment(
num_layers=2, magnitude=randaug_magnitude) num_layers=aug_type.randaug.num_layers,
magnitude=aug_type.randaug.magnitude,
cutout_const=aug_type.randaug.cutout_const,
translate_const=aug_type.randaug.translate_const)
else: else:
raise ValueError( raise ValueError('Augmentation policy {} not supported.'.format(
'Augmentation policy {} not supported.'.format(aug_policy)) aug_type.type))
else: else:
self._augmenter = None self._augmenter = None
......
...@@ -100,8 +100,8 @@ class ImageClassificationTask(base_task.Task): ...@@ -100,8 +100,8 @@ class ImageClassificationTask(base_task.Task):
num_classes=num_classes, num_classes=num_classes,
image_field_key=image_field_key, image_field_key=image_field_key,
label_field_key=label_field_key, label_field_key=label_field_key,
aug_policy=params.aug_policy, aug_rand_hflip=params.aug_rand_hflip,
randaug_magnitude=params.randaug_magnitude, aug_type=params.aug_type,
dtype=params.dtype) dtype=params.dtype)
reader = input_reader_factory.input_reader_generator( reader = input_reader_factory.input_reader_generator(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment