Commit 90585434 authored by Chaochao Yan, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 481733792
parent d309fff8
@@ -75,7 +75,7 @@ task_factory.register_task_cls(ImageClassificationTask)(
image_classification.ImageClassificationTask)
-@exp_factory.register_config_factory('deit_imagenet_pretrain')
+@exp_factory.register_config_factory('legacy_deit_imagenet_pretrain')
def image_classification_imagenet_deit_pretrain() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 4096  # originally was 1024 but 4096 better for tpu v3-32
@@ -156,7 +156,7 @@ def image_classification_imagenet_deit_pretrain() -> cfg.ExperimentConfig:
return config
-@exp_factory.register_config_factory('vit_imagenet_pretrain')
+@exp_factory.register_config_factory('legacy_vit_imagenet_pretrain')
def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 4096
@@ -220,7 +220,7 @@ def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
return config
-@exp_factory.register_config_factory('vit_imagenet_finetune')
+@exp_factory.register_config_factory('legacy_vit_imagenet_finetune')
def image_classification_imagenet_vit_finetune() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 512
......
@@ -294,7 +294,7 @@ class VisionTransformer(tf.keras.Model):
super(VisionTransformer, self).__init__(inputs=inputs, outputs=endpoints)
-@factory.register_backbone_builder('vit')
+@factory.register_backbone_builder('legacy_vit')
def build_vit(input_specs,
backbone_config,
norm_activation_config,
......
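For orientation, the two renames above keep the old registrations alive under a 'legacy_' prefix, while the unprefixed keys are re-registered by the new official.vision code added later in this commit. A minimal sketch (not part of the change; the official.core.exp_factory entry point is assumed):

from official.core import exp_factory

# The previous behaviour now lives under the 'legacy_' names...
legacy_cfg = exp_factory.get_exp_config('legacy_vit_imagenet_pretrain')
# ...while the unprefixed names resolve to the new configs added below.
new_cfg = exp_factory.get_exp_config('vit_imagenet_pretrain')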
@@ -14,13 +14,37 @@
"""Backbones configurations."""
import dataclasses
-from typing import Optional, List
+from typing import List, Optional, Tuple
# Import libraries
from official.modeling import hyperparams
@dataclasses.dataclass
class Transformer(hyperparams.Config):
"""Transformer config."""
mlp_dim: int = 1
num_heads: int = 1
num_layers: int = 1
attention_dropout_rate: float = 0.0
dropout_rate: float = 0.1
@dataclasses.dataclass
class VisionTransformer(hyperparams.Config):
"""VisionTransformer config."""
model_name: str = 'vit-b16'
# pylint: disable=line-too-long
pooler: str = 'token' # 'token', 'gap' or 'none'. If set to 'token', an extra classification token is added to sequence.
# pylint: enable=line-too-long
representation_size: int = 0
hidden_size: int = 1
patch_size: int = 16
transformer: Transformer = Transformer()
init_stochastic_depth_rate: float = 0.0
original_init: bool = True
pos_embed_shape: Optional[Tuple[int, int]] = None
@dataclasses.dataclass
class ResNet(hyperparams.Config):
"""ResNet config."""
@@ -120,6 +144,7 @@ class Backbone(hyperparams.OneOfConfig):
spinenet_mobile: mobile spinenet backbone config.
mobilenet: mobilenet backbone config.
mobiledet: mobiledet backbone config.
vit: vision transformer backbone config.
""" """
type: Optional[str] = None type: Optional[str] = None
resnet: ResNet = ResNet() resnet: ResNet = ResNet()
...@@ -130,4 +155,4 @@ class Backbone(hyperparams.OneOfConfig): ...@@ -130,4 +155,4 @@ class Backbone(hyperparams.OneOfConfig):
spinenet_mobile: SpineNetMobile = SpineNetMobile() spinenet_mobile: SpineNetMobile = SpineNetMobile()
mobilenet: MobileNet = MobileNet() mobilenet: MobileNet = MobileNet()
mobiledet: MobileDet = MobileDet() mobiledet: MobileDet = MobileDet()
vit: VisionTransformer = VisionTransformer()
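As a reading aid (not part of the commit), the new OneOfConfig member is selected like any other backbone; field names follow the dataclasses added above and the import path is assumed:

from official.vision.configs import backbones

backbone = backbones.Backbone(
    type='vit',
    vit=backbones.VisionTransformer(
        model_name='vit-b16',
        representation_size=768,
        transformer=backbones.Transformer(dropout_rate=0.0)))
print(backbone.get().model_name)  # active sub-config -> 'vit-b16'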
@@ -402,3 +402,201 @@ def image_classification_imagenet_mobilenet() -> cfg.ExperimentConfig:
])
return config
@exp_factory.register_config_factory('deit_imagenet_pretrain')
def image_classification_imagenet_deit_pretrain() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 4096 # originally was 1024 but 4096 better for tpu v3-32
eval_batch_size = 4096 # originally was 1024 but 4096 better for tpu v3-32
label_smoothing = 0.1
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
config = cfg.ExperimentConfig(
task=ImageClassificationTask(
model=ImageClassificationModel(
num_classes=1001,
input_size=[224, 224, 3],
kernel_initializer='zeros',
backbone=backbones.Backbone(
type='vit',
vit=backbones.VisionTransformer(
model_name='vit-b16',
representation_size=768,
init_stochastic_depth_rate=0.1,
original_init=False,
transformer=backbones.Transformer(
dropout_rate=0.0, attention_dropout_rate=0.0)))),
losses=Losses(
l2_weight_decay=0.0,
label_smoothing=label_smoothing,
one_hot=False,
soft_labels=True),
train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size,
aug_type=common.Augmentation(
type='randaug',
randaug=common.RandAugment(
magnitude=9, exclude_ops=['Cutout'])),
mixup_and_cutmix=common.MixupAndCutmix(
label_smoothing=label_smoothing)),
validation_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
is_training=False,
global_batch_size=eval_batch_size)),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=300 * steps_per_epoch,
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate': 0.05,
'include_in_weight_decay': r'.*(kernel|weight):0$',
'gradient_clip_norm': 0.0
}
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 0.0005 * train_batch_size / 512,
'decay_steps': 300 * steps_per_epoch,
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 5 * steps_per_epoch,
'warmup_learning_rate': 0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
@exp_factory.register_config_factory('vit_imagenet_pretrain')
def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 4096
eval_batch_size = 4096
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
config = cfg.ExperimentConfig(
task=ImageClassificationTask(
model=ImageClassificationModel(
num_classes=1001,
input_size=[224, 224, 3],
kernel_initializer='zeros',
backbone=backbones.Backbone(
type='vit',
vit=backbones.VisionTransformer(
model_name='vit-b16', representation_size=768))),
losses=Losses(l2_weight_decay=0.0),
train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size),
validation_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
is_training=False,
global_batch_size=eval_batch_size)),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=300 * steps_per_epoch,
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate': 0.3,
'include_in_weight_decay': r'.*(kernel|weight):0$',
'gradient_clip_norm': 0.0
}
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 0.003 * train_batch_size / 4096,
'decay_steps': 300 * steps_per_epoch,
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 10000,
'warmup_learning_rate': 0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
@exp_factory.register_config_factory('vit_imagenet_finetune')
def image_classification_imagenet_vit_finetune() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 512
eval_batch_size = 512
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
config = cfg.ExperimentConfig(
task=ImageClassificationTask(
model=ImageClassificationModel(
num_classes=1001,
input_size=[384, 384, 3],
backbone=backbones.Backbone(
type='vit',
vit=backbones.VisionTransformer(model_name='vit-b16'))),
losses=Losses(l2_weight_decay=0.0),
train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size),
validation_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
is_training=False,
global_batch_size=eval_batch_size)),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=20000,
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'sgd',
'sgd': {
'momentum': 0.9,
'global_clipnorm': 1.0,
}
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 0.003,
'decay_steps': 20000,
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
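A minimal usage sketch (not part of the commit) for the experiments registered above, assuming the standard exp_factory entry point; the override values are illustrative only:

from official.core import exp_factory

config = exp_factory.get_exp_config('vit_imagenet_pretrain')
config.task.train_data.global_batch_size = 1024  # e.g. a smaller TPU slice
config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = (
    0.003 * 1024 / 4096)  # keep the batch-size-linear LR scaling used above
config.validate()  # enforces the is_training restrictions listed above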
@@ -29,7 +29,10 @@ class ImageClassificationConfigTest(tf.test.TestCase, parameterized.TestCase):
('resnet_imagenet',),
('resnet_rs_imagenet',),
('revnet_imagenet',),
-('mobilenet_imagenet'),
+('mobilenet_imagenet',),
('deit_imagenet_pretrain',),
('vit_imagenet_pretrain',),
('vit_imagenet_finetune',),
)
def test_image_classification_configs(self, config_name):
config = exp_factory.get_exp_config(config_name)
......
@@ -23,3 +23,4 @@ from official.vision.modeling.backbones.resnet_deeplab import DilatedResNet
from official.vision.modeling.backbones.revnet import RevNet
from official.vision.modeling.backbones.spinenet import SpineNet
from official.vision.modeling.backbones.spinenet_mobile import SpineNetMobile
from official.vision.modeling.backbones.vit import VisionTransformer
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VisionTransformer models."""
from typing import Optional, Tuple
from absl import logging
import tensorflow as tf
from official.modeling import activations
from official.vision.modeling.backbones import factory
from official.vision.modeling.backbones.vit_specs import VIT_SPECS
from official.vision.modeling.layers import nn_blocks
from official.vision.modeling.layers import nn_layers
layers = tf.keras.layers
class AddPositionEmbs(tf.keras.layers.Layer):
"""Adds (optionally learned) positional embeddings to the inputs."""
def __init__(self,
posemb_init: Optional[tf.keras.initializers.Initializer] = None,
posemb_origin_shape: Optional[Tuple[int, int]] = None,
posemb_target_shape: Optional[Tuple[int, int]] = None,
**kwargs):
"""Constructs Postional Embedding module.
The logic of this module is: the learnable positional embeddings length will
be determined by the inputs_shape or posemb_origin_shape (if provided)
during the construction. If the posemb_target_shape is provided and is
different from the positional embeddings length, the embeddings will be
interpolated during the forward call.
Args:
posemb_init: The positional embedding initializer.
posemb_origin_shape: The intended positional embedding shape.
posemb_target_shape: The potential target shape positional embedding may
be interpolated to.
**kwargs: other args.
"""
super().__init__(**kwargs)
self.posemb_init = posemb_init
self.posemb_origin_shape = posemb_origin_shape
self.posemb_target_shape = posemb_target_shape
def build(self, inputs_shape):
if self.posemb_origin_shape is not None:
pos_emb_length = self.posemb_origin_shape[0] * self.posemb_origin_shape[1]
else:
pos_emb_length = inputs_shape[1]
pos_emb_shape = (1, pos_emb_length, inputs_shape[2])
self.pos_embedding = self.add_weight(
'pos_embedding', pos_emb_shape, initializer=self.posemb_init)
def _interpolate(self, pos_embedding: tf.Tensor, from_shape: Tuple[int, int],
to_shape: Tuple[int, int]) -> tf.Tensor:
"""Interpolates the positional embeddings."""
logging.info('Interpolating positional embedding from shape %s to %s',
             from_shape, to_shape)
grid_emb = tf.reshape(pos_embedding, [1] + list(from_shape) + [-1])
# NOTE: Using BILINEAR interpolation by default.
grid_emb = tf.image.resize(grid_emb, to_shape)
return tf.reshape(grid_emb, [1, to_shape[0] * to_shape[1], -1])
def call(self, inputs, inputs_positions=None):
del inputs_positions
pos_embedding = self.pos_embedding
# inputs.shape is (batch_size, seq_len, emb_dim).
if inputs.shape[1] != pos_embedding.shape[1]:
pos_embedding = self._interpolate(
pos_embedding,
from_shape=self.posemb_origin_shape,
to_shape=self.posemb_target_shape)
pos_embedding = tf.cast(pos_embedding, inputs.dtype)
return inputs + pos_embedding
class TokenLayer(tf.keras.layers.Layer):
"""A simple layer to wrap token parameters."""
def build(self, inputs_shape):
self.cls = self.add_weight(
'cls', (1, 1, inputs_shape[-1]), initializer='zeros')
def call(self, inputs):
cls = tf.cast(self.cls, inputs.dtype)
cls = cls + tf.zeros_like(inputs[:, 0:1]) # A hacky way to tile.
x = tf.concat([cls, inputs], axis=1)
return x
class Encoder(tf.keras.layers.Layer):
"""Transformer Encoder."""
def __init__(self,
num_layers,
mlp_dim,
num_heads,
dropout_rate=0.1,
attention_dropout_rate=0.1,
kernel_regularizer=None,
inputs_positions=None,
init_stochastic_depth_rate=0.0,
kernel_initializer='glorot_uniform',
add_pos_embed=True,
pos_embed_origin_shape=None,
pos_embed_target_shape=None,
**kwargs):
super().__init__(**kwargs)
self._num_layers = num_layers
self._mlp_dim = mlp_dim
self._num_heads = num_heads
self._dropout_rate = dropout_rate
self._attention_dropout_rate = attention_dropout_rate
self._kernel_regularizer = kernel_regularizer
self._inputs_positions = inputs_positions
self._init_stochastic_depth_rate = init_stochastic_depth_rate
self._kernel_initializer = kernel_initializer
self._add_pos_embed = add_pos_embed
self._pos_embed_origin_shape = pos_embed_origin_shape
self._pos_embed_target_shape = pos_embed_target_shape
def build(self, input_shape):
if self._add_pos_embed:
self._pos_embed = AddPositionEmbs(
posemb_init=tf.keras.initializers.RandomNormal(stddev=0.02),
posemb_origin_shape=self._pos_embed_origin_shape,
posemb_target_shape=self._pos_embed_target_shape,
name='posembed_input')
self._dropout = layers.Dropout(rate=self._dropout_rate)
self._encoder_layers = []
# Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
# https://flax.readthedocs.io/en/latest/_autosummary/flax.deprecated.nn.LayerNorm.html
for i in range(self._num_layers):
encoder_layer = nn_blocks.TransformerEncoderBlock(
inner_activation=activations.gelu,
num_attention_heads=self._num_heads,
inner_dim=self._mlp_dim,
output_dropout=self._dropout_rate,
attention_dropout=self._attention_dropout_rate,
kernel_regularizer=self._kernel_regularizer,
kernel_initializer=self._kernel_initializer,
norm_first=True,
stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
self._init_stochastic_depth_rate, i + 1, self._num_layers),
norm_epsilon=1e-6)
self._encoder_layers.append(encoder_layer)
self._norm = layers.LayerNormalization(epsilon=1e-6)
super().build(input_shape)
def call(self, inputs, training=None):
x = inputs
if self._add_pos_embed:
x = self._pos_embed(x, inputs_positions=self._inputs_positions)
x = self._dropout(x, training=training)
for encoder_layer in self._encoder_layers:
x = encoder_layer(x, training=training)
x = self._norm(x)
return x
def get_config(self):
config = super().get_config()
updates = {
'num_layers': self._num_layers,
'mlp_dim': self._mlp_dim,
'num_heads': self._num_heads,
'dropout_rate': self._dropout_rate,
'attention_dropout_rate': self._attention_dropout_rate,
'kernel_regularizer': self._kernel_regularizer,
'inputs_positions': self._inputs_positions,
'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
'kernel_initializer': self._kernel_initializer,
'add_pos_embed': self._add_pos_embed,
'pos_embed_origin_shape': self._pos_embed_origin_shape,
'pos_embed_target_shape': self._pos_embed_target_shape,
}
config.update(updates)
return config
class VisionTransformer(tf.keras.Model):
"""Class to build VisionTransformer family model."""
def __init__(self,
mlp_dim=3072,
num_heads=12,
num_layers=12,
attention_dropout_rate=0.0,
dropout_rate=0.1,
init_stochastic_depth_rate=0.0,
input_specs=layers.InputSpec(shape=[None, None, None, 3]),
patch_size=16,
hidden_size=768,
representation_size=0,
pooler='token',
kernel_regularizer=None,
original_init: bool = True,
pos_embed_shape: Optional[Tuple[int, int]] = None):
"""VisionTransformer initialization function."""
self._mlp_dim = mlp_dim
self._num_heads = num_heads
self._num_layers = num_layers
self._hidden_size = hidden_size
self._patch_size = patch_size
inputs = tf.keras.Input(shape=input_specs.shape[1:])
x = layers.Conv2D(
filters=hidden_size,
kernel_size=patch_size,
strides=patch_size,
padding='valid',
kernel_regularizer=kernel_regularizer,
kernel_initializer='lecun_normal' if original_init else 'he_uniform')(
inputs)
if tf.keras.backend.image_data_format() == 'channels_last':
rows_axis, cols_axis = (1, 2)
else:
rows_axis, cols_axis = (2, 3)
# The reshape below assumes the data_format is 'channels_last,' so
# transpose to that. Once the data is flattened by the reshape, the
# data_format is irrelevant, so no need to update
# tf.keras.backend.image_data_format.
x = tf.transpose(x, perm=[0, 2, 3, 1])
pos_embed_target_shape = (x.shape[rows_axis], x.shape[cols_axis])
seq_len = (input_specs.shape[rows_axis] // patch_size) * (
input_specs.shape[cols_axis] // patch_size)
x = tf.reshape(x, [-1, seq_len, hidden_size])
# If we want to add a class token, add it here.
if pooler == 'token':
x = TokenLayer(name='cls')(x)
x = Encoder(
num_layers=num_layers,
mlp_dim=mlp_dim,
num_heads=num_heads,
dropout_rate=dropout_rate,
attention_dropout_rate=attention_dropout_rate,
kernel_regularizer=kernel_regularizer,
kernel_initializer='glorot_uniform' if original_init else dict(
class_name='TruncatedNormal', config=dict(stddev=.02)),
init_stochastic_depth_rate=init_stochastic_depth_rate,
pos_embed_origin_shape=pos_embed_shape,
pos_embed_target_shape=pos_embed_target_shape)(
x)
if pooler == 'token':
x = x[:, 0]
elif pooler == 'gap':
x = tf.reduce_mean(x, axis=1)
elif pooler == 'none':
x = tf.identity(x, name='encoded_tokens')
else:
raise ValueError(f'unrecognized pooler type: {pooler}')
if representation_size:
x = tf.keras.layers.Dense(
representation_size,
kernel_regularizer=kernel_regularizer,
name='pre_logits',
kernel_initializer='lecun_normal' if original_init else 'he_uniform')(
x)
x = tf.nn.tanh(x)
else:
x = tf.identity(x, name='pre_logits')
if pooler == 'none':
endpoints = {'encoded_tokens': x}
else:
endpoints = {
'pre_logits':
tf.reshape(x, [-1, 1, 1, representation_size or hidden_size])
}
super(VisionTransformer, self).__init__(inputs=inputs, outputs=endpoints)
@factory.register_backbone_builder('vit')
def build_vit(input_specs,
backbone_config,
norm_activation_config,
l2_regularizer=None):
"""Build ViT model."""
del norm_activation_config
backbone_type = backbone_config.type
backbone_cfg = backbone_config.get()
assert backbone_type == 'vit', (f'Inconsistent backbone type '
f'{backbone_type}')
backbone_cfg.override(VIT_SPECS[backbone_cfg.model_name])
return VisionTransformer(
mlp_dim=backbone_cfg.transformer.mlp_dim,
num_heads=backbone_cfg.transformer.num_heads,
num_layers=backbone_cfg.transformer.num_layers,
attention_dropout_rate=backbone_cfg.transformer.attention_dropout_rate,
dropout_rate=backbone_cfg.transformer.dropout_rate,
init_stochastic_depth_rate=backbone_cfg.init_stochastic_depth_rate,
input_specs=input_specs,
patch_size=backbone_cfg.patch_size,
hidden_size=backbone_cfg.hidden_size,
representation_size=backbone_cfg.representation_size,
pooler=backbone_cfg.pooler,
kernel_regularizer=l2_regularizer,
original_init=backbone_cfg.original_init,
pos_embed_shape=backbone_cfg.pos_embed_shape)
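For reference, a minimal sketch (not part of the commit) of driving build_vit directly with the config classes added earlier; the output shape assumes the default 'token' pooler and the vit-b16 spec:

import tensorflow as tf
from official.vision.configs import backbones as backbones_cfg
from official.vision.modeling.backbones import vit

backbone = vit.build_vit(
    input_specs=tf.keras.layers.InputSpec(shape=[None, 224, 224, 3]),
    backbone_config=backbones_cfg.Backbone(
        type='vit',
        vit=backbones_cfg.VisionTransformer(model_name='vit-b16')),
    norm_activation_config=None,  # unused: build_vit deletes it
    l2_regularizer=None)
endpoints = backbone(tf.zeros([1, 224, 224, 3]))
print(endpoints['pre_logits'].shape)  # (1, 1, 1, 768) for vit-b16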
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VisionTransformer backbone specs."""
import immutabledict
VIT_SPECS = immutabledict.immutabledict({
'vit-ti16':
dict(
hidden_size=192,
patch_size=16,
transformer=dict(mlp_dim=768, num_heads=3, num_layers=12),
),
'vit-s16':
dict(
hidden_size=384,
patch_size=16,
transformer=dict(mlp_dim=1536, num_heads=6, num_layers=12),
),
'vit-b16':
dict(
hidden_size=768,
patch_size=16,
transformer=dict(mlp_dim=3072, num_heads=12, num_layers=12),
),
'vit-b32':
dict(
hidden_size=768,
patch_size=32,
transformer=dict(mlp_dim=3072, num_heads=12, num_layers=12),
),
'vit-l16':
dict(
hidden_size=1024,
patch_size=16,
transformer=dict(mlp_dim=4096, num_heads=16, num_layers=24),
),
'vit-l32':
dict(
hidden_size=1024,
patch_size=32,
transformer=dict(mlp_dim=4096, num_heads=16, num_layers=24),
),
'vit-h14':
dict(
hidden_size=1280,
patch_size=14,
transformer=dict(mlp_dim=5120, num_heads=16, num_layers=32),
),
'vit-g14':
dict(
hidden_size=1664,
patch_size=14,
transformer=dict(mlp_dim=8192, num_heads=16, num_layers=48),
),
})
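As an illustration (not part of the commit), build_vit flattens one of these spec entries into the backbone config via override, so the per-variant dimensions never need to be spelled out by hand:

from official.vision.configs import backbones
from official.vision.modeling.backbones.vit_specs import VIT_SPECS

cfg = backbones.VisionTransformer(model_name='vit-l16')
cfg.override(VIT_SPECS[cfg.model_name])
print(cfg.hidden_size, cfg.patch_size, cfg.transformer.num_layers)  # 1024 16 24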
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for VIT."""
from absl.testing import parameterized
import tensorflow as tf
from official.vision.modeling.backbones import vit
class VisionTransformerTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(224, 85798656),
(256, 85844736),
)
def test_network_creation(self, input_size, params_count):
"""Test creation of VisionTransformer family models."""
tf.keras.backend.set_image_data_format('channels_last')
input_specs = tf.keras.layers.InputSpec(
shape=[2, input_size, input_size, 3])
network = vit.VisionTransformer(input_specs=input_specs)
inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
_ = network(inputs)
self.assertEqual(network.count_params(), params_count)
def test_network_none_pooler(self):
tf.keras.backend.set_image_data_format('channels_last')
input_size = 256
input_specs = tf.keras.layers.InputSpec(
shape=[2, input_size, input_size, 3])
network = vit.VisionTransformer(
input_specs=input_specs,
patch_size=16,
pooler='none',
representation_size=128,
pos_embed_shape=(14, 14)) # (224 // 16)
inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
output = network(inputs)['encoded_tokens']
self.assertEqual(output.shape, [1, 256, 128])
def test_posembedding_interpolation(self):
tf.keras.backend.set_image_data_format('channels_last')
input_size = 256
input_specs = tf.keras.layers.InputSpec(
shape=[2, input_size, input_size, 3])
network = vit.VisionTransformer(
input_specs=input_specs,
patch_size=16,
pooler='gap',
pos_embed_shape=(14, 14)) # (224 // 16)
inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
output = network(inputs)['pre_logits']
self.assertEqual(output.shape, [1, 1, 1, 768])
if __name__ == '__main__':
tf.test.main()
@@ -27,20 +27,49 @@ from official.vision.modeling import classification_model
class ClassificationNetworkTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(192 * 4, 3, 12, 192, 5524416),
(384 * 4, 6, 12, 384, 21665664),
)
def test_vision_transformer_creation(self, mlp_dim, num_heads, num_layers,
hidden_size, num_params):
"""Test for creation of a Vision Transformer classifier."""
inputs = np.random.rand(2, 224, 224, 3)
tf.keras.backend.set_image_data_format('channels_last')
backbone = backbones.VisionTransformer(
mlp_dim=mlp_dim,
num_heads=num_heads,
num_layers=num_layers,
hidden_size=hidden_size,
input_specs=tf.keras.layers.InputSpec(shape=[None, 224, 224, 3]),
)
self.assertEqual(backbone.count_params(), num_params)
num_classes = 1000
model = classification_model.ClassificationModel(
backbone=backbone,
num_classes=num_classes,
dropout_rate=0.2,
)
logits = model(inputs)
self.assertAllEqual([2, num_classes], logits.numpy().shape)
@parameterized.parameters(
(128, 50, 'relu'),
(128, 50, 'relu'),
(128, 50, 'swish'),
)
-def test_resnet_network_creation(
-    self, input_size, resnet_model_id, activation):
+def test_resnet_network_creation(self, input_size, resnet_model_id,
+                                 activation):
"""Test for creation of a ResNet-50 classifier."""
inputs = np.random.rand(2, input_size, input_size, 3)
tf.keras.backend.set_image_data_format('channels_last')
-backbone = backbones.ResNet(
-    model_id=resnet_model_id, activation=activation)
+backbone = backbones.ResNet(model_id=resnet_model_id, activation=activation)
self.assertEqual(backbone.count_params(), 23561152)
num_classes = 1000
......
@@ -21,6 +21,7 @@ from absl import logging
import tensorflow as tf
from official.modeling import tf_utils
+from official.nlp import modeling as nlp_modeling
from official.vision.modeling.layers import nn_layers
@@ -538,8 +539,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
se_inner_activation: A `str` name of squeeze-excitation inner activation.
se_gating_activation: A `str` name of squeeze-excitation gating
activation.
-se_round_down_protect: A `bool` of whether round down more than 10%
-  will be allowed in SE layer.
+se_round_down_protect: A `bool` of whether round down more than 10% will
+  be allowed in SE layer.
expand_se_in_filters: A `bool` of whether or not to expand in_filter in
squeeze and excitation layer.
depthwise_activation: A `str` name of the activation function for
@@ -547,9 +548,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
use_sync_bn: A `bool`. If True, use synchronized batch normalization.
dilation_rate: An `int` that specifies the dilation rate to use for.
-divisible_by: An `int` that ensures all inner dimensions are divisible by
-  this number.
-dilated convolution: An `int` to specify the same value for all spatial
-  dimensions.
+divisible_by: An `int` that ensures all inner dimensions are divisible by
+  this number. dilated convolution: An `int` to specify the same value for
+  all spatial dimensions.
regularize_depthwise: A `bool` of whether or not apply regularization on
depthwise.
use_depthwise: A `bool` of whether to uses fused convolutions instead of
@@ -1204,7 +1204,8 @@ class ReversibleLayer(tf.keras.layers.Layer):
@tf.keras.utils.register_keras_serializable(package='Vision')
class DepthwiseSeparableConvBlock(tf.keras.layers.Layer):
-"""Creates an depthwise separable convolution block with batch normalization."""
+"""Creates a depthwise separable convolution block with batch normalization.
+"""
def __init__(
self,
@@ -1354,10 +1355,10 @@ class TuckerConvBlock(tf.keras.layers.Layer):
Args:
in_filters: An `int` number of filters of the input tensor.
out_filters: An `int` number of filters of the output tensor.
-input_compression_ratio: An `float` of compression ratio for
-  input filters.
-output_compression_ratio: An `float` of compression ratio for
-  output filters.
+input_compression_ratio: An `float` of compression ratio for input
+  filters.
+output_compression_ratio: An `float` of compression ratio for output
+  filters.
strides: An `int` block stride. If greater than 1, this block will
ultimately downsample the input.
kernel_size: An `int` kernel_size of the depthwise conv layer.
@@ -1510,11 +1511,114 @@ class TuckerConvBlock(tf.keras.layers.Layer):
x = self._conv2(x)
x = self._norm2(x)
-if (self._use_residual and
-    self._in_filters == self._out_filters and
-    self._strides == 1):
+if (self._use_residual and self._in_filters == self._out_filters and
+    self._strides == 1):
if self._stochastic_depth:
x = self._stochastic_depth(x, training=training)
x = self._add([x, shortcut])
return x
class TransformerEncoderBlock(nlp_modeling.layers.TransformerEncoderBlock):
"""TransformerEncoderBlock layer with stochastic depth."""
def __init__(self,
*args,
stochastic_depth_drop_rate=0.0,
return_attention=False,
**kwargs):
"""Initializes TransformerEncoderBlock."""
super().__init__(*args, **kwargs)
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
self._return_attention = return_attention
def build(self, input_shape):
if self._stochastic_depth_drop_rate:
self._stochastic_depth = nn_layers.StochasticDepth(
self._stochastic_depth_drop_rate)
else:
self._stochastic_depth = lambda x, *args, **kwargs: tf.identity(x)
super().build(input_shape)
def get_config(self):
config = {'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs, training=None):
"""Transformer self-attention encoder block call."""
if isinstance(inputs, (list, tuple)):
if len(inputs) == 2:
input_tensor, attention_mask = inputs
key_value = None
elif len(inputs) == 3:
input_tensor, key_value, attention_mask = inputs
else:
raise ValueError('Unexpected inputs to %s with length at %d' %
(self.__class__, len(inputs)))
else:
input_tensor, key_value, attention_mask = (inputs, None, None)
if self._output_range:
if self._norm_first:
source_tensor = input_tensor[:, 0:self._output_range, :]
input_tensor = self._attention_layer_norm(input_tensor)
if key_value is not None:
key_value = self._attention_layer_norm(key_value)
target_tensor = input_tensor[:, 0:self._output_range, :]
if attention_mask is not None:
attention_mask = attention_mask[:, 0:self._output_range, :]
else:
if self._norm_first:
source_tensor = input_tensor
input_tensor = self._attention_layer_norm(input_tensor)
if key_value is not None:
key_value = self._attention_layer_norm(key_value)
target_tensor = input_tensor
if key_value is None:
key_value = input_tensor
attention_output, attention_scores = self._attention_layer(
query=target_tensor,
value=key_value,
attention_mask=attention_mask,
return_attention_scores=True)
attention_output = self._attention_dropout(attention_output)
if self._norm_first:
attention_output = source_tensor + self._stochastic_depth(
attention_output, training=training)
else:
attention_output = self._attention_layer_norm(
target_tensor +
self._stochastic_depth(attention_output, training=training))
if self._norm_first:
source_attention_output = attention_output
attention_output = self._output_layer_norm(attention_output)
inner_output = self._intermediate_dense(attention_output)
inner_output = self._intermediate_activation_layer(inner_output)
inner_output = self._inner_dropout_layer(inner_output)
layer_output = self._output_dense(inner_output)
layer_output = self._output_dropout(layer_output)
if self._norm_first:
if self._return_attention:
return source_attention_output + self._stochastic_depth(
layer_output, training=training), attention_scores
else:
return source_attention_output + self._stochastic_depth(
layer_output, training=training)
# During mixed precision training, layer norm output is always fp32 for now.
# Casts fp32 for the subsequent add.
layer_output = tf.cast(layer_output, tf.float32)
if self._return_attention:
return self._output_layer_norm(layer_output + self._stochastic_depth(
attention_output, training=training)), attention_scores
else:
return self._output_layer_norm(
layer_output +
self._stochastic_depth(attention_output, training=training))
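A minimal sketch (not part of the commit) of the block in isolation, using constructor arguments inherited from nlp_modeling.layers.TransformerEncoderBlock plus the new stochastic_depth_drop_rate:

import tensorflow as tf
from official.modeling import activations
from official.vision.modeling.layers import nn_blocks

block = nn_blocks.TransformerEncoderBlock(
    inner_activation=activations.gelu,
    num_attention_heads=12,
    inner_dim=3072,
    output_dropout=0.1,
    attention_dropout=0.0,
    norm_first=True,
    norm_epsilon=1e-6,
    stochastic_depth_drop_rate=0.1)
tokens = tf.zeros([2, 197, 768])  # batch of 2; 1 class token + 14*14 patches
print(block(tokens, training=True).shape)  # (2, 197, 768)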