Commit 27faf029 authored by Liangzhe Yuan, committed by A. Unique TensorFlower

Support ViT dynamic positional embedding interpolation.

This is useful to finetune a model pre-trained at a different resolution.

PiperOrigin-RevId: 461666803
parent 0a64493c
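For reference, a minimal usage sketch (not part of the commit; the constructor kwargs mirror the test added at the bottom of this diff, and the import path is an assumption): fine-tune a ViT-B/16 pre-trained at 224x224 on 256x256 inputs by recording the pre-training patch grid, 224 // 16 = 14, in pos_embed_shape.

import tensorflow as tf
from official.vision.modeling.backbones import vit  # assumed import path

# The pre-training grid was 14x14 (224 // 16); at 256x256 the patch grid is
# 16x16, so the embeddings are bilinearly interpolated in the forward pass.
network = vit.VisionTransformer(
    input_specs=tf.keras.layers.InputSpec(shape=[None, 256, 256, 3]),
    patch_size=16,
    classifier='gap',
    pos_embed_shape=(14, 14))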
@@ -13,9 +13,8 @@
 # limitations under the License.
 """Backbones configurations."""
-from typing import Optional
 import dataclasses
+from typing import Optional, Tuple
 from official.modeling import hyperparams
@@ -43,6 +42,7 @@ class VisionTransformer(hyperparams.Config):
   transformer: Transformer = Transformer()
   init_stochastic_depth_rate: float = 0.0
   original_init: bool = True
+  pos_embed_shape: Optional[Tuple[int, int]] = None


 @dataclasses.dataclass
@@ -13,6 +13,9 @@
 # limitations under the License.
 """VisionTransformer models."""
+from typing import Optional, Tuple
+
+from absl import logging
 import immutabledict
 import tensorflow as tf
@@ -80,18 +83,61 @@ VIT_SPECS = immutabledict.immutabledict({
 class AddPositionEmbs(tf.keras.layers.Layer):
   """Adds (optionally learned) positional embeddings to the inputs."""

-  def __init__(self, posemb_init=None, **kwargs):
+  def __init__(self,
+               posemb_init: Optional[tf.keras.initializers.Initializer] = None,
+               posemb_origin_shape: Optional[Tuple[int, int]] = None,
+               posemb_target_shape: Optional[Tuple[int, int]] = None,
+               **kwargs):
+    """Constructs a positional embedding module.
+
+    The length of the learnable positional embeddings is determined at build
+    time from inputs_shape or, if provided, from posemb_origin_shape. If
+    posemb_target_shape is provided and differs from that length, the
+    embeddings are interpolated during the forward call.
+
+    Args:
+      posemb_init: The positional embedding initializer.
+      posemb_origin_shape: The intended positional embedding shape.
+      posemb_target_shape: The potential target shape the positional
+        embeddings may be interpolated to.
+      **kwargs: Other args.
+    """
     super().__init__(**kwargs)
     self.posemb_init = posemb_init
+    self.posemb_origin_shape = posemb_origin_shape
+    self.posemb_target_shape = posemb_target_shape

   def build(self, inputs_shape):
-    pos_emb_shape = (1, inputs_shape[1], inputs_shape[2])
+    if self.posemb_origin_shape is not None:
+      pos_emb_length = self.posemb_origin_shape[0] * self.posemb_origin_shape[1]
+    else:
+      pos_emb_length = inputs_shape[1]
+    pos_emb_shape = (1, pos_emb_length, inputs_shape[2])
     self.pos_embedding = self.add_weight(
         'pos_embedding', pos_emb_shape, initializer=self.posemb_init)

+  def _interpolate(self,
+                   pos_embedding: tf.Tensor,
+                   from_shape: Tuple[int, int],
+                   to_shape: Tuple[int, int]) -> tf.Tensor:
+    """Interpolates the positional embeddings."""
+    logging.info('Interpolating positional embedding from shape %s to %s',
+                 from_shape, to_shape)
+    grid_emb = tf.reshape(pos_embedding, [1] + list(from_shape) + [-1])
+    # NOTE: Using bilinear interpolation, the default of tf.image.resize.
+    grid_emb = tf.image.resize(grid_emb, to_shape)
+    return tf.reshape(grid_emb, [1, to_shape[0] * to_shape[1], -1])
+
   def call(self, inputs, inputs_positions=None):
     del inputs_positions
+    pos_embedding = self.pos_embedding
     # inputs.shape is (batch_size, seq_len, emb_dim).
-    pos_embedding = tf.cast(self.pos_embedding, inputs.dtype)
+    if inputs.shape[1] != pos_embedding.shape[1]:
+      pos_embedding = self._interpolate(pos_embedding,
+                                        from_shape=self.posemb_origin_shape,
+                                        to_shape=self.posemb_target_shape)
+    pos_embedding = tf.cast(pos_embedding, inputs.dtype)
     return inputs + pos_embedding
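To make the interpolation concrete, here is a self-contained sketch (not part of the commit) of what _interpolate does for the 224 -> 256 case exercised by the test below: 196 embeddings learned on a 14x14 patch grid are reshaped to 2D, resized bilinearly to 16x16, and flattened back to 256 embeddings.

import tensorflow as tf

embed_dim = 768
# Stand-in for a pre-trained (1, 196, 768) positional embedding table.
pos_embedding = tf.random.normal([1, 14 * 14, embed_dim])

grid_emb = tf.reshape(pos_embedding, [1, 14, 14, embed_dim])  # back to the grid
grid_emb = tf.image.resize(grid_emb, (16, 16))  # bilinear is the default method
interpolated = tf.reshape(grid_emb, [1, 16 * 16, embed_dim])
print(interpolated.shape)  # (1, 256, 768)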
@@ -124,6 +170,8 @@ class Encoder(tf.keras.layers.Layer):
                init_stochastic_depth_rate=0.0,
                kernel_initializer='glorot_uniform',
                add_pos_embed=True,
+               pos_embed_origin_shape=None,
+               pos_embed_target_shape=None,
                **kwargs):
     super().__init__(**kwargs)
     self._num_layers = num_layers
@@ -136,11 +184,15 @@ class Encoder(tf.keras.layers.Layer):
     self._init_stochastic_depth_rate = init_stochastic_depth_rate
     self._kernel_initializer = kernel_initializer
     self._add_pos_embed = add_pos_embed
+    self._pos_embed_origin_shape = pos_embed_origin_shape
+    self._pos_embed_target_shape = pos_embed_target_shape

   def build(self, input_shape):
     if self._add_pos_embed:
       self._pos_embed = AddPositionEmbs(
           posemb_init=tf.keras.initializers.RandomNormal(stddev=0.02),
+          posemb_origin_shape=self._pos_embed_origin_shape,
+          posemb_target_shape=self._pos_embed_target_shape,
           name='posembed_input')
     self._dropout = layers.Dropout(rate=self._dropout_rate)
@@ -208,7 +260,8 @@ class VisionTransformer(tf.keras.Model):
                representation_size=0,
                classifier='token',
                kernel_regularizer=None,
-               original_init=True):
+               original_init: bool = True,
+               pos_embed_shape: Optional[Tuple[int, int]] = None):
     """VisionTransformer initialization function."""

     inputs = tf.keras.Input(shape=input_specs.shape[1:])
@@ -229,6 +282,8 @@ class VisionTransformer(tf.keras.Model):
       # data_format is irrelevant, so no need to update
       # tf.keras.backend.image_data_format.
       x = tf.transpose(x, perm=[0, 2, 3, 1])
+    pos_embed_target_shape = (x.shape[rows_axis], x.shape[cols_axis])
+
     seq_len = (input_specs.shape[rows_axis] // patch_size) * (
         input_specs.shape[cols_axis] // patch_size)
     x = tf.reshape(x, [-1, seq_len, hidden_size])
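Note that the interpolation target is derived from the patchified feature map of the new input rather than from a config value. An illustrative check of the arithmetic (not part of the diff; numbers match the test below):

# For a 256x256 input with 16x16 patches, the conv output is a 16x16 grid,
# so 14x14 pre-trained embeddings get interpolated to length 256.
input_size, patch_size = 256, 16
pos_embed_target_shape = (input_size // patch_size, input_size // patch_size)
seq_len = pos_embed_target_shape[0] * pos_embed_target_shape[1]
assert pos_embed_target_shape == (16, 16)
assert seq_len == 256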
@@ -246,13 +301,16 @@ class VisionTransformer(tf.keras.Model):
         kernel_regularizer=kernel_regularizer,
         kernel_initializer='glorot_uniform' if original_init else dict(
             class_name='TruncatedNormal', config=dict(stddev=.02)),
-        init_stochastic_depth_rate=init_stochastic_depth_rate)(
-            x)
+        init_stochastic_depth_rate=init_stochastic_depth_rate,
+        pos_embed_origin_shape=pos_embed_shape,
+        pos_embed_target_shape=pos_embed_target_shape)(x)

     if classifier == 'token':
       x = x[:, 0]
     elif classifier == 'gap':
       x = tf.reduce_mean(x, axis=1)
+    else:
+      raise ValueError(f'unrecognized classifier type: {classifier}')

     if representation_size:
       x = tf.keras.layers.Dense(
@@ -298,4 +356,5 @@ def build_vit(input_specs,
       representation_size=backbone_cfg.representation_size,
       classifier=backbone_cfg.classifier,
       kernel_regularizer=l2_regularizer,
-      original_init=backbone_cfg.original_init)
+      original_init=backbone_cfg.original_init,
+      pos_embed_shape=backbone_cfg.pos_embed_shape)
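On the config side, a sketch of how the new field would be set when fine-tuning from a 224x224 checkpoint (field names other than pos_embed_shape are assumptions, not shown in this diff):

# Record the pre-training patch grid in the backbone config so that
# build_vit forwards it as pos_embed_origin_shape.
backbone_cfg = backbones.VisionTransformer(
    model_name='vit-b16',        # assumed field and value
    pos_embed_shape=(14, 14))    # 224 // 16 per side at pre-training time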
@@ -37,6 +37,21 @@ class VisionTransformerTest(parameterized.TestCase, tf.test.TestCase):
     _ = network(inputs)
     self.assertEqual(network.count_params(), params_count)

+  def test_posembedding_interpolation(self):
+    tf.keras.backend.set_image_data_format('channels_last')
+    input_size = 256
+    input_specs = tf.keras.layers.InputSpec(
+        shape=[2, input_size, input_size, 3])
+    network = vit.VisionTransformer(
+        input_specs=input_specs,
+        patch_size=16,
+        classifier='gap',
+        pos_embed_shape=(14, 14))  # The pre-training grid: 224 // 16 = 14.
+    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
+    output = network(inputs)['pre_logits']
+    self.assertEqual(output.shape, [1, 1, 1, 768])

 if __name__ == '__main__':
   tf.test.main()