# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""VisionTransformer models."""

import immutabledict
import tensorflow as tf

from official.modeling import activations
from official.projects.vit.modeling import nn_blocks
from official.vision.modeling.backbones import factory
from official.vision.modeling.layers import nn_layers


layers = tf.keras.layers

VIT_SPECS = immutabledict.immutabledict({
    'vit-ti16':
        dict(
            hidden_size=192,
            patch_size=16,
            transformer=dict(mlp_dim=768, num_heads=3, num_layers=12),
        ),
    'vit-s16':
        dict(
            hidden_size=384,
            patch_size=16,
            transformer=dict(mlp_dim=1536, num_heads=6, num_layers=12),
        ),
    'vit-b16':
        dict(
            hidden_size=768,
            patch_size=16,
            transformer=dict(mlp_dim=3072, num_heads=12, num_layers=12),
        ),
    'vit-b32':
        dict(
            hidden_size=768,
            patch_size=32,
            transformer=dict(mlp_dim=3072, num_heads=12, num_layers=12),
        ),
    'vit-l16':
        dict(
            hidden_size=1024,
            patch_size=16,
            transformer=dict(mlp_dim=4096, num_heads=16, num_layers=24),
        ),
    'vit-l32':
        dict(
            hidden_size=1024,
            patch_size=32,
            transformer=dict(mlp_dim=4096, num_heads=16, num_layers=24),
        ),
    'vit-h14':
        dict(
            hidden_size=1280,
            patch_size=14,
            transformer=dict(mlp_dim=5120, num_heads=16, num_layers=32),
        ),
    'vit-g14':
        dict(
            hidden_size=1664,
            patch_size=14,
            transformer=dict(mlp_dim=8192, num_heads=16, num_layers=48),
        ),
})
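
# Example (illustrative sketch): each entry keys a canonical model name to the
# patch size, hidden width, and transformer shape consumed by
# `VisionTransformer` below, e.g.:
#
#   spec = VIT_SPECS['vit-b16']
#   spec['hidden_size']               # 768
#   spec['transformer']['num_heads']  # 12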


class AddPositionEmbs(tf.keras.layers.Layer):
  """Adds (optionally learned) positional embeddings to the inputs."""

  def __init__(self, posemb_init=None, **kwargs):
    super().__init__(**kwargs)
    self.posemb_init = posemb_init

  def build(self, inputs_shape):
    pos_emb_shape = (1, inputs_shape[1], inputs_shape[2])
    self.pos_embedding = self.add_weight(
        'pos_embedding', pos_emb_shape, initializer=self.posemb_init)

  def call(self, inputs, inputs_positions=None):
    # inputs.shape is (batch_size, seq_len, emb_dim).
    # `inputs_positions` is accepted for interface compatibility but is
    # currently unused; the full position-embedding table is always added.
    del inputs_positions  # Unused.
    pos_embedding = tf.cast(self.pos_embedding, inputs.dtype)

    return inputs + pos_embedding
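
# Example (illustrative sketch, assuming a (batch, seq_len, emb_dim) input):
#
#   x = tf.zeros([2, 196, 768])
#   posemb = AddPositionEmbs(
#       posemb_init=tf.keras.initializers.RandomNormal(stddev=0.02))
#   posemb(x).shape  # TensorShape([2, 196, 768]); a (1, 196, 768) table added.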


class TokenLayer(tf.keras.layers.Layer):
  """A simple layer to wrap token parameters."""

  def build(self, inputs_shape):
    self.cls = self.add_weight(
        'cls', (1, 1, inputs_shape[-1]), initializer='zeros')

  def call(self, inputs):
    cls = tf.cast(self.cls, inputs.dtype)
    cls = cls + tf.zeros_like(inputs[:, 0:1])  # A hacky way to tile.
    x = tf.concat([cls, inputs], axis=1)
    return x
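
# Example (illustrative sketch): one learned classification token is prepended
# per example, growing the sequence length by one.
#
#   x = tf.zeros([2, 196, 768])
#   TokenLayer(name='cls')(x).shape  # TensorShape([2, 197, 768])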


class Encoder(tf.keras.layers.Layer):
  """Transformer Encoder."""

  def __init__(self,
               num_layers,
               mlp_dim,
               num_heads,
               dropout_rate=0.1,
               attention_dropout_rate=0.1,
               kernel_regularizer=None,
               inputs_positions=None,
               init_stochastic_depth_rate=0.0,
               kernel_initializer='glorot_uniform',
               add_pos_embed=True,
               **kwargs):
    super().__init__(**kwargs)
    self._num_layers = num_layers
    self._mlp_dim = mlp_dim
    self._num_heads = num_heads
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._kernel_regularizer = kernel_regularizer
    self._inputs_positions = inputs_positions
    self._init_stochastic_depth_rate = init_stochastic_depth_rate
    self._kernel_initializer = kernel_initializer
    self._add_pos_embed = add_pos_embed

  def build(self, input_shape):
    if self._add_pos_embed:
      self._pos_embed = AddPositionEmbs(
          posemb_init=tf.keras.initializers.RandomNormal(stddev=0.02),
          name='posembed_input')
    self._dropout = layers.Dropout(rate=self._dropout_rate)

    self._encoder_layers = []
    # Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
    # https://flax.readthedocs.io/en/latest/_autosummary/flax.deprecated.nn.LayerNorm.html
    for i in range(self._num_layers):
      encoder_layer = nn_blocks.TransformerEncoderBlock(
          inner_activation=activations.gelu,
          num_attention_heads=self._num_heads,
          inner_dim=self._mlp_dim,
          output_dropout=self._dropout_rate,
          attention_dropout=self._attention_dropout_rate,
          kernel_regularizer=self._kernel_regularizer,
          kernel_initializer=self._kernel_initializer,
          norm_first=True,
          stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
              self._init_stochastic_depth_rate, i + 1, self._num_layers),
          norm_epsilon=1e-6)
      self._encoder_layers.append(encoder_layer)
    self._norm = layers.LayerNormalization(epsilon=1e-6)
    super().build(input_shape)

  def call(self, inputs, training=None):
    x = inputs
    if self._add_pos_embed:
      x = self._pos_embed(x, inputs_positions=self._inputs_positions)
    x = self._dropout(x, training=training)

    for encoder_layer in self._encoder_layers:
      x = encoder_layer(x, training=training)
    x = self._norm(x)
    return x

  def get_config(self):
    config = {
        'num_layers': self._num_layers,
        'mlp_dim': self._mlp_dim,
        'num_heads': self._num_heads,
        'dropout_rate': self._dropout_rate,
        'attention_dropout_rate': self._attention_dropout_rate,
        'kernel_regularizer': self._kernel_regularizer,
        'inputs_positions': self._inputs_positions,
        'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
        'kernel_initializer': self._kernel_initializer,
        'add_pos_embed': self._add_pos_embed,
    }
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))
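
# Example (illustrative sketch; the ViT-B/16-like sizes are an assumption, not
# a requirement of the layer):
#
#   encoder = Encoder(num_layers=12, mlp_dim=3072, num_heads=12)
#   tokens = tf.zeros([2, 197, 768])
#   encoder(tokens, training=False).shape  # TensorShape([2, 197, 768])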


class VisionTransformer(tf.keras.Model):
  """Class to build VisionTransformer family model."""

  def __init__(self,
               mlp_dim=3072,
               num_heads=12,
               num_layers=12,
               attention_dropout_rate=0.0,
               dropout_rate=0.1,
               init_stochastic_depth_rate=0.0,
               input_specs=layers.InputSpec(shape=[None, None, None, 3]),
               patch_size=16,
               hidden_size=768,
               representation_size=0,
               classifier='token',
               kernel_regularizer=None,
               original_init=True):
    """VisionTransformer initialization function."""
    inputs = tf.keras.Input(shape=input_specs.shape[1:])

    x = layers.Conv2D(
        filters=hidden_size,
        kernel_size=patch_size,
        strides=patch_size,
        padding='valid',
        kernel_regularizer=kernel_regularizer,
        kernel_initializer='lecun_normal' if original_init else 'he_uniform')(
            inputs)
    if tf.keras.backend.image_data_format() == 'channels_last':
      rows_axis, cols_axis = (1, 2)
    else:
      rows_axis, cols_axis = (2, 3)
      # The reshape below assumes the data_format is 'channels_last', so
      # transpose to that. Once the data is flattened by the reshape, the
      # data_format is irrelevant, so no need to update
      # tf.keras.backend.image_data_format.
      x = tf.transpose(x, perm=[0, 2, 3, 1])
    seq_len = (input_specs.shape[rows_axis] // patch_size) * (
        input_specs.shape[cols_axis] // patch_size)
    x = tf.reshape(x, [-1, seq_len, hidden_size])

    # If we want to add a class token, add it here.
    if classifier == 'token':
      x = TokenLayer(name='cls')(x)

    x = Encoder(
        num_layers=num_layers,
        mlp_dim=mlp_dim,
        num_heads=num_heads,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        kernel_regularizer=kernel_regularizer,
        kernel_initializer='glorot_uniform' if original_init else dict(
            class_name='TruncatedNormal', config=dict(stddev=.02)),
        init_stochastic_depth_rate=init_stochastic_depth_rate)(
            x)

    if classifier == 'token':
      x = x[:, 0]
    elif classifier == 'gap':
      x = tf.reduce_mean(x, axis=1)

    if representation_size:
      x = tf.keras.layers.Dense(
          representation_size,
          kernel_regularizer=kernel_regularizer,
          name='pre_logits',
          kernel_initializer='lecun_normal' if original_init else 'he_uniform')(
              x)
      x = tf.nn.tanh(x)
    else:
      x = tf.identity(x, name='pre_logits')
    endpoints = {
        'pre_logits':
            tf.reshape(x, [-1, 1, 1, representation_size or hidden_size])
    }

    super().__init__(inputs=inputs, outputs=endpoints)
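
# Example (illustrative sketch; the 224x224 input size is an assumption, not a
# library default):
#
#   spec = VIT_SPECS['vit-b16']
#   backbone = VisionTransformer(
#       input_specs=layers.InputSpec(shape=[None, 224, 224, 3]),
#       patch_size=spec['patch_size'],
#       hidden_size=spec['hidden_size'],
#       **spec['transformer'])
#   endpoints = backbone(tf.zeros([1, 224, 224, 3]))
#   endpoints['pre_logits'].shape  # TensorShape([1, 1, 1, 768])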


@factory.register_backbone_builder('vit')
def build_vit(input_specs,
              backbone_config,
              norm_activation_config,
              l2_regularizer=None):
  """Build ViT model."""
  del norm_activation_config
  backbone_type = backbone_config.type
  backbone_cfg = backbone_config.get()
  assert backbone_type == 'vit', (f'Inconsistent backbone type '
                                  f'{backbone_type}')
  backbone_cfg.override(VIT_SPECS[backbone_cfg.model_name])

  return VisionTransformer(
      mlp_dim=backbone_cfg.transformer.mlp_dim,
      num_heads=backbone_cfg.transformer.num_heads,
      num_layers=backbone_cfg.transformer.num_layers,
      attention_dropout_rate=backbone_cfg.transformer.attention_dropout_rate,
      dropout_rate=backbone_cfg.transformer.dropout_rate,
      init_stochastic_depth_rate=backbone_cfg.init_stochastic_depth_rate,
      input_specs=input_specs,
      patch_size=backbone_cfg.patch_size,
      hidden_size=backbone_cfg.hidden_size,
      representation_size=backbone_cfg.representation_size,
      classifier=backbone_cfg.classifier,
      kernel_regularizer=l2_regularizer,
      original_init=backbone_cfg.original_init)
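
# Example (illustrative sketch of the registry path; the config module path and
# field names below are assumptions, so adapt them to the ViT config
# dataclasses in your checkout):
#
#   from official.projects.vit.configs import backbones as backbones_cfg
#   backbone_config = backbones_cfg.Backbone(
#       type='vit', vit=backbones_cfg.VisionTransformer(model_name='vit-b16'))
#   model = factory.build_backbone(
#       input_specs=layers.InputSpec(shape=[None, 224, 224, 3]),
#       backbone_config=backbone_config,
#       norm_activation_config=None)  # Unused by the ViT builder.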