# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""VisionTransformer models."""
from typing import Optional, Tuple

from absl import logging
import immutabledict
import tensorflow as tf

from official.modeling import activations
from official.projects.vit.modeling import nn_blocks
from official.vision.modeling.backbones import factory
from official.vision.modeling.layers import nn_layers


layers = tf.keras.layers


def _vit_spec(hidden_size, patch_size, mlp_dim, num_heads, num_layers):
  """Returns one ViT size preset in the layout `build_vit` merges into config.

  Args:
    hidden_size: Number of filters of the patch-embedding conv, i.e. the token
      width.
    patch_size: Kernel size and stride of the patch-embedding conv.
    mlp_dim: Inner dimension of each transformer block's MLP.
    num_heads: Number of attention heads per transformer block.
    num_layers: Number of transformer blocks.

  Returns:
    A dict with keys `hidden_size`, `patch_size` and a nested `transformer`
    dict holding `mlp_dim`, `num_heads` and `num_layers`.
  """
  transformer_spec = dict(
      mlp_dim=mlp_dim, num_heads=num_heads, num_layers=num_layers)
  return dict(
      hidden_size=hidden_size,
      patch_size=patch_size,
      transformer=transformer_spec,
  )


# Named ViT size presets, looked up by `model_name` in `build_vit`. The outer
# mapping is an immutabledict; each preset value is a plain dict.
# Columns: hidden_size, patch_size, mlp_dim, num_heads, num_layers.
VIT_SPECS = immutabledict.immutabledict({
    'vit-ti16': _vit_spec(192, 16, 768, 3, 12),
    'vit-s16': _vit_spec(384, 16, 1536, 6, 12),
    'vit-b16': _vit_spec(768, 16, 3072, 12, 12),
    'vit-b32': _vit_spec(768, 32, 3072, 12, 12),
    'vit-l16': _vit_spec(1024, 16, 4096, 16, 24),
    'vit-l32': _vit_spec(1024, 32, 4096, 16, 24),
    'vit-h14': _vit_spec(1280, 14, 5120, 16, 32),
    'vit-g14': _vit_spec(1664, 14, 8192, 16, 48),
})


class AddPositionEmbs(tf.keras.layers.Layer):
  """Adds (optionally learned) positional embeddings to the inputs.

  NOTE(review): when `posemb_origin_shape` is set, the embedding table holds
  exactly `rows * cols` entries, with no extra slot for a prepended class token
  (see `TokenLayer`). If such a token is present, the (possibly interpolated)
  embeddings are one entry shorter than the inputs and the final addition
  fails to broadcast — confirm callers never combine the two.
  """

  def __init__(self,
               posemb_init: Optional[tf.keras.initializers.Initializer] = None,
               posemb_origin_shape: Optional[Tuple[int, int]] = None,
               posemb_target_shape: Optional[Tuple[int, int]] = None,
               **kwargs):
    """Constructs the positional embedding module.

    The length of the learnable positional embeddings is fixed at build time.
    It is `posemb_origin_shape[0] * posemb_origin_shape[1]` when
    `posemb_origin_shape` is provided, otherwise the sequence length of the
    input the layer is built with. If at call time the input sequence length
    differs from that length, the embeddings are resized from
    `posemb_origin_shape` to `posemb_target_shape` before being added.

    Args:
      posemb_init: The positional embedding initializer.
      posemb_origin_shape: The `(rows, cols)` grid the learned embeddings are
        laid out as.
      posemb_target_shape: The `(rows, cols)` grid the embeddings are
        interpolated to when the input length does not match.
      **kwargs: Other args, forwarded to `tf.keras.layers.Layer`.
    """
    super().__init__(**kwargs)
    self.posemb_init = posemb_init
    self.posemb_origin_shape = posemb_origin_shape
    self.posemb_target_shape = posemb_target_shape

  def build(self, inputs_shape):
    # One embedding per cell of the origin grid if given, otherwise one per
    # position of the input this layer is built with.
    if self.posemb_origin_shape is not None:
      pos_emb_length = self.posemb_origin_shape[0] * self.posemb_origin_shape[1]
    else:
      pos_emb_length = inputs_shape[1]
    pos_emb_shape = (1, pos_emb_length, inputs_shape[2])
    self.pos_embedding = self.add_weight(
        'pos_embedding', pos_emb_shape, initializer=self.posemb_init)

  def _interpolate(self,
                   pos_embedding: tf.Tensor,
                   from_shape: Tuple[int, int],
                   to_shape: Tuple[int, int]) -> tf.Tensor:
    """Resizes flattened positional embeddings from one grid to another.

    Args:
      pos_embedding: Embeddings whose length is `from_shape[0] *
        from_shape[1]`.
      from_shape: The `(rows, cols)` grid `pos_embedding` is laid out as.
      to_shape: The `(rows, cols)` grid to resize to.

    Returns:
      Embeddings of shape `(1, to_shape[0] * to_shape[1], dim)`.
    """
    # The shapes are tuples, so they must be formatted with %s; a %d
    # placeholder makes the logging call fail to format the record.
    logging.info('Interpolating positional embedding from shape: %s to %s',
                 from_shape, to_shape)
    grid_emb = tf.reshape(pos_embedding, [1] + list(from_shape) + [-1])
    # NOTE: Using BILINEAR interpolation by default.
    grid_emb = tf.image.resize(grid_emb, to_shape)
    return tf.reshape(grid_emb, [1, to_shape[0] * to_shape[1], -1])

  def call(self, inputs, inputs_positions=None):
    # `inputs_positions` is accepted because `Encoder` passes it; it is unused.
    del inputs_positions
    pos_embedding = self.pos_embedding
    # inputs.shape is (batch_size, seq_len, emb_dim).
    if inputs.shape[1] != pos_embedding.shape[1]:
      # Interpolation needs both grids. Without them `_interpolate` would fail
      # inside `list(None)` / `tf.image.resize` with an opaque error.
      if self.posemb_origin_shape is None or self.posemb_target_shape is None:
        raise ValueError(
            f'Input sequence length {inputs.shape[1]} does not match the '
            f'positional embedding length {pos_embedding.shape[1]}, and '
            'posemb_origin_shape / posemb_target_shape are not both set, so '
            'the embeddings cannot be interpolated.')
      pos_embedding = self._interpolate(pos_embedding,
                                        from_shape=self.posemb_origin_shape,
                                        to_shape=self.posemb_target_shape)
    pos_embedding = tf.cast(pos_embedding, inputs.dtype)

    return inputs + pos_embedding


class TokenLayer(tf.keras.layers.Layer):
  """Prepends one learnable class token to the front of the sequence axis.

  The token is a zero-initialized weight named `cls` of shape
  `(1, 1, channels)`, where `channels` is the last input dimension.
  """

  def build(self, inputs_shape):
    channels = inputs_shape[-1]
    self.cls = self.add_weight('cls', (1, 1, channels), initializer='zeros')

  def call(self, inputs):
    token = tf.cast(self.cls, inputs.dtype)
    # Adding zeros shaped like the first position broadcasts the (1, 1, C)
    # token across the batch without reading the batch size explicitly.
    # Presumably inputs are (batch, seq, C) — confirm against callers.
    batch_token = token + tf.zeros_like(inputs[:, 0:1])
    return tf.concat([batch_token, inputs], axis=1)


class Encoder(tf.keras.layers.Layer):
  """Transformer Encoder.

  Optionally adds positional embeddings, applies dropout, runs a stack of
  pre-norm `nn_blocks.TransformerEncoderBlock`s with GELU activation, and
  finishes with a LayerNormalization.
  """

  def __init__(self,
               num_layers,
               mlp_dim,
               num_heads,
               dropout_rate=0.1,
               attention_dropout_rate=0.1,
               kernel_regularizer=None,
               inputs_positions=None,
               init_stochastic_depth_rate=0.0,
               kernel_initializer='glorot_uniform',
               add_pos_embed=True,
               pos_embed_origin_shape=None,
               pos_embed_target_shape=None,
               **kwargs):
    """Initializes the encoder.

    Args:
      num_layers: Number of `TransformerEncoderBlock`s to stack.
      mlp_dim: Passed to each block as `inner_dim`.
      num_heads: Passed to each block as `num_attention_heads`.
      dropout_rate: Rate of the dropout applied before the blocks; also passed
        to each block as `output_dropout`.
      attention_dropout_rate: Passed to each block as `attention_dropout`.
      kernel_regularizer: Passed to each block.
      inputs_positions: Forwarded to the positional embedding layer on every
        call.
      init_stochastic_depth_rate: Base stochastic depth rate. Block `i`
        (0-based) uses `nn_layers.get_stochastic_depth_rate(rate, i + 1,
        num_layers)`.
      kernel_initializer: Passed to each block.
      add_pos_embed: Whether to add learned positional embeddings, initialized
        with `RandomNormal(stddev=0.02)`, before the blocks.
      pos_embed_origin_shape: Forwarded as `posemb_origin_shape` to the
        positional embedding layer.
      pos_embed_target_shape: Forwarded as `posemb_target_shape` to the
        positional embedding layer.
      **kwargs: Forwarded to `tf.keras.layers.Layer`.
    """
    super().__init__(**kwargs)
    self._num_layers = num_layers
    self._mlp_dim = mlp_dim
    self._num_heads = num_heads
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._kernel_regularizer = kernel_regularizer
    self._inputs_positions = inputs_positions
    self._init_stochastic_depth_rate = init_stochastic_depth_rate
    self._kernel_initializer = kernel_initializer
    self._add_pos_embed = add_pos_embed
    self._pos_embed_origin_shape = pos_embed_origin_shape
    self._pos_embed_target_shape = pos_embed_target_shape

  def build(self, input_shape):
    if self._add_pos_embed:
      self._pos_embed = AddPositionEmbs(
          posemb_init=tf.keras.initializers.RandomNormal(stddev=0.02),
          posemb_origin_shape=self._pos_embed_origin_shape,
          posemb_target_shape=self._pos_embed_target_shape,
          name='posembed_input')
    self._dropout = layers.Dropout(rate=self._dropout_rate)

    self._encoder_layers = []
    # Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
    # https://flax.readthedocs.io/en/latest/_autosummary/flax.deprecated.nn.LayerNorm.html
    for i in range(self._num_layers):
      encoder_layer = nn_blocks.TransformerEncoderBlock(
          inner_activation=activations.gelu,
          num_attention_heads=self._num_heads,
          inner_dim=self._mlp_dim,
          output_dropout=self._dropout_rate,
          attention_dropout=self._attention_dropout_rate,
          kernel_regularizer=self._kernel_regularizer,
          kernel_initializer=self._kernel_initializer,
          norm_first=True,
          # Stochastic depth rate grows with depth (1-based block index).
          stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
              self._init_stochastic_depth_rate, i + 1, self._num_layers),
          norm_epsilon=1e-6)
      self._encoder_layers.append(encoder_layer)
    self._norm = layers.LayerNormalization(epsilon=1e-6)
    super().build(input_shape)

  def call(self, inputs, training=None):
    x = inputs
    if self._add_pos_embed:
      x = self._pos_embed(x, inputs_positions=self._inputs_positions)
    x = self._dropout(x, training=training)

    for encoder_layer in self._encoder_layers:
      x = encoder_layer(x, training=training)
    x = self._norm(x)
    return x

  def get_config(self):
    """Returns the layer config, including every constructor argument.

    `kernel_regularizer` is stored as the object passed in, not a serialized
    form.
    """
    config = super().get_config()
    config.update({
        'num_layers': self._num_layers,
        'mlp_dim': self._mlp_dim,
        'num_heads': self._num_heads,
        'dropout_rate': self._dropout_rate,
        'attention_dropout_rate': self._attention_dropout_rate,
        'kernel_regularizer': self._kernel_regularizer,
        'inputs_positions': self._inputs_positions,
        'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
        'kernel_initializer': self._kernel_initializer,
        'add_pos_embed': self._add_pos_embed,
        'pos_embed_origin_shape': self._pos_embed_origin_shape,
        'pos_embed_target_shape': self._pos_embed_target_shape,
    })
    # `dict.update` returns None, so the dict itself must be returned.
    return config


class VisionTransformer(tf.keras.Model):
  """Class to build VisionTransformer family model."""

  def __init__(self,
               mlp_dim=3072,
               num_heads=12,
               num_layers=12,
               attention_dropout_rate=0.0,
               dropout_rate=0.1,
               init_stochastic_depth_rate=0.0,
               input_specs=layers.InputSpec(shape=[None, None, None, 3]),
               patch_size=16,
               hidden_size=768,
               representation_size=0,
               pooler='token',
               kernel_regularizer=None,
               original_init: bool = True,
               pos_embed_shape: Optional[Tuple[int, int]] = None):
    """VisionTransformer initialization function.

    Args:
      mlp_dim: Inner MLP dimension of each encoder block.
      num_heads: Attention heads per encoder block.
      num_layers: Number of encoder blocks.
      attention_dropout_rate: Attention dropout passed to the encoder.
      dropout_rate: Dropout rate passed to the encoder.
      init_stochastic_depth_rate: Base stochastic depth rate for the encoder.
      input_specs: `InputSpec` of the input images. Its spatial dimensions
        are used to compute the token sequence length.
      patch_size: Kernel size and stride of the patch-embedding conv.
      hidden_size: Number of patch-embedding filters (token width).
      representation_size: If non-zero, a `Dense` layer of this size with
        tanh activation named `pre_logits` is applied after pooling.
      pooler: One of `'token'` (prepend a class token and take it),
        `'gap'` (mean over tokens) or `'none'` (return all tokens).
      kernel_regularizer: Regularizer for the conv, encoder and dense kernels.
      original_init: If True, use `lecun_normal` for the conv/dense layers and
        `glorot_uniform` in the encoder. Otherwise use `he_uniform` and
        `TruncatedNormal(stddev=.02)` respectively.
      pos_embed_shape: Optional `(rows, cols)` grid the positional embeddings
        are stored as. They are interpolated to the actual patch grid when the
        lengths differ.

    Raises:
      ValueError: If `pooler` is not recognized.
    """
    inputs = tf.keras.Input(shape=input_specs.shape[1:])

    # Patch embedding: non-overlapping patches via kernel size == stride.
    x = layers.Conv2D(
        filters=hidden_size,
        kernel_size=patch_size,
        strides=patch_size,
        padding='valid',
        kernel_regularizer=kernel_regularizer,
        kernel_initializer='lecun_normal' if original_init else 'he_uniform')(
            inputs)
    if tf.keras.backend.image_data_format() == 'channels_last':
      rows_axis, cols_axis = (1, 2)
    else:
      rows_axis, cols_axis = (2, 3)
      # The reshape below assumes the data_format is 'channels_last,' so
      # transpose to that. Once the data is flattened by the reshape, the
      # data_format is irrelevant, so no need to update
      # tf.keras.backend.image_data_format.
      x = tf.transpose(x, perm=[0, 2, 3, 1])

    # `x` is channels_last here in both branches, so the patch grid is always
    # on axes 1 and 2. `rows_axis`/`cols_axis` describe the *input* layout and
    # would select (W, channels) for channels_first.
    pos_embed_target_shape = (x.shape[1], x.shape[2])
    # NOTE(review): requires static spatial dims in `input_specs`. With the
    # default shape [None, None, None, 3], `None // patch_size` raises
    # TypeError.
    seq_len = (input_specs.shape[rows_axis] // patch_size) * (
        input_specs.shape[cols_axis] // patch_size)
    x = tf.reshape(x, [-1, seq_len, hidden_size])

    # If we want to add a class token, add it here.
    if pooler == 'token':
      x = TokenLayer(name='cls')(x)

    x = Encoder(
        num_layers=num_layers,
        mlp_dim=mlp_dim,
        num_heads=num_heads,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        kernel_regularizer=kernel_regularizer,
        kernel_initializer='glorot_uniform' if original_init else dict(
            class_name='TruncatedNormal', config=dict(stddev=.02)),
        init_stochastic_depth_rate=init_stochastic_depth_rate,
        pos_embed_origin_shape=pos_embed_shape,
        pos_embed_target_shape=pos_embed_target_shape)(x)

    if pooler == 'token':
      x = x[:, 0]
    elif pooler == 'gap':
      x = tf.reduce_mean(x, axis=1)
    elif pooler == 'none':
      x = tf.identity(x, name='encoded_tokens')
    else:
      raise ValueError(f'unrecognized pooler type: {pooler}')

    if representation_size:
      x = tf.keras.layers.Dense(
          representation_size,
          kernel_regularizer=kernel_regularizer,
          name='pre_logits',
          kernel_initializer='lecun_normal' if original_init else 'he_uniform')(
              x)
      x = tf.nn.tanh(x)
    else:
      x = tf.identity(x, name='pre_logits')

    if pooler == 'none':
      endpoints = {'encoded_tokens': x}
    else:
      # Pooled features are exposed as a 1x1 spatial map.
      endpoints = {
          'pre_logits':
              tf.reshape(x, [-1, 1, 1, representation_size or hidden_size])
      }

    super().__init__(inputs=inputs, outputs=endpoints)


@factory.register_backbone_builder('vit')
def build_vit(input_specs,
              backbone_config,
              norm_activation_config,
              l2_regularizer=None):
  """Builds a `VisionTransformer` backbone from a backbone config.

  The `VIT_SPECS` preset named by the config's `model_name` is merged into the
  config via `override` before the model is constructed.

  Args:
    input_specs: Input spec forwarded to `VisionTransformer`.
    backbone_config: Backbone config whose `type` must be `'vit'`.
    norm_activation_config: Not used by ViT; discarded immediately.
    l2_regularizer: Forwarded to `VisionTransformer` as `kernel_regularizer`.

  Returns:
    The constructed `VisionTransformer`.
  """
  del norm_activation_config
  kind = backbone_config.type
  cfg = backbone_config.get()
  assert kind == 'vit', f'Inconsistent backbone type {kind}'
  cfg.override(VIT_SPECS[cfg.model_name])

  transformer_cfg = cfg.transformer
  return VisionTransformer(
      mlp_dim=transformer_cfg.mlp_dim,
      num_heads=transformer_cfg.num_heads,
      num_layers=transformer_cfg.num_layers,
      attention_dropout_rate=transformer_cfg.attention_dropout_rate,
      dropout_rate=transformer_cfg.dropout_rate,
      init_stochastic_depth_rate=cfg.init_stochastic_depth_rate,
      input_specs=input_specs,
      patch_size=cfg.patch_size,
      hidden_size=cfg.hidden_size,
      representation_size=cfg.representation_size,
      pooler=cfg.pooler,
      kernel_regularizer=l2_regularizer,
      original_init=cfg.original_init,
      pos_embed_shape=cfg.pos_embed_shape)