Commit 27faf029 authored by Liangzhe Yuan, committed by A. Unique TensorFlower

Support ViT dynamic positional embedding interpolation. This is useful to finetune a model pre-trained on a different resolution.

PiperOrigin-RevId: 461666803
parent 0a64493c
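Usage note: the new `pos_embed_shape` field records the patch-grid shape the checkpoint's positional embeddings were trained on; at build time the model computes the target grid from the fine-tuning input size and interpolates the loaded embeddings. A minimal sketch follows — the `backbones.Backbone` wiring, import path, and `model_name` value are assumptions from the surrounding Model Garden configs, not part of this change:

```python
# Sketch: fine-tune a ViT-B/16 pre-trained at 224x224 on 384x384 inputs.
# Only `pos_embed_shape` is added by this change; the config plumbing
# below is assumed.
from official.vision.configs import backbones

backbone = backbones.Backbone(
    type='vit',
    vit=backbones.VisionTransformer(
        model_name='vit-b16',
        # Pre-training grid: 224 // 16 = 14 patches per side.
        pos_embed_shape=(14, 14)))
# With 384x384 inputs the target grid is 384 // 16 = 24, so the loaded
# (14*14)-length embeddings are bilinearly resized to length 24*24.
```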
@@ -13,9 +13,8 @@
 # limitations under the License.
 """Backbones configurations."""
-from typing import Optional
-
 import dataclasses
+from typing import Optional, Tuple
 from official.modeling import hyperparams
@@ -43,6 +42,7 @@ class VisionTransformer(hyperparams.Config):
   transformer: Transformer = Transformer()
   init_stochastic_depth_rate: float = 0.0
   original_init: bool = True
+  pos_embed_shape: Optional[Tuple[int, int]] = None


 @dataclasses.dataclass
...
@@ -13,6 +13,9 @@
 # limitations under the License.
 """VisionTransformer models."""
+from typing import Optional, Tuple
+
+from absl import logging
 import immutabledict
 import tensorflow as tf
@@ -80,18 +83,61 @@ VIT_SPECS = immutabledict.immutabledict({
 class AddPositionEmbs(tf.keras.layers.Layer):
   """Adds (optionally learned) positional embeddings to the inputs."""

-  def __init__(self, posemb_init=None, **kwargs):
+  def __init__(self,
+               posemb_init: Optional[tf.keras.initializers.Initializer] = None,
+               posemb_origin_shape: Optional[Tuple[int, int]] = None,
+               posemb_target_shape: Optional[Tuple[int, int]] = None,
+               **kwargs):
+    """Constructs Positional Embedding module.
+
+    The length of the learnable positional embeddings is determined at
+    construction time by inputs_shape, or by posemb_origin_shape if provided.
+    If posemb_target_shape is provided and differs from the positional
+    embeddings length, the embeddings are interpolated during the forward
+    call.
+
+    Args:
+      posemb_init: The positional embedding initializer.
+      posemb_origin_shape: The intended positional embedding shape.
+      posemb_target_shape: The potential target shape the positional
+        embedding may be interpolated to.
+      **kwargs: other args.
+    """
     super().__init__(**kwargs)
     self.posemb_init = posemb_init
+    self.posemb_origin_shape = posemb_origin_shape
+    self.posemb_target_shape = posemb_target_shape

   def build(self, inputs_shape):
-    pos_emb_shape = (1, inputs_shape[1], inputs_shape[2])
+    if self.posemb_origin_shape is not None:
+      pos_emb_length = (
+          self.posemb_origin_shape[0] * self.posemb_origin_shape[1])
+    else:
+      pos_emb_length = inputs_shape[1]
+    pos_emb_shape = (1, pos_emb_length, inputs_shape[2])
     self.pos_embedding = self.add_weight(
         'pos_embedding', pos_emb_shape, initializer=self.posemb_init)

+  def _interpolate(self,
+                   pos_embedding: tf.Tensor,
+                   from_shape: Tuple[int, int],
+                   to_shape: Tuple[int, int]) -> tf.Tensor:
+    """Interpolates the positional embeddings."""
+    logging.info('Interpolating positional embedding from shape %s to %s',
+                 from_shape, to_shape)
+    grid_emb = tf.reshape(pos_embedding, [1] + list(from_shape) + [-1])
+    # NOTE: Using BILINEAR interpolation by default.
+    grid_emb = tf.image.resize(grid_emb, to_shape)
+    return tf.reshape(grid_emb, [1, to_shape[0] * to_shape[1], -1])
+
   def call(self, inputs, inputs_positions=None):
+    del inputs_positions
+    pos_embedding = self.pos_embedding
     # inputs.shape is (batch_size, seq_len, emb_dim).
-    pos_embedding = tf.cast(self.pos_embedding, inputs.dtype)
+    if inputs.shape[1] != pos_embedding.shape[1]:
+      pos_embedding = self._interpolate(pos_embedding,
+                                        from_shape=self.posemb_origin_shape,
+                                        to_shape=self.posemb_target_shape)
+    pos_embedding = tf.cast(pos_embedding, inputs.dtype)

     return inputs + pos_embedding
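To make the `_interpolate` round trip concrete, here is a standalone sketch with hypothetical sizes (a 14x14 origin grid resized to the 16x16 grid that a 256x256 input with 16x16 patches produces):

```python
import tensorflow as tf

# A (1, 196, 768) embedding learned on a 14x14 patch grid, resized to a
# 16x16 grid, i.e. sequence length 256 (hypothetical numbers).
pos_embedding = tf.random.normal([1, 14 * 14, 768])  # checkpoint embedding
grid_emb = tf.reshape(pos_embedding, [1, 14, 14, -1])  # back to a 2D grid
grid_emb = tf.image.resize(grid_emb, (16, 16))  # bilinear by default
pos_embedding = tf.reshape(grid_emb, [1, 16 * 16, -1])
assert pos_embedding.shape == (1, 256, 768)
```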
@@ -124,6 +170,8 @@ class Encoder(tf.keras.layers.Layer):
                init_stochastic_depth_rate=0.0,
                kernel_initializer='glorot_uniform',
                add_pos_embed=True,
+               pos_embed_origin_shape=None,
+               pos_embed_target_shape=None,
                **kwargs):
     super().__init__(**kwargs)
     self._num_layers = num_layers
@@ -136,11 +184,15 @@ class Encoder(tf.keras.layers.Layer):
     self._init_stochastic_depth_rate = init_stochastic_depth_rate
     self._kernel_initializer = kernel_initializer
     self._add_pos_embed = add_pos_embed
+    self._pos_embed_origin_shape = pos_embed_origin_shape
+    self._pos_embed_target_shape = pos_embed_target_shape

   def build(self, input_shape):
     if self._add_pos_embed:
       self._pos_embed = AddPositionEmbs(
           posemb_init=tf.keras.initializers.RandomNormal(stddev=0.02),
+          posemb_origin_shape=self._pos_embed_origin_shape,
+          posemb_target_shape=self._pos_embed_target_shape,
           name='posembed_input')
     self._dropout = layers.Dropout(rate=self._dropout_rate)
@@ -208,7 +260,8 @@ class VisionTransformer(tf.keras.Model):
                representation_size=0,
                classifier='token',
                kernel_regularizer=None,
-               original_init=True):
+               original_init: bool = True,
+               pos_embed_shape: Optional[Tuple[int, int]] = None):
     """VisionTransformer initialization function."""

     inputs = tf.keras.Input(shape=input_specs.shape[1:])
@@ -229,6 +282,8 @@ class VisionTransformer(tf.keras.Model):
       # data_format is irrelevant, so no need to update
       # tf.keras.backend.image_data_format.
       x = tf.transpose(x, perm=[0, 2, 3, 1])
+
+    pos_embed_target_shape = (x.shape[rows_axis], x.shape[cols_axis])
     seq_len = (input_specs.shape[rows_axis] // patch_size) * (
         input_specs.shape[cols_axis] // patch_size)
     x = tf.reshape(x, [-1, seq_len, hidden_size])
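The target shape above falls directly out of the patchified feature map. With assumed numbers matching the test added below:

```python
# With 256x256 inputs and patch_size=16 the patchified feature map is
# 16x16, so pos_embed_target_shape == (16, 16) and seq_len == 256.
input_size, patch_size = 256, 16
pos_embed_target_shape = (input_size // patch_size, input_size // patch_size)
seq_len = pos_embed_target_shape[0] * pos_embed_target_shape[1]
assert pos_embed_target_shape == (16, 16) and seq_len == 256
```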
@@ -246,13 +301,16 @@ class VisionTransformer(tf.keras.Model):
         kernel_regularizer=kernel_regularizer,
         kernel_initializer='glorot_uniform' if original_init else dict(
             class_name='TruncatedNormal', config=dict(stddev=.02)),
-        init_stochastic_depth_rate=init_stochastic_depth_rate)(
-            x)
+        init_stochastic_depth_rate=init_stochastic_depth_rate,
+        pos_embed_origin_shape=pos_embed_shape,
+        pos_embed_target_shape=pos_embed_target_shape)(x)

     if classifier == 'token':
       x = x[:, 0]
     elif classifier == 'gap':
       x = tf.reduce_mean(x, axis=1)
+    else:
+      raise ValueError(f'unrecognized classifier type: {classifier}')

     if representation_size:
       x = tf.keras.layers.Dense(
@@ -298,4 +356,5 @@ def build_vit(input_specs,
       representation_size=backbone_cfg.representation_size,
       classifier=backbone_cfg.classifier,
       kernel_regularizer=l2_regularizer,
-      original_init=backbone_cfg.original_init)
+      original_init=backbone_cfg.original_init,
+      pos_embed_shape=backbone_cfg.pos_embed_shape)
@@ -37,6 +37,21 @@ class VisionTransformerTest(parameterized.TestCase, tf.test.TestCase):
     _ = network(inputs)
     self.assertEqual(network.count_params(), params_count)

+  def test_posembedding_interpolation(self):
+    tf.keras.backend.set_image_data_format('channels_last')
+    input_size = 256
+    input_specs = tf.keras.layers.InputSpec(
+        shape=[2, input_size, input_size, 3])
+    network = vit.VisionTransformer(
+        input_specs=input_specs,
+        patch_size=16,
+        classifier='gap',
+        pos_embed_shape=(14, 14))  # (224 // 16)
+    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
+    output = network(inputs)['pre_logits']
+    self.assertEqual(output.shape, [1, 1, 1, 768])
+
 if __name__ == '__main__':
   tf.test.main()