Commit a91f3779 authored by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 393981019
parent bf9805c5
@@ -51,6 +51,9 @@ class SimCLRMTModelConfig(hyperparams.Config):
   # L2 weight decay is used in the model, not in task.
   # Note that this can not be used together with lars optimizer.
   l2_weight_decay: float = 0.0
+  init_checkpoint: str = ''
+  # backbone_projection or backbone
+  init_checkpoint_modules: str = 'backbone_projection'


 @exp_factory.register_config_factory('multitask_simclr')
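Note on the two config fields added above: they enable warm-starting the multi-task model from a pretrained SimCLR checkpoint (consumed by initialize() further down). A minimal sketch of setting them on the registered experiment config; the attribute path and checkpoint location are illustrative assumptions, not part of this commit:

from official.core import exp_factory

config = exp_factory.get_exp_config('multitask_simclr')
# Hypothetical directory holding a SimCLR pretraining checkpoint.
config.task.model.init_checkpoint = '/tmp/simclr_pretrain'
# 'backbone_projection' restores backbone + projection head;
# 'backbone' restores the backbone only.
config.task.model.init_checkpoint_modules = 'backbone_projection'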
@@ -14,6 +14,7 @@
 """Multi-task image SimCLR model definition."""
 from typing import Dict, Text

+from absl import logging
 import tensorflow as tf
@@ -52,15 +53,10 @@ class SimCLRMTModel(base_model.MultiTaskBaseModel):
         norm_activation_config=config.norm_activation,
         l2_regularizer=self._l2_regularizer)

-    super().__init__(**kwargs)
-
-  def _instantiate_sub_tasks(self) -> Dict[Text, tf.keras.Model]:
-    tasks = {}
-
     # Build the shared projection head
     norm_activation_config = self._config.norm_activation
     projection_head_config = self._config.projection_head
-    projection_head = simclr_head.ProjectionHead(
+    self._projection_head = simclr_head.ProjectionHead(
         proj_output_dim=projection_head_config.proj_output_dim,
         num_proj_layers=projection_head_config.num_proj_layers,
         ft_proj_idx=projection_head_config.ft_proj_idx,
@@ -69,6 +65,11 @@ class SimCLRMTModel(base_model.MultiTaskBaseModel):
         norm_momentum=norm_activation_config.norm_momentum,
         norm_epsilon=norm_activation_config.norm_epsilon)

+    super().__init__(**kwargs)
+
+  def _instantiate_sub_tasks(self) -> Dict[Text, tf.keras.Model]:
+    tasks = {}
+
     for model_config in self._config.heads:
       # Build supervised head
       supervised_head_config = model_config.supervised_head
@@ -87,13 +88,38 @@ class SimCLRMTModel(base_model.MultiTaskBaseModel):
       tasks[model_config.task_name] = simclr_model.SimCLRModel(
           input_specs=self._input_specs,
           backbone=self._backbone,
-          projection_head=projection_head,
+          projection_head=self._projection_head,
           supervised_head=supervised_head,
           mode=model_config.mode,
           backbone_trainable=self._config.backbone_trainable)
     return tasks

-  # TODO(huythong): Implement initialize function to load the pretrained
-  # checkpoint of backbone.
-  # def initialize(self):
+  def initialize(self):
+    """Loads the multi-task SimCLR model with a pretrained checkpoint."""
+    ckpt_dir_or_file = self._config.init_checkpoint
+    if tf.io.gfile.isdir(ckpt_dir_or_file):
+      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
+    if not ckpt_dir_or_file:
+      return
+
+    logging.info('Loading pretrained %s', self._config.init_checkpoint_modules)
+    if self._config.init_checkpoint_modules == 'backbone':
+      pretrained_items = dict(backbone=self._backbone)
+    elif self._config.init_checkpoint_modules == 'backbone_projection':
+      pretrained_items = dict(
+          backbone=self._backbone, projection_head=self._projection_head)
+    else:
+      raise ValueError("Only 'backbone_projection' or 'backbone' can be used "
+                       'to initialize the model.')
+
+    ckpt = tf.train.Checkpoint(**pretrained_items)
+    status = ckpt.read(ckpt_dir_or_file)
+    status.expect_partial().assert_existing_objects_matched()
+    logging.info('Finished loading pretrained checkpoint from %s',
+                 ckpt_dir_or_file)
+
+  @property
+  def checkpoint_items(self):
+    """Returns a dictionary of items to be additionally checkpointed."""
+    return dict(backbone=self._backbone, projection_head=self._projection_head)
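Taken together, checkpoint_items and initialize() form a save/restore round trip. A hedged usage sketch; the constructor call and paths are assumptions for illustration, only initialize() and checkpoint_items come from this commit:

import tensorflow as tf

model = SimCLRMTModel(config)  # config: a SimCLRMTModelConfig (assumed)

# Persist the shared modules. checkpoint_items exposes exactly the keys
# ('backbone', 'projection_head') that initialize() reads back.
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
manager = tf.train.CheckpointManager(ckpt, '/tmp/simclr_mt', max_to_keep=3)
manager.save()

# A later run that sets config.init_checkpoint = '/tmp/simclr_mt' and
# init_checkpoint_modules = 'backbone_projection' restores both modules;
# initialize() resolves the directory via tf.train.latest_checkpoint().
model.initialize()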
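A side effect of hoisting the projection head into __init__ (the hunks above) is that every per-task SimCLRModel wraps the same ProjectionHead instance, so its weights are trained jointly across tasks. A small sketch of checking the sharing; it assumes at least two configured heads and that the per-task SimCLRModel also exposes a checkpoint_items mapping containing the projection head:

tasks = model._instantiate_sub_tasks()  # normally called by the base class
task_a, task_b = list(tasks.values())[:2]
assert (task_a.checkpoint_items['projection_head']
        is task_b.checkpoint_items['projection_head'])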