ModelZoo / ResNet50_tensorflow / Commits

Commit aba78478, authored Mar 29, 2021 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 365713370

Parent: f3f3ec34
Showing 4 changed files with 986 additions and 0 deletions:

- official/vision/beta/projects/simclr/modeling/simclr_model.py (+177, -0)
- official/vision/beta/projects/simclr/modeling/simclr_model_test.py (+87, -0)
- official/vision/beta/projects/simclr/tasks/simclr.py (+640, -0)
- official/vision/beta/projects/simclr/train.py (+82, -0)
official/vision/beta/projects/simclr/modeling/simclr_model.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Build simclr models."""
from typing import Optional

from absl import logging
import tensorflow as tf

layers = tf.keras.layers

PRETRAIN = 'pretrain'
FINETUNE = 'finetune'
PROJECTION_OUTPUT_KEY = 'projection_outputs'
SUPERVISED_OUTPUT_KEY = 'supervised_outputs'


@tf.keras.utils.register_keras_serializable(package='simclr')
class SimCLRModel(tf.keras.Model):
  """A classification model based on SimCLR framework."""

  def __init__(self,
               backbone: tf.keras.models.Model,
               projection_head: tf.keras.layers.Layer,
               supervised_head: Optional[tf.keras.layers.Layer] = None,
               input_specs=layers.InputSpec(shape=[None, None, None, 3]),
               mode: str = PRETRAIN,
               backbone_trainable: bool = True,
               **kwargs):
    """A classification model based on SimCLR framework.

    Args:
      backbone: a backbone network.
      projection_head: a projection head network.
      supervised_head: a head network for supervised learning, e.g.
        classification head.
      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
      mode: `str` indicates mode of training to be executed.
      backbone_trainable: `bool` whether the backbone is trainable or not.
      **kwargs: keyword arguments to be passed.
    """
    super(SimCLRModel, self).__init__(**kwargs)
    self._config_dict = {
        'backbone': backbone,
        'projection_head': projection_head,
        'supervised_head': supervised_head,
        'input_specs': input_specs,
        'mode': mode,
        'backbone_trainable': backbone_trainable,
    }
    self._input_specs = input_specs
    self._backbone = backbone
    self._projection_head = projection_head
    self._supervised_head = supervised_head
    self._mode = mode
    self._backbone_trainable = backbone_trainable

    # Set whether the backbone is trainable
    self._backbone.trainable = backbone_trainable

  def call(self, inputs, training=None, **kwargs):
    model_outputs = {}

    if training and self._mode == PRETRAIN:
      num_transforms = 2
    else:
      num_transforms = 1

    # Split channels, and optionally apply extra batched augmentation.
    # (bsz, h, w, c*num_transforms) -> [(bsz, h, w, c), ....]
    features_list = tf.split(
        inputs, num_or_size_splits=num_transforms, axis=-1)
    # (num_transforms * bsz, h, w, c)
    features = tf.concat(features_list, 0)

    # Base network forward pass.
    endpoints = self._backbone(features, training=training)
    features = endpoints[max(endpoints.keys())]
    projection_inputs = layers.GlobalAveragePooling2D()(features)

    # Add heads.
    projection_outputs, supervised_inputs = self._projection_head(
        projection_inputs, training)

    if self._supervised_head is not None:
      if self._mode == PRETRAIN:
        logging.info('Ignoring gradient from supervised outputs !')
        # When performing pretraining and supervised_head together, we do not
        # want information from supervised evaluation flowing back into
        # pretraining network. So we put a stop_gradient.
        supervised_outputs = self._supervised_head(
            tf.stop_gradient(supervised_inputs), training)
      else:
        supervised_outputs = self._supervised_head(supervised_inputs, training)
    else:
      supervised_outputs = None

    model_outputs.update({
        PROJECTION_OUTPUT_KEY: projection_outputs,
        SUPERVISED_OUTPUT_KEY: supervised_outputs
    })

    return model_outputs

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    if self._supervised_head is not None:
      items = dict(
          backbone=self.backbone,
          projection_head=self.projection_head,
          supervised_head=self.supervised_head)
    else:
      items = dict(
          backbone=self.backbone, projection_head=self.projection_head)
    return items

  @property
  def backbone(self):
    return self._backbone

  @property
  def projection_head(self):
    return self._projection_head

  @property
  def supervised_head(self):
    return self._supervised_head

  @property
  def mode(self):
    return self._mode

  @mode.setter
  def mode(self, value):
    self._mode = value

  @property
  def backbone_trainable(self):
    return self._backbone_trainable

  @backbone_trainable.setter
  def backbone_trainable(self, value):
    self._backbone_trainable = value
    self._backbone.trainable = value

  def get_config(self):
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)
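As an illustration (not part of the commit), the snippet below shows the input layout SimCLRModel.call expects in pretrain mode: the dataloader packs the two augmented views of each image along the channel axis, and call splits them back apart and stacks them along the batch axis before the backbone forward pass.

# Illustration only: how a pretrain batch is packed and unpacked.
import tensorflow as tf

bsz, h, w = 4, 32, 32
view_a = tf.random.uniform([bsz, h, w, 3])
view_b = tf.random.uniform([bsz, h, w, 3])

# Dataloader side: concatenate both views along channels -> (bsz, h, w, 6).
packed = tf.concat([view_a, view_b], axis=-1)

# Model side (mirrors SimCLRModel.call with num_transforms=2): split the
# channels back into two views and stack them along the batch axis.
features_list = tf.split(packed, num_or_size_splits=2, axis=-1)
features = tf.concat(features_list, 0)
print(features.shape)  # (8, 32, 32, 3), i.e. (2 * bsz, h, w, 3)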
official/vision/beta/projects/simclr/modeling/simclr_model_test.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# ==============================================================================
from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from official.vision.beta.modeling import backbones
from official.vision.beta.projects.simclr.heads import simclr_head
from official.vision.beta.projects.simclr.modeling import simclr_model


class SimCLRModelTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.parameters(
      (128, 3, 0),
      (128, 3, 1),
      (128, 1, 0),
      (128, 1, 1),
  )
  def test_model_creation(self, project_dim, num_proj_layers, ft_proj_idx):
    input_size = 224
    inputs = np.random.rand(2, input_size, input_size, 3)
    input_specs = tf.keras.layers.InputSpec(
        shape=[None, input_size, input_size, 3])

    tf.keras.backend.set_image_data_format('channels_last')

    backbone = backbones.ResNet(
        model_id=50, activation='relu', input_specs=input_specs)
    projection_head = simclr_head.ProjectionHead(
        proj_output_dim=project_dim,
        num_proj_layers=num_proj_layers,
        ft_proj_idx=ft_proj_idx)
    num_classes = 10
    supervised_head = simclr_head.ClassificationHead(num_classes=10)

    model = simclr_model.SimCLRModel(
        input_specs=input_specs,
        backbone=backbone,
        projection_head=projection_head,
        supervised_head=supervised_head,
        mode=simclr_model.PRETRAIN)

    outputs = model(inputs)
    projection_outputs = outputs[simclr_model.PROJECTION_OUTPUT_KEY]
    supervised_outputs = outputs[simclr_model.SUPERVISED_OUTPUT_KEY]

    self.assertAllEqual(projection_outputs.shape.as_list(),
                        [2, project_dim])
    self.assertAllEqual([2, num_classes],
                        supervised_outputs.numpy().shape)


if __name__ == '__main__':
  tf.test.main()
official/vision/beta/projects/simclr/tasks/simclr.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# ==============================================================================
"""Image SimCLR task definition.
SimCLR training two different modes:
- pretrain
- fine-tuning
For the above two different modes, the following components are different in
the task definition:
- training data format
- training loss
- projection_head and/or supervised_head
"""
from typing import Dict, Optional

from absl import logging
import tensorflow as tf

from official.core import base_task
from official.core import config_definitions
from official.core import input_reader
from official.core import task_factory
from official.modeling import optimization
from official.modeling import performance
from official.modeling import tf_utils
from official.vision.beta.modeling import backbones
from official.vision.beta.projects.simclr.configs import simclr as exp_cfg
from official.vision.beta.projects.simclr.dataloaders import simclr_input
from official.vision.beta.projects.simclr.heads import simclr_head
from official.vision.beta.projects.simclr.losses import contrastive_losses
from official.vision.beta.projects.simclr.modeling import simclr_model

OptimizationConfig = optimization.OptimizationConfig
RuntimeConfig = config_definitions.RuntimeConfig
@task_factory.register_task_cls(exp_cfg.SimCLRPretrainTask)
class SimCLRPretrainTask(base_task.Task):
  """A task for image classification."""

  def create_optimizer(self,
                       optimizer_config: OptimizationConfig,
                       runtime_config: Optional[RuntimeConfig] = None):
    """Creates a TF optimizer from configurations.

    Args:
      optimizer_config: the parameters of the Optimization settings.
      runtime_config: the parameters of the runtime.

    Returns:
      A tf.optimizers.Optimizer object.
    """
    if (optimizer_config.optimizer.type == 'lars' and
        self.task_config.loss.l2_weight_decay > 0.0):
      raise ValueError('The l2_weight_decay cannot be used together with lars '
                       'optimizer. Please set it to 0.')

    opt_factory = optimization.OptimizerFactory(optimizer_config)
    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
    # Configuring optimizer when loss_scale is set in runtime config. This
    # helps avoiding overflow/underflow for float16 computations.
    if runtime_config and runtime_config.loss_scale:
      optimizer = performance.configure_optimizer(
          optimizer,
          use_float16=runtime_config.mixed_precision_dtype == 'float16',
          loss_scale=runtime_config.loss_scale)

    return optimizer

  def build_model(self):
    model_config = self.task_config.model
    input_specs = tf.keras.layers.InputSpec(shape=[None] +
                                            model_config.input_size)

    l2_weight_decay = self.task_config.loss.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (
        tf.keras.regularizers.l2(l2_weight_decay / 2.0)
        if l2_weight_decay else None)

    # Build backbone.
    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

    # Build projection head.
    norm_activation_config = model_config.norm_activation
    projection_head_config = model_config.projection_head
    projection_head = simclr_head.ProjectionHead(
        proj_output_dim=projection_head_config.proj_output_dim,
        num_proj_layers=projection_head_config.num_proj_layers,
        ft_proj_idx=projection_head_config.ft_proj_idx,
        kernel_regularizer=l2_regularizer,
        use_sync_bn=norm_activation_config.use_sync_bn,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon)

    # Build supervised head.
    supervised_head_config = model_config.supervised_head
    if supervised_head_config:
      if supervised_head_config.zero_init:
        s_kernel_initializer = 'zeros'
      else:
        s_kernel_initializer = 'random_uniform'
      supervised_head = simclr_head.ClassificationHead(
          num_classes=supervised_head_config.num_classes,
          kernel_initializer=s_kernel_initializer,
          kernel_regularizer=l2_regularizer)
    else:
      supervised_head = None

    model = simclr_model.SimCLRModel(
        input_specs=input_specs,
        backbone=backbone,
        projection_head=projection_head,
        supervised_head=supervised_head,
        mode=model_config.mode,
        backbone_trainable=model_config.backbone_trainable)

    logging.info(model.get_config())

    return model

  def initialize(self, model: tf.keras.Model):
    """Loads a pretrained checkpoint."""
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Restoring checkpoint.
    if self.task_config.init_checkpoint_modules == 'all':
      ckpt = tf.train.Checkpoint(**model.checkpoint_items)
      status = ckpt.restore(ckpt_dir_or_file)
      status.assert_consumed()
    elif self.task_config.init_checkpoint_modules == 'backbone':
      ckpt = tf.train.Checkpoint(backbone=model.backbone)
      status = ckpt.restore(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()
    else:
      raise ValueError(
          "Only 'all' or 'backbone' can be used to initialize the model.")

    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def build_inputs(self, params, input_context=None):
    input_size = self.task_config.model.input_size

    if params.tfds_name:
      decoder = simclr_input.TFDSDecoder(params.decoder.decode_label)
    else:
      decoder = simclr_input.Decoder(params.decoder.decode_label)
    parser = simclr_input.Parser(
        output_size=input_size[:2],
        aug_rand_crop=params.parser.aug_rand_crop,
        aug_rand_hflip=params.parser.aug_rand_hflip,
        aug_color_distort=params.parser.aug_color_distort,
        aug_color_jitter_strength=params.parser.aug_color_jitter_strength,
        aug_color_jitter_impl=params.parser.aug_color_jitter_impl,
        aug_rand_blur=params.parser.aug_rand_blur,
        parse_label=params.parser.parse_label,
        test_crop=params.parser.test_crop,
        mode=params.parser.mode,
        dtype=params.dtype)

    reader = input_reader.InputReader(
        params,
        dataset_fn=tf.data.TFRecordDataset,
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))

    dataset = reader.read(input_context=input_context)

    return dataset

  def build_losses(self,
                   labels,
                   model_outputs,
                   aux_losses=None) -> Dict[str, tf.Tensor]:
    # Compute contrastive relative loss.
    con_losses_obj = contrastive_losses.ContrastiveLoss(
        projection_norm=self.task_config.loss.projection_norm,
        temperature=self.task_config.loss.temperature)
    # The projection outputs from the model have size (2 * bsz, project_dim).
    projection_outputs = model_outputs[simclr_model.PROJECTION_OUTPUT_KEY]
    projection1, projection2 = tf.split(projection_outputs, 2, 0)
    contrast_loss, (contrast_logits, contrast_labels) = con_losses_obj(
        projection1=projection1, projection2=projection2)

    contrast_accuracy = tf.equal(
        tf.argmax(contrast_labels, axis=1), tf.argmax(contrast_logits, axis=1))
    contrast_accuracy = tf.reduce_mean(tf.cast(contrast_accuracy, tf.float32))

    contrast_prob = tf.nn.softmax(contrast_logits)
    contrast_entropy = -tf.reduce_mean(
        tf.reduce_sum(contrast_prob * tf.math.log(contrast_prob + 1e-8), -1))

    model_loss = contrast_loss

    losses = {
        'contrast_loss': contrast_loss,
        'contrast_accuracy': contrast_accuracy,
        'contrast_entropy': contrast_entropy
    }

    if self.task_config.model.supervised_head is not None:
      outputs = model_outputs[simclr_model.SUPERVISED_OUTPUT_KEY]
      labels = tf.concat([labels, labels], 0)

      if self.task_config.evaluation.one_hot:
        sup_loss = tf.keras.losses.CategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(
                labels, outputs)
      else:
        sup_loss = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE)(
                labels, outputs)
      sup_loss = tf.reduce_mean(sup_loss)

      label_acc = tf.equal(
          tf.argmax(labels, axis=1), tf.argmax(outputs, axis=1))
      label_acc = tf.reduce_mean(tf.cast(label_acc, tf.float32))

      model_loss = contrast_loss + sup_loss

      losses.update({
          'accuracy': label_acc,
          'supervised_loss': sup_loss,
      })

    total_loss = model_loss
    if aux_losses:
      reg_loss = tf.reduce_sum(aux_losses)
      total_loss = model_loss + reg_loss

    losses['total_loss'] = total_loss

    return losses

  def build_metrics(self, training=True):
    if training:
      metrics = []
      metric_names = [
          'total_loss', 'contrast_loss', 'contrast_accuracy', 'contrast_entropy'
      ]
      if self.task_config.model.supervised_head:
        metric_names.extend(['supervised_loss', 'accuracy'])
      for name in metric_names:
        metrics.append(tf.keras.metrics.Mean(name, dtype=tf.float32))
    else:
      k = self.task_config.evaluation.top_k
      if self.task_config.evaluation.one_hot:
        metrics = [
            tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
            tf.keras.metrics.TopKCategoricalAccuracy(
                k=k, name='top_{}_accuracy'.format(k))
        ]
      else:
        metrics = [
            tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
            tf.keras.metrics.SparseTopKCategoricalAccuracy(
                k=k, name='top_{}_accuracy'.format(k))
        ]
    return metrics

  def train_step(self, inputs, model, optimizer, metrics=None):
    features, labels = inputs

    if (self.task_config.model.supervised_head is not None and
        self.task_config.evaluation.one_hot):
      num_classes = self.task_config.model.supervised_head.num_classes
      labels = tf.one_hot(labels, num_classes)

    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Casting output layer as float32 is necessary when mixed_precision is
      # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      # Computes per-replica loss.
      losses = self.build_losses(
          model_outputs=outputs, labels=labels, aux_losses=model.losses)

      scaled_loss = losses['total_loss'] / num_replicas
      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    logging.info('Trainable variables:')
    for var in tvars:
      logging.info(var.name)
    grads = tape.gradient(scaled_loss, tvars)

    # Scales back gradient when LossScaleOptimizer is used.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {self.loss: losses['total_loss']}

    for m in metrics:
      m.update_state(losses[m.name])
      logs.update({m.name: m.result()})

    return logs

  def validation_step(self, inputs, model, metrics=None):
    if self.task_config.model.supervised_head is None:
      raise ValueError(
          'Skipping eval during pretraining without supervised head.')
    features, labels = inputs
    if self.task_config.evaluation.one_hot:
      num_classes = self.task_config.model.supervised_head.num_classes
      labels = tf.one_hot(labels, num_classes)

    outputs = model(
        features, training=False)[simclr_model.SUPERVISED_OUTPUT_KEY]
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)

    logs = {self.loss: 0}

    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in model.metrics})

    return logs
@task_factory.register_task_cls(exp_cfg.SimCLRFinetuneTask)
class SimCLRFinetuneTask(base_task.Task):
  """A task for image classification."""

  def create_optimizer(self,
                       optimizer_config: OptimizationConfig,
                       runtime_config: Optional[RuntimeConfig] = None):
    """Creates a TF optimizer from configurations.

    Args:
      optimizer_config: the parameters of the Optimization settings.
      runtime_config: the parameters of the runtime.

    Returns:
      A tf.optimizers.Optimizer object.
    """
    if (optimizer_config.optimizer.type == 'lars' and
        self.task_config.loss.l2_weight_decay > 0.0):
      raise ValueError('The l2_weight_decay cannot be used together with lars '
                       'optimizer. Please set it to 0.')

    opt_factory = optimization.OptimizerFactory(optimizer_config)
    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
    # Configuring optimizer when loss_scale is set in runtime config. This
    # helps avoiding overflow/underflow for float16 computations.
    if runtime_config and runtime_config.loss_scale:
      optimizer = performance.configure_optimizer(
          optimizer,
          use_float16=runtime_config.mixed_precision_dtype == 'float16',
          loss_scale=runtime_config.loss_scale)

    return optimizer

  def build_model(self):
    model_config = self.task_config.model
    input_specs = tf.keras.layers.InputSpec(shape=[None] +
                                            model_config.input_size)

    l2_weight_decay = self.task_config.loss.l2_weight_decay
    # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
    # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
    # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
    l2_regularizer = (
        tf.keras.regularizers.l2(l2_weight_decay / 2.0)
        if l2_weight_decay else None)

    backbone = backbones.factory.build_backbone(
        input_specs=input_specs,
        model_config=model_config,
        l2_regularizer=l2_regularizer)

    norm_activation_config = model_config.norm_activation
    projection_head_config = model_config.projection_head
    projection_head = simclr_head.ProjectionHead(
        proj_output_dim=projection_head_config.proj_output_dim,
        num_proj_layers=projection_head_config.num_proj_layers,
        ft_proj_idx=projection_head_config.ft_proj_idx,
        kernel_regularizer=l2_regularizer,
        use_sync_bn=norm_activation_config.use_sync_bn,
        norm_momentum=norm_activation_config.norm_momentum,
        norm_epsilon=norm_activation_config.norm_epsilon)

    supervised_head_config = model_config.supervised_head
    if supervised_head_config.zero_init:
      s_kernel_initializer = 'zeros'
    else:
      s_kernel_initializer = 'random_uniform'
    supervised_head = simclr_head.ClassificationHead(
        num_classes=supervised_head_config.num_classes,
        kernel_initializer=s_kernel_initializer,
        kernel_regularizer=l2_regularizer)

    model = simclr_model.SimCLRModel(
        input_specs=input_specs,
        backbone=backbone,
        projection_head=projection_head,
        supervised_head=supervised_head,
        mode=model_config.mode,
        backbone_trainable=model_config.backbone_trainable)

    logging.info(model.get_config())

    return model

  def initialize(self, model: tf.keras.Model):
    """Loads a pretrained checkpoint."""
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Restoring checkpoint.
    if self.task_config.init_checkpoint_modules == 'all':
      ckpt = tf.train.Checkpoint(**model.checkpoint_items)
      status = ckpt.restore(ckpt_dir_or_file)
      status.assert_consumed()
    elif self.task_config.init_checkpoint_modules == 'backbone_projection':
      ckpt = tf.train.Checkpoint(
          backbone=model.backbone, projection_head=model.projection_head)
      status = ckpt.restore(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()
    elif self.task_config.init_checkpoint_modules == 'backbone':
      ckpt = tf.train.Checkpoint(backbone=model.backbone)
      status = ckpt.restore(ckpt_dir_or_file)
      status.expect_partial().assert_existing_objects_matched()
    else:
      raise ValueError(
          "Only 'all', 'backbone_projection' or 'backbone' can be used to "
          'initialize the model.')

    # If the checkpoint is from pretraining, reset the following parameters.
    model.backbone_trainable = self.task_config.model.backbone_trainable
    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def build_inputs(self, params, input_context=None):
    input_size = self.task_config.model.input_size

    if params.tfds_name:
      decoder = simclr_input.TFDSDecoder(params.decoder.decode_label)
    else:
      decoder = simclr_input.Decoder(params.decoder.decode_label)
    parser = simclr_input.Parser(
        output_size=input_size[:2],
        parse_label=params.parser.parse_label,
        test_crop=params.parser.test_crop,
        mode=params.parser.mode,
        dtype=params.dtype)

    reader = input_reader.InputReader(
        params,
        dataset_fn=tf.data.TFRecordDataset,
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training))

    dataset = reader.read(input_context=input_context)

    return dataset

  def build_losses(self, labels, model_outputs, aux_losses=None):
    """Sparse categorical cross entropy loss.

    Args:
      labels: labels.
      model_outputs: Output logits of the classifier.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

    Returns:
      The total loss tensor.
    """
    losses_config = self.task_config.loss
    if losses_config.one_hot:
      total_loss = tf.keras.losses.categorical_crossentropy(
          labels,
          model_outputs,
          from_logits=True,
          label_smoothing=losses_config.label_smoothing)
    else:
      total_loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, model_outputs, from_logits=True)

    total_loss = tf_utils.safe_mean(total_loss)
    if aux_losses:
      total_loss += tf.add_n(aux_losses)

    return total_loss

  def build_metrics(self, training=True):
    """Gets streaming metrics for training/validation."""
    k = self.task_config.evaluation.top_k
    if self.task_config.evaluation.one_hot:
      metrics = [
          tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
          tf.keras.metrics.TopKCategoricalAccuracy(
              k=k, name='top_{}_accuracy'.format(k))
      ]
    else:
      metrics = [
          tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
          tf.keras.metrics.SparseTopKCategoricalAccuracy(
              k=k, name='top_{}_accuracy'.format(k))
      ]
    return metrics

  def train_step(self, inputs, model, optimizer, metrics=None):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    if self.task_config.loss.one_hot:
      num_classes = self.task_config.model.supervised_head.num_classes
      labels = tf.one_hot(labels, num_classes)

    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    with tf.GradientTape() as tape:
      outputs = model(
          features, training=True)[simclr_model.SUPERVISED_OUTPUT_KEY]
      # Casting output layer as float32 is necessary when mixed_precision is
      # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      # Computes per-replica loss.
      loss = self.build_losses(
          model_outputs=outputs, labels=labels, aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / num_replicas

      # For mixed_precision policy, when LossScaleOptimizer is used, loss is
      # scaled for numerical stability.
      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    logging.info('Trainable variables:')
    for var in tvars:
      logging.info(var.name)

    grads = tape.gradient(scaled_loss, tvars)
    # Scales back gradient before apply_gradients when LossScaleOptimizer is
    # used.
    if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in model.metrics})

    return logs

  def validation_step(self, inputs, model, metrics=None):
    """Validation step.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    features, labels = inputs
    if self.task_config.loss.one_hot:
      num_classes = self.task_config.model.supervised_head.num_classes
      labels = tf.one_hot(labels, num_classes)

    outputs = self.inference_step(
        features, model)[simclr_model.SUPERVISED_OUTPUT_KEY]
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
    loss = self.build_losses(
        model_outputs=outputs, labels=labels, aux_losses=model.losses)

    logs = {self.loss: loss}
    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in model.metrics})

    return logs
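As a side note (illustration only, not part of the commit), the weight-decay comment in build_model can be checked numerically: Keras' L2 regularizer computes l2 * sum(w**2), while tf.nn.l2_loss computes sum(w**2) / 2, so dividing l2_weight_decay by 2.0 makes the Keras penalty equal to l2_weight_decay * tf.nn.l2_loss(w).

# Illustration only: regularizers.l2(l2_weight_decay / 2.0) reproduces
# l2_weight_decay * tf.nn.l2_loss(w).
import tensorflow as tf

w = tf.constant([1.0, 2.0, 3.0])
l2_weight_decay = 1e-4
keras_penalty = tf.keras.regularizers.l2(l2_weight_decay / 2.0)(w)
nn_penalty = l2_weight_decay * tf.nn.l2_loss(w)
tf.debugging.assert_near(keras_penalty, nn_penalty)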
official/vision/beta/projects/simclr/train.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# ==============================================================================
"""TensorFlow Model Garden Vision SimCLR training driver."""
from absl import app
from absl import flags
import gin

from official.common import distribute_utils
from official.common import flags as tfm_flags
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.modeling import performance
from official.vision.beta.projects.simclr.common import registry_imports  # pylint: disable=unused-import

FLAGS = flags.FLAGS


def main(_):
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  print(FLAGS.experiment)
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have significant impact on model speeds by utilizing float16 in case
  # of GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only
  # when dtype is float16.
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(
        params.runtime.mixed_precision_dtype, params.runtime.loss_scale)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)
  with distribution_strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)

  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir)


if __name__ == '__main__':
  tfm_flags.define_flags()
  app.run(main)
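For reference, this driver is launched with the standard Model Garden flags registered by tfm_flags.define_flags() (--experiment, --mode, --model_dir, --config_file, ...). The command below is a placeholder sketch rather than an invocation taken from this commit; the registered SimCLR experiment names are defined in official/vision/beta/projects/simclr/configs/simclr.py.

python3 -m official.vision.beta.projects.simclr.train \
  --experiment=<registered_simclr_experiment> \
  --mode=train_and_eval \
  --model_dir=<path/to/model_dir> \
  --config_file=<path/to/config_override.yaml>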