"vscode:/vscode.git/clone" did not exist on "bc71d8e9e155d34a38af8489ad4cbb2fde6fa152"
Commit 219f6f06 authored by Xianzhi Du, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 380237223
parent c5783656
# Vision Transformer (ViT)
**DISCLAIMER**: This implementation is still under development. No support will
be provided during the development phase.
[![Paper](http://img.shields.io/badge/Paper-arXiv.2010.11929-B3181B?logo=arXiv)](https://arxiv.org/abs/2010.11929)

This repository contains the implementation of the Vision Transformer (ViT) in
TensorFlow 2. A sample training command is shown below.

* Paper title:
  [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/pdf/2010.11929.pdf).
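
Training uses the standard TensorFlow Model Garden driver bundled in this
project. A typical invocation is sketched below; it assumes the common Model
Garden flags (`--experiment`, `--mode`, `--model_dir`) defined by
`official.common.flags`, and ImageNet TFRecords at the path configured in
`configs/image_classification.py`:

```shell
python3 -m official.vision.beta.projects.vit.train \
  --experiment=vit_imagenet_pretrain \
  --mode=train_and_eval \
  --model_dir=/tmp/vit_b16_pretrain
```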
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Configs package definition."""
from official.vision.beta.projects.vit.configs import image_classification
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Backbones configurations."""
from typing import Optional
import dataclasses
from official.modeling import hyperparams


@dataclasses.dataclass
class Transformer(hyperparams.Config):
  """Transformer config."""
  mlp_dim: int = 1
  num_heads: int = 1
  num_layers: int = 1
  attention_dropout_rate: float = 0.0
  dropout_rate: float = 0.1


@dataclasses.dataclass
class VisionTransformer(hyperparams.Config):
  """VisionTransformer config."""
  model_name: str = 'vit-b16'
  # pylint: disable=line-too-long
  classifier: str = 'token'  # 'token' or 'gap'. If set to 'token', an extra classification token is added to the sequence.
  # pylint: enable=line-too-long
  representation_size: int = 0
  hidden_size: int = 1
  patch_size: int = 16
  transformer: Transformer = Transformer()


@dataclasses.dataclass
class Backbone(hyperparams.OneOfConfig):
  """Configuration for backbones.

  Attributes:
    type: 'str', type of backbone to be used, one of the fields below.
    vit: vit backbone config.
  """
  type: Optional[str] = None
  vit: VisionTransformer = VisionTransformer()
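

# Example (illustrative addition, not part of the original file): selecting
# the ViT-L/16 variant with global-average-pool ('gap') classification. The
# remaining architecture fields (hidden_size, patch_size, transformer.*) are
# filled in later from VIT_SPECS in modeling/vit.py via
# `backbone_cfg.override(...)` inside the backbone builder.
def _example_vit_l16_backbone() -> Backbone:
  return Backbone(
      type='vit',
      vit=VisionTransformer(model_name='vit-l16', classifier='gap'))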
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Image classification configuration definition."""
import os
from typing import List, Optional
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.core import task_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.vision.beta.configs import common
from official.vision.beta.configs import image_classification as img_cls_cfg
from official.vision.beta.projects.vit.configs import backbones
from official.vision.beta.tasks import image_classification
DataConfig = img_cls_cfg.DataConfig


@dataclasses.dataclass
class ImageClassificationModel(hyperparams.Config):
  """The model config."""
  num_classes: int = 0
  input_size: List[int] = dataclasses.field(default_factory=list)
  backbone: backbones.Backbone = backbones.Backbone(
      type='vit', vit=backbones.VisionTransformer())
  dropout_rate: float = 0.0
  norm_activation: common.NormActivation = common.NormActivation(
      use_sync_bn=False)
  # Adds a BatchNormalization layer pre-GlobalAveragePooling in classification.
  add_head_batch_norm: bool = False


@dataclasses.dataclass
class Losses(hyperparams.Config):
  one_hot: bool = True
  label_smoothing: float = 0.0
  l2_weight_decay: float = 0.0


@dataclasses.dataclass
class Evaluation(hyperparams.Config):
  top_k: int = 5


@dataclasses.dataclass
class ImageClassificationTask(cfg.TaskConfig):
  """The task config. Same as the classification task for convnets."""
  model: ImageClassificationModel = ImageClassificationModel()
  train_data: DataConfig = DataConfig(is_training=True)
  validation_data: DataConfig = DataConfig(is_training=False)
  losses: Losses = Losses()
  evaluation: Evaluation = Evaluation()
  init_checkpoint: Optional[str] = None
  init_checkpoint_modules: str = 'all'  # all or backbone


IMAGENET_TRAIN_EXAMPLES = 1281167
IMAGENET_VAL_EXAMPLES = 50000
IMAGENET_INPUT_PATH_BASE = 'imagenet-2012-tfrecord'
# TODO(b/177942984): integrate the experiments to TF-vision.
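# The ViT task reuses the standard image-classification task implementation;
# only the config (and hence the 'vit' backbone) differs from the convnet
# setup.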
task_factory.register_task_cls(ImageClassificationTask)(
    image_classification.ImageClassificationTask)


@exp_factory.register_config_factory('vit_imagenet_pretrain')
def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
  """Image classification on imagenet with vision transformer."""
  train_batch_size = 4096
  eval_batch_size = 4096
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
  config = cfg.ExperimentConfig(
      task=ImageClassificationTask(
          model=ImageClassificationModel(
              num_classes=1001,
              input_size=[224, 224, 3],
              backbone=backbones.Backbone(
                  type='vit',
                  vit=backbones.VisionTransformer(
                      model_name='vit-b16', representation_size=768))),
          losses=Losses(l2_weight_decay=0.0),
          train_data=DataConfig(
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
              is_training=True,
              global_batch_size=train_batch_size),
          validation_data=DataConfig(
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
              is_training=False,
              global_batch_size=eval_batch_size)),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=300 * steps_per_epoch,
          validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'adamw',
                  'adamw': {
                      'weight_decay_rate': 0.3,
                      'include_in_weight_decay': r'.*(kernel|weight):0$',
                  }
              },
              'learning_rate': {
                  'type': 'cosine',
                  'cosine': {
                      'initial_learning_rate': 0.003,
                      'decay_steps': 300 * steps_per_epoch,
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 10000,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config


@exp_factory.register_config_factory('vit_imagenet_finetune')
def image_classification_imagenet_vit_finetune() -> cfg.ExperimentConfig:
  """Image classification on imagenet with vision transformer."""
  train_batch_size = 512
  eval_batch_size = 512
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
  config = cfg.ExperimentConfig(
      task=ImageClassificationTask(
          model=ImageClassificationModel(
              num_classes=1001,
              input_size=[384, 384, 3],
              backbone=backbones.Backbone(
                  type='vit',
                  vit=backbones.VisionTransformer(model_name='vit-b16'))),
          losses=Losses(l2_weight_decay=0.0),
          train_data=DataConfig(
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
              is_training=True,
              global_batch_size=train_batch_size),
          validation_data=DataConfig(
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
              is_training=False,
              global_batch_size=eval_batch_size)),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=20000,
          validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9,
                      'global_clipnorm': 1.0,
                  }
              },
              'learning_rate': {
                  'type': 'cosine',
                  'cosine': {
                      'initial_learning_rate': 0.003,
                      'decay_steps': 20000,
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
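

# Example (illustrative addition, not part of the original file): the
# registered experiments are normally consumed by the training driver through
# the --experiment flag, but they can also be built programmatically.
def _example_pretrain_config() -> cfg.ExperimentConfig:
  config = exp_factory.get_exp_config('vit_imagenet_pretrain')
  # Individual fields can still be overridden before training.
  config.override({'task': {'train_data': {'global_batch_size': 1024}}})
  return config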
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""VisionTransformer models."""
import tensorflow as tf
from official.modeling import activations
from official.nlp import keras_nlp
from official.vision.beta.modeling.backbones import factory
layers = tf.keras.layers

VIT_SPECS = {
    'vit-testing':
        dict(
            hidden_size=1,
            patch_size=16,
            transformer=dict(mlp_dim=1, num_heads=1, num_layers=1),
        ),
    'vit-b16':
        dict(
            hidden_size=768,
            patch_size=16,
            transformer=dict(mlp_dim=3072, num_heads=12, num_layers=12),
        ),
    'vit-b32':
        dict(
            hidden_size=768,
            patch_size=32,
            transformer=dict(mlp_dim=3072, num_heads=12, num_layers=12),
        ),
    'vit-l16':
        dict(
            hidden_size=1024,
            patch_size=16,
            transformer=dict(mlp_dim=4096, num_heads=16, num_layers=24),
        ),
    'vit-l32':
        dict(
            hidden_size=1024,
            patch_size=32,
            transformer=dict(mlp_dim=4096, num_heads=16, num_layers=24),
        ),
    'vit-h14':
        dict(
            hidden_size=1280,
            patch_size=14,
            transformer=dict(mlp_dim=5120, num_heads=16, num_layers=32),
        ),
}
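# The spec names above follow the paper's convention: 'b'/'l'/'h' stand for
# Base/Large/Huge and the trailing number is the square patch size in pixels
# (e.g. 'vit-b16' is ViT-Base with 16x16 patches).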


class AddPositionEmbs(tf.keras.layers.Layer):
  """Adds (optionally learned) positional embeddings to the inputs."""

  def __init__(self, posemb_init=None, **kwargs):
    super().__init__(**kwargs)
    self.posemb_init = posemb_init

  def build(self, inputs_shape):
    pos_emb_shape = (1, inputs_shape[1], inputs_shape[2])
    self.pos_embedding = self.add_weight(
        'pos_embedding', pos_emb_shape, initializer=self.posemb_init)

  def call(self, inputs, inputs_positions=None):
    # inputs.shape is (batch_size, seq_len, emb_dim).
    pos_embedding = tf.cast(self.pos_embedding, inputs.dtype)
    return inputs + pos_embedding


class TokenLayer(tf.keras.layers.Layer):
  """A simple layer to wrap token parameters."""

  def build(self, inputs_shape):
    self.cls = self.add_weight(
        'cls', (1, 1, inputs_shape[-1]), initializer='zeros')

  def call(self, inputs):
    cls = tf.cast(self.cls, inputs.dtype)
    cls = cls + tf.zeros_like(inputs[:, 0:1])  # A hacky way to tile.
    x = tf.concat([cls, inputs], axis=1)
    return x


class Encoder(tf.keras.layers.Layer):
  """Transformer Encoder."""

  def __init__(self,
               num_layers,
               mlp_dim,
               num_heads,
               dropout_rate=0.1,
               attention_dropout_rate=0.1,
               kernel_regularizer=None,
               inputs_positions=None,
               **kwargs):
    super().__init__(**kwargs)
    self._num_layers = num_layers
    self._mlp_dim = mlp_dim
    self._num_heads = num_heads
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._kernel_regularizer = kernel_regularizer
    self._inputs_positions = inputs_positions

  def build(self, input_shape):
    self._pos_embed = AddPositionEmbs(
        posemb_init=tf.keras.initializers.RandomNormal(stddev=0.02),
        name='posembed_input')
    self._dropout = layers.Dropout(rate=self._dropout_rate)

    self._encoder_layers = []
    # Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
    # https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.LayerNorm.html
    for _ in range(self._num_layers):
      encoder_layer = keras_nlp.layers.TransformerEncoderBlock(
          inner_activation=activations.gelu,
          num_attention_heads=self._num_heads,
          inner_dim=self._mlp_dim,
          output_dropout=self._dropout_rate,
          attention_dropout=self._attention_dropout_rate,
          kernel_regularizer=self._kernel_regularizer,
          norm_first=True,
          norm_epsilon=1e-6)
      self._encoder_layers.append(encoder_layer)
    self._norm = layers.LayerNormalization(epsilon=1e-6)
    super().build(input_shape)

  def call(self, inputs, training=None):
    x = self._pos_embed(inputs, inputs_positions=self._inputs_positions)
    x = self._dropout(x, training=training)

    for encoder_layer in self._encoder_layers:
      x = encoder_layer(x, training=training)
    x = self._norm(x)
    return x


class VisionTransformer(tf.keras.Model):
  """Class to build VisionTransformer family model."""

  def __init__(self,
               mlp_dim=3072,
               num_heads=12,
               num_layers=12,
               attention_dropout_rate=0.0,
               dropout_rate=0.1,
               input_specs=layers.InputSpec(shape=[None, None, None, 3]),
               patch_size=16,
               hidden_size=768,
               representation_size=0,
               classifier='token',
               kernel_regularizer=None):
    """VisionTransformer initialization function."""
    inputs = tf.keras.Input(shape=input_specs.shape[1:])

    # Patch embedding: a patch_size x patch_size convolution with stride
    # patch_size projects each non-overlapping patch to a hidden_size vector.
    x = layers.Conv2D(
        filters=hidden_size,
        kernel_size=patch_size,
        strides=patch_size,
        padding='valid',
        kernel_regularizer=kernel_regularizer)(
            inputs)
    if tf.keras.backend.image_data_format() == 'channels_last':
      rows_axis, cols_axis = (1, 2)
    else:
      rows_axis, cols_axis = (2, 3)
      # The reshape below assumes the data_format is 'channels_last,' so
      # transpose to that. Once the data is flattened by the reshape, the
      # data_format is irrelevant, so no need to update
      # tf.keras.backend.image_data_format.
      x = tf.transpose(x, perm=[0, 2, 3, 1])
    seq_len = (input_specs.shape[rows_axis] // patch_size) * (
        input_specs.shape[cols_axis] // patch_size)
    x = tf.reshape(x, [-1, seq_len, hidden_size])

    # If we want to add a class token, add it here.
    if classifier == 'token':
      x = TokenLayer(name='cls')(x)

    x = Encoder(
        num_layers=num_layers,
        mlp_dim=mlp_dim,
        num_heads=num_heads,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        kernel_regularizer=kernel_regularizer)(
            x)

    if classifier == 'token':
      x = x[:, 0]
    elif classifier == 'gap':
      x = tf.reduce_mean(x, axis=1)

    if representation_size:
      x = tf.keras.layers.Dense(
          representation_size,
          kernel_regularizer=kernel_regularizer,
          name='pre_logits')(
              x)
      x = tf.nn.tanh(x)
    else:
      x = tf.identity(x, name='pre_logits')
    # Expose a 4-D endpoint so the output matches the feature-map layout
    # produced by the other vision backbones.
    endpoints = {
        'pre_logits':
            tf.reshape(x, [-1, 1, 1, representation_size or hidden_size])
    }

    super(VisionTransformer, self).__init__(inputs=inputs, outputs=endpoints)


@factory.register_backbone_builder('vit')
def build_vit(input_specs,
              backbone_config,
              norm_activation_config,
              l2_regularizer=None):
  """Build ViT model."""
  del norm_activation_config
  backbone_type = backbone_config.type
  backbone_cfg = backbone_config.get()
  assert backbone_type == 'vit', (f'Inconsistent backbone type '
                                  f'{backbone_type}')
  backbone_cfg.override(VIT_SPECS[backbone_cfg.model_name])

  return VisionTransformer(
      mlp_dim=backbone_cfg.transformer.mlp_dim,
      num_heads=backbone_cfg.transformer.num_heads,
      num_layers=backbone_cfg.transformer.num_layers,
      attention_dropout_rate=backbone_cfg.transformer.attention_dropout_rate,
      dropout_rate=backbone_cfg.transformer.dropout_rate,
      input_specs=input_specs,
      patch_size=backbone_cfg.patch_size,
      hidden_size=backbone_cfg.hidden_size,
      representation_size=backbone_cfg.representation_size,
      classifier=backbone_cfg.classifier,
      kernel_regularizer=l2_regularizer)
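

# Example (illustrative addition, not part of the original file): building the
# default backbone (ViT-B/16 hyperparameters) directly and reading its single
# 'pre_logits' endpoint, which has shape (batch, 1, 1, hidden_size).
def _example_vit_b16_endpoint():
  model = VisionTransformer(
      input_specs=layers.InputSpec(shape=[None, 224, 224, 3]))
  outputs = model(tf.zeros([2, 224, 224, 3]))
  return outputs['pre_logits']  # Shape: (2, 1, 1, 768).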
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for VIT."""
from absl.testing import parameterized
import tensorflow as tf
from official.vision.beta.projects.vit.modeling import vit


class VisionTransformerTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.parameters(
      (224, 85798656),
      (256, 85844736),
  )
  def test_network_creation(self, input_size, params_count):
    """Test creation of VisionTransformer family models."""
    tf.keras.backend.set_image_data_format('channels_last')
    input_specs = tf.keras.layers.InputSpec(
        shape=[2, input_size, input_size, 3])
    network = vit.VisionTransformer(input_specs=input_specs)

    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
    _ = network(inputs)
    self.assertEqual(network.count_params(), params_count)


if __name__ == '__main__':
  tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""TensorFlow Model Garden Vision training driver, including ViT configs.."""
from absl import app
from official.common import flags as tfm_flags
from official.vision.beta import train
from official.vision.beta.projects.vit import configs # pylint: disable=unused-import
from official.vision.beta.projects.vit.modeling import vit # pylint: disable=unused-import
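# Note: the two project imports above are needed only for their side effects;
# importing them registers the 'vit_imagenet_pretrain'/'vit_imagenet_finetune'
# experiments and the 'vit' backbone builder so that they can be selected via
# the --experiment flag.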


if __name__ == '__main__':
  tfm_flags.define_flags()
  app.run(train.main)