# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Build simclr models."""
from typing import Optional
from absl import logging

import tensorflow as tf

# Short alias so layer construction below reads `layers.X` per TF convention.
layers = tf.keras.layers

# Training modes: in PRETRAIN mode `call` expects two augmented views stacked
# along the channel axis (see SimCLRModel.call); FINETUNE uses a single view.
PRETRAIN = 'pretrain'
FINETUNE = 'finetune'

# Keys of the output dictionary returned by SimCLRModel.call.
PROJECTION_OUTPUT_KEY = 'projection_outputs'
SUPERVISED_OUTPUT_KEY = 'supervised_outputs'


class SimCLRModel(tf.keras.Model):
  """A classification model based on SimCLR framework.

  In `pretrain` mode the input is expected to carry two augmented views
  concatenated on the channel axis; they are split, stacked on the batch
  axis, and passed through the backbone and projection head. An optional
  supervised head can be evaluated alongside (with its gradient stopped so
  it cannot leak into the contrastive pretraining).
  """

  def __init__(self,
               backbone: tf.keras.models.Model,
               projection_head: tf.keras.layers.Layer,
               supervised_head: Optional[tf.keras.layers.Layer] = None,
               input_specs=layers.InputSpec(shape=[None, None, None, 3]),
               mode: str = PRETRAIN,
               backbone_trainable: bool = True,
               **kwargs):
    """A classification model based on SimCLR framework.

    Args:
      backbone: a backbone network.
      projection_head: a projection head network.
      supervised_head: a head network for supervised learning, e.g.
        classification head.
      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
      mode: `str` indicates mode of training to be executed.
      backbone_trainable: `bool` whether the backbone is trainable or not.
      **kwargs: keyword arguments to be passed.
    """
    super().__init__(**kwargs)
    self._config_dict = {
        'backbone': backbone,
        'projection_head': projection_head,
        'supervised_head': supervised_head,
        'input_specs': input_specs,
        'mode': mode,
        'backbone_trainable': backbone_trainable,
    }
    self._input_specs = input_specs
    self._backbone = backbone
    self._projection_head = projection_head
    self._supervised_head = supervised_head
    self._mode = mode
    self._backbone_trainable = backbone_trainable

    # Build the pooling layer once here rather than inside `call`, which
    # would otherwise allocate a new layer object on every forward pass.
    # GlobalAveragePooling2D holds no variables, so checkpoints are
    # unaffected by tracking it.
    self._global_pool = layers.GlobalAveragePooling2D()

    # Set whether the backbone is trainable.
    self._backbone.trainable = backbone_trainable

  def call(self, inputs, training=None, **kwargs):  # pytype: disable=signature-mismatch
    """Runs the forward pass.

    Args:
      inputs: a `(bsz, h, w, c)` image batch; in pretrain mode during
        training, `c` holds two augmented views concatenated channel-wise.
      training: `bool` Keras training flag.
      **kwargs: additional call arguments (unused).

    Returns:
      A dict with `PROJECTION_OUTPUT_KEY` and `SUPERVISED_OUTPUT_KEY`
      entries; the latter is `None` when no supervised head is configured.
    """
    model_outputs = {}

    if training and self._mode == PRETRAIN:
      num_transforms = 2
      # Split channels, and optionally apply extra batched augmentation.
      # (bsz, h, w, c*num_transforms) -> [(bsz, h, w, c), ....]
      features_list = tf.split(
          inputs, num_or_size_splits=num_transforms, axis=-1)
      # (num_transforms * bsz, h, w, c)
      features = tf.concat(features_list, 0)
    else:
      features = inputs

    # Base network forward pass. The backbone only sees `training=True`
    # when it is actually trainable, so e.g. batch-norm statistics stay
    # frozen while fine-tuning heads on a frozen backbone.
    endpoints = self._backbone(
        features, training=training and self._backbone_trainable)
    # Use the deepest endpoint (highest key) as the representation.
    features = endpoints[max(endpoints.keys())]
    projection_inputs = self._global_pool(features)

    # Add heads.
    projection_outputs, supervised_inputs = self._projection_head(
        projection_inputs, training=training)

    if self._supervised_head is not None:
      if self._mode == PRETRAIN:
        logging.info('Ignoring gradient from supervised outputs !')
        # When performing pretraining and supervised_head together, we do not
        # want information from supervised evaluation flowing back into
        # pretraining network. So we put a stop_gradient.
        supervised_outputs = self._supervised_head(
            tf.stop_gradient(supervised_inputs), training=training)
      else:
        supervised_outputs = self._supervised_head(
            supervised_inputs, training=training)
    else:
      supervised_outputs = None

    model_outputs.update({
        PROJECTION_OUTPUT_KEY: projection_outputs,
        SUPERVISED_OUTPUT_KEY: supervised_outputs
    })

    return model_outputs

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    items = dict(
        backbone=self.backbone, projection_head=self.projection_head)
    if self._supervised_head is not None:
      items['supervised_head'] = self.supervised_head
    return items

  @property
  def backbone(self):
    """The backbone network."""
    return self._backbone

  @property
  def projection_head(self):
    """The projection head network."""
    return self._projection_head

  @property
  def supervised_head(self):
    """The supervised head network, or `None` if not configured."""
    return self._supervised_head

  @property
  def mode(self):
    """Current training mode (`PRETRAIN` or `FINETUNE`)."""
    return self._mode

  @mode.setter
  def mode(self, value):
    self._mode = value

  @property
  def backbone_trainable(self):
    """Whether the backbone participates in training."""
    return self._backbone_trainable

  @backbone_trainable.setter
  def backbone_trainable(self, value):
    # Keep the flag and the actual Keras `trainable` attribute in sync.
    self._backbone_trainable = value
    self._backbone.trainable = value

  def get_config(self):
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)