Commit 90585434 authored by Chaochao Yan's avatar Chaochao Yan Committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 481733792
parent d309fff8
...
@@ -75,7 +75,7 @@ task_factory.register_task_cls(ImageClassificationTask)(
     image_classification.ImageClassificationTask)
 
-@exp_factory.register_config_factory('deit_imagenet_pretrain')
+@exp_factory.register_config_factory('legacy_deit_imagenet_pretrain')
 def image_classification_imagenet_deit_pretrain() -> cfg.ExperimentConfig:
   """Image classification on imagenet with vision transformer."""
   train_batch_size = 4096  # originally was 1024 but 4096 better for tpu v3-32
...
@@ -156,7 +156,7 @@ def image_classification_imagenet_deit_pretrain() -> cfg.ExperimentConfig:
   return config
 
-@exp_factory.register_config_factory('vit_imagenet_pretrain')
+@exp_factory.register_config_factory('legacy_vit_imagenet_pretrain')
 def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
   """Image classification on imagenet with vision transformer."""
   train_batch_size = 4096
...
@@ -220,7 +220,7 @@ def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
   return config
 
-@exp_factory.register_config_factory('vit_imagenet_finetune')
+@exp_factory.register_config_factory('legacy_vit_imagenet_finetune')
 def image_classification_imagenet_vit_finetune() -> cfg.ExperimentConfig:
   """Image classification on imagenet with vision transformer."""
   train_batch_size = 512
...
...
@@ -294,7 +294,7 @@ class VisionTransformer(tf.keras.Model):
     super(VisionTransformer, self).__init__(inputs=inputs, outputs=endpoints)
 
-@factory.register_backbone_builder('vit')
+@factory.register_backbone_builder('legacy_vit')
 def build_vit(input_specs,
               backbone_config,
               norm_activation_config,
...
...
@@ -14,13 +14,37 @@
 """Backbones configurations."""
 import dataclasses
-from typing import Optional, List
+from typing import List, Optional, Tuple
 
 # Import libraries
 
 from official.modeling import hyperparams
 
 
+@dataclasses.dataclass
+class Transformer(hyperparams.Config):
+  """Transformer config."""
+  mlp_dim: int = 1
+  num_heads: int = 1
+  num_layers: int = 1
+  attention_dropout_rate: float = 0.0
+  dropout_rate: float = 0.1
+
+
+@dataclasses.dataclass
+class VisionTransformer(hyperparams.Config):
+  """VisionTransformer config."""
+  model_name: str = 'vit-b16'
+  # pylint: disable=line-too-long
+  pooler: str = 'token'  # 'token', 'gap' or 'none'. If set to 'token', an extra classification token is added to the sequence.
+  # pylint: enable=line-too-long
+  representation_size: int = 0
+  hidden_size: int = 1
+  patch_size: int = 16
+  transformer: Transformer = Transformer()
+  init_stochastic_depth_rate: float = 0.0
+  original_init: bool = True
+  pos_embed_shape: Optional[Tuple[int, int]] = None
+
+
 @dataclasses.dataclass
 class ResNet(hyperparams.Config):
   """ResNet config."""
...
@@ -120,6 +144,7 @@ class Backbone(hyperparams.OneOfConfig):
     spinenet_mobile: mobile spinenet backbone config.
     mobilenet: mobilenet backbone config.
     mobiledet: mobiledet backbone config.
+    vit: vision transformer backbone config.
   """
   type: Optional[str] = None
   resnet: ResNet = ResNet()
...
@@ -130,4 +155,4 @@ class Backbone(hyperparams.OneOfConfig):
   spinenet_mobile: SpineNetMobile = SpineNetMobile()
   mobilenet: MobileNet = MobileNet()
   mobiledet: MobileDet = MobileDet()
+  vit: VisionTransformer = VisionTransformer()
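
A minimal usage sketch (not part of this commit) of the one-of backbone config
added above; the import path is assumed from this file's location, and
Backbone.get() returns the sub-config selected by `type`, which is how
build_vit later in this diff reads it:

  from official.vision.configs import backbones

  backbone = backbones.Backbone(
      type='vit',
      vit=backbones.VisionTransformer(
          model_name='vit-s16',
          transformer=backbones.Transformer(dropout_rate=0.0)))
  assert backbone.get().model_name == 'vit-s16'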
...
@@ -402,3 +402,201 @@ def image_classification_imagenet_mobilenet() -> cfg.ExperimentConfig:
   ])
   return config
+
+
+@exp_factory.register_config_factory('deit_imagenet_pretrain')
+def image_classification_imagenet_deit_pretrain() -> cfg.ExperimentConfig:
+  """Image classification on imagenet with vision transformer."""
+  train_batch_size = 4096  # originally was 1024 but 4096 better for tpu v3-32
+  eval_batch_size = 4096  # originally was 1024 but 4096 better for tpu v3-32
+  label_smoothing = 0.1
+  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
+  config = cfg.ExperimentConfig(
+      task=ImageClassificationTask(
+          model=ImageClassificationModel(
+              num_classes=1001,
+              input_size=[224, 224, 3],
+              kernel_initializer='zeros',
+              backbone=backbones.Backbone(
+                  type='vit',
+                  vit=backbones.VisionTransformer(
+                      model_name='vit-b16',
+                      representation_size=768,
+                      init_stochastic_depth_rate=0.1,
+                      original_init=False,
+                      transformer=backbones.Transformer(
+                          dropout_rate=0.0, attention_dropout_rate=0.0)))),
+          losses=Losses(
+              l2_weight_decay=0.0,
+              label_smoothing=label_smoothing,
+              one_hot=False,
+              soft_labels=True),
+          train_data=DataConfig(
+              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
+              is_training=True,
+              global_batch_size=train_batch_size,
+              aug_type=common.Augmentation(
+                  type='randaug',
+                  randaug=common.RandAugment(
+                      magnitude=9, exclude_ops=['Cutout'])),
+              mixup_and_cutmix=common.MixupAndCutmix(
+                  label_smoothing=label_smoothing)),
+          validation_data=DataConfig(
+              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
+              is_training=False,
+              global_batch_size=eval_batch_size)),
+      trainer=cfg.TrainerConfig(
+          steps_per_loop=steps_per_epoch,
+          summary_interval=steps_per_epoch,
+          checkpoint_interval=steps_per_epoch,
+          train_steps=300 * steps_per_epoch,
+          validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
+          validation_interval=steps_per_epoch,
+          optimizer_config=optimization.OptimizationConfig({
+              'optimizer': {
+                  'type': 'adamw',
+                  'adamw': {
+                      'weight_decay_rate': 0.05,
+                      'include_in_weight_decay': r'.*(kernel|weight):0$',
+                      'gradient_clip_norm': 0.0
+                  }
+              },
+              'learning_rate': {
+                  'type': 'cosine',
+                  'cosine': {
+                      'initial_learning_rate': 0.0005 * train_batch_size / 512,
+                      'decay_steps': 300 * steps_per_epoch,
+                  }
+              },
+              'warmup': {
+                  'type': 'linear',
+                  'linear': {
+                      'warmup_steps': 5 * steps_per_epoch,
+                      'warmup_learning_rate': 0
+                  }
+              }
+          })),
+      restrictions=[
+          'task.train_data.is_training != None',
+          'task.validation_data.is_training != None'
+      ])
+  return config
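
A worked example of the scaled learning rate above, at the default batch size;
the linear scaling rule anchors 0.0005 to a batch of 512:

  initial_learning_rate = 0.0005 * 4096 / 512  # = 0.004

so the cosine schedule decays from 0.004 over 300 epochs after a 5-epoch
linear warmup.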
+
+
+@exp_factory.register_config_factory('vit_imagenet_pretrain')
+def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
+  """Image classification on imagenet with vision transformer."""
+  train_batch_size = 4096
+  eval_batch_size = 4096
+  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
+  config = cfg.ExperimentConfig(
+      task=ImageClassificationTask(
+          model=ImageClassificationModel(
+              num_classes=1001,
+              input_size=[224, 224, 3],
+              kernel_initializer='zeros',
+              backbone=backbones.Backbone(
+                  type='vit',
+                  vit=backbones.VisionTransformer(
+                      model_name='vit-b16', representation_size=768))),
+          losses=Losses(l2_weight_decay=0.0),
+          train_data=DataConfig(
+              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
+              is_training=True,
+              global_batch_size=train_batch_size),
+          validation_data=DataConfig(
+              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
+              is_training=False,
+              global_batch_size=eval_batch_size)),
+      trainer=cfg.TrainerConfig(
+          steps_per_loop=steps_per_epoch,
+          summary_interval=steps_per_epoch,
+          checkpoint_interval=steps_per_epoch,
+          train_steps=300 * steps_per_epoch,
+          validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
+          validation_interval=steps_per_epoch,
+          optimizer_config=optimization.OptimizationConfig({
+              'optimizer': {
+                  'type': 'adamw',
+                  'adamw': {
+                      'weight_decay_rate': 0.3,
+                      'include_in_weight_decay': r'.*(kernel|weight):0$',
+                      'gradient_clip_norm': 0.0
+                  }
+              },
+              'learning_rate': {
+                  'type': 'cosine',
+                  'cosine': {
+                      'initial_learning_rate': 0.003 * train_batch_size / 4096,
+                      'decay_steps': 300 * steps_per_epoch,
+                  }
+              },
+              'warmup': {
+                  'type': 'linear',
+                  'linear': {
+                      'warmup_steps': 10000,
+                      'warmup_learning_rate': 0
+                  }
+              }
+          })),
+      restrictions=[
+          'task.train_data.is_training != None',
+          'task.validation_data.is_training != None'
+      ])
+  return config
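
For comparison with the DeiT recipe above: this config anchors its learning
rate to a base batch of 4096, so 0.003 * 4096 / 4096 = 0.003 at the default
batch size, and it warms up for a fixed 10000 steps rather than a number of
epochs.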
+
+
+@exp_factory.register_config_factory('vit_imagenet_finetune')
+def image_classification_imagenet_vit_finetune() -> cfg.ExperimentConfig:
+  """Image classification on imagenet with vision transformer."""
+  train_batch_size = 512
+  eval_batch_size = 512
+  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
+  config = cfg.ExperimentConfig(
+      task=ImageClassificationTask(
+          model=ImageClassificationModel(
+              num_classes=1001,
+              input_size=[384, 384, 3],
+              backbone=backbones.Backbone(
+                  type='vit',
+                  vit=backbones.VisionTransformer(model_name='vit-b16'))),
+          losses=Losses(l2_weight_decay=0.0),
+          train_data=DataConfig(
+              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
+              is_training=True,
+              global_batch_size=train_batch_size),
+          validation_data=DataConfig(
+              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
+              is_training=False,
+              global_batch_size=eval_batch_size)),
+      trainer=cfg.TrainerConfig(
+          steps_per_loop=steps_per_epoch,
+          summary_interval=steps_per_epoch,
+          checkpoint_interval=steps_per_epoch,
+          train_steps=20000,
+          validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
+          validation_interval=steps_per_epoch,
+          optimizer_config=optimization.OptimizationConfig({
+              'optimizer': {
+                  'type': 'sgd',
+                  'sgd': {
+                      'momentum': 0.9,
+                      'global_clipnorm': 1.0,
+                  }
+              },
+              'learning_rate': {
+                  'type': 'cosine',
+                  'cosine': {
+                      'initial_learning_rate': 0.003,
+                      'decay_steps': 20000,
+                  }
+              }
+          })),
+      restrictions=[
+          'task.train_data.is_training != None',
+          'task.validation_data.is_training != None'
+      ])
+  return config
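
A sketch of retrieving one of the experiments registered above; the factory
call mirrors the config test further down in this commit, and validate() is
assumed to check the restrictions lists as in other hyperparams configs:

  from official.core import exp_factory

  config = exp_factory.get_exp_config('vit_imagenet_finetune')
  config.validate()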
...
@@ -29,7 +29,10 @@ class ImageClassificationConfigTest(tf.test.TestCase, parameterized.TestCase):
       ('resnet_imagenet',),
       ('resnet_rs_imagenet',),
       ('revnet_imagenet',),
-      ('mobilenet_imagenet'),
+      ('mobilenet_imagenet',),
+      ('deit_imagenet_pretrain',),
+      ('vit_imagenet_pretrain',),
+      ('vit_imagenet_finetune',),
   )
   def test_image_classification_configs(self, config_name):
     config = exp_factory.get_exp_config(config_name)
...
...
@@ -23,3 +23,4 @@ from official.vision.modeling.backbones.resnet_deeplab import DilatedResNet
 from official.vision.modeling.backbones.revnet import RevNet
 from official.vision.modeling.backbones.spinenet import SpineNet
 from official.vision.modeling.backbones.spinenet_mobile import SpineNetMobile
+from official.vision.modeling.backbones.vit import VisionTransformer
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VisionTransformer models."""
from typing import Optional, Tuple
from absl import logging
import tensorflow as tf
from official.modeling import activations
from official.vision.modeling.backbones import factory
from official.vision.modeling.backbones.vit_specs import VIT_SPECS
from official.vision.modeling.layers import nn_blocks
from official.vision.modeling.layers import nn_layers
layers = tf.keras.layers


class AddPositionEmbs(tf.keras.layers.Layer):
  """Adds (optionally learned) positional embeddings to the inputs."""

  def __init__(self,
               posemb_init: Optional[tf.keras.initializers.Initializer] = None,
               posemb_origin_shape: Optional[Tuple[int, int]] = None,
               posemb_target_shape: Optional[Tuple[int, int]] = None,
               **kwargs):
    """Constructs a positional embedding module.

    The length of the learnable positional embeddings is determined at
    construction time by inputs_shape, or by posemb_origin_shape if provided.
    If posemb_target_shape is provided and differs from that length, the
    embeddings are interpolated during the forward call.

    Args:
      posemb_init: The positional embedding initializer.
      posemb_origin_shape: The intended positional embedding shape.
      posemb_target_shape: The potential target shape the positional
        embedding may be interpolated to.
      **kwargs: other args.
    """
    super().__init__(**kwargs)
    self.posemb_init = posemb_init
    self.posemb_origin_shape = posemb_origin_shape
    self.posemb_target_shape = posemb_target_shape

  def build(self, inputs_shape):
    if self.posemb_origin_shape is not None:
      pos_emb_length = self.posemb_origin_shape[0] * self.posemb_origin_shape[1]
    else:
      pos_emb_length = inputs_shape[1]
    pos_emb_shape = (1, pos_emb_length, inputs_shape[2])
    self.pos_embedding = self.add_weight(
        'pos_embedding', pos_emb_shape, initializer=self.posemb_init)

  def _interpolate(self, pos_embedding: tf.Tensor, from_shape: Tuple[int, int],
                   to_shape: Tuple[int, int]) -> tf.Tensor:
    """Interpolates the positional embeddings."""
    # Use %s: from_shape and to_shape are tuples, which %d cannot format.
    logging.info('Interpolating positional embedding from shape %s to %s',
                 from_shape, to_shape)
    grid_emb = tf.reshape(pos_embedding, [1] + list(from_shape) + [-1])
    # NOTE: Using BILINEAR interpolation by default.
    grid_emb = tf.image.resize(grid_emb, to_shape)
    return tf.reshape(grid_emb, [1, to_shape[0] * to_shape[1], -1])

  def call(self, inputs, inputs_positions=None):
    del inputs_positions
    pos_embedding = self.pos_embedding
    # inputs.shape is (batch_size, seq_len, emb_dim).
    if inputs.shape[1] != pos_embedding.shape[1]:
      pos_embedding = self._interpolate(
          pos_embedding,
          from_shape=self.posemb_origin_shape,
          to_shape=self.posemb_target_shape)
    pos_embedding = tf.cast(pos_embedding, inputs.dtype)

    return inputs + pos_embedding
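
A sanity sketch of the resize performed by _interpolate, with assumed numbers:
a checkpoint trained at 224x224 with 16x16 patches stores a 14x14 grid, and
applying it to 256x256 inputs needs a 16x16 grid:

  pos = tf.random.normal([1, 14 * 14, 768])
  grid = tf.reshape(pos, [1, 14, 14, 768])
  grid = tf.image.resize(grid, (16, 16))  # bilinear by default
  pos = tf.reshape(grid, [1, 16 * 16, 768])  # -> (1, 256, 768)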


class TokenLayer(tf.keras.layers.Layer):
  """A simple layer to wrap token parameters."""

  def build(self, inputs_shape):
    self.cls = self.add_weight(
        'cls', (1, 1, inputs_shape[-1]), initializer='zeros')

  def call(self, inputs):
    cls = tf.cast(self.cls, inputs.dtype)
    cls = cls + tf.zeros_like(inputs[:, 0:1])  # A hacky way to tile.
    x = tf.concat([cls, inputs], axis=1)
    return x


class Encoder(tf.keras.layers.Layer):
  """Transformer Encoder."""

  def __init__(self,
               num_layers,
               mlp_dim,
               num_heads,
               dropout_rate=0.1,
               attention_dropout_rate=0.1,
               kernel_regularizer=None,
               inputs_positions=None,
               init_stochastic_depth_rate=0.0,
               kernel_initializer='glorot_uniform',
               add_pos_embed=True,
               pos_embed_origin_shape=None,
               pos_embed_target_shape=None,
               **kwargs):
    super().__init__(**kwargs)
    self._num_layers = num_layers
    self._mlp_dim = mlp_dim
    self._num_heads = num_heads
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._kernel_regularizer = kernel_regularizer
    self._inputs_positions = inputs_positions
    self._init_stochastic_depth_rate = init_stochastic_depth_rate
    self._kernel_initializer = kernel_initializer
    self._add_pos_embed = add_pos_embed
    self._pos_embed_origin_shape = pos_embed_origin_shape
    self._pos_embed_target_shape = pos_embed_target_shape

  def build(self, input_shape):
    if self._add_pos_embed:
      self._pos_embed = AddPositionEmbs(
          posemb_init=tf.keras.initializers.RandomNormal(stddev=0.02),
          posemb_origin_shape=self._pos_embed_origin_shape,
          posemb_target_shape=self._pos_embed_target_shape,
          name='posembed_input')
    self._dropout = layers.Dropout(rate=self._dropout_rate)

    self._encoder_layers = []
    # Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
    # https://flax.readthedocs.io/en/latest/_autosummary/flax.deprecated.nn.LayerNorm.html
    for i in range(self._num_layers):
      encoder_layer = nn_blocks.TransformerEncoderBlock(
          inner_activation=activations.gelu,
          num_attention_heads=self._num_heads,
          inner_dim=self._mlp_dim,
          output_dropout=self._dropout_rate,
          attention_dropout=self._attention_dropout_rate,
          kernel_regularizer=self._kernel_regularizer,
          kernel_initializer=self._kernel_initializer,
          norm_first=True,
          stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
              self._init_stochastic_depth_rate, i + 1, self._num_layers),
          norm_epsilon=1e-6)
      self._encoder_layers.append(encoder_layer)
    self._norm = layers.LayerNormalization(epsilon=1e-6)
    super().build(input_shape)

  def call(self, inputs, training=None):
    x = inputs
    if self._add_pos_embed:
      x = self._pos_embed(x, inputs_positions=self._inputs_positions)
    x = self._dropout(x, training=training)

    for encoder_layer in self._encoder_layers:
      x = encoder_layer(x, training=training)
    x = self._norm(x)
    return x

  def get_config(self):
    config = super().get_config()
    updates = {
        'num_layers': self._num_layers,
        'mlp_dim': self._mlp_dim,
        'num_heads': self._num_heads,
        'dropout_rate': self._dropout_rate,
        'attention_dropout_rate': self._attention_dropout_rate,
        'kernel_regularizer': self._kernel_regularizer,
        'inputs_positions': self._inputs_positions,
        'init_stochastic_depth_rate': self._init_stochastic_depth_rate,
        'kernel_initializer': self._kernel_initializer,
        'add_pos_embed': self._add_pos_embed,
        'pos_embed_origin_shape': self._pos_embed_origin_shape,
        'pos_embed_target_shape': self._pos_embed_target_shape,
    }
    config.update(updates)
    return config
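
A minimal construction sketch for the encoder above, using the ViT-B/16 sizes
from VIT_SPECS later in this diff (196 patch tokens plus one class token):

  encoder = Encoder(num_layers=12, mlp_dim=3072, num_heads=12)
  tokens = tf.ones([2, 197, 768])
  encoded = encoder(tokens, training=False)  # -> (2, 197, 768)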


class VisionTransformer(tf.keras.Model):
  """Class to build VisionTransformer family model."""

  def __init__(self,
               mlp_dim=3072,
               num_heads=12,
               num_layers=12,
               attention_dropout_rate=0.0,
               dropout_rate=0.1,
               init_stochastic_depth_rate=0.0,
               input_specs=layers.InputSpec(shape=[None, None, None, 3]),
               patch_size=16,
               hidden_size=768,
               representation_size=0,
               pooler='token',
               kernel_regularizer=None,
               original_init: bool = True,
               pos_embed_shape: Optional[Tuple[int, int]] = None):
    """VisionTransformer initialization function."""
    self._mlp_dim = mlp_dim
    self._num_heads = num_heads
    self._num_layers = num_layers
    self._hidden_size = hidden_size
    self._patch_size = patch_size

    inputs = tf.keras.Input(shape=input_specs.shape[1:])

    x = layers.Conv2D(
        filters=hidden_size,
        kernel_size=patch_size,
        strides=patch_size,
        padding='valid',
        kernel_regularizer=kernel_regularizer,
        kernel_initializer='lecun_normal' if original_init else 'he_uniform')(
            inputs)
    if tf.keras.backend.image_data_format() == 'channels_last':
      rows_axis, cols_axis = (1, 2)
    else:
      rows_axis, cols_axis = (2, 3)
      # The reshape below assumes the data_format is 'channels_last,' so
      # transpose to that. Once the data is flattened by the reshape, the
      # data_format is irrelevant, so no need to update
      # tf.keras.backend.image_data_format.
      x = tf.transpose(x, perm=[0, 2, 3, 1])

    pos_embed_target_shape = (x.shape[rows_axis], x.shape[cols_axis])
    seq_len = (input_specs.shape[rows_axis] // patch_size) * (
        input_specs.shape[cols_axis] // patch_size)
    x = tf.reshape(x, [-1, seq_len, hidden_size])

    # If we want to add a class token, add it here.
    if pooler == 'token':
      x = TokenLayer(name='cls')(x)

    x = Encoder(
        num_layers=num_layers,
        mlp_dim=mlp_dim,
        num_heads=num_heads,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        kernel_regularizer=kernel_regularizer,
        kernel_initializer='glorot_uniform' if original_init else dict(
            class_name='TruncatedNormal', config=dict(stddev=.02)),
        init_stochastic_depth_rate=init_stochastic_depth_rate,
        pos_embed_origin_shape=pos_embed_shape,
        pos_embed_target_shape=pos_embed_target_shape)(
            x)

    if pooler == 'token':
      x = x[:, 0]
    elif pooler == 'gap':
      x = tf.reduce_mean(x, axis=1)
    elif pooler == 'none':
      x = tf.identity(x, name='encoded_tokens')
    else:
      raise ValueError(f'unrecognized pooler type: {pooler}')

    if representation_size:
      x = tf.keras.layers.Dense(
          representation_size,
          kernel_regularizer=kernel_regularizer,
          name='pre_logits',
          kernel_initializer='lecun_normal'
          if original_init else 'he_uniform')(x)
      x = tf.nn.tanh(x)
    else:
      x = tf.identity(x, name='pre_logits')

    if pooler == 'none':
      endpoints = {'encoded_tokens': x}
    else:
      endpoints = {
          'pre_logits':
              tf.reshape(x, [-1, 1, 1, representation_size or hidden_size])
      }

    super(VisionTransformer, self).__init__(inputs=inputs, outputs=endpoints)
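
A shape walk-through under the defaults above (224x224x3 inputs, patch size
16, hidden size 768, pooler='token'): the stride-16 Conv2D produces
(B, 14, 14, 768), the reshape flattens it to (B, 196, 768), TokenLayer
prepends the class token giving (B, 197, 768), and the 'token' pooler keeps
x[:, 0], so the 'pre_logits' endpoint is reshaped to (B, 1, 1, 768).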


@factory.register_backbone_builder('vit')
def build_vit(input_specs,
              backbone_config,
              norm_activation_config,
              l2_regularizer=None):
  """Build ViT model."""
  del norm_activation_config
  backbone_type = backbone_config.type
  backbone_cfg = backbone_config.get()
  assert backbone_type == 'vit', (f'Inconsistent backbone type '
                                  f'{backbone_type}')
  backbone_cfg.override(VIT_SPECS[backbone_cfg.model_name])

  return VisionTransformer(
      mlp_dim=backbone_cfg.transformer.mlp_dim,
      num_heads=backbone_cfg.transformer.num_heads,
      num_layers=backbone_cfg.transformer.num_layers,
      attention_dropout_rate=backbone_cfg.transformer.attention_dropout_rate,
      dropout_rate=backbone_cfg.transformer.dropout_rate,
      init_stochastic_depth_rate=backbone_cfg.init_stochastic_depth_rate,
      input_specs=input_specs,
      patch_size=backbone_cfg.patch_size,
      hidden_size=backbone_cfg.hidden_size,
      representation_size=backbone_cfg.representation_size,
      pooler=backbone_cfg.pooler,
      kernel_regularizer=l2_regularizer,
      original_init=backbone_cfg.original_init,
      pos_embed_shape=backbone_cfg.pos_embed_shape)
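
A sketch of driving the builder above directly (it is normally invoked through
the backbone factory); the config import path is assumed, and since
norm_activation_config is deleted unused, passing None is acceptable for this
builder only:

  from official.vision.configs import backbones as backbones_cfg

  input_specs = tf.keras.layers.InputSpec(shape=[None, 224, 224, 3])
  config = backbones_cfg.Backbone(
      type='vit', vit=backbones_cfg.VisionTransformer(model_name='vit-b16'))
  model = build_vit(input_specs, config, norm_activation_config=None)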
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""VisionTransformer backbone specs."""
import immutabledict

VIT_SPECS = immutabledict.immutabledict({
    'vit-ti16':
        dict(
            hidden_size=192,
            patch_size=16,
            transformer=dict(mlp_dim=768, num_heads=3, num_layers=12),
        ),
    'vit-s16':
        dict(
            hidden_size=384,
            patch_size=16,
            transformer=dict(mlp_dim=1536, num_heads=6, num_layers=12),
        ),
    'vit-b16':
        dict(
            hidden_size=768,
            patch_size=16,
            transformer=dict(mlp_dim=3072, num_heads=12, num_layers=12),
        ),
    'vit-b32':
        dict(
            hidden_size=768,
            patch_size=32,
            transformer=dict(mlp_dim=3072, num_heads=12, num_layers=12),
        ),
    'vit-l16':
        dict(
            hidden_size=1024,
            patch_size=16,
            transformer=dict(mlp_dim=4096, num_heads=16, num_layers=24),
        ),
    'vit-l32':
        dict(
            hidden_size=1024,
            patch_size=32,
            transformer=dict(mlp_dim=4096, num_heads=16, num_layers=24),
        ),
    'vit-h14':
        dict(
            hidden_size=1280,
            patch_size=14,
            transformer=dict(mlp_dim=5120, num_heads=16, num_layers=32),
        ),
    'vit-g14':
        dict(
            hidden_size=1664,
            patch_size=14,
            transformer=dict(mlp_dim=8192, num_heads=16, num_layers=48),
        ),
})
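
As a consistency check on the table above: every variant except vit-g14 uses
the standard MLP ratio mlp_dim = 4 * hidden_size (vit-b16: 3072 = 4 * 768,
vit-h14: 5120 = 4 * 1280), while vit-g14 pairs a wider mlp_dim of 8192 with
hidden size 1664.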
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for VIT."""
from absl.testing import parameterized
import tensorflow as tf
from official.vision.modeling.backbones import vit


class VisionTransformerTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.parameters(
      (224, 85798656),
      (256, 85844736),
  )
  def test_network_creation(self, input_size, params_count):
    """Test creation of VisionTransformer family models."""
    tf.keras.backend.set_image_data_format('channels_last')
    input_specs = tf.keras.layers.InputSpec(
        shape=[2, input_size, input_size, 3])
    network = vit.VisionTransformer(input_specs=input_specs)
    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
    _ = network(inputs)
    self.assertEqual(network.count_params(), params_count)

  def test_network_none_pooler(self):
    tf.keras.backend.set_image_data_format('channels_last')
    input_size = 256
    input_specs = tf.keras.layers.InputSpec(
        shape=[2, input_size, input_size, 3])
    network = vit.VisionTransformer(
        input_specs=input_specs,
        patch_size=16,
        pooler='none',
        representation_size=128,
        pos_embed_shape=(14, 14))  # (224 // 16)
    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
    output = network(inputs)['encoded_tokens']
    self.assertEqual(output.shape, [1, 256, 128])

  def test_posembedding_interpolation(self):
    tf.keras.backend.set_image_data_format('channels_last')
    input_size = 256
    input_specs = tf.keras.layers.InputSpec(
        shape=[2, input_size, input_size, 3])
    network = vit.VisionTransformer(
        input_specs=input_specs,
        patch_size=16,
        pooler='gap',
        pos_embed_shape=(14, 14))  # (224 // 16)
    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
    output = network(inputs)['pre_logits']
    self.assertEqual(output.shape, [1, 1, 1, 768])


if __name__ == '__main__':
  tf.test.main()
...
@@ -27,20 +27,49 @@ from official.vision.modeling import classification_model
 
 class ClassificationNetworkTest(parameterized.TestCase, tf.test.TestCase):
 
+  @parameterized.parameters(
+      (192 * 4, 3, 12, 192, 5524416),
+      (384 * 4, 6, 12, 384, 21665664),
+  )
+  def test_vision_transformer_creation(self, mlp_dim, num_heads, num_layers,
+                                       hidden_size, num_params):
+    """Test for creation of a Vision Transformer classifier."""
+    inputs = np.random.rand(2, 224, 224, 3)
+    tf.keras.backend.set_image_data_format('channels_last')
+
+    backbone = backbones.VisionTransformer(
+        mlp_dim=mlp_dim,
+        num_heads=num_heads,
+        num_layers=num_layers,
+        hidden_size=hidden_size,
+        input_specs=tf.keras.layers.InputSpec(shape=[None, 224, 224, 3]),
+    )
+    self.assertEqual(backbone.count_params(), num_params)
+
+    num_classes = 1000
+    model = classification_model.ClassificationModel(
+        backbone=backbone,
+        num_classes=num_classes,
+        dropout_rate=0.2,
+    )
+    logits = model(inputs)
+    self.assertAllEqual([2, num_classes], logits.numpy().shape)
+
   @parameterized.parameters(
       (128, 50, 'relu'),
       (128, 50, 'relu'),
       (128, 50, 'swish'),
   )
-  def test_resnet_network_creation(
-      self, input_size, resnet_model_id, activation):
+  def test_resnet_network_creation(self, input_size, resnet_model_id,
+                                   activation):
     """Test for creation of a ResNet-50 classifier."""
     inputs = np.random.rand(2, input_size, input_size, 3)
     tf.keras.backend.set_image_data_format('channels_last')
-    backbone = backbones.ResNet(
-        model_id=resnet_model_id, activation=activation)
+    backbone = backbones.ResNet(model_id=resnet_model_id, activation=activation)
     self.assertEqual(backbone.count_params(), 23561152)
     num_classes = 1000
...
...
@@ -21,6 +21,7 @@ from absl import logging
 import tensorflow as tf
 
 from official.modeling import tf_utils
+from official.nlp import modeling as nlp_modeling
 from official.vision.modeling.layers import nn_layers
...
@@ -538,8 +539,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
       se_inner_activation: A `str` name of squeeze-excitation inner activation.
       se_gating_activation: A `str` name of squeeze-excitation gating
         activation.
-      se_round_down_protect: A `bool` of whether round down more than 10%
-        will be allowed in SE layer.
+      se_round_down_protect: A `bool` of whether round down more than 10% will
+        be allowed in SE layer.
       expand_se_in_filters: A `bool` of whether or not to expand in_filter in
         squeeze and excitation layer.
       depthwise_activation: A `str` name of the activation function for
@@ -547,9 +548,8 @@ class InvertedBottleneckBlock(tf.keras.layers.Layer):
       use_sync_bn: A `bool`. If True, use synchronized batch normalization.
       dilation_rate: An `int` that specifies the dilation rate to use for.
       divisible_by: An `int` that ensures all inner dimensions are divisible by
-        this number.
-      dilated convolution: An `int` to specify the same value for all spatial
-        dimensions.
+        this number. dilated convolution: An `int` to specify the same value for
+        all spatial dimensions.
       regularize_depthwise: A `bool` of whether or not apply regularization on
         depthwise.
       use_depthwise: A `bool` of whether to uses fused convolutions instead of
...
@@ -1048,7 +1048,7 @@ class ReversibleLayer(tf.keras.layers.Layer):
         (bottleneck) residual functions. Where the input to the reversible layer
         is x, the input gets partitioned in the channel dimension and the
         forward pass follows (eq8): x = [x1; x2], z1 = x1 + f(x2), y2 = x2 +
         g(z1), y1 = stop_gradient(z1).
       g: A `tf.keras.layers.Layer` instance of `g` inner block referred to in
         paper. Detailed explanation same as above as `f` arg.
       manual_grads: A `bool` [Testing Only] of whether to manually take
...
@@ -1204,7 +1204,8 @@ class ReversibleLayer(tf.keras.layers.Layer):
 
 @tf.keras.utils.register_keras_serializable(package='Vision')
 class DepthwiseSeparableConvBlock(tf.keras.layers.Layer):
-  """Creates an depthwise separable convolution block with batch normalization."""
+  """Creates a depthwise separable convolution block with batch normalization.
+  """
 
   def __init__(
       self,
...
@@ -1354,10 +1355,10 @@ class TuckerConvBlock(tf.keras.layers.Layer):
     Args:
       in_filters: An `int` number of filters of the input tensor.
       out_filters: An `int` number of filters of the output tensor.
-      input_compression_ratio: An `float` of compression ratio for
-        input filters.
-      output_compression_ratio: An `float` of compression ratio for
-        output filters.
+      input_compression_ratio: An `float` of compression ratio for input
+        filters.
+      output_compression_ratio: An `float` of compression ratio for output
+        filters.
       strides: An `int` block stride. If greater than 1, this block will
         ultimately downsample the input.
       kernel_size: An `int` kernel_size of the depthwise conv layer.
...
@@ -1510,11 +1511,114 @@ class TuckerConvBlock(tf.keras.layers.Layer):
     x = self._conv2(x)
     x = self._norm2(x)
 
-    if (self._use_residual and
-        self._in_filters == self._out_filters and
+    if (self._use_residual and self._in_filters == self._out_filters and
         self._strides == 1):
       if self._stochastic_depth:
         x = self._stochastic_depth(x, training=training)
       x = self._add([x, shortcut])
 
     return x
+
+
+class TransformerEncoderBlock(nlp_modeling.layers.TransformerEncoderBlock):
+  """TransformerEncoderBlock layer with stochastic depth."""
+
+  def __init__(self,
+               *args,
+               stochastic_depth_drop_rate=0.0,
+               return_attention=False,
+               **kwargs):
+    """Initializes TransformerEncoderBlock."""
+    super().__init__(*args, **kwargs)
+    self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
+    self._return_attention = return_attention
+
+  def build(self, input_shape):
+    if self._stochastic_depth_drop_rate:
+      self._stochastic_depth = nn_layers.StochasticDepth(
+          self._stochastic_depth_drop_rate)
+    else:
+      self._stochastic_depth = lambda x, *args, **kwargs: tf.identity(x)
+    super().build(input_shape)
+
+  def get_config(self):
+    config = {'stochastic_depth_drop_rate': self._stochastic_depth_drop_rate}
+    base_config = super().get_config()
+    return dict(list(base_config.items()) + list(config.items()))
+
+  def call(self, inputs, training=None):
+    """Transformer self-attention encoder block call."""
+    if isinstance(inputs, (list, tuple)):
+      if len(inputs) == 2:
+        input_tensor, attention_mask = inputs
+        key_value = None
+      elif len(inputs) == 3:
+        input_tensor, key_value, attention_mask = inputs
+      else:
+        raise ValueError('Unexpected inputs to %s with length at %d' %
+                         (self.__class__, len(inputs)))
+    else:
+      input_tensor, key_value, attention_mask = (inputs, None, None)
+
+    if self._output_range:
+      if self._norm_first:
+        source_tensor = input_tensor[:, 0:self._output_range, :]
+        input_tensor = self._attention_layer_norm(input_tensor)
+        if key_value is not None:
+          key_value = self._attention_layer_norm(key_value)
+      target_tensor = input_tensor[:, 0:self._output_range, :]
+      if attention_mask is not None:
+        attention_mask = attention_mask[:, 0:self._output_range, :]
+    else:
+      if self._norm_first:
+        source_tensor = input_tensor
+        input_tensor = self._attention_layer_norm(input_tensor)
+        if key_value is not None:
+          key_value = self._attention_layer_norm(key_value)
+      target_tensor = input_tensor
+
+    if key_value is None:
+      key_value = input_tensor
+    attention_output, attention_scores = self._attention_layer(
+        query=target_tensor,
+        value=key_value,
+        attention_mask=attention_mask,
+        return_attention_scores=True)
+    attention_output = self._attention_dropout(attention_output)
+    if self._norm_first:
+      attention_output = source_tensor + self._stochastic_depth(
+          attention_output, training=training)
+    else:
+      attention_output = self._attention_layer_norm(
+          target_tensor +
+          self._stochastic_depth(attention_output, training=training))
+    if self._norm_first:
+      source_attention_output = attention_output
+      attention_output = self._output_layer_norm(attention_output)
+    inner_output = self._intermediate_dense(attention_output)
+    inner_output = self._intermediate_activation_layer(inner_output)
+    inner_output = self._inner_dropout_layer(inner_output)
+    layer_output = self._output_dense(inner_output)
+    layer_output = self._output_dropout(layer_output)
+
+    if self._norm_first:
+      if self._return_attention:
+        return source_attention_output + self._stochastic_depth(
+            layer_output, training=training), attention_scores
+      else:
+        return source_attention_output + self._stochastic_depth(
+            layer_output, training=training)
+
+    # During mixed precision training, layer norm output is always fp32 for
+    # now. Casts fp32 for the subsequent add.
+    layer_output = tf.cast(layer_output, tf.float32)
+    if self._return_attention:
+      return self._output_layer_norm(layer_output + self._stochastic_depth(
+          attention_output, training=training)), attention_scores
+    else:
+      return self._output_layer_norm(
+          layer_output +
+          self._stochastic_depth(attention_output, training=training))
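
A construction sketch for the block above with stochastic depth enabled,
mirroring how Encoder.build in vit.py wires it; the required arguments follow
the NLP TransformerEncoderBlock this class subclasses:

  block = TransformerEncoderBlock(
      num_attention_heads=12,
      inner_dim=3072,
      inner_activation='gelu',
      norm_first=True,
      stochastic_depth_drop_rate=0.1)
  out = block(tf.ones([2, 197, 768]), training=True)  # -> (2, 197, 768)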