Commit 7797ebad authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 388565418
parent 194d47c6
runtime:
  distribution_strategy: tpu
  mixed_precision_dtype: 'bfloat16'
task:
  init_checkpoint: ''
  model:
    backbone:
      resnet:
        model_id: 50
      type: resnet
    projection_head:
      ft_proj_idx: 1
      num_proj_layers: 3
      proj_output_dim: 128
    backbone_trainable: true
    heads: !!python/tuple
    # Define heads for the PRETRAIN networks here.
    - task_name: pretrain_imagenet
      mode: pretrain
    # Define heads for the FINETUNE networks here.
    - task_name: finetune_imagenet_10percent
      mode: finetune
      supervised_head:
        num_classes: 1001
        zero_init: true
    input_size: [224, 224, 3]
    l2_weight_decay: 0.0
    norm_activation:
      norm_epsilon: 1.0e-05
      norm_momentum: 0.9
      use_sync_bn: true
  task_routines: !!python/tuple
  # Define TASK CONFIG for the PRETRAIN networks here.
  - task_name: pretrain_imagenet
    task_weight: 30.0
    task_config:
      evaluation:
        one_hot: true
        top_k: 5
      loss:
        l2_weight_decay: 0.0
        projection_norm: true
        temperature: 0.1
      model:
        input_size: [224, 224, 3]
        mode: pretrain
      train_data:
        input_path: /readahead/200M/placer/prod/home/distbelief/imagenet-tensorflow/imagenet-2012-tfrecord/train*
        input_set_label_to_zero: true  # Set labels to zeros to double-check that no label is used during pretrain.
        is_training: true
        global_batch_size: 4096
        dtype: 'bfloat16'
        parser:
          aug_rand_hflip: true
          mode: pretrain
        decoder:
          decode_label: true
      validation_data:
        input_path: /readahead/200M/placer/prod/home/distbelief/imagenet-tensorflow/imagenet-2012-tfrecord/valid*
        is_training: false
        global_batch_size: 2048
        dtype: 'bfloat16'
        drop_remainder: false
        parser:
          mode: pretrain
        decoder:
          decode_label: true
  # Define TASK CONFIG for the FINETUNE networks here.
  - task_name: finetune_imagenet_10percent
    task_weight: 1.0
    task_config:
      evaluation:
        one_hot: true
        top_k: 5
      loss:
        l2_weight_decay: 0.0
        label_smoothing: 0.0
        one_hot: true
      model:
        input_size: [224, 224, 3]
        mode: finetune
        supervised_head:
          num_classes: 1001
          zero_init: true
      train_data:
        tfds_name: 'imagenet2012_subset/10pct'
        tfds_split: 'train'
        input_path: ''
        is_training: true
        global_batch_size: 1024
        dtype: 'bfloat16'
        parser:
          aug_rand_hflip: true
          mode: finetune
        decoder:
          decode_label: true
      validation_data:
        tfds_name: 'imagenet2012_subset/10pct'
        tfds_split: 'validation'
        input_path: ''
        is_training: false
        global_batch_size: 2048
        dtype: 'bfloat16'
        drop_remainder: false
        parser:
          mode: finetune
        decoder:
          decode_label: true
trainer:
  trainer_type: interleaving
  task_sampler:
    proportional:
      alpha: 1.0
    type: proportional
  train_steps: 32000  # 100 epochs
  validation_steps: 24  # NUM_EXAMPLES (50000) // global_batch_size
  validation_interval: 625
  steps_per_loop: 625  # NUM_EXAMPLES (1281167) // global_batch_size
  summary_interval: 625
  checkpoint_interval: 625
  max_to_keep: 3
  optimizer_config:
    learning_rate:
      cosine:
        decay_steps: 32000
        initial_learning_rate: 4.8
      type: cosine
    optimizer:
      lars:
        exclude_from_weight_decay: [batch_normalization, bias]
        momentum: 0.9
        weight_decay_rate: 1.0e-06
      type: lars
    warmup:
      linear:
        name: linear
        warmup_steps: 3200
      type: linear
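
With trainer_type: interleaving, the proportional task sampler decides how often each task routine above is drawn. A minimal sketch of the assumed arithmetic, i.e. sampling rate proportional to task_weight ** alpha normalized over tasks (this is an illustration of the config values, not code from this commit):

# Sketch only: assumed proportional sampling of the two task routines above.
task_weights = {'pretrain_imagenet': 30.0, 'finetune_imagenet_10percent': 1.0}
alpha = 1.0  # trainer.task_sampler.proportional.alpha
total = sum(w ** alpha for w in task_weights.values())
rates = {name: (w ** alpha) / total for name, w in task_weights.items()}
print(rates)  # {'pretrain_imagenet': 0.9677..., 'finetune_imagenet_10percent': 0.0322...}

Under these settings the pretrain routine is sampled roughly 30 out of every 31 steps, and the 10% finetune routine about once every 31 steps.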
@@ -29,6 +29,7 @@ from official.vision.beta.projects.simclr.modeling import simclr_model
 @dataclasses.dataclass
 class SimCLRMTHeadConfig(hyperparams.Config):
   """Per-task specific configs."""
+  task_name: str = 'task_name'
   # Supervised head is required for finetune, but optional for pretrain.
   supervised_head: simclr_configs.SupervisedHead = simclr_configs.SupervisedHead(
       num_classes=1001)
@@ -57,14 +58,17 @@ def multitask_simclr() -> multitask_configs.MultiTaskExperimentConfig:
   return multitask_configs.MultiTaskExperimentConfig(
       task=multitask_configs.MultiTaskConfig(
           model=SimCLRMTModelConfig(
-              heads=(SimCLRMTHeadConfig(mode=simclr_model.PRETRAIN),
-                     SimCLRMTHeadConfig(mode=simclr_model.FINETUNE))),
+              heads=(SimCLRMTHeadConfig(
+                  task_name='pretrain_simclr', mode=simclr_model.PRETRAIN),
+                     SimCLRMTHeadConfig(
+                         task_name='finetune_simclr',
+                         mode=simclr_model.FINETUNE))),
           task_routines=(multitask_configs.TaskRoutine(
-              task_name=simclr_model.PRETRAIN,
+              task_name='pretrain_simclr',
               task_config=simclr_configs.SimCLRPretrainTask(),
               task_weight=2.0),
                          multitask_configs.TaskRoutine(
-                             task_name=simclr_model.FINETUNE,
+                             task_name='finetune_simclr',
                              task_config=simclr_configs.SimCLRFinetuneTask(),
                              task_weight=1.0))),
       trainer=multitask_configs.MultiTaskTrainerConfig())
@@ -14,10 +14,9 @@
 """SimCLR configurations."""
 import dataclasses
-import os.path
 import os
 from typing import List, Optional

 from official.core import config_definitions as cfg
 from official.core import exp_factory
 from official.modeling import hyperparams
@@ -58,6 +57,9 @@ class DataConfig(cfg.DataConfig):
   # simclr specific configs
   parser: Parser = Parser()
   decoder: Decoder = Decoder()
+  # Useful when doing a sanity check that we absolutely use no labels while
+  # pretraining, by setting labels to zeros (default = False, keep original labels).
+  input_set_label_to_zero: bool = False


 @dataclasses.dataclass
@@ -84,7 +84,7 @@ class SimCLRMTModel(base_model.MultiTaskBaseModel):
       else:
         supervised_head = None

-      tasks[model_config.mode] = simclr_model.SimCLRModel(
+      tasks[model_config.task_name] = simclr_model.SimCLRModel(
           input_specs=self._input_specs,
           backbone=self._backbone,
           projection_head=projection_head,
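
The sub-model dictionary is now keyed by task_name rather than mode. Since several heads can share a mode (e.g. more than one finetune head), keying by mode would silently overwrite all but one of them, while keying by task_name keeps every head. A self-contained illustration (head names here are hypothetical, not from this commit):

heads = [
    {'task_name': 'pretrain_imagenet', 'mode': 'pretrain'},
    {'task_name': 'finetune_imagenet_10percent', 'mode': 'finetune'},
    {'task_name': 'finetune_imagenet_1percent', 'mode': 'finetune'},  # hypothetical extra head
]
by_mode = {h['mode']: h for h in heads}            # 2 entries: one finetune head is lost
by_task_name = {h['task_name']: h for h in heads}  # 3 entries: nothing is lost
assert len(by_mode) == 2 and len(by_task_name) == 3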
@@ -29,11 +29,13 @@ class MultitaskModelTest(tf.test.TestCase):
     ckpt_dir = self.get_temp_dir()
     config = multitask_config.SimCLRMTModelConfig(
         input_size=[64, 64, 3],
-        heads=(multitask_config.SimCLRMTHeadConfig(mode=simclr_model.PRETRAIN),
-               multitask_config.SimCLRMTHeadConfig(mode=simclr_model.FINETUNE)))
+        heads=(multitask_config.SimCLRMTHeadConfig(
+            mode=simclr_model.PRETRAIN, task_name='pretrain_simclr'),
+               multitask_config.SimCLRMTHeadConfig(
+                   mode=simclr_model.FINETUNE, task_name='finetune_simclr')))
     model = multitask_model.SimCLRMTModel(config)
-    self.assertIn(simclr_model.PRETRAIN, model.sub_tasks)
-    self.assertIn(simclr_model.FINETUNE, model.sub_tasks)
+    self.assertIn('pretrain_simclr', model.sub_tasks)
+    self.assertIn('finetune_simclr', model.sub_tasks)
     ckpt = tf.train.Checkpoint(backbone=model._backbone)
     ckpt.save(os.path.join(ckpt_dir, 'ckpt'))
     model.initialize()
@@ -287,6 +287,12 @@ class SimCLRPretrainTask(base_task.Task):

   def train_step(self, inputs, model, optimizer, metrics=None):
     features, labels = inputs
+    # To do a sanity check that we absolutely use no labels when pretraining,
+    # we can set the labels here to zero.
+    if self.task_config.train_data.input_set_label_to_zero:
+      labels *= 0
+
     if (self.task_config.model.supervised_head is not None and
         self.task_config.evaluation.one_hot):
       num_classes = self.task_config.model.supervised_head.num_classes
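
The new guard only zeroes the labels; it does not change their shape or dtype, so the rest of the pretrain step runs unchanged while any accidental use of label information is starved of signal. A tiny sketch of the effect (the label values are made up for illustration):

import tensorflow as tf

labels = tf.constant([207, 281, 90], dtype=tf.int32)  # hypothetical batch of labels
labels *= 0  # what the guarded branch does when input_set_label_to_zero is true
print(labels)  # tf.Tensor([0 0 0], shape=(3,), dtype=int32)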