Commit aba78478 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 365713370
parent f3f3ec34
@@ -39,6 +39,7 @@ class OptimizerConfig(oneof.OneOfConfig):
adamw: adam with weight decay.
lamb: lamb optimizer.
rmsprop: rmsprop optimizer.
lars: lars optimizer.
"""
type: Optional[str] = None
sgd: opt_cfg.SGDConfig = opt_cfg.SGDConfig()
@@ -46,6 +47,7 @@ class OptimizerConfig(oneof.OneOfConfig):
adamw: opt_cfg.AdamWeightDecayConfig = opt_cfg.AdamWeightDecayConfig()
lamb: opt_cfg.LAMBConfig = opt_cfg.LAMBConfig()
rmsprop: opt_cfg.RMSPropConfig = opt_cfg.RMSPropConfig()
lars: opt_cfg.LARSConfig = opt_cfg.LARSConfig()
@dataclasses.dataclass
......
@@ -170,3 +170,38 @@ class EMAConfig(BaseOptimizerConfig):
average_decay: float = 0.99
start_step: int = 0
dynamic_decay: bool = True
@dataclasses.dataclass
class LARSConfig(BaseOptimizerConfig):
"""Layer-wise adaptive rate scaling config.
Attributes:
name: `str`, name of the optimizer.
momentum: `float` hyperparameter >= 0 that accelerates gradient descent
in the relevant direction and dampens oscillations. Defaults to 0.9.
eeta: `float` LARS coefficient as used in the paper; defaults to the
value from the paper. (eeta / weight_decay) determines the highest
scaling factor in LARS.
weight_decay_rate: `float` for weight decay.
nesterov: `bool`, whether to use Nesterov momentum.
classic_momentum: `bool`, whether to use classic momentum (as opposed
to the popular variant). With classic momentum the learning rate is
applied during the momentum update; with the popular variant it is
applied after the momentum update.
exclude_from_weight_decay: A list of `str` patterns for variable
screening; if any pattern appears in a variable's name, that variable
is excluded from weight decay. For example, specify
['batch_normalization', 'bias'] to exclude BN and bias variables.
exclude_from_layer_adaptation: Similar to exclude_from_weight_decay, but
for layer adaptation. If None, it defaults to the same value as
exclude_from_weight_decay.
"""
name: str = "LARS"
momentum: float = 0.9
eeta: float = 0.001
weight_decay_rate: float = 0.0
nesterov: bool = False
classic_momentum: bool = True
exclude_from_weight_decay: Optional[List[str]] = None
exclude_from_layer_adaptation: Optional[List[str]] = None
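For orientation, a minimal sketch of constructing the new config; the import path for `LARSConfig` is assumed from the surrounding diff context (the module aliased as `opt_cfg` above), and the values mirror the SimCLR configs later in this change:

```python
# A hedged sketch, assuming LARSConfig is defined in
# official/modeling/optimization/configs/optimizer_config.py (the module
# aliased as opt_cfg in the diff above).
from official.modeling.optimization.configs import optimizer_config as opt_cfg

lars_cfg = opt_cfg.LARSConfig(
    momentum=0.9,
    weight_decay_rate=1e-6,
    exclude_from_weight_decay=['batch_normalization', 'bias'],
)
# exclude_from_layer_adaptation is left as None, so it falls back to
# exclude_from_weight_decay when the optimizer is constructed.
```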
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layer-wise adaptive rate scaling optimizer."""
import re
from typing import Text, List, Optional
import tensorflow as tf
# pylint: disable=protected-access
class LARS(tf.keras.optimizers.Optimizer):
"""Layer-wise Adaptive Rate Scaling for large batch training.
Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)
"""
def __init__(self,
learning_rate: float = 0.01,
momentum: float = 0.9,
weight_decay_rate: float = 0.0,
eeta: float = 0.001,
nesterov: bool = False,
classic_momentum: bool = True,
exclude_from_weight_decay: Optional[List[Text]] = None,
exclude_from_layer_adaptation: Optional[List[Text]] = None,
name: Text = "LARS",
**kwargs):
"""Constructs a LARSOptimizer.
Args:
learning_rate: `float` for learning rate. Defaults to 0.01.
momentum: `float` hyperparameter >= 0 that accelerates gradient descent
in the relevant direction and dampens oscillations. Defaults to 0.9.
weight_decay_rate: `float` for weight decay.
eeta: `float` LARS coefficient as used in the paper; defaults to the
value from the paper. (eeta / weight_decay) determines the highest
scaling factor in LARS.
nesterov: `bool`, whether to use Nesterov momentum.
classic_momentum: `bool`, whether to use classic momentum (as opposed
to the popular variant). With classic momentum the learning rate is
applied during the momentum update; with the popular variant it is
applied after the momentum update.
exclude_from_weight_decay: A list of `str` patterns for variable
screening; if any pattern appears in a variable's name, that variable
is excluded from weight decay. For example, specify
['batch_normalization', 'bias'] to exclude BN and bias variables.
exclude_from_layer_adaptation: Similar to exclude_from_weight_decay, but
for layer adaptation. If None, it defaults to the same value as
exclude_from_weight_decay.
name: `Text` as optional name for the operations created when applying
gradients. Defaults to "LARS".
**kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
`decay`}. `clipnorm` clips gradients by norm; `clipvalue` clips
gradients by value; `decay` is included for backward compatibility to
allow time-inverse decay of the learning rate; `lr` is included for
backward compatibility, and `learning_rate` is recommended instead.
"""
super(LARS, self).__init__(name, **kwargs)
self._set_hyper("learning_rate", learning_rate)
self._set_hyper("decay", self._initial_decay)
self.momentum = momentum
self.weight_decay_rate = weight_decay_rate
self.eeta = eeta
self.nesterov = nesterov
self.classic_momentum = classic_momentum
self.exclude_from_weight_decay = exclude_from_weight_decay
# exclude_from_layer_adaptation is set to exclude_from_weight_decay if the
# arg is None.
if exclude_from_layer_adaptation:
self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
else:
self.exclude_from_layer_adaptation = exclude_from_weight_decay
def _create_slots(self, var_list):
for v in var_list:
self.add_slot(v, "momentum")
def _resource_apply_dense(self, grad, param, apply_state=None):
if grad is None or param is None:
return tf.no_op()
var_device, var_dtype = param.device, param.dtype.base_dtype
coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
self._fallback_apply_state(var_device, var_dtype))
learning_rate = coefficients["lr_t"]
param_name = param.name
v = self.get_slot(param, "momentum")
if self._use_weight_decay(param_name):
grad += self.weight_decay_rate * param
if self.classic_momentum:
trust_ratio = 1.0
if self._do_layer_adaptation(param_name):
w_norm = tf.norm(param, ord=2)
g_norm = tf.norm(grad, ord=2)
trust_ratio = tf.where(
tf.greater(w_norm, 0),
tf.where(tf.greater(g_norm, 0), (self.eeta * w_norm / g_norm), 1.0),
1.0)
scaled_lr = learning_rate * trust_ratio
next_v = tf.multiply(self.momentum, v) + scaled_lr * grad
if self.nesterov:
update = tf.multiply(self.momentum, next_v) + scaled_lr * grad
else:
update = next_v
next_param = param - update
else:
next_v = tf.multiply(self.momentum, v) + grad
if self.nesterov:
update = tf.multiply(self.momentum, next_v) + grad
else:
update = next_v
trust_ratio = 1.0
if self._do_layer_adaptation(param_name):
w_norm = tf.norm(param, ord=2)
v_norm = tf.norm(update, ord=2)
trust_ratio = tf.where(
tf.greater(w_norm, 0),
tf.where(tf.greater(v_norm, 0), (self.eeta * w_norm / v_norm), 1.0),
1.0)
scaled_lr = trust_ratio * learning_rate
next_param = param - scaled_lr * update
return tf.group(*[
param.assign(next_param, use_locking=False),
v.assign(next_v, use_locking=False)
])
def _resource_apply_sparse(self, grad, handle, indices, apply_state):
raise NotImplementedError("Applying sparse gradients is not implemented.")
def _use_weight_decay(self, param_name):
"""Whether to use L2 weight decay for `param_name`."""
if not self.weight_decay_rate:
return False
if self.exclude_from_weight_decay:
for r in self.exclude_from_weight_decay:
if re.search(r, param_name) is not None:
return False
return True
def _do_layer_adaptation(self, param_name):
"""Whether to do layer-wise learning rate adaptation for `param_name`."""
if self.exclude_from_layer_adaptation:
for r in self.exclude_from_layer_adaptation:
if re.search(r, param_name) is not None:
return False
return True
def get_config(self):
config = super(LARS, self).get_config()
config.update({
"learning_rate": self._serialize_hyperparameter("learning_rate"),
"decay": self._serialize_hyperparameter("decay"),
"momentum": self.momentum,
"classic_momentum": self.classic_momentum,
"weight_decay_rate": self.weight_decay_rate,
"eeta": self.eeta,
"nesterov": self.nesterov,
})
return config
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
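To make the update rule concrete, here is a small self-contained demonstration of the trust ratio computed in `_resource_apply_dense` above: when both norms are positive it equals `eeta * ||w|| / ||g||`, and the per-variable learning rate is scaled by it. Values are illustrative only.

```python
import tensorflow as tf

eeta = 0.001
param = tf.constant([[3.0, 4.0]])  # ||w||_2 = 5.0
grad = tf.constant([[0.6, 0.8]])   # ||g||_2 = 1.0

# tf.norm with axis=None flattens the tensor, matching the optimizer code.
w_norm = tf.norm(param, ord=2)
g_norm = tf.norm(grad, ord=2)
# Mirrors the nested tf.where guard above: fall back to 1.0 when a norm is 0.
trust_ratio = tf.where(
    tf.greater(w_norm, 0),
    tf.where(tf.greater(g_norm, 0), eeta * w_norm / g_norm, 1.0),
    1.0)
print(float(trust_ratio))  # 0.005, so scaled_lr = 0.005 * learning_rate
```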
@@ -21,6 +21,7 @@ import tensorflow as tf
import tensorflow_addons.optimizers as tfa_optimizers
from official.modeling.optimization import ema_optimizer
from official.modeling.optimization import lars_optimizer
from official.modeling.optimization import lr_schedule
from official.modeling.optimization.configs import optimization_config as opt_cfg
from official.nlp import optimization as nlp_optimization
@@ -30,7 +31,8 @@ OPTIMIZERS_CLS = {
'adam': tf.keras.optimizers.Adam,
'adamw': nlp_optimization.AdamWeightDecay,
'lamb': tfa_optimizers.LAMB,
'rmsprop': tf.keras.optimizers.RMSprop,
'lars': lars_optimizer.LARS,
}
LR_CLS = {
......
@@ -23,7 +23,9 @@ from official.modeling.optimization.configs import optimization_config
class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(('sgd'), ('rmsprop'),
                          ('adam'), ('adamw'),
                          ('lamb'), ('lars'))
def test_optimizers(self, optimizer_type):
params = {
'optimizer': {
......
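The `params` dict above is elided; a hedged reconstruction of the one-of layout the factory test consumes, with LARS values taken from the ImageNet pretraining YAML later in this commit:

```python
# Not the literal test body (it is elided above); a sketch of the expected
# config shape, mirroring OPTIMIZERS_CLS and the YAML configs below.
params = {
    'optimizer': {
        'type': 'lars',
        'lars': {
            'momentum': 0.9,
            'weight_decay_rate': 0.000001,
            'exclude_from_weight_decay': ['batch_normalization', 'bias'],
        }
    },
    'learning_rate': {
        'type': 'cosine',
        'cosine': {'initial_learning_rate': 1.6, 'decay_steps': 177840}
    },
    'warmup': {
        'type': 'linear',
        'linear': {'warmup_steps': 9360}
    }
}
```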
# Simple Framework for Contrastive Learning
[![Paper](http://img.shields.io/badge/Paper-arXiv.2002.05709-B3181B?logo=arXiv)](https://arxiv.org/abs/2002.05709)
[![Paper](http://img.shields.io/badge/Paper-arXiv.2006.10029-B3181B?logo=arXiv)](https://arxiv.org/abs/2006.10029)
<div align="center">
<img width="50%" alt="SimCLR Illustration" src="https://1.bp.blogspot.com/--vH4PKpE9Yo/Xo4a2BYervI/AAAAAAAAFpM/vaFDwPXOyAokAC8Xh852DzOgEs22NhbXwCLcBGAsYHQ/s1600/image4.gif">
</div>
<div align="center">
An illustration of SimCLR (from <a href="https://ai.googleblog.com/2020/04/advancing-self-supervised-and-semi.html">our blog here</a>).
</div>
## Environment setup
The code can be run on multiple GPUs or TPUs with different distribution
strategies. See the TensorFlow distributed training
[guide](https://www.tensorflow.org/guide/distributed_training) for an overview
of `tf.distribute`.
The code is compatible with TensorFlow 2.4+. See `requirements.txt` for all
prerequisites; you can install them with
`pip install -r ./official/requirements.txt`.
## Pretraining
To pretrain the model on ImageNet, try the following command:
```
python3 -m official.vision.beta.projects.simclr.train \
--mode=train_and_eval \
--experiment=simclr_pretraining \
--model_dir={MODEL_DIR} \
--config_file={CONFIG_FILE}
```
An example of the config file can be found [here](./configs/experiments/imagenet_simclr_pretrain_gpu.yaml).
## Semi-supervised learning and fine-tuning the whole network
You can access 1% and 10% ImageNet subsets used for semi-supervised learning via
[tensorflow datasets](https://www.tensorflow.org/datasets/catalog/imagenet2012_subset).
You can also find image IDs of these subsets in `imagenet_subsets/`.
To fine-tune the whole network, refer to the following command:
```
python3 -m official.vision.beta.projects.simclr.train \
--mode=train_and_eval \
--experiment=simclr_finetuning \
--model_dir={MODEL_DIR} \
--config_file={CONFIG_FILE}
```
An example of the config file can be found [here](./configs/experiments/imagenet_simclr_finetune_gpu.yaml).
## Cite
[SimCLR paper](https://arxiv.org/abs/2002.05709):
```
@article{chen2020simple,
title={A Simple Framework for Contrastive Learning of Visual Representations},
author={Chen, Ting and Kornblith, Simon and Norouzi, Mohammad and Hinton, Geoffrey},
journal={arXiv preprint arXiv:2002.05709},
year={2020}
}
```
[SimCLRv2 paper](https://arxiv.org/abs/2006.10029):
```
@article{chen2020big,
title={Big Self-Supervised Models are Strong Semi-Supervised Learners},
author={Chen, Ting and Kornblith, Simon and Swersky, Kevin and Norouzi, Mohammad and Hinton, Geoffrey},
journal={arXiv preprint arXiv:2006.10029},
year={2020}
}
```
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All necessary imports for registration."""
# pylint: disable=unused-import
from official.common import registry_imports
from official.vision.beta.projects.simclr.configs import simclr
from official.vision.beta.projects.simclr.losses import contrastive_losses
from official.vision.beta.projects.simclr.modeling import simclr_model
from official.vision.beta.projects.simclr.tasks import simclr as simclr_task
# CIFAR-10 classification.
runtime:
distribution_strategy: 'mirrored'
mixed_precision_dtype: 'float16'
loss_scale: 'dynamic'
num_gpus: 16
task:
model:
mode: 'pretrain'
input_size: [32, 32, 3]
backbone:
type: 'resnet'
resnet:
model_id: 50
backbone_trainable: true
projection_head:
proj_output_dim: 64
num_proj_layers: 2
ft_proj_idx: 1
supervised_head:
num_classes: 10
norm_activation:
use_sync_bn: true
norm_momentum: 0.9
norm_epsilon: 0.00001
loss:
projection_norm: true
temperature: 0.2
evaluation:
top_k: 5
one_hot: true
train_data:
tfds_name: 'cifar10'
tfds_split: 'train'
input_path: ''
is_training: true
global_batch_size: 512
dtype: 'float16'
parser:
mode: 'pretrain'
aug_color_jitter_strength: 0.5
aug_rand_blur: false
decoder:
decode_label: true
validation_data:
tfds_name: 'cifar10'
tfds_split: 'test'
input_path: ''
is_training: false
global_batch_size: 512
dtype: 'float16'
drop_remainder: false
parser:
mode: 'pretrain'
decoder:
decode_label: true
trainer:
train_steps: 48000 # 500 epochs
validation_steps: 18 # NUM_EXAMPLES (10000) // global_batch_size
validation_interval: 96
steps_per_loop: 96 # NUM_EXAMPLES (50000) // global_batch_size
summary_interval: 96
checkpoint_interval: 96
optimizer_config:
optimizer:
type: 'lars'
lars:
momentum: 0.9
weight_decay_rate: 0.000001
exclude_from_weight_decay: ['batch_normalization', 'bias']
learning_rate:
type: 'cosine'
cosine:
initial_learning_rate: 0.6 # 0.3 × BatchSize / 256
decay_steps: 43200 # train_steps - warmup_steps
warmup:
type: 'linear'
linear:
warmup_steps: 4800 # 10% of total epochs
# ImageNet classification.
runtime:
distribution_strategy: 'mirrored'
mixed_precision_dtype: 'float16'
loss_scale: 'dynamic'
num_gpus: 16
task:
model:
mode: 'finetune'
input_size: [224, 224, 3]
backbone:
type: 'resnet'
resnet:
model_id: 50
backbone_trainable: true
projection_head:
proj_output_dim: 128
num_proj_layers: 3
ft_proj_idx: 1
supervised_head:
num_classes: 1001
zero_init: true
norm_activation:
use_sync_bn: false
norm_momentum: 0.9
norm_epsilon: 0.00001
loss:
label_smoothing: 0.0
one_hot: true
evaluation:
top_k: 5
one_hot: true
init_checkpoint: '/placer/prod/scratch/home/tf-model-garden-dev/vision/simclr/r50_1x/2021-03-26'
init_checkpoint_modules: 'backbone_projection'
train_data:
tfds_name: 'imagenet2012_subset/10pct'
tfds_split: 'train'
input_path: ''
is_training: true
global_batch_size: 1024
dtype: 'float16'
parser:
mode: 'finetune'
validation_data:
tfds_name: 'imagenet2012_subset/10pct'
tfds_split: 'validation'
input_path: ''
is_training: false
global_batch_size: 1024
dtype: 'float16'
drop_remainder: false
parser:
mode: 'finetune'
trainer:
train_steps: 12500 # 100 epochs
validation_steps: 49 # NUM_EXAMPLES (50000) // global_batch_size
validation_interval: 125
steps_per_loop: 125 # NUM_EXAMPLES (1281167) // global_batch_size
summary_interval: 125
checkpoint_interval: 125
optimizer_config:
optimizer:
type: 'lars'
lars:
momentum: 0.9
weight_decay_rate: 0.0
exclude_from_weight_decay: ['batch_normalization', 'bias']
learning_rate:
type: 'cosine'
cosine:
initial_learning_rate: 0.04 # 0.01 × BatchSize / 256
decay_steps: 12500 # train_steps
# ImageNet classification.
runtime:
distribution_strategy: 'mirrored'
mixed_precision_dtype: 'float16'
loss_scale: 'dynamic'
num_gpus: 16
task:
model:
mode: 'pretrain'
input_size: [224, 224, 3]
backbone:
type: 'resnet'
resnet:
model_id: 50
backbone_trainable: true
projection_head:
proj_output_dim: 128
num_proj_layers: 3
ft_proj_idx: 0
supervised_head:
num_classes: 1001
norm_activation:
use_sync_bn: true
norm_momentum: 0.9
norm_epsilon: 0.00001
loss:
projection_norm: true
temperature: 0.1
evaluation:
top_k: 5
one_hot: true
train_data:
input_path: '/readahead/200M/placer/prod/home/distbelief/imagenet-tensorflow/imagenet-2012-tfrecord/train*'
is_training: true
global_batch_size: 2048
dtype: 'float16'
parser:
mode: 'pretrain'
decoder:
decode_label: true
validation_data:
input_path: '/readahead/200M/placer/prod/home/distbelief/imagenet-tensorflow/imagenet-2012-tfrecord/valid*'
is_training: false
global_batch_size: 2048
dtype: 'float16'
drop_remainder: false
parser:
mode: 'pretrain'
decoder:
decode_label: true
trainer:
train_steps: 187200 # 300 epochs
validation_steps: 24 # NUM_EXAMPLES (50000) // global_batch_size
validation_interval: 624
steps_per_loop: 624 # NUM_EXAMPLES (1281167) // global_batch_size
summary_interval: 624
checkpoint_interval: 624
optimizer_config:
optimizer:
type: 'lars'
lars:
momentum: 0.9
weight_decay_rate: 0.000001
exclude_from_weight_decay: ['batch_normalization', 'bias']
learning_rate:
type: 'cosine'
cosine:
initial_learning_rate: 1.6 # 0.2 * BatchSize / 256
decay_steps: 177840 # train_steps - warmup_steps
warmup:
type: 'linear'
linear:
warmup_steps: 9360 # 5% of total epochs
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SimCLR configurations."""
import os
from typing import List, Optional
import dataclasses
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.vision.beta.configs import backbones
from official.vision.beta.configs import common
from official.vision.beta.projects.simclr.modeling import simclr_model
@dataclasses.dataclass
class Decoder(hyperparams.Config):
decode_label: bool = True
@dataclasses.dataclass
class Parser(hyperparams.Config):
"""Parser config."""
aug_rand_crop: bool = True
aug_rand_hflip: bool = True
aug_color_distort: bool = True
aug_color_jitter_strength: float = 1.0
aug_color_jitter_impl: str = 'simclrv2' # 'simclrv1' or 'simclrv2'
aug_rand_blur: bool = True
parse_label: bool = True
test_crop: bool = True
mode: str = simclr_model.PRETRAIN
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
"""Training data config."""
input_path: str = ''
global_batch_size: int = 0
is_training: bool = True
dtype: str = 'float32'
shuffle_buffer_size: int = 10000
cycle_length: int = 10
# simclr specific configs
parser: Parser = Parser()
decoder: Decoder = Decoder()
@dataclasses.dataclass
class ProjectionHead(hyperparams.Config):
proj_output_dim: int = 128
num_proj_layers: int = 3
ft_proj_idx: int = 1 # layer of the projection head to use for fine-tuning.
@dataclasses.dataclass
class SupervisedHead(hyperparams.Config):
num_classes: int = 1001
zero_init: bool = False
@dataclasses.dataclass
class ContrastiveLoss(hyperparams.Config):
projection_norm: bool = True
temperature: float = 0.1
l2_weight_decay: float = 0.0
@dataclasses.dataclass
class ClassificationLosses(hyperparams.Config):
label_smoothing: float = 0.0
one_hot: bool = True
l2_weight_decay: float = 0.0
@dataclasses.dataclass
class Evaluation(hyperparams.Config):
top_k: int = 5
one_hot: bool = True
@dataclasses.dataclass
class SimCLRModel(hyperparams.Config):
"""SimCLR model config."""
input_size: List[int] = dataclasses.field(default_factory=list)
backbone: backbones.Backbone = backbones.Backbone(
type='resnet', resnet=backbones.ResNet())
projection_head: ProjectionHead = ProjectionHead(
proj_output_dim=128,
num_proj_layers=3,
ft_proj_idx=1)
supervised_head: SupervisedHead = SupervisedHead(num_classes=1001)
norm_activation: common.NormActivation = common.NormActivation(
norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)
mode: str = simclr_model.PRETRAIN
backbone_trainable: bool = True
@dataclasses.dataclass
class SimCLRPretrainTask(cfg.TaskConfig):
"""SimCLR pretraining task config."""
model: SimCLRModel = SimCLRModel(mode=simclr_model.PRETRAIN)
train_data: DataConfig = DataConfig(
parser=Parser(mode=simclr_model.PRETRAIN), is_training=True)
validation_data: DataConfig = DataConfig(
parser=Parser(mode=simclr_model.PRETRAIN), is_training=False)
loss: ContrastiveLoss = ContrastiveLoss()
evaluation: Evaluation = Evaluation()
init_checkpoint: Optional[str] = None
# all or backbone
init_checkpoint_modules: str = 'all'
@dataclasses.dataclass
class SimCLRFinetuneTask(cfg.TaskConfig):
"""SimCLR fine tune task config."""
model: SimCLRModel = SimCLRModel(
mode=simclr_model.FINETUNE,
supervised_head=SupervisedHead(num_classes=1001, zero_init=True))
train_data: DataConfig = DataConfig(
parser=Parser(mode=simclr_model.FINETUNE), is_training=True)
validation_data: DataConfig = DataConfig(
parser=Parser(mode=simclr_model.FINETUNE), is_training=False)
loss: ClassificationLosses = ClassificationLosses()
evaluation: Evaluation = Evaluation()
init_checkpoint: Optional[str] = None
# all, backbone_projection or backbone
init_checkpoint_modules: str = 'backbone_projection'
@exp_factory.register_config_factory('simclr_pretraining')
def simclr_pretraining() -> cfg.ExperimentConfig:
"""Image classification general."""
return cfg.ExperimentConfig(
task=SimCLRPretrainTask(),
trainer=cfg.TrainerConfig(),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
@exp_factory.register_config_factory('simclr_finetuning')
def simclr_finetuning() -> cfg.ExperimentConfig:
"""Image classification general."""
return cfg.ExperimentConfig(
task=SimCLRFinetuneTask(),
trainer=cfg.TrainerConfig(),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
IMAGENET_TRAIN_EXAMPLES = 1281167
IMAGENET_VAL_EXAMPLES = 50000
IMAGENET_INPUT_PATH_BASE = 'imagenet-2012-tfrecord'
@exp_factory.register_config_factory('simclr_pretraining_imagenet')
def simclr_pretraining_imagenet() -> cfg.ExperimentConfig:
"""Image classification general."""
train_batch_size = 4096
eval_batch_size = 4096
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
return cfg.ExperimentConfig(
task=SimCLRPretrainTask(
model=SimCLRModel(
mode=simclr_model.PRETRAIN,
backbone_trainable=True,
input_size=[224, 224, 3],
backbone=backbones.Backbone(
type='resnet', resnet=backbones.ResNet(model_id=50)),
projection_head=ProjectionHead(
proj_output_dim=128,
num_proj_layers=3,
ft_proj_idx=1),
supervised_head=SupervisedHead(num_classes=1001),
norm_activation=common.NormActivation(
norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=True)),
loss=ContrastiveLoss(),
evaluation=Evaluation(),
train_data=DataConfig(
parser=Parser(mode=simclr_model.PRETRAIN),
decoder=Decoder(decode_label=True),
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size),
validation_data=DataConfig(
parser=Parser(mode=simclr_model.PRETRAIN),
decoder=Decoder(decode_label=True),
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
is_training=False,
global_batch_size=eval_batch_size),
),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=500 * steps_per_epoch,
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'lars',
'lars': {
'momentum': 0.9,
'weight_decay_rate': 0.000001,
'exclude_from_weight_decay': [
'batch_normalization', 'bias']
}
},
'learning_rate': {
'type': 'cosine',
'cosine': {
# 0.2 * BatchSize / 256
'initial_learning_rate': 0.2 * train_batch_size / 256,
# train_steps - warmup_steps
'decay_steps': 475 * steps_per_epoch
}
},
'warmup': {
'type': 'linear',
'linear': {
# 5% of total epochs
'warmup_steps': 25 * steps_per_epoch
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
@exp_factory.register_config_factory('simclr_finetuning_imagenet')
def simclr_finetuning_imagenet() -> cfg.ExperimentConfig:
"""Image classification general."""
train_batch_size = 1024
eval_batch_size = 1024
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
pretrain_model_base = ''
return cfg.ExperimentConfig(
task=SimCLRFinetuneTask(
model=SimCLRModel(
mode=simclr_model.FINETUNE,
backbone_trainable=True,
input_size=[224, 224, 3],
backbone=backbones.Backbone(
type='resnet', resnet=backbones.ResNet(model_id=50)),
projection_head=ProjectionHead(
proj_output_dim=128,
num_proj_layers=3,
ft_proj_idx=1),
supervised_head=SupervisedHead(
num_classes=1001, zero_init=True),
norm_activation=common.NormActivation(
norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)),
loss=ClassificationLosses(),
evaluation=Evaluation(),
train_data=DataConfig(
parser=Parser(mode=simclr_model.FINETUNE),
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size),
validation_data=DataConfig(
parser=Parser(mode=simclr_model.FINETUNE),
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
is_training=False,
global_batch_size=eval_batch_size),
init_checkpoint=pretrain_model_base,
# all, backbone_projection or backbone
init_checkpoint_modules='backbone_projection'),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=60 * steps_per_epoch,
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'lars',
'lars': {
'momentum': 0.9,
'weight_decay_rate': 0.0,
'exclude_from_weight_decay': [
'batch_normalization', 'bias']
}
},
'learning_rate': {
'type': 'cosine',
'cosine': {
# 0.01 × BatchSize / 512
'initial_learning_rate': 0.01 * train_batch_size / 512,
'decay_steps': 60 * steps_per_epoch
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
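Since both experiments are registered with `exp_factory.register_config_factory`, they can be retrieved by name; a minimal usage sketch (the registry import matches the test file later in this change):

```python
from official.core import exp_factory
# Importing registry_imports runs the registrations above as a side effect.
from official.vision.beta.projects.simclr.common import registry_imports  # pylint: disable=unused-import

config = exp_factory.get_exp_config('simclr_pretraining_imagenet')
print(config.task.train_data.global_batch_size)  # 4096
```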
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for simclr."""
# pylint: disable=unused-import
from absl.testing import parameterized
import tensorflow as tf
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.vision.beta.projects.simclr.common import registry_imports # pylint: disable=unused-import
from official.vision.beta.projects.simclr.configs import simclr as exp_cfg
class SimCLRConfigTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(
'simclr_pretraining_imagenet', 'simclr_finetuning_imagenet')
def test_simclr_configs(self, config_name):
config = exp_factory.get_exp_config(config_name)
self.assertIsInstance(config, cfg.ExperimentConfig)
if config_name == 'simclr_pretraining_imagenet':
self.assertIsInstance(config.task, exp_cfg.SimCLRPretrainTask)
elif config_name == 'simclr_finetuning_imagenet':
self.assertIsInstance(config.task, exp_cfg.SimCLRFinetuneTask)
self.assertIsInstance(config.task.model,
exp_cfg.SimCLRModel)
self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
config.task.train_data.is_training = None
with self.assertRaises(KeyError):
config.validate()
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Preprocessing ops."""
import functools
import tensorflow as tf
CROP_PROPORTION = 0.875 # Standard for ImageNet.
def random_apply(func, p, x):
"""Randomly apply function func to x with probability p."""
return tf.cond(
tf.less(
tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32),
tf.cast(p, tf.float32)), lambda: func(x), lambda: x)
def random_brightness(image, max_delta, impl='simclrv2'):
"""A multiplicative vs additive change of brightness."""
if impl == 'simclrv2':
factor = tf.random.uniform([], tf.maximum(1.0 - max_delta, 0),
1.0 + max_delta)
image = image * factor
elif impl == 'simclrv1':
image = tf.image.random_brightness(image, max_delta=max_delta)
else:
raise ValueError('Unknown impl {} for random brightness.'.format(impl))
return image
def to_grayscale(image, keep_channels=True):
image = tf.image.rgb_to_grayscale(image)
if keep_channels:
image = tf.tile(image, [1, 1, 3])
return image
def color_jitter_nonrand(image,
brightness=0,
contrast=0,
saturation=0,
hue=0,
impl='simclrv2'):
"""Distorts the color of the image (jittering order is fixed).
Args:
image: The input image tensor.
brightness: A float, specifying the brightness for color jitter.
contrast: A float, specifying the contrast for color jitter.
saturation: A float, specifying the saturation for color jitter.
hue: A float, specifying the hue for color jitter.
impl: 'simclrv1' or 'simclrv2'. Whether to use simclrv1 or simclrv2's
version of random brightness.
Returns:
The distorted image tensor.
"""
with tf.name_scope('distort_color'):
def apply_transform(i, x, brightness, contrast, saturation, hue):
"""Apply the i-th transformation."""
if brightness != 0 and i == 0:
x = random_brightness(x, max_delta=brightness, impl=impl)
elif contrast != 0 and i == 1:
x = tf.image.random_contrast(
x, lower=1 - contrast, upper=1 + contrast)
elif saturation != 0 and i == 2:
x = tf.image.random_saturation(
x, lower=1 - saturation, upper=1 + saturation)
elif hue != 0:
x = tf.image.random_hue(x, max_delta=hue)
return x
for i in range(4):
image = apply_transform(i, image, brightness, contrast, saturation, hue)
image = tf.clip_by_value(image, 0., 1.)
return image
def color_jitter_rand(image,
brightness=0,
contrast=0,
saturation=0,
hue=0,
impl='simclrv2'):
"""Distorts the color of the image (jittering order is random).
Args:
image: The input image tensor.
brightness: A float, specifying the brightness for color jitter.
contrast: A float, specifying the contrast for color jitter.
saturation: A float, specifying the saturation for color jitter.
hue: A float, specifying the hue for color jitter.
impl: 'simclrv1' or 'simclrv2'. Whether to use simclrv1 or simclrv2's
version of random brightness.
Returns:
The distorted image tensor.
"""
with tf.name_scope('distort_color'):
def apply_transform(i, x):
"""Apply the i-th transformation."""
def brightness_foo():
if brightness == 0:
return x
else:
return random_brightness(x, max_delta=brightness, impl=impl)
def contrast_foo():
if contrast == 0:
return x
else:
return tf.image.random_contrast(x, lower=1 - contrast,
upper=1 + contrast)
def saturation_foo():
if saturation == 0:
return x
else:
return tf.image.random_saturation(
x, lower=1 - saturation, upper=1 + saturation)
def hue_foo():
if hue == 0:
return x
else:
return tf.image.random_hue(x, max_delta=hue)
x = tf.cond(tf.less(i, 2),
lambda: tf.cond(tf.less(i, 1), brightness_foo, contrast_foo),
lambda: tf.cond(tf.less(i, 3), saturation_foo, hue_foo))
return x
perm = tf.random.shuffle(tf.range(4))
for i in range(4):
image = apply_transform(perm[i], image)
image = tf.clip_by_value(image, 0., 1.)
return image
def color_jitter(image, strength, random_order=True, impl='simclrv2'):
"""Distorts the color of the image.
Args:
image: The input image tensor.
strength: A float, specifying the strength of the color augmentation.
random_order: A bool, specifying whether to randomize the jittering order.
impl: 'simclrv1' or 'simclrv2'. Whether to use simclrv1 or simclrv2's
version of random brightness.
Returns:
The distorted image tensor.
"""
brightness = 0.8 * strength
contrast = 0.8 * strength
saturation = 0.8 * strength
hue = 0.2 * strength
if random_order:
return color_jitter_rand(
image, brightness, contrast, saturation, hue, impl=impl)
else:
return color_jitter_nonrand(
image, brightness, contrast, saturation, hue, impl=impl)
def random_color_jitter(image,
p=1.0,
color_jitter_strength=1.0,
impl='simclrv2'):
"""Perform random color jitter."""
def _transform(image):
color_jitter_t = functools.partial(
color_jitter, strength=color_jitter_strength, impl=impl)
image = random_apply(color_jitter_t, p=0.8, x=image)
return random_apply(to_grayscale, p=0.2, x=image)
return random_apply(_transform, p=p, x=image)
def gaussian_blur(image, kernel_size, sigma, padding='SAME'):
"""Blurs the given image with separable convolution.
Args:
image: Tensor of shape [height, width, channels] and dtype float to blur.
kernel_size: Integer Tensor for the size of the blur kernel. This should
be an odd number; if it is even, the actual kernel size will be
kernel_size + 1.
sigma: Sigma value for gaussian operator.
padding: Padding to use for the convolution. Typically 'SAME' or 'VALID'.
Returns:
A Tensor representing the blurred image.
"""
radius = tf.cast(kernel_size / 2, dtype=tf.int32)
kernel_size = radius * 2 + 1
x = tf.cast(tf.range(-radius, radius + 1), dtype=tf.float32)
blur_filter = tf.exp(-tf.pow(x, 2.0) /
(2.0 * tf.pow(tf.cast(sigma, dtype=tf.float32), 2.0)))
blur_filter /= tf.reduce_sum(blur_filter)
# One vertical and one horizontal filter.
blur_v = tf.reshape(blur_filter, [kernel_size, 1, 1, 1])
blur_h = tf.reshape(blur_filter, [1, kernel_size, 1, 1])
num_channels = tf.shape(image)[-1]
blur_h = tf.tile(blur_h, [1, 1, num_channels, 1])
blur_v = tf.tile(blur_v, [1, 1, num_channels, 1])
expand_batch_dim = image.shape.ndims == 3
if expand_batch_dim:
# TensorFlow requires batched input to convolutions, which we can fake with
# an extra dimension.
image = tf.expand_dims(image, axis=0)
blurred = tf.nn.depthwise_conv2d(
image, blur_h, strides=[1, 1, 1, 1], padding=padding)
blurred = tf.nn.depthwise_conv2d(
blurred, blur_v, strides=[1, 1, 1, 1], padding=padding)
if expand_batch_dim:
blurred = tf.squeeze(blurred, axis=0)
return blurred
def random_blur(image, height, width, p=0.5):
"""Randomly blur an image.
Args:
image: `Tensor` representing an image of arbitrary size.
height: Height of output image.
width: Width of output image.
p: probability of applying this transformation.
Returns:
A preprocessed image `Tensor`.
"""
del width
def _transform(image):
sigma = tf.random.uniform([], 0.1, 2.0, dtype=tf.float32)
return gaussian_blur(
image, kernel_size=height // 10, sigma=sigma, padding='SAME')
return random_apply(_transform, p=p, x=image)
def distorted_bounding_box_crop(image,
bbox,
min_object_covered=0.1,
aspect_ratio_range=(0.75, 1.33),
area_range=(0.05, 1.0),
max_attempts=100,
scope=None):
"""Generates cropped_image using one of the bboxes randomly distorted.
See `tf.image.sample_distorted_bounding_box` for more documentation.
Args:
image: `Tensor` of image data.
bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
where each coordinate is [0, 1) and the coordinates are arranged
as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
image.
min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
area of the image must contain at least this fraction of any bounding
box supplied.
aspect_ratio_range: An optional list of `float`s. The cropped area of the
image must have an aspect ratio = width / height within this range.
area_range: An optional list of `float`s. The cropped area of the image
must contain a fraction of the supplied image within this range.
max_attempts: An optional `int`. Number of attempts at generating a cropped
region of the image of the specified constraints. After `max_attempts`
failures, return the entire image.
scope: Optional `str` for name scope.
Returns:
(cropped image `Tensor`, distorted bbox `Tensor`).
"""
with tf.name_scope(scope or 'distorted_bounding_box_crop'):
shape = tf.shape(image)
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
shape,
bounding_boxes=bbox,
min_object_covered=min_object_covered,
aspect_ratio_range=aspect_ratio_range,
area_range=area_range,
max_attempts=max_attempts,
use_image_if_no_bounding_boxes=True)
bbox_begin, bbox_size, _ = sample_distorted_bounding_box
# Crop the image to the specified bounding box.
offset_y, offset_x, _ = tf.unstack(bbox_begin)
target_height, target_width, _ = tf.unstack(bbox_size)
image = tf.image.crop_to_bounding_box(
image, offset_y, offset_x, target_height, target_width)
return image
def crop_and_resize(image, height, width):
"""Make a random crop and resize it to height `height` and width `width`.
Args:
image: Tensor representing the image.
height: Desired image height.
width: Desired image width.
Returns:
A `height` x `width` x channels Tensor holding a random crop of `image`.
"""
bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
aspect_ratio = width / height
image = distorted_bounding_box_crop(
image,
bbox,
min_object_covered=0.1,
aspect_ratio_range=(3. / 4 * aspect_ratio, 4. / 3. * aspect_ratio),
area_range=(0.08, 1.0),
max_attempts=100,
scope=None)
return tf.image.resize([image], [height, width],
method=tf.image.ResizeMethod.BICUBIC)[0]
def random_crop_with_resize(image, height, width, p=1.0):
"""Randomly crop and resize an image.
Args:
image: `Tensor` representing an image of arbitrary size.
height: Height of output image.
width: Width of output image.
p: Probability of applying this transformation.
Returns:
A preprocessed image `Tensor`.
"""
def _transform(image): # pylint: disable=missing-docstring
image = crop_and_resize(image, height, width)
return image
return random_apply(_transform, p=p, x=image)
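Putting the ops in this file together, a minimal sketch of the two-view pretraining augmentation in the order the SimCLR parser applies them (crop, flip, color jitter, blur); the helper name is hypothetical:

```python
import tensorflow as tf

def simclr_pretrain_views(image, height=224, width=224, strength=1.0):
  """Hypothetical helper: two independently augmented views of one image."""
  def one_view(img):
    img = random_crop_with_resize(img, height, width)
    img = tf.image.random_flip_left_right(img)
    img = random_color_jitter(img, color_jitter_strength=strength)
    img = random_blur(img, height, width)
    return tf.clip_by_value(img, 0., 1.)
  # Channel-concatenated views of shape [height, width, 6], matching the
  # pretrain branch of the parser below.
  return tf.concat([one_view(image), one_view(image)], axis=-1)
```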
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data parser and processing for SimCLR.
For pre-training:
- Preprocessing:
-> random cropping
-> resize back to the original size
-> random color distortions
-> random Gaussian blur (sequential)
- Each image needs to be randomly processed twice
```snippets
if train_mode == 'pretrain':
xs = []
for _ in range(2): # Two transformations
xs.append(preprocess_fn_pretrain(image))
image = tf.concat(xs, -1)
else:
image = preprocess_fn_finetune(image)
```
For fine-tuning:
typical image classification input
"""
from typing import List
import tensorflow as tf
from official.vision.beta.dataloaders import decoder
from official.vision.beta.dataloaders import parser
from official.vision.beta.ops import preprocess_ops
from official.vision.beta.projects.simclr.dataloaders import preprocess_ops as simclr_preprocess_ops
from official.vision.beta.projects.simclr.modeling import simclr_model
class Decoder(decoder.Decoder):
"""A tf.Example decoder for classification task."""
def __init__(self, decode_label=True):
self._decode_label = decode_label
self._keys_to_features = {
'image/encoded': tf.io.FixedLenFeature((), tf.string, default_value=''),
}
if self._decode_label:
self._keys_to_features.update({
'image/class/label': (
tf.io.FixedLenFeature((), tf.int64, default_value=-1))
})
def decode(self, serialized_example):
return tf.io.parse_single_example(
serialized_example, self._keys_to_features)
class TFDSDecoder(decoder.Decoder):
"""A TFDS decoder for classification task."""
def __init__(self, decode_label=True):
self._decode_label = decode_label
def decode(self, serialized_example):
sample_dict = {
'image/encoded': tf.io.encode_jpeg(
serialized_example['image'], quality=100),
}
if self._decode_label:
sample_dict.update({
'image/class/label': serialized_example['label'],
})
return sample_dict
class Parser(parser.Parser):
"""Parser for SimCLR training."""
def __init__(self,
output_size: List[int],
aug_rand_crop: bool = True,
aug_rand_hflip: bool = True,
aug_color_distort: bool = True,
aug_color_jitter_strength: float = 1.0,
aug_color_jitter_impl: str = 'simclrv2',
aug_rand_blur: bool = True,
parse_label: bool = True,
test_crop: bool = True,
mode: str = simclr_model.PRETRAIN,
dtype: str = 'float32'):
"""Initializes parameters for parsing annotations in the dataset.
Args:
output_size: `Tensor` or `list` for [height, width] of the output image. The
output_size should be divisible by the largest feature stride 2^max_level.
aug_rand_crop: `bool`, if True, augment training with random cropping.
aug_rand_hflip: `bool`, if True, augment training with random
horizontal flip.
aug_color_distort: `bool`, if True augment training with color distortion.
aug_color_jitter_strength: `float`, the strength of the color augmentation.
aug_color_jitter_impl: `str`, 'simclrv1' or 'simclrv2'. Define whether
to use simclrv1 or simclrv2's version of random brightness.
aug_rand_blur: `bool`, if True, augment training with random blur.
parse_label: `bool`, if True, parse label together with image.
test_crop: `bool`, if True, augment eval with center cropping.
mode: `str`, 'pretrain' or 'finetune'. Defines the training mode.
dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
or 'bfloat16'.
"""
self._output_size = output_size
self._aug_rand_crop = aug_rand_crop
self._aug_rand_hflip = aug_rand_hflip
self._aug_color_distort = aug_color_distort
self._aug_color_jitter_strength = aug_color_jitter_strength
self._aug_color_jitter_impl = aug_color_jitter_impl
self._aug_rand_blur = aug_rand_blur
self._parse_label = parse_label
self._mode = mode
self._test_crop = test_crop
if max(self._output_size[0], self._output_size[1]) <= 32:
self._test_crop = False
if dtype == 'float32':
self._dtype = tf.float32
elif dtype == 'float16':
self._dtype = tf.float16
elif dtype == 'bfloat16':
self._dtype = tf.bfloat16
else:
raise ValueError('dtype {!r} is not supported!'.format(dtype))
def _parse_one_train_image(self, image_bytes):
image = tf.image.decode_jpeg(image_bytes, channels=3)
# Convert the image to floats in [0.0, 1.0].
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
if self._aug_rand_crop:
image = simclr_preprocess_ops.random_crop_with_resize(
image, self._output_size[0], self._output_size[1])
if self._aug_rand_hflip:
image = tf.image.random_flip_left_right(image)
if self._aug_color_distort and self._mode == simclr_model.PRETRAIN:
image = simclr_preprocess_ops.random_color_jitter(
image=image,
color_jitter_strength=self._aug_color_jitter_strength,
impl=self._aug_color_jitter_impl)
if self._aug_rand_blur and self._mode == simclr_model.PRETRAIN:
image = simclr_preprocess_ops.random_blur(
image, self._output_size[0], self._output_size[1])
image = tf.image.resize(
image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
image = tf.reshape(image, [self._output_size[0], self._output_size[1], 3])
image = tf.clip_by_value(image, 0., 1.)
# Convert image to self._dtype.
image = tf.image.convert_image_dtype(image, self._dtype)
return image
def _parse_train_data(self, decoded_tensors):
"""Parses data for training."""
image_bytes = decoded_tensors['image/encoded']
if self._mode == simclr_model.FINETUNE:
image = self._parse_one_train_image(image_bytes)
elif self._mode == simclr_model.PRETRAIN:
# Transform each example twice using a combination of
# simple augmentations, resulting in 2N data points
xs = []
for _ in range(2):
xs.append(self._parse_one_train_image(image_bytes))
image = tf.concat(xs, -1)
else:
raise ValueError('The mode {} is not supported by the Parser.'
.format(self._mode))
if self._parse_label:
label = tf.cast(decoded_tensors['image/class/label'], dtype=tf.int32)
return image, label
return image
def _parse_eval_data(self, decoded_tensors):
"""Parses data for evaluation."""
image_bytes = decoded_tensors['image/encoded']
image_shape = tf.image.extract_jpeg_shape(image_bytes)
if self._test_crop:
image = preprocess_ops.center_crop_image_v2(image_bytes, image_shape)
else:
image = tf.image.decode_jpeg(image_bytes, channels=3)
# Convert the image to floats in [0.0, 1.0].
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
image = tf.image.resize(
image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
image = tf.reshape(image, [self._output_size[0], self._output_size[1], 3])
image = tf.clip_by_value(image, 0., 1.)
# Convert image to self._dtype.
image = tf.image.convert_image_dtype(image, self._dtype)
if self._parse_label:
label = tf.cast(decoded_tensors['image/class/label'], dtype=tf.int32)
return image, label
return image
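A hedged usage sketch of the decoder/parser pair above in a TFDS pretraining pipeline; the dataset wiring is simplified, and the call into the protected `_parse_train_data` is for illustration only:

```python
decoder_fn = TFDSDecoder(decode_label=True)
parser_fn = Parser(
    output_size=[224, 224],
    mode=simclr_model.PRETRAIN,
    dtype='float32')

def map_fn(tfds_example):
  decoded = decoder_fn.decode(tfds_example)
  # In pretrain mode the parsed image stacks two augmented views on the
  # channel axis ([224, 224, 6]); the label is an int32 scalar.
  return parser_fn._parse_train_data(decoded)  # pylint: disable=protected-access
```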
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dense prediction heads."""
from typing import Text, Optional
import tensorflow as tf
from official.vision.beta.projects.simclr.modeling.layers import nn_blocks
regularizers = tf.keras.regularizers
layers = tf.keras.layers
@tf.keras.utils.register_keras_serializable(package='simclr')
class ProjectionHead(tf.keras.layers.Layer):
"""Projection head."""
def __init__(
self,
num_proj_layers: int = 3,
proj_output_dim: Optional[int] = None,
ft_proj_idx: int = 0,
kernel_initializer: Text = 'VarianceScaling',
kernel_regularizer: Optional[regularizers.Regularizer] = None,
bias_regularizer: Optional[regularizers.Regularizer] = None,
use_sync_bn: bool = False,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
**kwargs):
"""The projection head used during pretraining of SimCLR.
Args:
num_proj_layers: `int` number of Dense layers used.
proj_output_dim: `int` output dimension of projection head, i.e., output
dimension of the final layer.
ft_proj_idx: `int` index of layer to use during fine-tuning. 0 means no
projection head during fine tuning, -1 means the final layer.
kernel_initializer: kernel initializer for the dense layers.
kernel_regularizer: tf.keras.regularizers.Regularizer object for the dense
layers. Defaults to None.
bias_regularizer: tf.keras.regularizers.Regularizer object for the dense
layers. Defaults to None.
use_sync_bn: if True, use synchronized batch normalization.
norm_momentum: `float` normalization momentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
**kwargs: keyword arguments to be passed.
"""
super(ProjectionHead, self).__init__(**kwargs)
assert proj_output_dim is not None or num_proj_layers == 0
assert ft_proj_idx <= num_proj_layers, (num_proj_layers, ft_proj_idx)
self._proj_output_dim = proj_output_dim
self._num_proj_layers = num_proj_layers
self._ft_proj_idx = ft_proj_idx
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._use_sync_bn = use_sync_bn
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
self._layers = []
def get_config(self):
config = {
'proj_output_dim': self._proj_output_dim,
'num_proj_layers': self._num_proj_layers,
'ft_proj_idx': self._ft_proj_idx,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
'use_sync_bn': self._use_sync_bn,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon
}
base_config = super(ProjectionHead, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def build(self, input_shape):
self._layers = []
if self._num_proj_layers > 0:
intermediate_dim = int(input_shape[-1])
for j in range(self._num_proj_layers):
if j != self._num_proj_layers - 1:
# for the middle layers, use bias and relu for the output.
layer = nn_blocks.DenseBN(
output_dim=intermediate_dim,
use_bias=True,
use_normalization=True,
activation='relu',
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
use_sync_bn=self._use_sync_bn,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon,
name='nl_%d' % j)
else:
# for the final layer, neither bias nor relu is used.
layer = nn_blocks.DenseBN(
output_dim=self._proj_output_dim,
use_bias=False,
use_normalization=True,
activation=None,
kernel_regularizer=self._kernel_regularizer,
kernel_initializer=self._kernel_initializer,
use_sync_bn=self._use_sync_bn,
norm_momentum=self._norm_momentum,
norm_epsilon=self._norm_epsilon,
name='nl_%d' % j)
self._layers.append(layer)
super(ProjectionHead, self).build(input_shape)
def call(self, inputs, training=None):
hiddens_list = [tf.identity(inputs, 'proj_head_input')]
if self._num_proj_layers == 0:
proj_head_output = inputs
proj_finetune_output = inputs
else:
for j in range(self._num_proj_layers):
hiddens = self._layers[j](hiddens_list[-1], training)
hiddens_list.append(hiddens)
proj_head_output = tf.identity(
hiddens_list[-1], 'proj_head_output')
proj_finetune_output = tf.identity(
hiddens_list[self._ft_proj_idx], 'proj_finetune_output')
# The first element is the output of the projection head.
# The second element is the input of the finetune head.
return proj_head_output, proj_finetune_output
@tf.keras.utils.register_keras_serializable(package='simclr')
class ClassificationHead(tf.keras.layers.Layer):
"""Classification Head."""
def __init__(
self,
num_classes: int,
kernel_initializer: Text = 'random_uniform',
kernel_regularizer: Optional[regularizers.Regularizer] = None,
bias_regularizer: Optional[regularizers.Regularizer] = None,
name: Text = 'head_supervised',
**kwargs):
"""The classification head used during pretraining or fine tuning.
Args:
num_classes: `int` size of the output dimension or number of classes
for classification task.
kernel_initializer: kernel initializer for the dense layer.
kernel_regularizer: tf.keras.regularizers.Regularizer object for the dense
layer. Defaults to None.
bias_regularizer: tf.keras.regularizers.Regularizer object for the dense
layer. Defaults to None.
name: `str`, name of the layer.
**kwargs: keyword arguments to be passed.
"""
super(ClassificationHead, self).__init__(name=name, **kwargs)
self._num_classes = num_classes
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._name = name
def get_config(self):
config = {
'num_classes': self._num_classes,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
}
base_config = super(ClassificationHead, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def build(self, input_shape):
self._dense0 = layers.Dense(
units=self._num_classes,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activation=None)
super(ClassificationHead, self).build(input_shape)
def call(self, inputs, training=None):
inputs = self._dense0(inputs)
return inputs
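
A matching sketch for the supervised head; the class count and shapes are arbitrary:

# Hedged usage sketch: the head is a single Dense layer emitting raw logits.
import tensorflow as tf
from official.vision.beta.projects.simclr.heads import simclr_head

cls_head = simclr_head.ClassificationHead(num_classes=10)
logits = cls_head(tf.random.normal([8, 64]), training=True)  # -> [8, 10]
# No activation is applied, so pair the output with a from_logits=True loss.
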
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# ==============================================================================
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.vision.beta.projects.simclr.heads import simclr_head
class ProjectionHeadTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(
(0, None),
(1, 128),
(2, 128),
)
def test_head_creation(self, num_proj_layers, proj_output_dim):
test_layer = simclr_head.ProjectionHead(
num_proj_layers=num_proj_layers,
proj_output_dim=proj_output_dim)
input_dim = 64
x = tf.keras.Input(shape=(input_dim,))
proj_head_output, proj_finetune_output = test_layer(x)
proj_head_output_dim = input_dim
if num_proj_layers > 0:
proj_head_output_dim = proj_output_dim
self.assertAllEqual(proj_head_output.shape.as_list(),
[None, proj_head_output_dim])
if num_proj_layers > 0:
proj_finetune_output_dim = input_dim
self.assertAllEqual(proj_finetune_output.shape.as_list(),
[None, proj_finetune_output_dim])
@parameterized.parameters(
(0, None, 0),
(1, 128, 0),
(2, 128, 1),
(2, 128, 2),
)
def test_outputs(self, num_proj_layers, proj_output_dim, ft_proj_idx):
test_layer = simclr_head.ProjectionHead(
num_proj_layers=num_proj_layers,
proj_output_dim=proj_output_dim,
ft_proj_idx=ft_proj_idx
)
input_dim = 64
batch_size = 2
inputs = np.random.rand(batch_size, input_dim)
proj_head_output, proj_finetune_output = test_layer(inputs)
if num_proj_layers == 0:
self.assertAllClose(inputs, proj_head_output)
self.assertAllClose(inputs, proj_finetune_output)
else:
self.assertAllEqual(proj_head_output.shape.as_list(),
[batch_size, proj_output_dim])
if ft_proj_idx == 0:
self.assertAllClose(inputs, proj_finetune_output)
elif ft_proj_idx < num_proj_layers:
self.assertAllEqual(proj_finetune_output.shape.as_list(),
[batch_size, input_dim])
else:
self.assertAllEqual(proj_finetune_output.shape.as_list(),
[batch_size, proj_output_dim])
class ClassificationHeadTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(
10, 20
)
def test_head_creation(self, num_classes):
test_layer = simclr_head.ClassificationHead(num_classes=num_classes)
input_dim = 64
x = tf.keras.Input(shape=(input_dim,))
out_x = test_layer(x)
self.assertAllEqual(out_x.shape.as_list(),
[None, num_classes])
if __name__ == '__main__':
tf.test.main()
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# ==============================================================================
"""Contrastive loss functions."""
import functools
import tensorflow as tf
LARGE_NUM = 1e9
def cross_replica_concat(tensor: tf.Tensor, num_replicas: int) -> tf.Tensor:
"""Reduce a concatenation of the `tensor` across multiple replicas.
Args:
tensor: `tf.Tensor` to concatenate.
num_replicas: `int` number of replicas.
Returns:
Tensor of the same rank as `tensor` with first dimension `num_replicas`
times larger.
"""
if num_replicas <= 1:
return tensor
replica_context = tf.distribute.get_replica_context()
with tf.name_scope('cross_replica_concat'):
# This creates a tensor that is like the input tensor but has an added
# replica dimension as the outermost dimension. On each replica it will
# contain the local values and zeros for all other values that need to be
# fetched from other replicas.
ext_tensor = tf.scatter_nd(
indices=[[replica_context.replica_id_in_sync_group]],
updates=[tensor],
shape=tf.concat([[num_replicas], tf.shape(tensor)], axis=0))
# As every value is only present on one replica and 0 in all others, adding
# them all together will result in the full tensor on all replicas.
ext_tensor = replica_context.all_reduce(tf.distribute.ReduceOp.SUM,
ext_tensor)
# Flatten the replica dimension.
# The first dimension size will be: tensor.shape[0] * num_replicas
    # Use the [-1] trick to also support scalar inputs.
return tf.reshape(ext_tensor, [-1] + ext_tensor.shape.as_list()[2:])
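
The scatter-then-sum trick above can be checked off-device. A minimal NumPy sketch, assuming two replicas that each hold a [2, 2] shard:

import numpy as np

# Each "replica" scatters its shard into a zeroed tensor that has an extra
# leading replica axis; a SUM all-reduce then fills in the missing rows.
shards = [np.full((2, 2), 1.0), np.full((2, 2), 2.0)]  # per-replica tensors
ext = [np.zeros((2, 2, 2)) for _ in range(2)]
for r in range(2):
  ext[r][r] = shards[r]          # scatter_nd: local values, zeros elsewhere
reduced = ext[0] + ext[1]        # all_reduce(SUM): identical on every replica
concat = reduced.reshape(-1, 2)  # flatten the replica axis -> [4, 2] concat
assert (concat[:2] == 1.0).all() and (concat[2:] == 2.0).all()
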
@tf.keras.utils.register_keras_serializable(package='simclr')
class ContrastiveLoss(object):
"""Contrastive training loss function."""
def __init__(self, projection_norm: bool = True, temperature: float = 1.0):
"""Initializes `ContrastiveLoss`.
Args:
      projection_norm: whether to l2-normalize the projection vectors.
      temperature: a `float` number for temperature scaling.
"""
self._projection_norm = projection_norm
self._temperature = temperature
def __call__(self, projection1: tf.Tensor, projection2: tf.Tensor):
"""Compute the contrastive loss for contrastive learning.
Note that projection2 is generated with the same batch (same order) of raw
images, but with different augmentation. More specifically:
image[i] -> random augmentation 1 -> projection -> projection1[i]
image[i] -> random augmentation 2 -> projection -> projection2[i]
Args:
projection1: projection vector of shape (bsz, dim).
projection2: projection vector of shape (bsz, dim).
Returns:
A loss scalar.
The logits for contrastive prediction task.
The labels for contrastive prediction task.
"""
# Get (normalized) hidden1 and hidden2.
if self._projection_norm:
projection1 = tf.math.l2_normalize(projection1, -1)
projection2 = tf.math.l2_normalize(projection2, -1)
batch_size = tf.shape(projection1)[0]
p1_local, p2_local = projection1, projection2
# Gather projection1/projection2 across replicas and create local labels.
num_replicas_in_sync = tf.distribute.get_strategy().num_replicas_in_sync
if num_replicas_in_sync > 1:
p1_global = cross_replica_concat(p1_local, num_replicas_in_sync)
p2_global = cross_replica_concat(p2_local, num_replicas_in_sync)
global_batch_size = tf.shape(p1_global)[0]
replica_context = tf.distribute.get_replica_context()
replica_id = tf.cast(
tf.cast(replica_context.replica_id_in_sync_group, tf.uint32),
tf.int32)
labels_idx = tf.range(batch_size) + replica_id * batch_size
labels = tf.one_hot(labels_idx, global_batch_size * 2)
masks = tf.one_hot(labels_idx, global_batch_size)
else:
p1_global = p1_local
p2_global = p2_local
labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
masks = tf.one_hot(tf.range(batch_size), batch_size)
tb_matmul = functools.partial(tf.matmul, transpose_b=True)
logits_aa = tb_matmul(p1_local, p1_global) / self._temperature
logits_aa = logits_aa - masks * LARGE_NUM
logits_bb = tb_matmul(p2_local, p2_global) / self._temperature
logits_bb = logits_bb - masks * LARGE_NUM
logits_ab = tb_matmul(p1_local, p2_global) / self._temperature
logits_ba = tb_matmul(p2_local, p1_global) / self._temperature
loss_a_local = tf.nn.softmax_cross_entropy_with_logits(
labels, tf.concat([logits_ab, logits_aa], 1))
loss_b_local = tf.nn.softmax_cross_entropy_with_logits(
labels, tf.concat([logits_ba, logits_bb], 1))
loss_local = tf.reduce_mean(loss_a_local + loss_b_local)
return loss_local, (logits_ab, labels)
def get_config(self):
config = {
'projection_norm': self._projection_norm,
'temperature': self._temperature,
}
return config
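
A single-replica usage sketch of the loss; the sizes are arbitrary, and under the default strategy num_replicas_in_sync is 1, so no cross-replica gather happens:

import tensorflow as tf
from official.vision.beta.projects.simclr.losses import contrastive_losses

loss_fn = contrastive_losses.ContrastiveLoss(
    projection_norm=True, temperature=0.1)
z1 = tf.random.normal([32, 128])  # projections of augmentation 1
z2 = tf.random.normal([32, 128])  # projections of augmentation 2
loss, (logits_ab, labels) = loss_fn(z1, z2)
# loss: scalar; logits_ab: [32, 32]; labels: [32, 64], one-hot over 2N columns.
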
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# ==============================================================================
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.vision.beta.projects.simclr.losses import contrastive_losses
class ContrastiveLossesTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(1.0, 0.5)
def test_contrastive_loss_computation(self, temperature):
batch_size = 2
project_dim = 16
projection_norm = False
p_1_arr = np.random.rand(batch_size, project_dim)
p_1 = tf.constant(p_1_arr, dtype=tf.float32)
p_2_arr = np.random.rand(batch_size, project_dim)
p_2 = tf.constant(p_2_arr, dtype=tf.float32)
losses_obj = contrastive_losses.ContrastiveLoss(
projection_norm=projection_norm,
temperature=temperature)
comp_contrastive_loss = losses_obj(
projection1=p_1,
projection2=p_2)
def _exp_sim(p1, p2):
return np.exp(np.matmul(p1, p2) / temperature)
l11 = - np.log(
_exp_sim(p_1_arr[0], p_2_arr[0]) /
(_exp_sim(p_1_arr[0], p_1_arr[1])
+ _exp_sim(p_1_arr[0], p_2_arr[1])
+ _exp_sim(p_1_arr[0], p_2_arr[0]))
) - np.log(
_exp_sim(p_1_arr[0], p_2_arr[0]) /
(_exp_sim(p_2_arr[0], p_2_arr[1])
+ _exp_sim(p_2_arr[0], p_1_arr[1])
+ _exp_sim(p_1_arr[0], p_2_arr[0]))
)
l22 = - np.log(
_exp_sim(p_1_arr[1], p_2_arr[1]) /
(_exp_sim(p_1_arr[1], p_1_arr[0])
+ _exp_sim(p_1_arr[1], p_2_arr[0])
+ _exp_sim(p_1_arr[1], p_2_arr[1]))
) - np.log(
_exp_sim(p_1_arr[1], p_2_arr[1]) /
(_exp_sim(p_2_arr[1], p_2_arr[0])
+ _exp_sim(p_2_arr[1], p_1_arr[0])
+ _exp_sim(p_1_arr[1], p_2_arr[1]))
)
exp_contrastive_loss = (l11 + l22) / 2.0
self.assertAlmostEqual(comp_contrastive_loss[0].numpy(),
exp_contrastive_loss, places=5)
if __name__ == '__main__':
tf.test.main()
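
For reference, the per-example quantity this test reconstructs by hand is the NT-Xent term from the SimCLR paper. With similarity $s(u, v) = u \cdot v$ (unnormalized here, since projection_norm=False) and temperature $\tau$, example $i$ contributes

$$\ell_i = -\log \frac{\exp(s(z_i, z'_i)/\tau)}{\sum_{k \neq i} \exp(s(z_i, z_k)/\tau) + \sum_{k} \exp(s(z_i, z'_k)/\tau)} - \log \frac{\exp(s(z_i, z'_i)/\tau)}{\sum_{k \neq i} \exp(s(z'_i, z'_k)/\tau) + \sum_{k} \exp(s(z'_i, z_k)/\tau)}$$

and the reported loss is the batch mean $\frac{1}{N}\sum_i \ell_i$, matching (l11 + l22) / 2 above for N = 2.
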
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# ==============================================================================
"""Contains common building blocks for simclr neural networks."""
from typing import Text, Optional
import tensorflow as tf
from official.modeling import tf_utils
regularizers = tf.keras.regularizers
@tf.keras.utils.register_keras_serializable(package='simclr')
class DenseBN(tf.keras.layers.Layer):
"""Modified Dense layer to help build simclr system.
The layer is a standards combination of Dense, BatchNorm and Activation.
"""
def __init__(
self,
output_dim: int,
use_bias: bool = True,
use_normalization: bool = False,
use_sync_bn: bool = False,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
activation: Optional[Text] = 'relu',
kernel_initializer: Text = 'VarianceScaling',
kernel_regularizer: Optional[regularizers.Regularizer] = None,
bias_regularizer: Optional[regularizers.Regularizer] = None,
name='linear_layer',
**kwargs):
"""Customized Dense layer.
Args:
output_dim: `int` size of output dimension.
      use_bias: if True, use a bias term in the dense layer.
use_normalization: if True, use batch normalization.
use_sync_bn: if True, use synchronized batch normalization.
norm_momentum: `float` normalization momentum for the moving average.
norm_epsilon: `float` small float added to variance to avoid dividing by
zero.
activation: `str` name of the activation function.
      kernel_initializer: kernel initializer for the dense layer.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for the
        dense kernel. Defaults to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for the
        dense bias. Defaults to None.
name: `str`, name of the layer.
**kwargs: keyword arguments to be passed.
"""
    # Note: use_bias is ignored for the dense layer when use_normalization is
    # True; it then controls the batch norm's `center` parameter instead.
super(DenseBN, self).__init__(**kwargs)
self._output_dim = output_dim
self._use_bias = use_bias
self._use_normalization = use_normalization
self._use_sync_bn = use_sync_bn
self._norm_momentum = norm_momentum
self._norm_epsilon = norm_epsilon
self._activation = activation
self._kernel_initializer = kernel_initializer
self._kernel_regularizer = kernel_regularizer
self._bias_regularizer = bias_regularizer
self._name = name
if use_sync_bn:
self._norm = tf.keras.layers.experimental.SyncBatchNormalization
else:
self._norm = tf.keras.layers.BatchNormalization
if tf.keras.backend.image_data_format() == 'channels_last':
self._bn_axis = -1
else:
self._bn_axis = 1
if activation:
self._activation_fn = tf_utils.get_activation(activation)
else:
self._activation_fn = None
def get_config(self):
config = {
'output_dim': self._output_dim,
'use_bias': self._use_bias,
'activation': self._activation,
'use_sync_bn': self._use_sync_bn,
'use_normalization': self._use_normalization,
'norm_momentum': self._norm_momentum,
'norm_epsilon': self._norm_epsilon,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
}
base_config = super(DenseBN, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def build(self, input_shape):
self._dense0 = tf.keras.layers.Dense(
self._output_dim,
kernel_initializer=self._kernel_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
use_bias=self._use_bias and not self._use_normalization)
if self._use_normalization:
self._norm0 = self._norm(
axis=self._bn_axis,
momentum=self._norm_momentum,
epsilon=self._norm_epsilon,
center=self._use_bias,
scale=True)
super(DenseBN, self).build(input_shape)
def call(self, inputs, training=None):
assert inputs.shape.ndims == 2, inputs.shape
x = self._dense0(inputs)
if self._use_normalization:
x = self._norm0(x)
if self._activation:
x = self._activation_fn(x)
return x
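
A brief sketch of the bias/normalization interaction noted in the constructor; shapes are arbitrary:

import tensorflow as tf
from official.vision.beta.projects.simclr.modeling.layers import nn_blocks

# With use_normalization=True the dense kernel carries no bias term;
# use_bias instead toggles the batch norm's `center` (beta) parameter.
layer = nn_blocks.DenseBN(output_dim=128, use_bias=True,
                          use_normalization=True, activation='relu')
y = layer(tf.random.normal([4, 64]), training=True)  # -> [4, 128]
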
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# ==============================================================================
from absl.testing import parameterized
import tensorflow as tf
from official.vision.beta.projects.simclr.modeling.layers import nn_blocks
class DenseBNTest(tf.test.TestCase, parameterized.TestCase):
@parameterized.parameters(
(64, True, True),
(64, True, False),
(64, False, True),
)
def test_pass_through(self, output_dim, use_bias, use_normalization):
test_layer = nn_blocks.DenseBN(
output_dim=output_dim,
use_bias=use_bias,
use_normalization=use_normalization
)
x = tf.keras.Input(shape=(64,))
out_x = test_layer(x)
self.assertAllEqual(out_x.shape.as_list(), [None, output_dim])
    # Kernel of the dense layer.
    train_var_len = 1
    if use_normalization:
      if use_bias:
        # Batch norm introduces two trainable variables (gamma and beta).
        train_var_len += 2
      else:
        # `center` is set to False when use_bias is False, leaving only gamma.
        train_var_len += 1
    else:
      if use_bias:
        # Bias of the dense layer.
        train_var_len += 1
self.assertLen(test_layer.trainable_variables, train_var_len)
if __name__ == '__main__':
tf.test.main()