Commit 9465aa0e authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 387142853
parent 2de518be
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-task SimCLR configs."""
import dataclasses
from typing import List, Tuple
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling.multitask import configs as multitask_configs
from official.vision.beta.configs import backbones
from official.vision.beta.configs import common
from official.vision.beta.projects.simclr.configs import simclr as simclr_configs
from official.vision.beta.projects.simclr.modeling import simclr_model


@dataclasses.dataclass
class SimCLRMTHeadConfig(hyperparams.Config):
"""Per-task specific configs."""
# Supervised head is required for finetune, but optional for pretrain.
supervised_head: simclr_configs.SupervisedHead = simclr_configs.SupervisedHead(
num_classes=1001)
mode: str = simclr_model.PRETRAIN


@dataclasses.dataclass
class SimCLRMTModelConfig(hyperparams.Config):
"""Model config for multi-task SimCLR model."""
input_size: List[int] = dataclasses.field(default_factory=list)
backbone: backbones.Backbone = backbones.Backbone(
type='resnet', resnet=backbones.ResNet())
backbone_trainable: bool = True
projection_head: simclr_configs.ProjectionHead = simclr_configs.ProjectionHead(
proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1)
norm_activation: common.NormActivation = common.NormActivation(
norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)
heads: Tuple[SimCLRMTHeadConfig, ...] = ()
  # L2 weight decay is applied in the model, not in the task.
  # Note that this cannot be used together with the LARS optimizer.
l2_weight_decay: float = 0.0


@exp_factory.register_config_factory('multitask_simclr')
def multitask_simclr() -> multitask_configs.MultiTaskExperimentConfig:
return multitask_configs.MultiTaskExperimentConfig(
task=multitask_configs.MultiTaskConfig(
model=SimCLRMTModelConfig(
heads=(SimCLRMTHeadConfig(mode=simclr_model.PRETRAIN),
SimCLRMTHeadConfig(mode=simclr_model.FINETUNE))),
task_routines=(multitask_configs.TaskRoutine(
task_name=simclr_model.PRETRAIN,
task_config=simclr_configs.SimCLRPretrainTask(),
task_weight=2.0),
multitask_configs.TaskRoutine(
task_name=simclr_model.FINETUNE,
task_config=simclr_configs.SimCLRFinetuneTask(),
task_weight=1.0))),
trainer=multitask_configs.MultiTaskTrainerConfig())
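# Usage sketch (illustrative, not part of the checked-in file): the factory
# above registers under 'multitask_simclr', so the experiment config can be
# retrieved and overridden before training, e.g.:
#
#   from official.core import exp_factory
#   config = exp_factory.get_exp_config('multitask_simclr')
#   config.task.model.input_size = [224, 224, 3]  # hypothetical override
#   config.validate()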
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask_config."""
import tensorflow as tf
from official.core import exp_factory
from official.modeling.multitask import configs as multitask_configs
from official.vision.beta.projects.simclr.configs import multitask_config as simclr_multitask_config
from official.vision.beta.projects.simclr.configs import simclr as exp_cfg


class MultitaskConfigTest(tf.test.TestCase):

  def test_simclr_configs(self):
config = exp_factory.get_exp_config('multitask_simclr')
self.assertIsInstance(config, multitask_configs.MultiTaskExperimentConfig)
self.assertIsInstance(config.task.model,
simclr_multitask_config.SimCLRMTModelConfig)
self.assertIsInstance(config.task.task_routines[0].task_config,
exp_cfg.SimCLRPretrainTask)
self.assertIsInstance(config.task.task_routines[1].task_config,
exp_cfg.SimCLRFinetuneTask)


if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-task image multi-taskSimCLR model definition."""
from typing import Dict, Text
import tensorflow as tf
from official.modeling.multitask import base_model
from official.vision.beta.modeling import backbones
from official.vision.beta.projects.simclr.configs import multitask_config as simclr_multitask_config
from official.vision.beta.projects.simclr.heads import simclr_head
from official.vision.beta.projects.simclr.modeling import simclr_model

PROJECTION_OUTPUT_KEY = 'projection_outputs'
SUPERVISED_OUTPUT_KEY = 'supervised_outputs'


class SimCLRMTModel(base_model.MultiTaskBaseModel):
"""A multi-task SimCLR model that does both pretrain and finetune."""

  def __init__(self, config: simclr_multitask_config.SimCLRMTModelConfig,
**kwargs):
self._config = config
# Build shared backbone.
self._input_specs = tf.keras.layers.InputSpec(shape=[None] +
config.input_size)
l2_weight_decay = config.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
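    # Concretely, tf.keras.regularizers.l2(w) penalizes w * sum(x ** 2),
    # while w * tf.nn.l2_loss(x) equals w * sum(x ** 2) / 2.0, so passing
    # w / 2.0 makes the two penalties match.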
self._l2_regularizer = (
tf.keras.regularizers.l2(l2_weight_decay /
2.0) if l2_weight_decay else None)
self._backbone = backbones.factory.build_backbone(
input_specs=self._input_specs,
backbone_config=config.backbone,
norm_activation_config=config.norm_activation,
l2_regularizer=self._l2_regularizer)
super().__init__(**kwargs)

  def _instantiate_sub_tasks(self) -> Dict[Text, tf.keras.Model]:
tasks = {}
# Build the shared projection head
norm_activation_config = self._config.norm_activation
projection_head_config = self._config.projection_head
projection_head = simclr_head.ProjectionHead(
proj_output_dim=projection_head_config.proj_output_dim,
num_proj_layers=projection_head_config.num_proj_layers,
ft_proj_idx=projection_head_config.ft_proj_idx,
kernel_regularizer=self._l2_regularizer,
use_sync_bn=norm_activation_config.use_sync_bn,
norm_momentum=norm_activation_config.norm_momentum,
norm_epsilon=norm_activation_config.norm_epsilon)
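    # Note: this single projection head instance is shared by every sub-task
    # model built below, so the pretrain and finetune tasks reuse the same
    # projection weights on top of the shared backbone.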
for model_config in self._config.heads:
# Build supervised head
supervised_head_config = model_config.supervised_head
if supervised_head_config:
if supervised_head_config.zero_init:
s_kernel_initializer = 'zeros'
else:
s_kernel_initializer = 'random_uniform'
supervised_head = simclr_head.ClassificationHead(
num_classes=supervised_head_config.num_classes,
kernel_initializer=s_kernel_initializer,
kernel_regularizer=self._l2_regularizer)
else:
supervised_head = None
tasks[model_config.mode] = simclr_model.SimCLRModel(
input_specs=self._input_specs,
backbone=self._backbone,
projection_head=projection_head,
supervised_head=supervised_head,
mode=model_config.mode,
backbone_trainable=self._config.backbone_trainable)
return tasks
  # TODO(huythong): Implement the initialize function to load a pretrained
  # backbone checkpoint.
# def initialize(self):
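  # A possible body (illustrative sketch only; it assumes a hypothetical
  # `init_checkpoint` field on the model config and a checkpoint written as
  # `tf.train.Checkpoint(backbone=...)`, matching the unit test below):
  #   ckpt_dir_or_file = self._config.init_checkpoint
  #   if not ckpt_dir_or_file:
  #     return
  #   if tf.io.gfile.isdir(ckpt_dir_or_file):
  #     ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
  #   ckpt = tf.train.Checkpoint(backbone=self._backbone)
  #   ckpt.read(ckpt_dir_or_file).expect_partial()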
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask_model."""
import os.path
import tensorflow as tf
from official.vision.beta.projects.simclr.configs import multitask_config
from official.vision.beta.projects.simclr.modeling import multitask_model
from official.vision.beta.projects.simclr.modeling import simclr_model


class MultitaskModelTest(tf.test.TestCase):

  def test_initialize_model_success(self):
ckpt_dir = self.get_temp_dir()
config = multitask_config.SimCLRMTModelConfig(
input_size=[64, 64, 3],
heads=(multitask_config.SimCLRMTHeadConfig(mode=simclr_model.PRETRAIN),
multitask_config.SimCLRMTHeadConfig(mode=simclr_model.FINETUNE)))
model = multitask_model.SimCLRMTModel(config)
self.assertIn(simclr_model.PRETRAIN, model.sub_tasks)
self.assertIn(simclr_model.FINETUNE, model.sub_tasks)
ckpt = tf.train.Checkpoint(backbone=model._backbone)
ckpt.save(os.path.join(ckpt_dir, 'ckpt'))
model.initialize()


if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer binary for multitask simclr."""
from absl import app
from absl import flags
import gin
from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import train_utils
from official.modeling import performance
from official.modeling.multitask import multitask
from official.modeling.multitask import train_lib
# pylint: disable=unused-import
from official.vision.beta.projects.simclr.common import registry_imports
from official.vision.beta.projects.simclr.configs import multitask_config
from official.vision.beta.projects.simclr.modeling import multitask_model
# pylint: enable=unused-import

FLAGS = flags.FLAGS


def main(_):
gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
params = train_utils.parse_configuration(FLAGS)
model_dir = FLAGS.model_dir
if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files; otherwise, a continuous eval
    # job may race against the train job when writing the same file.
train_utils.serialize_config(params, model_dir)
  # Set the mixed-precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can significantly speed up the model by computing in float16 on GPUs and
  # in bfloat16 on TPUs. loss_scale takes effect only when the dtype is
  # float16.
if params.runtime.mixed_precision_dtype:
performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype)
distribution_strategy = distribute_utils.get_distribution_strategy(
distribution_strategy=params.runtime.distribution_strategy,
all_reduce_alg=params.runtime.all_reduce_alg,
num_gpus=params.runtime.num_gpus,
tpu_address=params.runtime.tpu)
with distribution_strategy.scope():
tasks = multitask.MultiTask.from_config(params.task)
model = multitask_model.SimCLRMTModel(params.task.model)
train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=tasks,
model=model,
mode=FLAGS.mode,
params=params,
model_dir=model_dir)
train_utils.save_gin_config(FLAGS.mode, model_dir)


if __name__ == '__main__':
tfm_flags.define_flags()
app.run(main)
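# Example invocation (illustrative; the flags are defined by
# tfm_flags.define_flags(), and the binary path is a placeholder):
#
#   python3 /path/to/this/trainer.py \
#     --experiment=multitask_simclr \
#     --mode=train \
#     --model_dir=/tmp/multitask_simclr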