# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""VisionTransformer models."""
from typing import Optional, Tuple

from absl import logging

# import immutabledict
import tensorflow as tf

from official.modeling import activations
from official.projects.vit.modeling import nn_blocks
from official.projects.vit.modeling.vit_specs import VIT_SPECS
from official.vision.modeling.backbones import factory
from official.vision.modeling.layers import nn_layers

layers = tf.keras.layers


class AddPositionEmbs(tf.keras.layers.Layer):
  """Adds (optionally learned) positional embeddings to the inputs."""

  def __init__(self,
               posemb_init: Optional[tf.keras.initializers.Initializer] = None,
               posemb_origin_shape: Optional[Tuple[int, int]] = None,
               posemb_target_shape: Optional[Tuple[int, int]] = None,
               **kwargs):
    """Constructs Postional Embedding module.

    The logic of this module is: the learnable positional embeddings length will
    be determined by the inputs_shape or posemb_origin_shape (if provided)
    during the construction. If the posemb_target_shape is provided and is
    different from the positional embeddings length, the embeddings will be
    interpolated during the forward call.

    Args:
      posemb_init: The positional embedding initializer.
      posemb_origin_shape: The intended positional embedding shape.
      posemb_target_shape: The target shape that the positional embeddings may
        be interpolated to.
      **kwargs: other args.
    """
    super().__init__(**kwargs)
    self.posemb_init = posemb_init
    self.posemb_origin_shape = posemb_origin_shape
    self.posemb_target_shape = posemb_target_shape

  def build(self, inputs_shape):
    if self.posemb_origin_shape is not None:
      pos_emb_length = self.posemb_origin_shape[0] * self.posemb_origin_shape[1]
    else:
      pos_emb_length = inputs_shape[1]
    pos_emb_shape = (1, pos_emb_length, inputs_shape[2])
    self.pos_embedding = self.add_weight(
        'pos_embedding', pos_emb_shape, initializer=self.posemb_init)

  def _interpolate(self, pos_embedding: tf.Tensor, from_shape: Tuple[int, int],
                   to_shape: Tuple[int, int]) -> tf.Tensor:
    """Interpolates the positional embeddings."""
    logging.info('Interpolating positional embeddings from shape %s to %s',
                 from_shape, to_shape)
    grid_emb = tf.reshape(pos_embedding, [1] + list(from_shape) + [-1])
    # NOTE: Using BILINEAR interpolation by default.
    grid_emb = tf.image.resize(grid_emb, to_shape)
    return tf.reshape(grid_emb, [1, to_shape[0] * to_shape[1], -1])

  def call(self, inputs, inputs_positions=None):
    del inputs_positions
    pos_embedding = self.pos_embedding
    # inputs.shape is (batch_size, seq_len, emb_dim).
    if inputs.shape[1] != pos_embedding.shape[1]:
      pos_embedding = self._interpolate(
          pos_embedding,
          from_shape=self.posemb_origin_shape,
          to_shape=self.posemb_target_shape)
    pos_embedding = tf.cast(pos_embedding, inputs.dtype)

    return inputs + pos_embedding
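
# A minimal usage sketch (not part of the module API): resizing a pre-trained
# 14x14 position-embedding grid to 16x16, e.g. when fine-tuning at a higher
# input resolution. The shapes below are illustrative assumptions.
#
#   pos_embed = AddPositionEmbs(
#       posemb_init=tf.keras.initializers.RandomNormal(stddev=0.02),
#       posemb_origin_shape=(14, 14),  # grid the embeddings were trained on
#       posemb_target_shape=(16, 16))  # grid of the new token sequence
#   tokens = tf.zeros([2, 16 * 16, 768])  # (batch_size, seq_len, emb_dim)
#   outputs = pos_embed(tokens)  # embeddings interpolated to 16x16, then added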


class TokenLayer(tf.keras.layers.Layer):
  """A simple layer to wrap token parameters."""

  def build(self, inputs_shape):
    self.cls = self.add_weight(
        'cls', (1, 1, inputs_shape[-1]), initializer='zeros')

  def call(self, inputs):
    cls = tf.cast(self.cls, inputs.dtype)
    cls = cls + tf.zeros_like(inputs[:, 0:1])  # A hacky way to tile.
    x = tf.concat([cls, inputs], axis=1)
    return x
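
# A minimal usage sketch (illustrative only): prepending the learnable class
# token to a sequence of 196 patch tokens, yielding 197 tokens per example.
#
#   token_layer = TokenLayer(name='cls')
#   patches = tf.zeros([2, 196, 768])
#   with_cls = token_layer(patches)  # shape [2, 197, 768]; cls token at index 0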


class Encoder(tf.keras.layers.Layer):
  """Transformer Encoder."""

  def __init__(self,
               num_layers,
               mlp_dim,
               num_heads,
               dropout_rate=0.1,
               attention_dropout_rate=0.1,
               kernel_regularizer=None,
               inputs_positions=None,
               init_stochastic_depth_rate=0.0,
               kernel_initializer='glorot_uniform',
               add_pos_embed=True,
               pos_embed_origin_shape=None,
               pos_embed_target_shape=None,
               **kwargs):
    super().__init__(**kwargs)
    self._num_layers = num_layers
    self._mlp_dim = mlp_dim
    self._num_heads = num_heads
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._kernel_regularizer = kernel_regularizer
    self._inputs_positions = inputs_positions
    self._init_stochastic_depth_rate = init_stochastic_depth_rate
    self._kernel_initializer = kernel_initializer
    self._add_pos_embed = add_pos_embed
    self._pos_embed_origin_shape = pos_embed_origin_shape
    self._pos_embed_target_shape = pos_embed_target_shape

  def build(self, input_shape):
    if self._add_pos_embed:
      self._pos_embed = AddPositionEmbs(
          posemb_init=tf.keras.initializers.RandomNormal(stddev=0.02),
          posemb_origin_shape=self._pos_embed_origin_shape,
          posemb_target_shape=self._pos_embed_target_shape,
          name='posembed_input')
    self._dropout = layers.Dropout(rate=self._dropout_rate)

    self._encoder_layers = []
    # Set layer norm epsilons to 1e-6 to be consistent with the JAX implementation.
    # https://flax.readthedocs.io/en/latest/_autosummary/flax.deprecated.nn.LayerNorm.html
    for i in range(self._num_layers):
      encoder_layer = nn_blocks.TransformerEncoderBlock(
          inner_activation=activations.gelu,
          num_attention_heads=self._num_heads,
          inner_dim=self._mlp_dim,
          output_dropout=self._dropout_rate,
          attention_dropout=self._attention_dropout_rate,
          kernel_regularizer=self._kernel_regularizer,
          kernel_initializer=self._kernel_initializer,
          norm_first=True,
          stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
              self._init_stochastic_depth_rate, i + 1, self._num_layers),
          norm_epsilon=1e-6)
      self._encoder_layers.append(encoder_layer)
    self._norm = layers.LayerNormalization(epsilon=1e-6)
    super().build(input_shape)

  def call(self, inputs, training=None):
    x = inputs
    if self._add_pos_embed:
      x = self._pos_embed(x, inputs_positions=self._inputs_positions)
    x = self._dropout(x, training=training)

    for encoder_layer in self._encoder_layers:
      x = encoder_layer(x, training=training)
    x = self._norm(x)
    return x

  def get_config(self):
    config = super().get_config()
    updates = {
        'num_layers': self._num_layers,
        'mlp_dim': self._mlp_dim,
        'num_heads': self._num_heads,
        'dropout_rate': self._dropout_rate,
        'attention_dropout_rate': self._attention_dropout_rate,
        'kernel_regularizer': self._kernel_regularizer,
        'inputs_positions': self._inputs_positions,
        'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
        'kernel_initializer': self._kernel_initializer,
        'add_pos_embed': self._add_pos_embed,
        'pos_embed_origin_shape': self._pos_embed_origin_shape,
        'pos_embed_target_shape': self._pos_embed_target_shape,
    }
    config.update(updates)
    return config
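
# A minimal usage sketch (illustrative only): a stand-alone Encoder with
# ViT-Base sized settings over 197 tokens (196 patches + 1 class token).
# The tensor shapes are assumptions, not requirements of the layer.
#
#   encoder = Encoder(num_layers=12, mlp_dim=3072, num_heads=12)
#   tokens = tf.zeros([2, 197, 768])
#   encoded = encoder(tokens, training=False)  # shape [2, 197, 768]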


class VisionTransformer(tf.keras.Model):
  """Class to build VisionTransformer family model."""

  def __init__(self,
               mlp_dim=3072,
               num_heads=12,
               num_layers=12,
               attention_dropout_rate=0.0,
               dropout_rate=0.1,
               init_stochastic_depth_rate=0.0,
               input_specs=layers.InputSpec(shape=[None, None, None, 3]),
               patch_size=16,
               hidden_size=768,
               representation_size=0,
               pooler='token',
               kernel_regularizer=None,
               original_init: bool = True,
               pos_embed_shape: Optional[Tuple[int, int]] = None):
    """VisionTransformer initialization function."""
    self._mlp_dim = mlp_dim
    self._num_heads = num_heads
    self._num_layers = num_layers
    self._hidden_size = hidden_size
    self._patch_size = patch_size

    inputs = tf.keras.Input(shape=input_specs.shape[1:])

    x = layers.Conv2D(
        filters=hidden_size,
        kernel_size=patch_size,
        strides=patch_size,
        padding='valid',
        kernel_regularizer=kernel_regularizer,
        kernel_initializer='lecun_normal' if original_init else 'he_uniform')(
            inputs)
    if tf.keras.backend.image_data_format() == 'channels_last':
      rows_axis, cols_axis = (1, 2)
    else:
      rows_axis, cols_axis = (2, 3)
      # The reshape below assumes the data_format is 'channels_last', so
      # transpose to that. Once the data is flattened by the reshape, the
      # data_format is irrelevant, so no need to update
      # tf.keras.backend.image_data_format.
      x = tf.transpose(x, perm=[0, 2, 3, 1])

    pos_embed_target_shape = (x.shape[rows_axis], x.shape[cols_axis])
    seq_len = (input_specs.shape[rows_axis] // patch_size) * (
        input_specs.shape[cols_axis] // patch_size)
    x = tf.reshape(x, [-1, seq_len, hidden_size])

    # If we want to add a class token, add it here.
    if pooler == 'token':
      x = TokenLayer(name='cls')(x)

    x = Encoder(
        num_layers=num_layers,
        mlp_dim=mlp_dim,
        num_heads=num_heads,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        kernel_regularizer=kernel_regularizer,
        kernel_initializer='glorot_uniform' if original_init else dict(
            class_name='TruncatedNormal', config=dict(stddev=.02)),
        init_stochastic_depth_rate=init_stochastic_depth_rate,
        pos_embed_origin_shape=pos_embed_shape,
        pos_embed_target_shape=pos_embed_target_shape)(
            x)

    if pooler == 'token':
      x = x[:, 0]
    elif pooler == 'gap':
      x = tf.reduce_mean(x, axis=1)
    elif pooler == 'none':
      x = tf.identity(x, name='encoded_tokens')
    else:
      raise ValueError(f'unrecognized pooler type: {pooler}')

    if representation_size:
      x = tf.keras.layers.Dense(
          representation_size,
          kernel_regularizer=kernel_regularizer,
          name='pre_logits',
          kernel_initializer='lecun_normal' if original_init else 'he_uniform')(
              x)
      x = tf.nn.tanh(x)
    else:
      x = tf.identity(x, name='pre_logits')

    if pooler == 'none':
      endpoints = {'encoded_tokens': x}
    else:
      endpoints = {
          'pre_logits':
              tf.reshape(x, [-1, 1, 1, representation_size or hidden_size])
      }
    super(VisionTransformer, self).__init__(inputs=inputs, outputs=endpoints)
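
# A minimal usage sketch (illustrative only): a ViT-B/16 style backbone on
# 224x224 inputs with the default 'token' pooler. The endpoint name and shape
# follow from the construction above; the input values are placeholders.
#
#   model = VisionTransformer(
#       mlp_dim=3072, num_heads=12, num_layers=12,
#       input_specs=layers.InputSpec(shape=[None, 224, 224, 3]),
#       patch_size=16, hidden_size=768, pooler='token')
#   endpoints = model(tf.zeros([1, 224, 224, 3]))
#   # endpoints['pre_logits'] has shape [1, 1, 1, 768].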


@factory.register_backbone_builder('legacy_vit')
def build_vit(input_specs,
              backbone_config,
              norm_activation_config,
              l2_regularizer=None):
  """Build ViT model."""
  del norm_activation_config
  backbone_type = backbone_config.type
  backbone_cfg = backbone_config.get()
  assert backbone_type == 'legacy_vit', (f'Inconsistent backbone type '
                                         f'{backbone_type}')
  backbone_cfg.override(VIT_SPECS[backbone_cfg.model_name])

  return VisionTransformer(
      mlp_dim=backbone_cfg.transformer.mlp_dim,
      num_heads=backbone_cfg.transformer.num_heads,
      num_layers=backbone_cfg.transformer.num_layers,
      attention_dropout_rate=backbone_cfg.transformer.attention_dropout_rate,
      dropout_rate=backbone_cfg.transformer.dropout_rate,
      init_stochastic_depth_rate=backbone_cfg.init_stochastic_depth_rate,
      input_specs=input_specs,
      patch_size=backbone_cfg.patch_size,
      hidden_size=backbone_cfg.hidden_size,
      representation_size=backbone_cfg.representation_size,
      pooler=backbone_cfg.pooler,
      kernel_regularizer=l2_regularizer,
      original_init=backbone_cfg.original_init,
      pos_embed_shape=backbone_cfg.pos_embed_shape)