Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
aba78478
Commit
aba78478
authored
Mar 29, 2021
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 365713370
parent
f3f3ec34
Changes
24
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2372 additions
and
2 deletions
+2372
-2
official/modeling/optimization/configs/optimization_config.py
...cial/modeling/optimization/configs/optimization_config.py
+2
-0
official/modeling/optimization/configs/optimizer_config.py
official/modeling/optimization/configs/optimizer_config.py
+35
-0
official/modeling/optimization/lars_optimizer.py
official/modeling/optimization/lars_optimizer.py
+186
-0
official/modeling/optimization/optimizer_factory.py
official/modeling/optimization/optimizer_factory.py
+3
-1
official/modeling/optimization/optimizer_factory_test.py
official/modeling/optimization/optimizer_factory_test.py
+3
-1
official/vision/beta/projects/simclr/README.md
official/vision/beta/projects/simclr/README.md
+78
-0
official/vision/beta/projects/simclr/common/registry_imports.py
...al/vision/beta/projects/simclr/common/registry_imports.py
+36
-0
official/vision/beta/projects/simclr/configs/experiments/cifar_simclr_pretrain.yaml
...cts/simclr/configs/experiments/cifar_simclr_pretrain.yaml
+79
-0
official/vision/beta/projects/simclr/configs/experiments/imagenet_simclr_finetune_gpu.yaml
...clr/configs/experiments/imagenet_simclr_finetune_gpu.yaml
+72
-0
official/vision/beta/projects/simclr/configs/experiments/imagenet_simclr_pretrain_gpu.yaml
...clr/configs/experiments/imagenet_simclr_pretrain_gpu.yaml
+73
-0
official/vision/beta/projects/simclr/configs/simclr.py
official/vision/beta/projects/simclr/configs/simclr.py
+332
-0
official/vision/beta/projects/simclr/configs/simclr_test.py
official/vision/beta/projects/simclr/configs/simclr_test.py
+62
-0
official/vision/beta/projects/simclr/dataloaders/preprocess_ops.py
...vision/beta/projects/simclr/dataloaders/preprocess_ops.py
+363
-0
official/vision/beta/projects/simclr/dataloaders/simclr_input.py
...l/vision/beta/projects/simclr/dataloaders/simclr_input.py
+242
-0
official/vision/beta/projects/simclr/heads/simclr_head.py
official/vision/beta/projects/simclr/heads/simclr_head.py
+215
-0
official/vision/beta/projects/simclr/heads/simclr_head_test.py
...ial/vision/beta/projects/simclr/heads/simclr_head_test.py
+117
-0
official/vision/beta/projects/simclr/losses/contrastive_losses.py
.../vision/beta/projects/simclr/losses/contrastive_losses.py
+157
-0
official/vision/beta/projects/simclr/losses/contrastive_losses_test.py
...on/beta/projects/simclr/losses/contrastive_losses_test.py
+93
-0
official/vision/beta/projects/simclr/modeling/layers/nn_blocks.py
.../vision/beta/projects/simclr/modeling/layers/nn_blocks.py
+150
-0
official/vision/beta/projects/simclr/modeling/layers/nn_blocks_test.py
...on/beta/projects/simclr/modeling/layers/nn_blocks_test.py
+74
-0
No files found.
official/modeling/optimization/configs/optimization_config.py
View file @
aba78478
...
@@ -39,6 +39,7 @@ class OptimizerConfig(oneof.OneOfConfig):
...
@@ -39,6 +39,7 @@ class OptimizerConfig(oneof.OneOfConfig):
adamw: adam with weight decay.
adamw: adam with weight decay.
lamb: lamb optimizer.
lamb: lamb optimizer.
rmsprop: rmsprop optimizer.
rmsprop: rmsprop optimizer.
lars: lars optimizer.
"""
"""
type
:
Optional
[
str
]
=
None
type
:
Optional
[
str
]
=
None
sgd
:
opt_cfg
.
SGDConfig
=
opt_cfg
.
SGDConfig
()
sgd
:
opt_cfg
.
SGDConfig
=
opt_cfg
.
SGDConfig
()
...
@@ -46,6 +47,7 @@ class OptimizerConfig(oneof.OneOfConfig):
...
@@ -46,6 +47,7 @@ class OptimizerConfig(oneof.OneOfConfig):
adamw
:
opt_cfg
.
AdamWeightDecayConfig
=
opt_cfg
.
AdamWeightDecayConfig
()
adamw
:
opt_cfg
.
AdamWeightDecayConfig
=
opt_cfg
.
AdamWeightDecayConfig
()
lamb
:
opt_cfg
.
LAMBConfig
=
opt_cfg
.
LAMBConfig
()
lamb
:
opt_cfg
.
LAMBConfig
=
opt_cfg
.
LAMBConfig
()
rmsprop
:
opt_cfg
.
RMSPropConfig
=
opt_cfg
.
RMSPropConfig
()
rmsprop
:
opt_cfg
.
RMSPropConfig
=
opt_cfg
.
RMSPropConfig
()
lars
:
opt_cfg
.
LARSConfig
=
opt_cfg
.
LARSConfig
()
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
...
official/modeling/optimization/configs/optimizer_config.py
View file @
aba78478
...
@@ -170,3 +170,38 @@ class EMAConfig(BaseOptimizerConfig):
...
@@ -170,3 +170,38 @@ class EMAConfig(BaseOptimizerConfig):
average_decay
:
float
=
0.99
average_decay
:
float
=
0.99
start_step
:
int
=
0
start_step
:
int
=
0
dynamic_decay
:
bool
=
True
dynamic_decay
:
bool
=
True
@dataclasses.dataclass
class LARSConfig(BaseOptimizerConfig):
  """Layer-wise adaptive rate scaling config.

  Attributes:
    name: 'str', name of the optimizer.
    momentum: `float` hyperparameter >= 0 that accelerates gradient descent
      in the relevant direction and dampens oscillations. Defaults to 0.9.
    eeta: `float` LARS coefficient as used in the paper. Default set to LARS
      coefficient from the paper. (eeta / weight_decay) determines the
      highest scaling factor in LARS.
    weight_decay_rate: `float` for weight decay.
    nesterov: 'boolean' for whether to use nesterov momentum.
    classic_momentum: `boolean` for whether to use classic (or popular)
      momentum. The learning rate is applied during momentum update in
      classic momentum, but after momentum for popular momentum.
    exclude_from_weight_decay: A list of `string` for variable screening, if
      any of the string appears in a variable's name, the variable will be
      excluded for computing weight decay. For example, one could specify
      the list like ['batch_normalization', 'bias'] to exclude BN and bias
      from weight decay.
    exclude_from_layer_adaptation: Similar to exclude_from_weight_decay, but
      for layer adaptation. If it is None, it will be defaulted the same as
      exclude_from_weight_decay.
  """
  name: str = "LARS"
  momentum: float = 0.9
  eeta: float = 0.001
  weight_decay_rate: float = 0.0
  nesterov: bool = False
  classic_momentum: bool = True
  exclude_from_weight_decay: Optional[List[str]] = None
  exclude_from_layer_adaptation: Optional[List[str]] = None
official/modeling/optimization/lars_optimizer.py
0 → 100644
View file @
aba78478
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Layer-wise adaptive rate scaling optimizer."""
import
re
from
typing
import
Text
,
List
,
Optional
import
tensorflow
as
tf
# pylint: disable=protected-access
class LARS(tf.keras.optimizers.Optimizer):
  """Layer-wise Adaptive Rate Scaling for large batch training.

  Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
  I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)
  """

  def __init__(self,
               learning_rate: float = 0.01,
               momentum: float = 0.9,
               weight_decay_rate: float = 0.0,
               eeta: float = 0.001,
               nesterov: bool = False,
               classic_momentum: bool = True,
               exclude_from_weight_decay: Optional[List[Text]] = None,
               exclude_from_layer_adaptation: Optional[List[Text]] = None,
               name: Text = "LARS",
               **kwargs):
    """Constructs a LARSOptimizer.

    Args:
      learning_rate: `float` for learning rate. Defaults to 0.01.
      momentum: `float` hyperparameter >= 0 that accelerates gradient descent
        in the relevant direction and dampens oscillations. Defaults to 0.9.
      weight_decay_rate: `float` for weight decay.
      eeta: `float` LARS coefficient as used in the paper. Default set to LARS
        coefficient from the paper. (eeta / weight_decay) determines the
        highest scaling factor in LARS.
      nesterov: 'boolean' for whether to use nesterov momentum.
      classic_momentum: `boolean` for whether to use classic (or popular)
        momentum. The learning rate is applied during momentum update in
        classic momentum, but after momentum for popular momentum.
      exclude_from_weight_decay: A list of `string` for variable screening, if
        any of the string appears in a variable's name, the variable will be
        excluded for computing weight decay. For example, one could specify
        the list like ['batch_normalization', 'bias'] to exclude BN and bias
        from weight decay.
      exclude_from_layer_adaptation: Similar to exclude_from_weight_decay, but
        for layer adaptation. If it is None, it will be defaulted the same as
        exclude_from_weight_decay.
      name: `Text` as optional name for the operations created when applying
        gradients. Defaults to "LARS".
      **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
        `decay`}. `clipnorm` is clip gradients by norm; `clipvalue` is clip
        gradients by value, `decay` is included for backward compatibility to
        allow time inverse decay of learning rate. `lr` is included for
        backward compatibility, recommended to use `learning_rate` instead.
    """
    super(LARS, self).__init__(name, **kwargs)

    # learning_rate/decay go through the Keras hyper machinery so they can be
    # tensors, schedules, or serialized via get_config(); the remaining
    # settings are plain Python attributes.
    self._set_hyper("learning_rate", learning_rate)
    self._set_hyper("decay", self._initial_decay)
    self.momentum = momentum
    self.weight_decay_rate = weight_decay_rate
    self.eeta = eeta
    self.nesterov = nesterov
    self.classic_momentum = classic_momentum
    self.exclude_from_weight_decay = exclude_from_weight_decay
    # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the
    # arg is None.
    if exclude_from_layer_adaptation:
      self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
    else:
      self.exclude_from_layer_adaptation = exclude_from_weight_decay

  def _create_slots(self, var_list):
    # One "momentum" slot per variable stores the running velocity used by
    # _resource_apply_dense.
    for v in var_list:
      self.add_slot(v, "momentum")

  def _resource_apply_dense(self, grad, param, apply_state=None):
    """Applies one LARS update step to a single dense variable."""
    if grad is None or param is None:
      return tf.no_op()

    var_device, var_dtype = param.device, param.dtype.base_dtype
    # Per-(device, dtype) precomputed coefficients from the base class; fall
    # back to recomputing them when apply_state was not populated.
    coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
                    self._fallback_apply_state(var_device, var_dtype))
    learning_rate = coefficients["lr_t"]

    param_name = param.name

    v = self.get_slot(param, "momentum")

    # L2 weight decay is folded directly into the gradient (unless this
    # variable's name matches an exclusion pattern).
    if self._use_weight_decay(param_name):
      grad += self.weight_decay_rate * param

    if self.classic_momentum:
      # Classic momentum: the (trust-ratio-scaled) learning rate is applied
      # before the velocity update.
      trust_ratio = 1.0
      if self._do_layer_adaptation(param_name):
        w_norm = tf.norm(param, ord=2)
        g_norm = tf.norm(grad, ord=2)
        # trust_ratio = eeta * ||w|| / ||g||, guarded so that a zero weight
        # or gradient norm falls back to 1.0.
        trust_ratio = tf.where(
            tf.greater(w_norm, 0),
            tf.where(
                tf.greater(g_norm, 0), (self.eeta * w_norm / g_norm), 1.0),
            1.0)

      scaled_lr = learning_rate * trust_ratio

      next_v = tf.multiply(self.momentum, v) + scaled_lr * grad
      if self.nesterov:
        update = tf.multiply(self.momentum, next_v) + scaled_lr * grad
      else:
        update = next_v
      next_param = param - update
    else:
      # "Popular" momentum: velocity is accumulated on the raw gradient and
      # the learning rate (and trust ratio) are applied after the momentum
      # update.
      next_v = tf.multiply(self.momentum, v) + grad
      if self.nesterov:
        update = tf.multiply(self.momentum, next_v) + grad
      else:
        update = next_v

      trust_ratio = 1.0
      if self._do_layer_adaptation(param_name):
        w_norm = tf.norm(param, ord=2)
        v_norm = tf.norm(update, ord=2)
        # Same guarded trust ratio as above, but computed against the update
        # (velocity) norm instead of the gradient norm.
        trust_ratio = tf.where(
            tf.greater(w_norm, 0),
            tf.where(
                tf.greater(v_norm, 0), (self.eeta * w_norm / v_norm), 1.0),
            1.0)

      scaled_lr = trust_ratio * learning_rate
      next_param = param - scaled_lr * update

    # Write both the variable and its momentum slot; group the two assigns
    # into a single op for the caller.
    return tf.group(*[
        param.assign(next_param, use_locking=False),
        v.assign(next_v, use_locking=False)
    ])

  def _resource_apply_sparse(self, grad, handle, indices, apply_state):
    raise NotImplementedError("Applying sparse gradients is not implemented.")

  def _use_weight_decay(self, param_name):
    """Whether to use L2 weight decay for `param_name`."""
    if not self.weight_decay_rate:
      return False
    if self.exclude_from_weight_decay:
      for r in self.exclude_from_weight_decay:
        # Patterns are regexes matched anywhere in the variable name.
        if re.search(r, param_name) is not None:
          return False
    return True

  def _do_layer_adaptation(self, param_name):
    """Whether to do layer-wise learning rate adaptation for `param_name`."""
    if self.exclude_from_layer_adaptation:
      for r in self.exclude_from_layer_adaptation:
        if re.search(r, param_name) is not None:
          return False
    return True

  def get_config(self):
    """Returns the serializable config of the optimizer (Keras protocol)."""
    config = super(LARS, self).get_config()
    config.update({
        "learning_rate": self._serialize_hyperparameter("learning_rate"),
        "decay": self._serialize_hyperparameter("decay"),
        "momentum": self.momentum,
        "classic_momentum": self.classic_momentum,
        "weight_decay_rate": self.weight_decay_rate,
        "eeta": self.eeta,
        "nesterov": self.nesterov,
    })
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    # NOTE(review): the exclude_from_* lists are not serialized by
    # get_config(), so they are lost on a get_config/from_config round trip —
    # confirm whether that is intentional.
    return cls(**config)
official/modeling/optimization/optimizer_factory.py
View file @
aba78478
...
@@ -21,6 +21,7 @@ import tensorflow as tf
...
@@ -21,6 +21,7 @@ import tensorflow as tf
import
tensorflow_addons.optimizers
as
tfa_optimizers
import
tensorflow_addons.optimizers
as
tfa_optimizers
from
official.modeling.optimization
import
ema_optimizer
from
official.modeling.optimization
import
ema_optimizer
from
official.modeling.optimization
import
lars_optimizer
from
official.modeling.optimization
import
lr_schedule
from
official.modeling.optimization
import
lr_schedule
from
official.modeling.optimization.configs
import
optimization_config
as
opt_cfg
from
official.modeling.optimization.configs
import
optimization_config
as
opt_cfg
from
official.nlp
import
optimization
as
nlp_optimization
from
official.nlp
import
optimization
as
nlp_optimization
...
@@ -30,7 +31,8 @@ OPTIMIZERS_CLS = {
...
@@ -30,7 +31,8 @@ OPTIMIZERS_CLS = {
'adam'
:
tf
.
keras
.
optimizers
.
Adam
,
'adam'
:
tf
.
keras
.
optimizers
.
Adam
,
'adamw'
:
nlp_optimization
.
AdamWeightDecay
,
'adamw'
:
nlp_optimization
.
AdamWeightDecay
,
'lamb'
:
tfa_optimizers
.
LAMB
,
'lamb'
:
tfa_optimizers
.
LAMB
,
'rmsprop'
:
tf
.
keras
.
optimizers
.
RMSprop
'rmsprop'
:
tf
.
keras
.
optimizers
.
RMSprop
,
'lars'
:
lars_optimizer
.
LARS
,
}
}
LR_CLS
=
{
LR_CLS
=
{
...
...
official/modeling/optimization/optimizer_factory_test.py
View file @
aba78478
...
@@ -23,7 +23,9 @@ from official.modeling.optimization.configs import optimization_config
...
@@ -23,7 +23,9 @@ from official.modeling.optimization.configs import optimization_config
class
OptimizerFactoryTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
class
OptimizerFactoryTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
@
parameterized
.
parameters
((
'sgd'
),
(
'rmsprop'
),
(
'adam'
),
(
'adamw'
),
(
'lamb'
))
@
parameterized
.
parameters
((
'sgd'
),
(
'rmsprop'
),
(
'adam'
),
(
'adamw'
),
(
'lamb'
),
(
'lars'
))
def
test_optimizers
(
self
,
optimizer_type
):
def
test_optimizers
(
self
,
optimizer_type
):
params
=
{
params
=
{
'optimizer'
:
{
'optimizer'
:
{
...
...
official/vision/beta/projects/simclr/README.md
0 → 100644
View file @
aba78478
# Simple Framework for Contrastive Learning
[

](https://arxiv.org/abs/2002.05709)
[

](https://arxiv.org/abs/2006.10029)
<div
align=
"center"
>
<img
width=
"50%"
alt=
"SimCLR Illustration"
src=
"https://1.bp.blogspot.com/--vH4PKpE9Yo/Xo4a2BYervI/AAAAAAAAFpM/vaFDwPXOyAokAC8Xh852DzOgEs22NhbXwCLcBGAsYHQ/s1600/image4.gif"
>
</div>
<div
align=
"center"
>
An illustration of SimCLR (from
<a
href=
"https://ai.googleblog.com/2020/04/advancing-self-supervised-and-semi.html"
>
our blog here
</a>
).
</div>
## Environment setup
The code can be run on multiple GPUs or TPUs with different distribution
strategies. See the TensorFlow distributed training
[
guide
](
https://www.tensorflow.org/guide/distributed_training
)
for an overview
of
`tf.distribute`
.
The code is compatible with TensorFlow 2.4+. See requirements.txt for all
prerequisites, and you can also install them using the following command.
`pip
install -r ./official/requirements.txt`
## Pretraining
To pretrain the model on Imagenet, try the following command:
```
python3 -m official.vision.beta.projects.simclr.train \
--mode=train_and_eval \
--experiment=simclr_pretraining \
--model_dir={MODEL_DIR} \
--config_file={CONFIG_FILE}
```
An example of the config file can be found
[
here
](
./configs/experiments/imagenet_simclr_pretrain_gpu.yaml
)
## Semi-supervised learning and fine-tuning the whole network
You can access 1% and 10% ImageNet subsets used for semi-supervised learning via
[
tensorflow datasets
](
https://www.tensorflow.org/datasets/catalog/imagenet2012_subset
)
.
You can also find image IDs of these subsets in
`imagenet_subsets/`
.
To fine-tune the whole network, refer to the following command:
```
python3 -m official.vision.beta.projects.simclr.train \
--mode=train_and_eval \
--experiment=simclr_finetuning \
--model_dir={MODEL_DIR} \
--config_file={CONFIG_FILE}
```
An example of the config file can be found
[
here
](
./configs/experiments/imagenet_simclr_finetune_gpu.yaml
)
.
## Cite
[
SimCLR paper
](
https://arxiv.org/abs/2002.05709
)
:
```
@article{chen2020simple,
title={A Simple Framework for Contrastive Learning of Visual Representations},
author={Chen, Ting and Kornblith, Simon and Norouzi, Mohammad and Hinton, Geoffrey},
journal={arXiv preprint arXiv:2002.05709},
year={2020}
}
```
[
SimCLRv2 paper
](
https://arxiv.org/abs/2006.10029
)
:
```
@article{chen2020big,
title={Big Self-Supervised Models are Strong Semi-Supervised Learners},
author={Chen, Ting and Kornblith, Simon and Swersky, Kevin and Norouzi, Mohammad and Hinton, Geoffrey},
journal={arXiv preprint arXiv:2006.10029},
year={2020}
}
```
official/vision/beta/projects/simclr/common/registry_imports.py
0 → 100644
View file @
aba78478
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""All necessary imports for registration."""
# pylint: disable=unused-import
from
official.common
import
registry_imports
from
official.vision.beta.projects.simclr.configs
import
simclr
from
official.vision.beta.projects.simclr.losses
import
contrastive_losses
from
official.vision.beta.projects.simclr.modeling
import
simclr_model
from
official.vision.beta.projects.simclr.tasks
import
simclr
as
simclr_task
official/vision/beta/projects/simclr/configs/experiments/cifar_simclr_pretrain.yaml
0 → 100644
View file @
aba78478
# Cifar classification.
runtime
:
distribution_strategy
:
'
mirrored'
mixed_precision_dtype
:
'
float16'
loss_scale
:
'
dynamic'
num_gpus
:
16
task
:
model
:
mode
:
'
pretrain'
input_size
:
[
32
,
32
,
3
]
backbone
:
type
:
'
resnet'
resnet
:
model_id
:
50
backbone_trainable
:
true
projection_head
:
proj_output_dim
:
64
num_proj_layers
:
2
ft_proj_idx
:
1
supervised_head
:
num_classes
:
10
norm_activation
:
use_sync_bn
:
true
norm_momentum
:
0.9
norm_epsilon
:
0.00001
loss
:
projection_norm
:
true
temperature
:
0.2
evaluation
:
top_k
:
5
one_hot
:
true
train_data
:
tfds_name
:
'
cifar10'
tfds_split
:
'
train'
input_path
:
'
'
is_training
:
true
global_batch_size
:
512
dtype
:
'
float16'
parser
:
mode
:
'
pretrain'
aug_color_jitter_strength
:
0.5
aug_rand_blur
:
false
decoder
:
decode_label
:
true
validation_data
:
tfds_name
:
'
cifar10'
tfds_split
:
'
test'
input_path
:
'
'
is_training
:
false
global_batch_size
:
512
dtype
:
'
float16'
drop_remainder
:
false
parser
:
mode
:
'
pretrain'
decoder
:
decode_label
:
true
trainer
:
train_steps
:
48000
# 500 epochs
validation_steps
:
18
# NUM_EXAMPLES (10000) // global_batch_size
validation_interval
:
96
steps_per_loop
:
96
# NUM_EXAMPLES (50000) // global_batch_size
summary_interval
:
96
checkpoint_interval
:
96
optimizer_config
:
optimizer
:
type
:
'
lars'
lars
:
momentum
:
0.9
weight_decay_rate
:
0.000001
exclude_from_weight_decay
:
[
'
batch_normalization'
,
'
bias'
]
learning_rate
:
type
:
'
cosine'
cosine
:
initial_learning_rate
:
0.6
# 0.3 × BatchSize / 256
decay_steps
:
43200
# train_steps - warmup_steps
warmup
:
type
:
'
linear'
linear
:
warmup_steps
:
4800
# 10% of total epochs
official/vision/beta/projects/simclr/configs/experiments/imagenet_simclr_finetune_gpu.yaml
0 → 100644
View file @
aba78478
# ImageNet classification.
runtime
:
distribution_strategy
:
'
mirrored'
mixed_precision_dtype
:
'
float16'
loss_scale
:
'
dynamic'
num_gpus
:
16
task
:
model
:
mode
:
'
finetune'
input_size
:
[
224
,
224
,
3
]
backbone
:
type
:
'
resnet'
resnet
:
model_id
:
50
backbone_trainable
:
true
projection_head
:
proj_output_dim
:
128
num_proj_layers
:
3
ft_proj_idx
:
1
supervised_head
:
num_classes
:
1001
zero_init
:
true
norm_activation
:
use_sync_bn
:
false
norm_momentum
:
0.9
norm_epsilon
:
0.00001
loss
:
label_smoothing
:
0.0
one_hot
:
true
evaluation
:
top_k
:
5
one_hot
:
true
init_checkpoint
:
'
/placer/prod/scratch/home/tf-model-garden-dev/vision/simclr/r50_1x/2021-03-26'
init_checkpoint_modules
:
'
backbone_projection'
train_data
:
tfds_name
:
'
imagenet2012_subset/10pct'
tfds_split
:
'
train'
input_path
:
'
'
is_training
:
true
global_batch_size
:
1024
dtype
:
'
float16'
parser
:
mode
:
'
finetune'
validation_data
:
tfds_name
:
'
imagenet2012_subset/10pct'
tfds_split
:
'
validation'
input_path
:
'
'
is_training
:
false
global_batch_size
:
1024
dtype
:
'
float16'
drop_remainder
:
false
parser
:
mode
:
'
finetune'
trainer
:
train_steps
:
12500
# 100 epochs
validation_steps
:
49
# NUM_EXAMPLES (50000) // global_batch_size
validation_interval
:
125
steps_per_loop
:
125
# NUM_EXAMPLES (1281167) // global_batch_size
summary_interval
:
125
checkpoint_interval
:
125
optimizer_config
:
optimizer
:
type
:
'
lars'
lars
:
momentum
:
0.9
weight_decay_rate
:
0.0
exclude_from_weight_decay
:
[
'
batch_normalization'
,
'
bias'
]
learning_rate
:
type
:
'
cosine'
cosine
:
initial_learning_rate
:
0.04
# 0.01 × BatchSize / 512
decay_steps
:
12500
# train_steps
official/vision/beta/projects/simclr/configs/experiments/imagenet_simclr_pretrain_gpu.yaml
0 → 100644
View file @
aba78478
# ImageNet classification.
runtime
:
distribution_strategy
:
'
mirrored'
mixed_precision_dtype
:
'
float16'
loss_scale
:
'
dynamic'
num_gpus
:
16
task
:
model
:
mode
:
'
pretrain'
input_size
:
[
224
,
224
,
3
]
backbone
:
type
:
'
resnet'
resnet
:
model_id
:
50
backbone_trainable
:
true
projection_head
:
proj_output_dim
:
128
num_proj_layers
:
3
ft_proj_idx
:
0
supervised_head
:
num_classes
:
1001
norm_activation
:
use_sync_bn
:
true
norm_momentum
:
0.9
norm_epsilon
:
0.00001
loss
:
projection_norm
:
true
temperature
:
0.1
evaluation
:
top_k
:
5
one_hot
:
true
train_data
:
input_path
:
'
/readahead/200M/placer/prod/home/distbelief/imagenet-tensorflow/imagenet-2012-tfrecord/train*'
is_training
:
true
global_batch_size
:
2048
dtype
:
'
float16'
parser
:
mode
:
'
pretrain'
decoder
:
decode_label
:
true
validation_data
:
input_path
:
'
/readahead/200M/placer/prod/home/distbelief/imagenet-tensorflow/imagenet-2012-tfrecord/valid*'
is_training
:
false
global_batch_size
:
2048
dtype
:
'
float16'
drop_remainder
:
false
parser
:
mode
:
'
pretrain'
decoder
:
decode_label
:
true
trainer
:
train_steps
:
187200
# 300 epochs
validation_steps
:
24
# NUM_EXAMPLES (50000) // global_batch_size
validation_interval
:
624
steps_per_loop
:
624
# NUM_EXAMPLES (1281167) // global_batch_size
summary_interval
:
624
checkpoint_interval
:
624
optimizer_config
:
optimizer
:
type
:
'
lars'
lars
:
momentum
:
0.9
weight_decay_rate
:
0.000001
exclude_from_weight_decay
:
[
'
batch_normalization'
,
'
bias'
]
learning_rate
:
type
:
'
cosine'
cosine
:
initial_learning_rate
:
1.6
# 0.2 * BatchSize / 256
decay_steps
:
177840
# train_steps - warmup_steps
warmup
:
type
:
'
linear'
linear
:
warmup_steps
:
9360
# 5% of total epochs
official/vision/beta/projects/simclr/configs/simclr.py
0 → 100644
View file @
aba78478
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SimCLR configurations."""
import
os
from
typing
import
List
,
Optional
import
dataclasses
from
official.core
import
config_definitions
as
cfg
from
official.core
import
exp_factory
from
official.modeling
import
hyperparams
from
official.modeling
import
optimization
from
official.vision.beta.configs
import
backbones
from
official.vision.beta.configs
import
common
from
official.vision.beta.projects.simclr.modeling
import
simclr_model
@dataclasses.dataclass
class Decoder(hyperparams.Config):
  """Example decoder config."""
  # Presumably controls whether the class label is decoded from each example;
  # verify against the simclr dataloader.
  decode_label: bool = True
@dataclasses.dataclass
class Parser(hyperparams.Config):
  """Parser config."""
  # Random-crop and horizontal-flip augmentation toggles.
  aug_rand_crop: bool = True
  aug_rand_hflip: bool = True
  # Color distortion augmentation and its jitter strength.
  aug_color_distort: bool = True
  aug_color_jitter_strength: float = 1.0
  aug_color_jitter_impl: str = 'simclrv2'  # 'simclrv1' or 'simclrv2'
  # Random blur augmentation toggle.
  aug_rand_blur: bool = True
  parse_label: bool = True
  test_crop: bool = True
  # Parsing mode; defaults to the pretrain constant from simclr_model.
  mode: str = simclr_model.PRETRAIN
@dataclasses.dataclass
class DataConfig(cfg.DataConfig):
  """Training data config."""
  input_path: str = ''
  # 0 means unset; concrete experiments must provide a real batch size.
  global_batch_size: int = 0
  is_training: bool = True
  dtype: str = 'float32'
  shuffle_buffer_size: int = 10000
  cycle_length: int = 10
  # simclr specific configs
  parser: Parser = Parser()
  decoder: Decoder = Decoder()
@dataclasses.dataclass
class ProjectionHead(hyperparams.Config):
  """Projection head config."""
  proj_output_dim: int = 128
  num_proj_layers: int = 3
  ft_proj_idx: int = 1  # layer of the projection head to use for fine-tuning.
@dataclasses.dataclass
class SupervisedHead(hyperparams.Config):
  """Supervised classification head config."""
  # 1001 matches the ImageNet convention with a background class — TODO
  # confirm against the dataloader's label space.
  num_classes: int = 1001
  zero_init: bool = False
@dataclasses.dataclass
class ContrastiveLoss(hyperparams.Config):
  """Contrastive (pretraining) loss config."""
  projection_norm: bool = True
  temperature: float = 0.1
  l2_weight_decay: float = 0.0
@dataclasses.dataclass
class ClassificationLosses(hyperparams.Config):
  """Classification (fine-tuning) loss config."""
  label_smoothing: float = 0.0
  one_hot: bool = True
  l2_weight_decay: float = 0.0
@dataclasses.dataclass
class Evaluation(hyperparams.Config):
  """Evaluation config."""
  top_k: int = 5
  one_hot: bool = True
@dataclasses.dataclass
class SimCLRModel(hyperparams.Config):
  """SimCLR model config."""
  input_size: List[int] = dataclasses.field(default_factory=list)
  backbone: backbones.Backbone = backbones.Backbone(
      type='resnet', resnet=backbones.ResNet())
  projection_head: ProjectionHead = ProjectionHead(
      proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1)
  supervised_head: SupervisedHead = SupervisedHead(num_classes=1001)
  norm_activation: common.NormActivation = common.NormActivation(
      norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)
  # One of simclr_model.PRETRAIN / FINETUNE.
  mode: str = simclr_model.PRETRAIN
  backbone_trainable: bool = True
@dataclasses.dataclass
class SimCLRPretrainTask(cfg.TaskConfig):
  """SimCLR pretraining task config."""
  model: SimCLRModel = SimCLRModel(mode=simclr_model.PRETRAIN)
  train_data: DataConfig = DataConfig(
      parser=Parser(mode=simclr_model.PRETRAIN), is_training=True)
  validation_data: DataConfig = DataConfig(
      parser=Parser(mode=simclr_model.PRETRAIN), is_training=False)
  loss: ContrastiveLoss = ContrastiveLoss()
  evaluation: Evaluation = Evaluation()
  init_checkpoint: Optional[str] = None
  # all or backbone
  init_checkpoint_modules: str = 'all'
@dataclasses.dataclass
class SimCLRFinetuneTask(cfg.TaskConfig):
  """SimCLR fine tune task config."""
  model: SimCLRModel = SimCLRModel(
      mode=simclr_model.FINETUNE,
      supervised_head=SupervisedHead(num_classes=1001, zero_init=True))
  train_data: DataConfig = DataConfig(
      parser=Parser(mode=simclr_model.FINETUNE), is_training=True)
  validation_data: DataConfig = DataConfig(
      parser=Parser(mode=simclr_model.FINETUNE), is_training=False)
  loss: ClassificationLosses = ClassificationLosses()
  evaluation: Evaluation = Evaluation()
  init_checkpoint: Optional[str] = None
  # all, backbone_projection or backbone
  init_checkpoint_modules: str = 'backbone_projection'
@exp_factory.register_config_factory('simclr_pretraining')
def simclr_pretraining() -> cfg.ExperimentConfig:
  """SimCLR pretraining experiment with all-default task/trainer settings."""
  return cfg.ExperimentConfig(
      task=SimCLRPretrainTask(),
      trainer=cfg.TrainerConfig(),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
@exp_factory.register_config_factory('simclr_finetuning')
def simclr_finetuning() -> cfg.ExperimentConfig:
  """SimCLR fine-tuning experiment with all-default task/trainer settings."""
  return cfg.ExperimentConfig(
      task=SimCLRFinetuneTask(),
      trainer=cfg.TrainerConfig(),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
# ImageNet-2012 dataset constants used by the registered imagenet experiments.
IMAGENET_TRAIN_EXAMPLES = 1281167
IMAGENET_VAL_EXAMPLES = 50000
IMAGENET_INPUT_PATH_BASE = 'imagenet-2012-tfrecord'
@exp_factory.register_config_factory('simclr_pretraining_imagenet')
def simclr_pretraining_imagenet() -> cfg.ExperimentConfig:
  """SimCLR pretraining on ImageNet with a ResNet-50 backbone and LARS.

  Returns:
    An `ExperimentConfig`: 500 epochs of contrastive pretraining at global
    batch size 4096, cosine learning-rate decay with linear warmup.
  """
  train_batch_size = 4096
  eval_batch_size = 4096
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
  return cfg.ExperimentConfig(
      task=SimCLRPretrainTask(
          model=SimCLRModel(
              mode=simclr_model.PRETRAIN,
              backbone_trainable=True,
              input_size=[224, 224, 3],
              backbone=backbones.Backbone(
                  type='resnet', resnet=backbones.ResNet(model_id=50)),
              # 3-layer projection MLP to a 128-d embedding; layer 1 is kept
              # for fine-tuning (ft_proj_idx=1).
              projection_head=ProjectionHead(
                  proj_output_dim=128,
                  num_proj_layers=3,
                  ft_proj_idx=1),
              supervised_head=SupervisedHead(num_classes=1001),
              # Sync BN across replicas during pretraining.
              norm_activation=common.NormActivation(
                  norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=True)),
          loss=ContrastiveLoss(),
          evaluation=Evaluation(),
          train_data=DataConfig(
              parser=Parser(mode=simclr_model.PRETRAIN),
              decoder=Decoder(decode_label=True),
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
              is_training=True,
              global_batch_size=train_batch_size),
          validation_data=DataConfig(
              parser=Parser(mode=simclr_model.PRETRAIN),
              decoder=Decoder(decode_label=True),
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
              is_training=False,
              global_batch_size=eval_batch_size),
      ),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=500 * steps_per_epoch,
          validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'lars',
                  'lars': {
                      'momentum': 0.9,
                      'weight_decay_rate': 0.000001,
                      # BN and bias parameters are excluded from LARS weight
                      # decay, matched by variable-name substring.
                      'exclude_from_weight_decay': [
                          'batch_normalization', 'bias'
                      ]
                  }
              },
              'learning_rate': {
                  'type': 'cosine',
                  'cosine': {
                      # 0.2 * BatchSize / 256
                      'initial_learning_rate': 0.2 * train_batch_size / 256,
                      # train_steps - warmup_steps
                      'decay_steps': 475 * steps_per_epoch
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      # 5% of total epochs
                      'warmup_steps': 25 * steps_per_epoch
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
@exp_factory.register_config_factory('simclr_finetuning_imagenet')
def simclr_finetuning_imagenet() -> cfg.ExperimentConfig:
  """SimCLR fine-tuning on ImageNet from a pretraining checkpoint.

  Returns:
    An `ExperimentConfig`: 60 epochs at global batch size 1024 with LARS
    (no weight decay) and cosine learning-rate decay. `init_checkpoint` is
    left empty here and must be overridden to point at a pretrained model.
  """
  train_batch_size = 1024
  eval_batch_size = 1024
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
  # Intentionally empty default; override with the pretraining checkpoint.
  pretrain_model_base = ''
  return cfg.ExperimentConfig(
      task=SimCLRFinetuneTask(
          model=SimCLRModel(
              mode=simclr_model.FINETUNE,
              backbone_trainable=True,
              input_size=[224, 224, 3],
              backbone=backbones.Backbone(
                  type='resnet', resnet=backbones.ResNet(model_id=50)),
              # Must match the pretraining projection head so checkpoint
              # variables line up when restoring 'backbone_projection'.
              projection_head=ProjectionHead(
                  proj_output_dim=128,
                  num_proj_layers=3,
                  ft_proj_idx=1),
              supervised_head=SupervisedHead(num_classes=1001, zero_init=True),
              # Unlike pretraining, sync BN is off here.
              norm_activation=common.NormActivation(
                  norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)),
          loss=ClassificationLosses(),
          evaluation=Evaluation(),
          train_data=DataConfig(
              parser=Parser(mode=simclr_model.FINETUNE),
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
              is_training=True,
              global_batch_size=train_batch_size),
          validation_data=DataConfig(
              parser=Parser(mode=simclr_model.FINETUNE),
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
              is_training=False,
              global_batch_size=eval_batch_size),
          init_checkpoint=pretrain_model_base,
          # all, backbone_projection or backbone
          init_checkpoint_modules='backbone_projection'),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=60 * steps_per_epoch,
          validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'lars',
                  'lars': {
                      'momentum': 0.9,
                      # No weight decay during fine-tuning.
                      'weight_decay_rate': 0.0,
                      'exclude_from_weight_decay': [
                          'batch_normalization', 'bias'
                      ]
                  }
              },
              'learning_rate': {
                  'type': 'cosine',
                  'cosine': {
                      # 0.01 × BatchSize / 512
                      'initial_learning_rate': 0.01 * train_batch_size / 512,
                      'decay_steps': 60 * steps_per_epoch
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
official/vision/beta/projects/simclr/configs/simclr_test.py
0 → 100644
View file @
aba78478
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for simclr."""
# pylint: disable=unused-import
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.core
import
config_definitions
as
cfg
from
official.core
import
exp_factory
from
official.vision.beta.projects.simclr.common
import
registry_imports
# pylint: disable=unused-import
from
official.vision.beta.projects.simclr.configs
import
simclr
as
exp_cfg
class SimCLRConfigTest(tf.test.TestCase, parameterized.TestCase):
  """Sanity checks for the registered SimCLR experiment configs."""

  @parameterized.parameters('simclr_pretraining_imagenet',
                            'simclr_finetuning_imagenet')
  def test_simclr_configs(self, config_name):
    """Builds each registered config and validates its basic structure."""
    config = exp_factory.get_exp_config(config_name)
    self.assertIsInstance(config, cfg.ExperimentConfig)
    # BUG FIX: this previously compared against 'simclr_pretrain_imagenet',
    # which never matches the parameterized name
    # 'simclr_pretraining_imagenet', so the pretrain-task assertion was
    # unreachable.
    if config_name == 'simclr_pretraining_imagenet':
      self.assertIsInstance(config.task, exp_cfg.SimCLRPretrainTask)
    elif config_name == 'simclr_finetuning_imagenet':
      self.assertIsInstance(config.task, exp_cfg.SimCLRFinetuneTask)
    self.assertIsInstance(config.task.model, exp_cfg.SimCLRModel)
    self.assertIsInstance(config.task.train_data, exp_cfg.DataConfig)
    # The restrictions forbid is_training == None; validate() must reject it.
    config.task.train_data.is_training = None
    with self.assertRaises(KeyError):
      config.validate()
# Run the test suite when this module is executed directly.
if __name__ == '__main__':
  tf.test.main()
official/vision/beta/projects/simclr/dataloaders/preprocess_ops.py
0 → 100644
View file @
aba78478
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preprocessing ops."""
import
functools
import
tensorflow
as
tf
# Fraction of the source image retained by center cropping. Not referenced
# within this chunk — presumably consumed by eval-time preprocessing
# elsewhere in the project (TODO confirm).
CROP_PROPORTION = 0.875  # Standard for ImageNet.
def random_apply(func, p, x):
  """Randomly apply function func to x with probability p."""
  coin = tf.random.uniform([], minval=0, maxval=1, dtype=tf.float32)
  should_apply = tf.less(coin, tf.cast(p, tf.float32))
  # Both branches are deferred via lambdas so only the chosen one runs.
  return tf.cond(should_apply, lambda: func(x), lambda: x)
def random_brightness(image, max_delta, impl='simclrv2'):
  """A multiplicative vs additive change of brightness."""
  if impl == 'simclrv1':
    # Additive jitter, as in tf.image.
    return tf.image.random_brightness(image, max_delta=max_delta)
  if impl == 'simclrv2':
    # Multiplicative jitter: scale the image by a factor drawn from
    # [max(1 - max_delta, 0), 1 + max_delta].
    lower_bound = tf.maximum(1.0 - max_delta, 0)
    scale = tf.random.uniform([], lower_bound, 1.0 + max_delta)
    return image * scale
  raise ValueError('Unknown impl {} for random brightness.'.format(impl))
def to_grayscale(image, keep_channels=True):
  """Converts an RGB image to grayscale, optionally keeping 3 channels."""
  gray = tf.image.rgb_to_grayscale(image)
  if not keep_channels:
    return gray
  # Replicate the single gray channel so the output stays HxWx3.
  return tf.tile(gray, [1, 1, 3])
def color_jitter_nonrand(image,
                         brightness=0,
                         contrast=0,
                         saturation=0,
                         hue=0,
                         impl='simclrv2'):
  """Distorts the color of the image (jittering order is fixed).

  Args:
    image: The input image tensor.
    brightness: A float, specifying the brightness for color jitter.
    contrast: A float, specifying the contrast for color jitter.
    saturation: A float, specifying the saturation for color jitter.
    hue: A float, specifying the hue for color jitter.
    impl: 'simclrv1' or 'simclrv2'. Whether to use simclrv1 or simclrv2's
      version of random brightness.

  Returns:
    The distorted image tensor.
  """
  with tf.name_scope('distort_color'):

    def apply_transform(i, x, brightness, contrast, saturation, hue):
      """Apply the i-th transformation."""
      # Fixed order: 0=brightness, 1=contrast, 2=saturation, 3=hue. A zero
      # strength disables that transform (its branch condition fails).
      if brightness != 0 and i == 0:
        x = random_brightness(x, max_delta=brightness, impl=impl)
      elif contrast != 0 and i == 1:
        x = tf.image.random_contrast(
            x, lower=1 - contrast, upper=1 + contrast)
      elif saturation != 0 and i == 2:
        x = tf.image.random_saturation(
            x, lower=1 - saturation, upper=1 + saturation)
      elif hue != 0:
        x = tf.image.random_hue(x, max_delta=hue)
      return x

    for i in range(4):
      image = apply_transform(i, image, brightness, contrast, saturation, hue)
      # Clip after every transform so values stay in [0, 1]; assumes the
      # input image is already float in [0, 1] — TODO confirm at call sites.
      image = tf.clip_by_value(image, 0., 1.)
    return image
def color_jitter_rand(image,
                      brightness=0,
                      contrast=0,
                      saturation=0,
                      hue=0,
                      impl='simclrv2'):
  """Distorts the color of the image (jittering order is random).

  Args:
    image: The input image tensor.
    brightness: A float, specifying the brightness for color jitter.
    contrast: A float, specifying the contrast for color jitter.
    saturation: A float, specifying the saturation for color jitter.
    hue: A float, specifying the hue for color jitter.
    impl: 'simclrv1' or 'simclrv2'. Whether to use simclrv1 or simclrv2's
      version of random brightness.

  Returns:
    The distorted image tensor.
  """
  with tf.name_scope('distort_color'):

    def apply_transform(i, x):
      """Apply the i-th transformation."""
      # `i` is a tensor here (an element of the shuffled permutation), so
      # the dispatch must be a tf.cond tree rather than Python `if`s.

      def brightness_foo():
        if brightness == 0:
          return x
        else:
          return random_brightness(x, max_delta=brightness, impl=impl)

      def contrast_foo():
        if contrast == 0:
          return x
        else:
          return tf.image.random_contrast(
              x, lower=1 - contrast, upper=1 + contrast)

      def saturation_foo():
        if saturation == 0:
          return x
        else:
          return tf.image.random_saturation(
              x, lower=1 - saturation, upper=1 + saturation)

      def hue_foo():
        if hue == 0:
          return x
        else:
          return tf.image.random_hue(x, max_delta=hue)

      # Binary-tree dispatch on i: 0->brightness, 1->contrast,
      # 2->saturation, 3->hue.
      x = tf.cond(
          tf.less(i, 2),
          lambda: tf.cond(tf.less(i, 1), brightness_foo, contrast_foo),
          lambda: tf.cond(tf.less(i, 3), saturation_foo, hue_foo))
      return x

    # Apply all four transforms in a freshly shuffled order.
    perm = tf.random.shuffle(tf.range(4))
    for i in range(4):
      image = apply_transform(perm[i], image)
      image = tf.clip_by_value(image, 0., 1.)
    return image
def color_jitter(image, strength, random_order=True, impl='simclrv2'):
  """Distorts the color of the image.

  Args:
    image: The input image tensor.
    strength: the floating number for the strength of the color augmentation.
    random_order: A bool, specifying whether to randomize the jittering order.
    impl: 'simclrv1' or 'simclrv2'. Whether to use simclrv1 or simclrv2's
      version of random brightness.

  Returns:
    The distorted image tensor.
  """
  # SimCLR scales brightness/contrast/saturation by 0.8 and hue by 0.2 of
  # the overall strength.
  jitter_fn = color_jitter_rand if random_order else color_jitter_nonrand
  return jitter_fn(
      image,
      0.8 * strength,
      0.8 * strength,
      0.8 * strength,
      0.2 * strength,
      impl=impl)
def random_color_jitter(image, p=1.0, color_jitter_strength=1.0,
                        impl='simclrv2'):
  """Perform random color jitter."""

  def _jitter_then_grayscale(img):
    # Color jitter with probability 0.8, then grayscale with probability 0.2.
    jitter_fn = functools.partial(
        color_jitter, strength=color_jitter_strength, impl=impl)
    img = random_apply(jitter_fn, p=0.8, x=img)
    return random_apply(to_grayscale, p=0.2, x=img)

  # The whole pipeline itself is applied with probability p.
  return random_apply(_jitter_then_grayscale, p=p, x=image)
def gaussian_blur(image, kernel_size, sigma, padding='SAME'):
  """Blurs the given image with separable convolution.

  Args:
    image: Tensor of shape [height, width, channels] and dtype float to blur.
    kernel_size: Integer Tensor for the size of the blur kernel. This is
      should be an odd number. If it is an even number, the actual kernel
      size will be size + 1.
    sigma: Sigma value for gaussian operator.
    padding: Padding to use for the convolution. Typically 'SAME' or 'VALID'.

  Returns:
    A Tensor representing the blurred image.
  """
  # Force an odd kernel size: radius = floor(kernel_size / 2), then
  # kernel_size = 2 * radius + 1.
  radius = tf.cast(kernel_size / 2, dtype=tf.int32)
  kernel_size = radius * 2 + 1
  # Sample the (unnormalized) Gaussian on [-radius, radius] and normalize.
  x = tf.cast(tf.range(-radius, radius + 1), dtype=tf.float32)
  blur_filter = tf.exp(-tf.pow(x, 2.0) /
                       (2.0 * tf.pow(tf.cast(sigma, dtype=tf.float32), 2.0)))
  blur_filter /= tf.reduce_sum(blur_filter)
  # One vertical and one horizontal filter.
  blur_v = tf.reshape(blur_filter, [kernel_size, 1, 1, 1])
  blur_h = tf.reshape(blur_filter, [1, kernel_size, 1, 1])
  # Tile the 1-channel filters across all input channels for depthwise conv.
  num_channels = tf.shape(image)[-1]
  blur_h = tf.tile(blur_h, [1, 1, num_channels, 1])
  blur_v = tf.tile(blur_v, [1, 1, num_channels, 1])
  expand_batch_dim = image.shape.ndims == 3
  if expand_batch_dim:
    # Tensorflow requires batched input to convolutions, which we can fake
    # with an extra dimension.
    image = tf.expand_dims(image, axis=0)
  # Separable blur: horizontal pass followed by vertical pass.
  blurred = tf.nn.depthwise_conv2d(
      image, blur_h, strides=[1, 1, 1, 1], padding=padding)
  blurred = tf.nn.depthwise_conv2d(
      blurred, blur_v, strides=[1, 1, 1, 1], padding=padding)
  if expand_batch_dim:
    blurred = tf.squeeze(blurred, axis=0)
  return blurred
def random_blur(image, height, width, p=0.5):
  """Randomly blur an image.

  Args:
    image: `Tensor` representing an image of arbitrary size.
    height: Height of output image.
    width: Width of output image.
    p: probability of applying this transformation.

  Returns:
    A preprocessed image `Tensor`.
  """
  del width  # Unused: the blur kernel size is derived from height only.

  def _blur(img):
    # Draw sigma uniformly from [0.1, 2.0] per application.
    random_sigma = tf.random.uniform([], 0.1, 2.0, dtype=tf.float32)
    return gaussian_blur(
        img, kernel_size=height // 10, sigma=random_sigma, padding='SAME')

  return random_apply(_blur, p=p, x=image)
def distorted_bounding_box_crop(image,
                                bbox,
                                min_object_covered=0.1,
                                aspect_ratio_range=(0.75, 1.33),
                                area_range=(0.05, 1.0),
                                max_attempts=100,
                                scope=None):
  """Generates cropped_image using one of the bboxes randomly distorted.

  See `tf.image.sample_distorted_bounding_box` for more documentation.

  Args:
    image: `Tensor` of image data.
    bbox: `Tensor` of bounding boxes arranged `[1, num_boxes, coords]`
      where each coordinate is [0, 1) and the coordinates are arranged
      as `[ymin, xmin, ymax, xmax]`. If num_boxes is 0 then use the whole
      image.
    min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
      area of the image must contain at least this fraction of any bounding
      box supplied.
    aspect_ratio_range: An optional list of `float`s. The cropped area of the
      image must have an aspect ratio = width / height within this range.
    area_range: An optional list of `float`s. The cropped area of the image
      must contain a fraction of the supplied image within in this range.
    max_attempts: An optional `int`. Number of attempts at generating a
      cropped region of the image of the specified constraints. After
      `max_attempts` failures, return the entire image.
    scope: Optional `str` for name scope.

  Returns:
    (cropped image `Tensor`, distorted bbox `Tensor`).
  """
  with tf.name_scope(scope or 'distorted_bounding_box_crop'):
    shape = tf.shape(image)
    sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
        shape,
        bounding_boxes=bbox,
        min_object_covered=min_object_covered,
        aspect_ratio_range=aspect_ratio_range,
        area_range=area_range,
        max_attempts=max_attempts,
        # Fall back to the whole image when sampling fails.
        use_image_if_no_bounding_boxes=True)
    bbox_begin, bbox_size, _ = sample_distorted_bounding_box

    # Crop the image to the specified bounding box.
    offset_y, offset_x, _ = tf.unstack(bbox_begin)
    target_height, target_width, _ = tf.unstack(bbox_size)
    image = tf.image.crop_to_bounding_box(
        image, offset_y, offset_x, target_height, target_width)

    return image
def crop_and_resize(image, height, width):
  """Make a random crop and resize it to height `height` and width `width`.

  Args:
    image: Tensor representing the image.
    height: Desired image height.
    width: Desired image width.

  Returns:
    A `height` x `width` x channels Tensor holding a random crop of `image`.
  """
  # Whole-image bounding box; the crop sampler distorts within it.
  bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
  aspect_ratio = width / height
  # Aspect-ratio range is 3/4..4/3 of the TARGET aspect ratio, and the crop
  # covers 8%-100% of the source area (SimCLR defaults).
  image = distorted_bounding_box_crop(
      image,
      bbox,
      min_object_covered=0.1,
      aspect_ratio_range=(3. / 4 * aspect_ratio, 4. / 3. * aspect_ratio),
      area_range=(0.08, 1.0),
      max_attempts=100,
      scope=None)
  # resize expects a batch; wrap in a singleton list and take element 0.
  return tf.image.resize([image], [height, width],
                         method=tf.image.ResizeMethod.BICUBIC)[0]
def random_crop_with_resize(image, height, width, p=1.0):
  """Randomly crop and resize an image.

  Args:
    image: `Tensor` representing an image of arbitrary size.
    height: Height of output image.
    width: Width of output image.
    p: Probability of applying this transformation.

  Returns:
    A preprocessed image `Tensor`.
  """

  def _crop_then_resize(img):
    return crop_and_resize(img, height, width)

  return random_apply(_crop_then_resize, p=p, x=image)
official/vision/beta/projects/simclr/dataloaders/simclr_input.py
0 → 100644
View file @
aba78478
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data parser and processing for SimCLR.
For pre-training:
- Preprocessing:
-> random cropping
-> resize back to the original size
-> random color distortions
-> random Gaussian blur (sequential)
- Each image need to be processed randomly twice
```snippets
if train_mode == 'pretrain':
xs = []
for _ in range(2): # Two transformations
xs.append(preprocess_fn_pretrain(image))
image = tf.concat(xs, -1)
else:
image = preprocess_fn_finetune(image)
```
For fine-tuning:
typical image classification input
"""
from
typing
import
List
import
tensorflow
as
tf
from
official.vision.beta.dataloaders
import
decoder
from
official.vision.beta.dataloaders
import
parser
from
official.vision.beta.ops
import
preprocess_ops
from
official.vision.beta.projects.simclr.dataloaders
import
preprocess_ops
as
simclr_preprocess_ops
from
official.vision.beta.projects.simclr.modeling
import
simclr_model
class Decoder(decoder.Decoder):
  """A tf.Example decoder for classification task."""

  def __init__(self, decode_label=True):
    self._decode_label = decode_label
    # Feature spec for tf.Example parsing; the label entry is optional.
    features = {
        'image/encoded':
            tf.io.FixedLenFeature((), tf.string, default_value=''),
    }
    if decode_label:
      features['image/class/label'] = tf.io.FixedLenFeature(
          (), tf.int64, default_value=-1)
    self._keys_to_features = features

  def decode(self, serialized_example):
    """Parses one serialized tf.Example into a feature dict."""
    return tf.io.parse_single_example(serialized_example,
                                      self._keys_to_features)
class TFDSDecoder(decoder.Decoder):
  """A TFDS decoder for classification task."""

  def __init__(self, decode_label=True):
    self._decode_label = decode_label

  def decode(self, serialized_example):
    """Re-encodes a decoded TFDS example into the tf.Example key layout."""
    decoded = {
        'image/encoded':
            tf.io.encode_jpeg(serialized_example['image'], quality=100),
    }
    if self._decode_label:
      decoded['image/class/label'] = serialized_example['label']
    return decoded
class Parser(parser.Parser):
  """Parser for SimCLR training."""

  def __init__(self,
               output_size: List[int],
               aug_rand_crop: bool = True,
               aug_rand_hflip: bool = True,
               aug_color_distort: bool = True,
               aug_color_jitter_strength: float = 1.0,
               aug_color_jitter_impl: str = 'simclrv2',
               aug_rand_blur: bool = True,
               parse_label: bool = True,
               test_crop: bool = True,
               mode: str = simclr_model.PRETRAIN,
               dtype: str = 'float32'):
    """Initializes parameters for parsing annotations in the dataset.

    Args:
      output_size: `Tensor` or `list` for [height, width] of output image.
        The output_size should be divided by the largest feature stride
        2^max_level.
      aug_rand_crop: `bool`, if True, augment training with random cropping.
      aug_rand_hflip: `bool`, if True, augment training with random
        horizontal flip.
      aug_color_distort: `bool`, if True augment training with color
        distortion.
      aug_color_jitter_strength: `float`, the floating number for the
        strength of the color augmentation
      aug_color_jitter_impl: `str`, 'simclrv1' or 'simclrv2'. Define whether
        to use simclrv1 or simclrv2's version of random brightness.
      aug_rand_blur: `bool`, if True, augment training with random blur.
      parse_label: `bool`, if True, parse label together with image.
      test_crop: `bool`, if True, augment eval with center cropping.
      mode: `str`, 'pretrain' or 'finetune'. Define training mode.
      dtype: `str`, cast output image in dtype. It can be 'float32',
        'float16', or 'bfloat16'.
    """
    self._output_size = output_size
    self._aug_rand_crop = aug_rand_crop
    self._aug_rand_hflip = aug_rand_hflip
    self._aug_color_distort = aug_color_distort
    self._aug_color_jitter_strength = aug_color_jitter_strength
    self._aug_color_jitter_impl = aug_color_jitter_impl
    self._aug_rand_blur = aug_rand_blur
    self._parse_label = parse_label
    self._mode = mode
    self._test_crop = test_crop
    # Disable eval center-crop for very small outputs (<= 32px on both
    # sides, e.g. CIFAR-sized inputs) where cropping would lose too much.
    if max(self._output_size[0], self._output_size[1]) <= 32:
      self._test_crop = False
    if dtype == 'float32':
      self._dtype = tf.float32
    elif dtype == 'float16':
      self._dtype = tf.float16
    elif dtype == 'bfloat16':
      self._dtype = tf.bfloat16
    else:
      raise ValueError('dtype {!r} is not supported!'.format(dtype))

  def _parse_one_train_image(self, image_bytes):
    """Decodes and augments a single training view of one image.

    Returns an HxWx3 tensor in [0, 1] cast to `self._dtype`.
    """
    image = tf.image.decode_jpeg(image_bytes, channels=3)
    # This line convert the image to float 0.0 - 1.0
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)

    if self._aug_rand_crop:
      image = simclr_preprocess_ops.random_crop_with_resize(
          image, self._output_size[0], self._output_size[1])

    if self._aug_rand_hflip:
      image = tf.image.random_flip_left_right(image)

    # Color jitter and blur are pretraining-only augmentations.
    if self._aug_color_distort and self._mode == simclr_model.PRETRAIN:
      image = simclr_preprocess_ops.random_color_jitter(
          image=image,
          color_jitter_strength=self._aug_color_jitter_strength,
          impl=self._aug_color_jitter_impl)

    if self._aug_rand_blur and self._mode == simclr_model.PRETRAIN:
      image = simclr_preprocess_ops.random_blur(
          image, self._output_size[0], self._output_size[1])

    image = tf.image.resize(
        image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
    # Pin the static shape (resize may leave it partially unknown).
    image = tf.reshape(image,
                       [self._output_size[0], self._output_size[1], 3])
    image = tf.clip_by_value(image, 0., 1.)

    # Convert image to self._dtype.
    image = tf.image.convert_image_dtype(image, self._dtype)

    return image

  def _parse_train_data(self, decoded_tensors):
    """Parses data for training."""
    image_bytes = decoded_tensors['image/encoded']

    if self._mode == simclr_model.FINETUNE:
      image = self._parse_one_train_image(image_bytes)
    elif self._mode == simclr_model.PRETRAIN:
      # Transform each example twice using a combination of
      # simple augmentations, resulting in 2N data points
      xs = []
      for _ in range(2):
        xs.append(self._parse_one_train_image(image_bytes))
      # The two views are stacked along the channel axis -> HxWx6.
      image = tf.concat(xs, -1)
    else:
      raise ValueError('The mode {} is not supported by the Parser.'
                       .format(self._mode))

    if self._parse_label:
      label = tf.cast(decoded_tensors['image/class/label'], dtype=tf.int32)
      return image, label

    return image

  def _parse_eval_data(self, decoded_tensors):
    """Parses data for evaluation."""
    image_bytes = decoded_tensors['image/encoded']
    image_shape = tf.image.extract_jpeg_shape(image_bytes)

    if self._test_crop:
      # Center-crop directly from the encoded bytes.
      image = preprocess_ops.center_crop_image_v2(image_bytes, image_shape)
    else:
      image = tf.image.decode_jpeg(image_bytes, channels=3)
    # This line convert the image to float 0.0 - 1.0
    image = tf.image.convert_image_dtype(image, dtype=tf.float32)

    image = tf.image.resize(
        image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
    image = tf.reshape(image,
                       [self._output_size[0], self._output_size[1], 3])
    image = tf.clip_by_value(image, 0., 1.)

    # Convert image to self._dtype.
    image = tf.image.convert_image_dtype(image, self._dtype)

    if self._parse_label:
      label = tf.cast(decoded_tensors['image/class/label'], dtype=tf.int32)
      return image, label

    return image
official/vision/beta/projects/simclr/heads/simclr_head.py
0 → 100644
View file @
aba78478
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dense prediction heads."""
from
typing
import
Text
,
Optional
import
tensorflow
as
tf
from
official.vision.beta.projects.simclr.modeling.layers
import
nn_blocks
# Short aliases for frequently used Keras namespaces.
regularizers = tf.keras.regularizers
layers = tf.keras.layers
@tf.keras.utils.register_keras_serializable(package='simclr')
class ProjectionHead(tf.keras.layers.Layer):
  """Projection head."""

  def __init__(self,
               num_proj_layers: int = 3,
               proj_output_dim: Optional[int] = None,
               ft_proj_idx: int = 0,
               kernel_initializer: Text = 'VarianceScaling',
               kernel_regularizer: Optional[regularizers.Regularizer] = None,
               bias_regularizer: Optional[regularizers.Regularizer] = None,
               use_sync_bn: bool = False,
               norm_momentum: float = 0.99,
               norm_epsilon: float = 0.001,
               **kwargs):
    """The projection head used during pretraining of SimCLR.

    Args:
      num_proj_layers: `int` number of Dense layers used.
      proj_output_dim: `int` output dimension of projection head, i.e.,
        output dimension of the final layer.
      ft_proj_idx: `int` index of layer to use during fine-tuning. 0 means no
        projection head during fine tuning, -1 means the final layer.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for
        Conv2D. Default to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
        Default to None.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing
        by zero.
      **kwargs: keyword arguments to be passed.
    """
    super(ProjectionHead, self).__init__(**kwargs)
    # An output dim is mandatory unless the head is a pure pass-through.
    assert proj_output_dim is not None or num_proj_layers == 0
    assert ft_proj_idx <= num_proj_layers, (num_proj_layers, ft_proj_idx)

    self._proj_output_dim = proj_output_dim
    self._num_proj_layers = num_proj_layers
    self._ft_proj_idx = ft_proj_idx
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._use_sync_bn = use_sync_bn
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    self._layers = []

  def get_config(self):
    """Returns the config needed to reconstruct this layer."""
    config = {
        'proj_output_dim': self._proj_output_dim,
        'num_proj_layers': self._num_proj_layers,
        'ft_proj_idx': self._ft_proj_idx,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        # BUG FIX: was `'use_normalization': self._use_normalization`, but
        # no `_use_normalization` attribute is ever assigned (and
        # `use_normalization` is not an __init__ argument), so calling
        # get_config() raised AttributeError and `use_sync_bn` was never
        # serialized. Round-trip the actual constructor argument instead.
        'use_sync_bn': self._use_sync_bn,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon
    }
    base_config = super(ProjectionHead, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def build(self, input_shape):
    """Creates the stack of DenseBN blocks lazily from the input width."""
    self._layers = []
    if self._num_proj_layers > 0:
      # Middle layers keep the input feature dimension.
      intermediate_dim = int(input_shape[-1])
      for j in range(self._num_proj_layers):
        if j != self._num_proj_layers - 1:
          # for the middle layers, use bias and relu for the output.
          layer = nn_blocks.DenseBN(
              output_dim=intermediate_dim,
              use_bias=True,
              use_normalization=True,
              activation='relu',
              kernel_initializer=self._kernel_initializer,
              kernel_regularizer=self._kernel_regularizer,
              bias_regularizer=self._bias_regularizer,
              use_sync_bn=self._use_sync_bn,
              norm_momentum=self._norm_momentum,
              norm_epsilon=self._norm_epsilon,
              name='nl_%d' % j)
        else:
          # for the final layer, neither bias nor relu is used.
          layer = nn_blocks.DenseBN(
              output_dim=self._proj_output_dim,
              use_bias=False,
              use_normalization=True,
              activation=None,
              kernel_regularizer=self._kernel_regularizer,
              kernel_initializer=self._kernel_initializer,
              use_sync_bn=self._use_sync_bn,
              norm_momentum=self._norm_momentum,
              norm_epsilon=self._norm_epsilon,
              name='nl_%d' % j)
        self._layers.append(layer)
    super(ProjectionHead, self).build(input_shape)

  def call(self, inputs, training=None):
    """Runs the projection stack.

    Returns:
      A tuple (proj_head_output, proj_finetune_output): the final projection
      and the intermediate activation at `ft_proj_idx` (index 0 is the raw
      input), used as the fine-tuning feature.
    """
    hiddens_list = [tf.identity(inputs, 'proj_head_input')]

    if self._num_proj_layers == 0:
      proj_head_output = inputs
      proj_finetune_output = inputs
    else:
      for j in range(self._num_proj_layers):
        hiddens = self._layers[j](hiddens_list[-1], training)
        hiddens_list.append(hiddens)
      proj_head_output = tf.identity(
          hiddens_list[-1], 'proj_head_output')
      proj_finetune_output = tf.identity(
          hiddens_list[self._ft_proj_idx], 'proj_finetune_output')

    # The first element is the output of the projection head.
    # The second element is the input of the finetune head.
    return proj_head_output, proj_finetune_output
@tf.keras.utils.register_keras_serializable(package='simclr')
class ClassificationHead(tf.keras.layers.Layer):
  """Supervised classification head: a single linear (logits) layer."""

  def __init__(self,
               num_classes: int,
               kernel_initializer: Text = 'random_uniform',
               kernel_regularizer: Optional[regularizers.Regularizer] = None,
               bias_regularizer: Optional[regularizers.Regularizer] = None,
               name: Text = 'head_supervised',
               **kwargs):
    """The classification head used during pretraining or fine tuning.

    Args:
      num_classes: `int` size of the output dimension or number of classes
        for classification task.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Default to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
        Default to None.
      name: `str`, name of the layer.
      **kwargs: keyword arguments to be passed.
    """
    super(ClassificationHead, self).__init__(name=name, **kwargs)
    self._num_classes = num_classes
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._name = name

  def get_config(self):
    """Returns the layer configuration for serialization."""
    base_config = super(ClassificationHead, self).get_config()
    head_config = {
        'num_classes': self._num_classes,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
    }
    return {**base_config, **head_config}

  def build(self, input_shape):
    # A single dense layer producing un-normalized class logits.
    self._dense0 = layers.Dense(
        units=self._num_classes,
        activation=None,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer)
    super(ClassificationHead, self).build(input_shape)

  def call(self, inputs, training=None):
    return self._dense0(inputs)
official/vision/beta/projects/simclr/heads/simclr_head_test.py
0 → 100644
View file @
aba78478
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
official.vision.beta.projects.simclr.heads
import
simclr_head
class ProjectionHeadTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for the SimCLR projection head."""

  @parameterized.parameters(
      (0, None),
      (1, 128),
      (2, 128),
  )
  def test_head_creation(self, num_proj_layers, proj_output_dim):
    head = simclr_head.ProjectionHead(
        num_proj_layers=num_proj_layers,
        proj_output_dim=proj_output_dim)

    input_dim = 64
    features = tf.keras.Input(shape=(input_dim,))
    proj_head_output, proj_finetune_output = head(features)

    # With no projection layers the head passes inputs through unchanged.
    expected_head_dim = (
        proj_output_dim if num_proj_layers > 0 else input_dim)
    self.assertAllEqual(proj_head_output.shape.as_list(),
                        [None, expected_head_dim])

    if num_proj_layers > 0:
      self.assertAllEqual(proj_finetune_output.shape.as_list(),
                          [None, input_dim])

  @parameterized.parameters(
      (0, None, 0),
      (1, 128, 0),
      (2, 128, 1),
      (2, 128, 2),
  )
  def test_outputs(self, num_proj_layers, proj_output_dim, ft_proj_idx):
    head = simclr_head.ProjectionHead(
        num_proj_layers=num_proj_layers,
        proj_output_dim=proj_output_dim,
        ft_proj_idx=ft_proj_idx)

    input_dim = 64
    batch_size = 2
    inputs = np.random.rand(batch_size, input_dim)
    proj_head_output, proj_finetune_output = head(inputs)

    if num_proj_layers == 0:
      # Identity behavior: both outputs equal the inputs.
      self.assertAllClose(inputs, proj_head_output)
      self.assertAllClose(inputs, proj_finetune_output)
      return

    self.assertAllEqual(proj_head_output.shape.as_list(),
                        [batch_size, proj_output_dim])
    if ft_proj_idx == 0:
      # Index 0 selects the projection-head input itself.
      self.assertAllClose(inputs, proj_finetune_output)
    elif ft_proj_idx < num_proj_layers:
      # Intermediate layers keep the input dimensionality.
      self.assertAllEqual(proj_finetune_output.shape.as_list(),
                          [batch_size, input_dim])
    else:
      # The final layer projects down to proj_output_dim.
      self.assertAllEqual(proj_finetune_output.shape.as_list(),
                          [batch_size, proj_output_dim])
class ClassificationHeadTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for the SimCLR supervised classification head."""

  @parameterized.parameters(10, 20)
  def test_head_creation(self, num_classes):
    head = simclr_head.ClassificationHead(num_classes=num_classes)

    input_dim = 64
    features = tf.keras.Input(shape=(input_dim,))
    logits = head(features)

    # The head maps (batch, input_dim) -> (batch, num_classes).
    self.assertAllEqual(logits.shape.as_list(), [None, num_classes])
if __name__ == '__main__':
  # Run all test cases defined in this module.
  tf.test.main()
official/vision/beta/projects/simclr/losses/contrastive_losses.py
0 → 100644
View file @
aba78478
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contrastive loss functions."""
import
functools
import
tensorflow
as
tf
# Large constant subtracted from masked-out (self-similarity) logits so they
# contribute effectively zero probability after softmax.
LARGE_NUM = 1e9
def cross_replica_concat(tensor: tf.Tensor,
                         num_replicas: int) -> tf.Tensor:
  """Concatenates the local `tensor` from every replica along axis 0.

  Args:
    tensor: `tf.Tensor` to concatenate.
    num_replicas: `int` number of replicas.

  Returns:
    Tensor of the same rank as `tensor` with first dimension `num_replicas`
    times larger.
  """
  if num_replicas <= 1:
    return tensor

  replica_context = tf.distribute.get_replica_context()
  with tf.name_scope('cross_replica_concat'):
    # Build a tensor with an extra, outermost replica axis that holds this
    # replica's values in its own slot and zeros in every other slot.
    padded = tf.scatter_nd(
        indices=[[replica_context.replica_id_in_sync_group]],
        updates=[tensor],
        shape=tf.concat([[num_replicas], tf.shape(tensor)], axis=0))

    # Each value is non-zero on exactly one replica, so summing across
    # replicas fills in all slots on every replica.
    padded = replica_context.all_reduce(tf.distribute.ReduceOp.SUM, padded)

    # Fold the replica axis into the first dimension; the [-1] reshape also
    # supports scalar input. Result size: tensor.shape[0] * num_replicas.
    return tf.reshape(padded, [-1] + padded.shape.as_list()[2:])
@tf.keras.utils.register_keras_serializable(package='simclr')
class ContrastiveLoss(object):
  """Contrastive training loss function.

  Computes a temperature-scaled softmax cross-entropy over pairwise
  similarities between two augmented views of the same batch, with support
  for multi-replica training via cross-replica concatenation.
  """

  def __init__(self,
               projection_norm: bool = True,
               temperature: float = 1.0):
    """Initializes `ContrastiveLoss`.

    Args:
      projection_norm: whether or not to use l2-normalization on the
        projection vectors before computing similarities.
      temperature: a `floating` number for temperature scaling.
    """
    self._projection_norm = projection_norm
    self._temperature = temperature

  def __call__(self,
               projection1: tf.Tensor,
               projection2: tf.Tensor):
    """Compute the contrastive loss for contrastive learning.

    Note that projection2 is generated with the same batch (same order) of raw
    images, but with different augmentation. More specifically:

      image[i] -> random augmentation 1 -> projection -> projection1[i]
      image[i] -> random augmentation 2 -> projection -> projection2[i]

    Args:
      projection1: projection vector of shape (bsz, dim).
      projection2: projection vector of shape (bsz, dim).

    Returns:
      A loss scalar.
      The logits for contrastive prediction task.
      The labels for contrastive prediction task.
    """
    # Get (normalized) hidden1 and hidden2.
    if self._projection_norm:
      projection1 = tf.math.l2_normalize(projection1, -1)
      projection2 = tf.math.l2_normalize(projection2, -1)
    batch_size = tf.shape(projection1)[0]

    p1_local, p2_local = projection1, projection2
    # Gather projection1/projection2 across replicas and create local labels.
    num_replicas_in_sync = tf.distribute.get_strategy().num_replicas_in_sync
    if num_replicas_in_sync > 1:
      p1_global = cross_replica_concat(p1_local, num_replicas_in_sync)
      p2_global = cross_replica_concat(p2_local, num_replicas_in_sync)
      global_batch_size = tf.shape(p1_global)[0]
      replica_context = tf.distribute.get_replica_context()
      # Round-trip through uint32 to obtain the replica id as a plain int32
      # tensor usable in index arithmetic.
      replica_id = tf.cast(
          tf.cast(replica_context.replica_id_in_sync_group, tf.uint32),
          tf.int32)
      # Local example i lives at global row i + replica_id * batch_size.
      labels_idx = tf.range(batch_size) + replica_id * batch_size
      # `labels` spans the doubled logits [cross-view, within-view];
      # `masks` marks each example's own slot (for self-similarity removal).
      labels = tf.one_hot(labels_idx, global_batch_size * 2)
      masks = tf.one_hot(labels_idx, global_batch_size)
    else:
      p1_global = p1_local
      p2_global = p2_local
      labels = tf.one_hot(tf.range(batch_size), batch_size * 2)
      masks = tf.one_hot(tf.range(batch_size), batch_size)

    tb_matmul = functools.partial(tf.matmul, transpose_b=True)
    # Within-view similarities; subtract LARGE_NUM on the diagonal so an
    # example is never treated as its own (trivial) positive.
    logits_aa = tb_matmul(p1_local, p1_global) / self._temperature
    logits_aa = logits_aa - masks * LARGE_NUM
    logits_bb = tb_matmul(p2_local, p2_global) / self._temperature
    logits_bb = logits_bb - masks * LARGE_NUM
    # Cross-view similarities; the matching index is the positive pair.
    logits_ab = tb_matmul(p1_local, p2_global) / self._temperature
    logits_ba = tb_matmul(p2_local, p1_global) / self._temperature

    # Symmetrized cross-entropy: view 1 against [ab, aa] and view 2 against
    # [ba, bb], averaged over the local batch.
    loss_a_local = tf.nn.softmax_cross_entropy_with_logits(
        labels, tf.concat([logits_ab, logits_aa], 1))
    loss_b_local = tf.nn.softmax_cross_entropy_with_logits(
        labels, tf.concat([logits_ba, logits_bb], 1))
    loss_local = tf.reduce_mean(loss_a_local + loss_b_local)

    return loss_local, (logits_ab, labels)

  def get_config(self):
    # Keys must stay in sync with the __init__ arguments for deserialization.
    config = {
        'projection_norm': self._projection_norm,
        'temperature': self._temperature,
    }
    return config
official/vision/beta/projects/simclr/losses/contrastive_losses_test.py
0 → 100644
View file @
aba78478
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
official.vision.beta.projects.simclr.losses
import
contrastive_losses
class ContrastiveLossesTest(tf.test.TestCase, parameterized.TestCase):
  """Checks ContrastiveLoss against a hand-computed reference value."""

  @parameterized.parameters(1.0, 0.5)
  def test_contrastive_loss_computation(self, temperature):
    # Small random batch; normalization disabled so the expected value can
    # be computed directly from the raw vectors.
    batch_size = 2
    project_dim = 16
    projection_norm = False

    p_1_arr = np.random.rand(batch_size, project_dim)
    p_1 = tf.constant(p_1_arr, dtype=tf.float32)
    p_2_arr = np.random.rand(batch_size, project_dim)
    p_2 = tf.constant(p_2_arr, dtype=tf.float32)

    losses_obj = contrastive_losses.ContrastiveLoss(
        projection_norm=projection_norm,
        temperature=temperature)
    comp_contrastive_loss = losses_obj(projection1=p_1, projection2=p_2)

    def _exp_sim(p1, p2):
      # exp(similarity / temperature) for one pair of vectors.
      return np.exp(np.matmul(p1, p2) / temperature)

    # Reference loss terms for batch_size == 2, written out by hand: each
    # positive pair (p_1[i], p_2[i]) against all non-self candidates, and
    # symmetrized over the two views.
    l11 = -np.log(
        _exp_sim(p_1_arr[0], p_2_arr[0]) /
        (_exp_sim(p_1_arr[0], p_1_arr[1]) + _exp_sim(p_1_arr[0], p_2_arr[1])
         + _exp_sim(p_1_arr[0], p_2_arr[0]))
    ) - np.log(
        _exp_sim(p_1_arr[0], p_2_arr[0]) /
        (_exp_sim(p_2_arr[0], p_2_arr[1]) + _exp_sim(p_2_arr[0], p_1_arr[1])
         + _exp_sim(p_1_arr[0], p_2_arr[0]))
    )
    l22 = -np.log(
        _exp_sim(p_1_arr[1], p_2_arr[1]) /
        (_exp_sim(p_1_arr[1], p_1_arr[0]) + _exp_sim(p_1_arr[1], p_2_arr[0])
         + _exp_sim(p_1_arr[1], p_2_arr[1]))
    ) - np.log(
        _exp_sim(p_1_arr[1], p_2_arr[1]) /
        (_exp_sim(p_2_arr[1], p_2_arr[0]) + _exp_sim(p_2_arr[1], p_1_arr[0])
         + _exp_sim(p_1_arr[1], p_2_arr[1]))
    )
    exp_contrastive_loss = (l11 + l22) / 2.0

    # The loss object returns (loss, (logits, labels)); compare the scalar.
    self.assertAlmostEqual(
        comp_contrastive_loss[0].numpy(), exp_contrastive_loss, places=5)
if __name__ == '__main__':
  # Run all test cases defined in this module.
  tf.test.main()
official/vision/beta/projects/simclr/modeling/layers/nn_blocks.py
0 → 100644
View file @
aba78478
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains common building blocks for simclr neural networks."""
from
typing
import
Text
,
Optional
import
tensorflow
as
tf
from
official.modeling
import
tf_utils
regularizers
=
tf
.
keras
.
regularizers
@tf.keras.utils.register_keras_serializable(package='simclr')
class DenseBN(tf.keras.layers.Layer):
  """Modified Dense layer to help build simclr system.

  The layer is a standards combination of Dense, BatchNorm and Activation.
  """

  def __init__(self,
               output_dim: int,
               use_bias: bool = True,
               use_normalization: bool = False,
               use_sync_bn: bool = False,
               norm_momentum: float = 0.99,
               norm_epsilon: float = 0.001,
               activation: Optional[Text] = 'relu',
               kernel_initializer: Text = 'VarianceScaling',
               kernel_regularizer: Optional[regularizers.Regularizer] = None,
               bias_regularizer: Optional[regularizers.Regularizer] = None,
               name='linear_layer',
               **kwargs):
    """Customized Dense layer.

    Args:
      output_dim: `int` size of output dimension.
      use_bias: if True, use bias in the dense layer.
      use_normalization: if True, use batch normalization.
      use_sync_bn: if True, use synchronized batch normalization.
      norm_momentum: `float` normalization momentum for the moving average.
      norm_epsilon: `float` small float added to variance to avoid dividing by
        zero.
      activation: `str` name of the activation function.
      kernel_initializer: kernel_initializer for convolutional layers.
      kernel_regularizer: tf.keras.regularizers.Regularizer object for Conv2D.
        Default to None.
      bias_regularizer: tf.keras.regularizers.Regularizer object for Conv2d.
        Default to None.
      name: `str`, name of the layer.
      **kwargs: keyword arguments to be passed.
    """
    # Note: use_bias is ignored for the dense layer when use_bn=True.
    # However, it is still used for batch norm (it controls `center`).
    super(DenseBN, self).__init__(**kwargs)
    self._output_dim = output_dim
    self._use_bias = use_bias
    self._use_normalization = use_normalization
    self._use_sync_bn = use_sync_bn
    self._norm_momentum = norm_momentum
    self._norm_epsilon = norm_epsilon
    self._activation = activation
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    # NOTE(review): `name` is not forwarded to super().__init__; assigning
    # self._name directly after construction appears to be how the layer
    # name is set here — confirm this matches tf.keras.layers.Layer
    # internals for the TF version in use.
    self._name = name

    # Select the normalization class; instantiation happens in build().
    if use_sync_bn:
      self._norm = tf.keras.layers.experimental.SyncBatchNormalization
    else:
      self._norm = tf.keras.layers.BatchNormalization

    # Channel axis for batch norm depends on the image data format.
    if tf.keras.backend.image_data_format() == 'channels_last':
      self._bn_axis = -1
    else:
      self._bn_axis = 1

    if activation:
      self._activation_fn = tf_utils.get_activation(activation)
    else:
      self._activation_fn = None

  def get_config(self):
    """Returns the layer configuration for serialization."""
    config = {
        'output_dim': self._output_dim,
        'use_bias': self._use_bias,
        'activation': self._activation,
        'use_sync_bn': self._use_sync_bn,
        'use_normalization': self._use_normalization,
        'norm_momentum': self._norm_momentum,
        'norm_epsilon': self._norm_epsilon,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
    }
    base_config = super(DenseBN, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def build(self, input_shape):
    # When normalization is on, the dense bias is redundant with the batch
    # norm `center` (beta) term, so it is dropped from the dense layer.
    self._dense0 = tf.keras.layers.Dense(
        self._output_dim,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        use_bias=self._use_bias and not self._use_normalization)

    if self._use_normalization:
      # `center` mirrors use_bias: the beta offset plays the bias role.
      self._norm0 = self._norm(
          axis=self._bn_axis,
          momentum=self._norm_momentum,
          epsilon=self._norm_epsilon,
          center=self._use_bias,
          scale=True)

    super(DenseBN, self).build(input_shape)

  def call(self, inputs, training=None):
    # Only rank-2 (batch, features) inputs are supported.
    assert inputs.shape.ndims == 2, inputs.shape

    x = self._dense0(inputs)
    if self._use_normalization:
      # NOTE(review): `training` is not passed explicitly; this relies on
      # Keras's implicit propagation of the training flag to nested layer
      # calls — confirm for the TF version in use.
      x = self._norm0(x)
    if self._activation:
      x = self._activation_fn(x)
    return x
official/vision/beta/projects/simclr/modeling/layers/nn_blocks_test.py
0 → 100644
View file @
aba78478
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.vision.beta.projects.simclr.modeling.layers
import
nn_blocks
class DenseBNTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for the DenseBN building block."""

  @parameterized.parameters(
      (64, True, True),
      (64, True, False),
      (64, False, True),
  )
  def test_pass_through(self, output_dim, use_bias, use_normalization):
    layer_under_test = nn_blocks.DenseBN(
        output_dim=output_dim,
        use_bias=use_bias,
        use_normalization=use_normalization)

    features = tf.keras.Input(shape=(64,))
    outputs = layer_under_test(features)

    self.assertAllEqual(outputs.shape.as_list(), [None, output_dim])

    # Start from the dense kernel, then add what the configuration brings:
    #   - with normalization: gamma, plus beta only when use_bias is True
    #     (center is set to False otherwise);
    #   - without normalization: the dense bias when use_bias is True.
    expected_var_count = 1
    if use_normalization:
      expected_var_count += 2 if use_bias else 1
    elif use_bias:
      expected_var_count += 1
    self.assertLen(layer_under_test.trainable_variables, expected_var_count)
if __name__ == '__main__':
  # Run all test cases defined in this module.
  tf.test.main()
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment