Commit 423b506a authored by Liangzhe Yuan's avatar Liangzhe Yuan Committed by A. Unique TensorFlower
Browse files

Add support for a 'none' classifier type in ViT. Also rename 'classifier' to 'pooler' for better naming.

PiperOrigin-RevId: 461706948
parent 1071c54e
......@@ -34,7 +34,7 @@ class VisionTransformer(hyperparams.Config):
"""VisionTransformer config."""
model_name: str = 'vit-b16'
# pylint: disable=line-too-long
classifier: str = 'token' # 'token' or 'gap'. If set to 'token', an extra classification token is added to sequence.
pooler: str = 'token' # 'token', 'gap' or 'none'. If set to 'token', an extra classification token is added to sequence.
# pylint: enable=line-too-long
representation_size: int = 0
hidden_size: int = 1
......
......@@ -258,7 +258,7 @@ class VisionTransformer(tf.keras.Model):
patch_size=16,
hidden_size=768,
representation_size=0,
classifier='token',
pooler='token',
kernel_regularizer=None,
original_init: bool = True,
pos_embed_shape: Optional[Tuple[int, int]] = None):
......@@ -289,7 +289,7 @@ class VisionTransformer(tf.keras.Model):
x = tf.reshape(x, [-1, seq_len, hidden_size])
# If we want to add a class token, add it here.
if classifier == 'token':
if pooler == 'token':
x = TokenLayer(name='cls')(x)
x = Encoder(
......@@ -305,12 +305,14 @@ class VisionTransformer(tf.keras.Model):
pos_embed_origin_shape=pos_embed_shape,
pos_embed_target_shape=pos_embed_target_shape)(x)
if classifier == 'token':
if pooler == 'token':
x = x[:, 0]
elif classifier == 'gap':
elif pooler == 'gap':
x = tf.reduce_mean(x, axis=1)
elif pooler == 'none':
x = tf.identity(x, name='encoded_tokens')
else:
raise ValueError(f'unrecognized classifier type: {classifier}')
raise ValueError(f'unrecognized pooler type: {pooler}')
if representation_size:
x = tf.keras.layers.Dense(
......@@ -322,11 +324,14 @@ class VisionTransformer(tf.keras.Model):
x = tf.nn.tanh(x)
else:
x = tf.identity(x, name='pre_logits')
endpoints = {
'pre_logits':
tf.reshape(x, [-1, 1, 1, representation_size or hidden_size])
}
if pooler == 'none':
endpoints = {'encoded_tokens': x}
else:
endpoints = {
'pre_logits':
tf.reshape(x, [-1, 1, 1, representation_size or hidden_size])
}
super(VisionTransformer, self).__init__(inputs=inputs, outputs=endpoints)
......@@ -354,7 +359,7 @@ def build_vit(input_specs,
patch_size=backbone_cfg.patch_size,
hidden_size=backbone_cfg.hidden_size,
representation_size=backbone_cfg.representation_size,
classifier=backbone_cfg.classifier,
pooler=backbone_cfg.pooler,
kernel_regularizer=l2_regularizer,
original_init=backbone_cfg.original_init,
pos_embed_shape=backbone_cfg.pos_embed_shape)
......@@ -37,6 +37,22 @@ class VisionTransformerTest(parameterized.TestCase, tf.test.TestCase):
_ = network(inputs)
self.assertEqual(network.count_params(), params_count)
def test_network_none_pooler(self):
  """Checks that pooler='none' exposes per-token encodings as an endpoint.

  With a 256x256 input and 16x16 patches there are (256/16)**2 = 256 tokens;
  representation_size=128 projects each token, so the 'encoded_tokens'
  output is expected to be (batch=1, tokens=256, dim=128).
  """
  tf.keras.backend.set_image_data_format('channels_last')
  image_size = 256
  # Spec batch dim (2) only declares the input signature; the actual batch
  # size comes from the Keras Input below.
  spec = tf.keras.layers.InputSpec(
      shape=[2, image_size, image_size, 3])
  model = vit.VisionTransformer(
      input_specs=spec,
      patch_size=16,
      pooler='none',
      representation_size=128,
      pos_embed_shape=(14, 14))  # (224 // 16)
  images = tf.keras.Input(
      shape=(image_size, image_size, 3), batch_size=1)
  encoded = model(images)['encoded_tokens']
  self.assertEqual(encoded.shape, [1, 256, 128])
def test_posembedding_interpolation(self):
tf.keras.backend.set_image_data_format('channels_last')
input_size = 256
......@@ -45,7 +61,7 @@ class VisionTransformerTest(parameterized.TestCase, tf.test.TestCase):
network = vit.VisionTransformer(
input_specs=input_specs,
patch_size=16,
classifier='gap',
pooler='gap',
pos_embed_shape=(14, 14)) # (224 // 16)
inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment