ModelZoo / ResNet50_tensorflow / Commits / 59e1ab8a

Commit 59e1ab8a, authored Aug 03, 2021 by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 388514034

parent 08b68031
Changes: showing 15 changed files with 69 additions and 297 deletions (+69 −297).
official/vision/beta/projects/simclr/common/registry_imports.py        +0  −14
official/vision/beta/projects/simclr/configs/simclr.py                 +18 −34
official/vision/beta/projects/simclr/configs/simclr_test.py            +1  −17
official/vision/beta/projects/simclr/dataloaders/preprocess_ops.py     +0  −14
official/vision/beta/projects/simclr/dataloaders/simclr_input.py       +0  −14
official/vision/beta/projects/simclr/heads/simclr_head.py              +1  −15
official/vision/beta/projects/simclr/heads/simclr_head_test.py         +0  −16
official/vision/beta/projects/simclr/losses/contrastive_losses.py      +0  −15
official/vision/beta/projects/simclr/losses/contrastive_losses_test.py +0  −16
official/vision/beta/projects/simclr/modeling/layers/nn_blocks.py      +0  −16
official/vision/beta/projects/simclr/modeling/layers/nn_blocks_test.py +0  −16
official/vision/beta/projects/simclr/modeling/simclr_model.py          +5  −20
official/vision/beta/projects/simclr/modeling/simclr_model_test.py     +1  −16
official/vision/beta/projects/simclr/tasks/simclr.py                   +42 −58
official/vision/beta/projects/simclr/train.py                          +1  −16
official/vision/beta/projects/simclr/common/registry_imports.py  (view file @ 59e1ab8a)

@@ -12,20 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
 """All necessary imports for registration."""
 # pylint: disable=unused-import
official/vision/beta/projects/simclr/configs/simclr.py  (view file @ 59e1ab8a)

@@ -12,26 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
 """SimCLR configurations."""
-import os
+import dataclasses
 import os.path
 from typing import List, Optional
-import dataclasses
 from official.core import config_definitions as cfg
 from official.core import exp_factory
@@ -115,9 +100,7 @@ class SimCLRModel(hyperparams.Config):
   backbone: backbones.Backbone = backbones.Backbone(
       type='resnet', resnet=backbones.ResNet())
   projection_head: ProjectionHead = ProjectionHead(
-      proj_output_dim=128,
-      num_proj_layers=3,
-      ft_proj_idx=1)
+      proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1)
   supervised_head: SupervisedHead = SupervisedHead(num_classes=1001)
   norm_activation: common.NormActivation = common.NormActivation(
       norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)
@@ -201,9 +184,7 @@ def simclr_pretraining_imagenet() -> cfg.ExperimentConfig:
           backbone=backbones.Backbone(
               type='resnet', resnet=backbones.ResNet(model_id=50)),
           projection_head=ProjectionHead(
-              proj_output_dim=128,
-              num_proj_layers=3,
-              ft_proj_idx=1),
+              proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1),
           supervised_head=SupervisedHead(num_classes=1001),
           norm_activation=common.NormActivation(
               norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=True)),
@@ -233,10 +214,13 @@ def simclr_pretraining_imagenet() -> cfg.ExperimentConfig:
                 'optimizer': {
                     'type': 'lars',
                     'lars': {
                         'momentum': 0.9,
                         'weight_decay_rate': 0.000001,
                         'exclude_from_weight_decay': [
-                            'batch_normalization', 'bias']
+                            'batch_normalization', 'bias'
+                        ]
                     }
                 },
                 'learning_rate': {
@@ -278,11 +262,8 @@ def simclr_finetuning_imagenet() -> cfg.ExperimentConfig:
           backbone=backbones.Backbone(
               type='resnet', resnet=backbones.ResNet(model_id=50)),
           projection_head=ProjectionHead(
-              proj_output_dim=128,
-              num_proj_layers=3,
-              ft_proj_idx=1),
-          supervised_head=SupervisedHead(num_classes=1001,
-                                         zero_init=True),
+              proj_output_dim=128, num_proj_layers=3, ft_proj_idx=1),
+          supervised_head=SupervisedHead(num_classes=1001, zero_init=True),
           norm_activation=common.NormActivation(
               norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=False)),
       loss=ClassificationLosses(),
@@ -311,10 +292,13 @@ def simclr_finetuning_imagenet() -> cfg.ExperimentConfig:
                 'optimizer': {
                     'type': 'lars',
                     'lars': {
                         'momentum': 0.9,
                         'weight_decay_rate': 0.0,
                         'exclude_from_weight_decay': [
-                            'batch_normalization', 'bias']
+                            'batch_normalization', 'bias'
+                        ]
                     }
                 },
                 'learning_rate': {
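Outside the diff itself, a minimal sketch of how these config dataclasses are typically consumed may help. It assumes the two factory functions shown above are registered with exp_factory under the experiment names 'simclr_pretraining_imagenet' and 'simclr_finetuning_imagenet', and that importing the project's registry_imports module performs that registration; neither detail is visible in the hunks above.

```python
# Hedged sketch, not part of this commit: build the registered SimCLR
# pre-training experiment config and read back the projection-head fields
# that the hunks above reformat.
from official.core import exp_factory
# Assumed to perform task/config registration, as its docstring states.
from official.vision.beta.projects.simclr.common import registry_imports  # pylint: disable=unused-import

config = exp_factory.get_exp_config('simclr_pretraining_imagenet')
head = config.task.model.projection_head
print(head.proj_output_dim, head.num_proj_layers, head.ft_proj_idx)  # 128 3 1

# Fields can still be overridden before use, e.g. a smaller projection:
config.task.model.projection_head.proj_output_dim = 64
```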
official/vision/beta/projects/simclr/configs/simclr_test.py  (view file @ 59e1ab8a)

@@ -12,23 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for simclr."""
-# pylint: disable=unused-import
+"""Tests for SimCLR config."""
 from absl.testing import parameterized
 import tensorflow as tf
official/vision/beta/projects/simclr/dataloaders/preprocess_ops.py  (view file @ 59e1ab8a)

@@ -12,20 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
 """Preprocessing ops."""
 import functools
 import tensorflow as tf
official/vision/beta/projects/simclr/dataloaders/simclr_input.py  (view file @ 59e1ab8a)

@@ -12,20 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
 """Data parser and processing for SimCLR.

 For pre-training:
official/vision/beta/projects/simclr/heads/simclr_head.py  (view file @ 59e1ab8a)

@@ -12,21 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Dense prediction heads."""
+"""SimCLR prediction heads."""
 from typing import Text, Optional
official/vision/beta/projects/simclr/heads/simclr_head_test.py  (view file @ 59e1ab8a)

@@ -12,22 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
 from absl.testing import parameterized
 import numpy as np
official/vision/beta/projects/simclr/losses/contrastive_losses.py  (view file @ 59e1ab8a)

@@ -12,21 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
 """Contrastive loss functions."""
 import functools
official/vision/beta/projects/simclr/losses/contrastive_losses_test.py  (view file @ 59e1ab8a)

@@ -12,22 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
 from absl.testing import parameterized
 import numpy as np
official/vision/beta/projects/simclr/modeling/layers/nn_blocks.py  (view file @ 59e1ab8a)

@@ -12,22 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
 """Contains common building blocks for simclr neural networks."""
 from typing import Text, Optional
official/vision/beta/projects/simclr/modeling/layers/nn_blocks_test.py  (view file @ 59e1ab8a)

@@ -12,22 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
 from absl.testing import parameterized
 import tensorflow as tf
official/vision/beta/projects/simclr/modeling/simclr_model.py  (view file @ 59e1ab8a)

@@ -12,22 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
 """Build simclr models."""
 from typing import Optional
 from absl import logging
@@ -133,12 +118,12 @@ class SimCLRModel(tf.keras.Model):
   def checkpoint_items(self):
     """Returns a dictionary of items to be additionally checkpointed."""
     if self._supervised_head is not None:
-      items = dict(backbone=self.backbone,
-                   projection_head=self.projection_head,
-                   supervised_head=self.supervised_head)
+      items = dict(
+          backbone=self.backbone, projection_head=self.projection_head,
+          supervised_head=self.supervised_head)
     else:
-      items = dict(backbone=self.backbone,
-                   projection_head=self.projection_head)
+      items = dict(
+          backbone=self.backbone, projection_head=self.projection_head)
     return items

   @property
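The checkpoint_items dictionary above is what downstream code feeds to tf.train.Checkpoint for partial saves and restores. Below is a self-contained sketch of that pattern with a stand-in model; the layer names, shapes, and save path are illustrative and not taken from this commit.

```python
import tensorflow as tf


class TinyModel(tf.keras.Model):
  """Stand-in exposing the same checkpoint_items contract as SimCLRModel."""

  def __init__(self):
    super().__init__()
    self.backbone = tf.keras.layers.Dense(8, name='backbone')
    self.projection_head = tf.keras.layers.Dense(4, name='projection_head')

  @property
  def checkpoint_items(self):
    # Mirrors the no-supervised-head branch in the hunk above.
    return dict(backbone=self.backbone, projection_head=self.projection_head)


pretrained = TinyModel()
_ = pretrained.projection_head(pretrained.backbone(tf.zeros([1, 8])))  # build vars
path = tf.train.Checkpoint(**pretrained.checkpoint_items).save(
    '/tmp/simclr_sketch/ckpt')

# A fresh model can restore just these pieces (no supervised head involved).
finetune = TinyModel()
_ = finetune.projection_head(finetune.backbone(tf.zeros([1, 8])))
tf.train.Checkpoint(**finetune.checkpoint_items).restore(
    path).assert_existing_objects_matched()
```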
official/vision/beta/projects/simclr/modeling/simclr_model_test.py  (view file @ 59e1ab8a)

@@ -12,22 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
+"""Test for SimCLR model."""
 from absl.testing import parameterized
 import numpy as np
official/vision/beta/projects/simclr/tasks/simclr.py  (view file @ 59e1ab8a)

@@ -12,21 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
 """Image SimCLR task definition.

 SimCLR training two different modes:
@@ -39,7 +24,6 @@ the task definition:
   - training loss
   - projection_head and/or supervised_head
 """
 from typing import Dict, Optional
 from absl import logging
@@ -67,7 +51,8 @@ RuntimeConfig = config_definitions.RuntimeConfig
 class SimCLRPretrainTask(base_task.Task):
   """A task for image classification."""

-  def create_optimizer(self, optimizer_config: OptimizationConfig,
+  def create_optimizer(self,
+                       optimizer_config: OptimizationConfig,
                        runtime_config: Optional[RuntimeConfig] = None):
     """Creates an TF optimizer from configurations.
@@ -78,8 +63,8 @@ class SimCLRPretrainTask(base_task.Task):
     Returns:
       A tf.optimizers.Optimizer object.
     """
     if (optimizer_config.optimizer.type == 'lars' and
         self.task_config.loss.l2_weight_decay > 0.0):
       raise ValueError('The l2_weight_decay cannot be used together with lars '
                        'optimizer. Please set it to 0.')
@@ -97,15 +82,16 @@ class SimCLRPretrainTask(base_task.Task):
   def build_model(self):
     model_config = self.task_config.model
     input_specs = tf.keras.layers.InputSpec(shape=[None] +
                                             model_config.input_size)

     l2_weight_decay = self.task_config.loss.l2_weight_decay
     # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
     # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
     # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
     l2_regularizer = (tf.keras.regularizers.l2(
         l2_weight_decay / 2.0) if l2_weight_decay else None)

     # Build backbone
     backbone = backbones.factory.build_backbone(
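The "divide weight decay by 2.0" comment in build_model can be checked directly: tf.keras.regularizers.l2 computes rate * sum(w**2), while tf.nn.l2_loss computes sum(w**2) / 2, so halving the rate makes the two penalties coincide. A quick numeric sketch with arbitrary values:

```python
import tensorflow as tf

w = tf.constant([[0.5, -1.0], [2.0, 0.25]])
l2_weight_decay = 1e-4

# Weight decay expressed through tf.nn.l2_loss: wd * sum(w**2) / 2.
penalty_l2_loss = l2_weight_decay * tf.nn.l2_loss(w)

# The Keras regularizer computes rate * sum(w**2), so pass wd / 2.0.
penalty_keras = tf.keras.regularizers.l2(l2_weight_decay / 2.0)(w)

tf.debugging.assert_near(penalty_l2_loss, penalty_keras)  # identical penalties
```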
@@ -220,8 +206,7 @@ class SimCLRPretrainTask(base_task.Task):
       projection_outputs = model_outputs[
           simclr_model.PROJECTION_OUTPUT_KEY]
       projection1, projection2 = tf.split(projection_outputs, 2, 0)
       contrast_loss, (contrast_logits, contrast_labels) = con_losses_obj(
-          projection1=projection1,
-          projection2=projection2)
+          projection1=projection1, projection2=projection2)
       contrast_accuracy = tf.equal(tf.argmax(contrast_labels, axis=1),
                                    tf.argmax(contrast_logits, axis=1))
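For intuition about the hunk above: the SimCLR input pipeline stacks the two augmented views along the batch axis, so tf.split(..., 2, 0) recovers one projection per view, and contrast accuracy is plain argmax agreement between logits and labels. A toy sketch with made-up shapes follows; the real logits and labels come from the contrastive loss object, which is not shown in this diff.

```python
import tensorflow as tf

batch, dim = 4, 128
# Two augmented views stacked on axis 0, as produced by the SimCLR input pipeline.
projection_outputs = tf.random.normal([2 * batch, dim])
projection1, projection2 = tf.split(projection_outputs, 2, 0)  # each [batch, dim]

# Stand-in contrastive logits/labels with matching shapes (illustrative only).
contrast_logits = tf.random.normal([batch, 2 * batch])
contrast_labels = tf.one_hot(tf.range(batch), 2 * batch)

contrast_accuracy = tf.reduce_mean(
    tf.cast(tf.equal(tf.argmax(contrast_labels, axis=1),
                     tf.argmax(contrast_logits, axis=1)), tf.float32))
```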
@@ -253,8 +238,8 @@ class SimCLRPretrainTask(base_task.Task):
             outputs)
         sup_loss = tf.reduce_mean(sup_loss)

         label_acc = tf.equal(tf.argmax(labels, axis=1),
                              tf.argmax(outputs, axis=1))
         label_acc = tf.reduce_mean(tf.cast(label_acc, tf.float32))

         model_loss = contrast_loss + sup_loss
@@ -278,10 +263,7 @@ class SimCLRPretrainTask(base_task.Task):
     if training:
       metrics = []
       metric_names = [
-          'total_loss', 'contrast_loss', 'contrast_accuracy',
-          'contrast_entropy'
+          'total_loss', 'contrast_loss', 'contrast_accuracy', 'contrast_entropy'
       ]
       if self.task_config.model.supervised_head:
         metric_names.extend(['supervised_loss', 'accuracy'])
@@ -293,18 +275,20 @@ class SimCLRPretrainTask(base_task.Task):
       metrics = [
           tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
           tf.keras.metrics.TopKCategoricalAccuracy(
-              k=k, name='top_{}_accuracy'.format(k))]
+              k=k, name='top_{}_accuracy'.format(k))
+      ]
     else:
       metrics = [
           tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
           tf.keras.metrics.SparseTopKCategoricalAccuracy(
-              k=k, name='top_{}_accuracy'.format(k))]
+              k=k, name='top_{}_accuracy'.format(k))
+      ]
     return metrics

   def train_step(self, inputs, model, optimizer, metrics=None):
     features, labels = inputs
     if (self.task_config.model.supervised_head is not None and
         self.task_config.evaluation.one_hot):
       num_classes = self.task_config.model.supervised_head.num_classes
       labels = tf.one_hot(labels, num_classes)
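The metric branches above track the label format: with one-hot evaluation the dense metrics expect one-hot targets, while the sparse variants take integer class ids. A small illustrative sketch with toy logits, not taken from the task:

```python
import tensorflow as tf

logits = tf.constant([[0.1, 0.7, 0.2],
                      [0.8, 0.1, 0.1]])

# One-hot labels -> CategoricalAccuracy / TopKCategoricalAccuracy.
dense_acc = tf.keras.metrics.CategoricalAccuracy(name='accuracy')
dense_acc.update_state(tf.one_hot([1, 0], 3), logits)

# Integer labels -> SparseCategoricalAccuracy / SparseTopKCategoricalAccuracy.
sparse_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy')
sparse_acc.update_state([1, 0], logits)

assert float(dense_acc.result()) == float(sparse_acc.result()) == 1.0
```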
@@ -313,8 +297,7 @@ class SimCLRPretrainTask(base_task.Task):
       outputs = model(features, training=True)
       # Casting output layer as float32 is necessary when mixed_precision is
       # mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
       outputs = tf.nest.map_structure(
           lambda x: tf.cast(x, tf.float32), outputs)

       # Computes per-replica loss.
       losses = self.build_losses(
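The float32 cast above only matters under a float16/bfloat16 policy; the following minimal sketch shows what the tf.nest.map_structure call does to a (possibly nested) output structure, using hand-made float16 tensors rather than real model outputs:

```python
import tensorflow as tf

# Under a mixed_float16 policy the model emits float16; losses and metrics
# want float32, hence the structure-wide cast.
outputs = {'logits': tf.random.normal([2, 10], dtype=tf.float16),
           'projection': tf.random.normal([2, 128], dtype=tf.float16)}

outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)

assert all(v.dtype == tf.float32 for v in outputs.values())
```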
@@ -373,7 +356,8 @@ class SimCLRPretrainTask(base_task.Task):
 class SimCLRFinetuneTask(base_task.Task):
   """A task for image classification."""

-  def create_optimizer(self, optimizer_config: OptimizationConfig,
+  def create_optimizer(self,
+                       optimizer_config: OptimizationConfig,
                        runtime_config: Optional[RuntimeConfig] = None):
     """Creates an TF optimizer from configurations.
@@ -384,8 +368,8 @@ class SimCLRFinetuneTask(base_task.Task):
     Returns:
       A tf.optimizers.Optimizer object.
     """
     if (optimizer_config.optimizer.type == 'lars' and
         self.task_config.loss.l2_weight_decay > 0.0):
       raise ValueError('The l2_weight_decay cannot be used together with lars '
                        'optimizer. Please set it to 0.')
@@ -403,15 +387,16 @@ class SimCLRFinetuneTask(base_task.Task):
   def build_model(self):
     model_config = self.task_config.model
     input_specs = tf.keras.layers.InputSpec(shape=[None] +
                                             model_config.input_size)

     l2_weight_decay = self.task_config.loss.l2_weight_decay
     # Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
     # (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
     # (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
     l2_regularizer = (tf.keras.regularizers.l2(
         l2_weight_decay / 2.0) if l2_weight_decay else None)

     backbone = backbones.factory.build_backbone(
         input_specs=input_specs,
@@ -467,8 +452,8 @@ class SimCLRFinetuneTask(base_task.Task):
       status = ckpt.restore(ckpt_dir_or_file)
       status.assert_consumed()
     elif self.task_config.init_checkpoint_modules == 'backbone_projection':
       ckpt = tf.train.Checkpoint(
           backbone=model.backbone, projection_head=model.projection_head)
       status = ckpt.restore(ckpt_dir_or_file)
       status.expect_partial().assert_existing_objects_matched()
     elif self.task_config.init_checkpoint_modules == 'backbone':
@@ -542,12 +527,14 @@ class SimCLRFinetuneTask(base_task.Task):
       metrics = [
           tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
           tf.keras.metrics.TopKCategoricalAccuracy(
-              k=k, name='top_{}_accuracy'.format(k))]
+              k=k, name='top_{}_accuracy'.format(k))
+      ]
     else:
       metrics = [
           tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
           tf.keras.metrics.SparseTopKCategoricalAccuracy(
-              k=k, name='top_{}_accuracy'.format(k))]
+              k=k, name='top_{}_accuracy'.format(k))
+      ]
     return metrics

   def train_step(self, inputs, model, optimizer, metrics=None):
@@ -577,16 +564,14 @@ class SimCLRFinetuneTask(base_task.Task):
       # Computes per-replica loss.
       loss = self.build_losses(
-          model_outputs=outputs,
-          labels=labels,
-          aux_losses=model.losses)
+          model_outputs=outputs, labels=labels, aux_losses=model.losses)
       # Scales loss as the default gradients allreduce performs sum inside the
       # optimizer.
       scaled_loss = loss / num_replicas

       # For mixed_precision policy, when LossScaleOptimizer is used, loss is
       # scaled for numerical stability.
       if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
         scaled_loss = optimizer.get_scaled_loss(scaled_loss)

       tvars = model.trainable_variables
@@ -596,8 +581,7 @@ class SimCLRFinetuneTask(base_task.Task):
       grads = tape.gradient(scaled_loss, tvars)
       # Scales back gradient before apply_gradients when LossScaleOptimizer is
       # used.
-      if isinstance(optimizer,
-                    tf.keras.mixed_precision.LossScaleOptimizer):
+      if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
         grads = optimizer.get_unscaled_gradients(grads)
       optimizer.apply_gradients(list(zip(grads, tvars)))
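The two hunks above follow the standard LossScaleOptimizer recipe: scale the loss before differentiating, unscale the gradients before applying them. A self-contained sketch with a plain SGD optimizer and a toy variable, not the task's real optimizer or model:

```python
import tensorflow as tf

optimizer = tf.keras.mixed_precision.LossScaleOptimizer(
    tf.keras.optimizers.SGD(learning_rate=0.1))
var = tf.Variable(1.0)

with tf.GradientTape() as tape:
  scaled_loss = tf.square(var)
  # Scale the loss for numerical stability, as in train_step above.
  if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
    scaled_loss = optimizer.get_scaled_loss(scaled_loss)

grads = tape.gradient(scaled_loss, [var])
# Undo the scaling before apply_gradients.
if isinstance(optimizer, tf.keras.mixed_precision.LossScaleOptimizer):
  grads = optimizer.get_unscaled_gradients(grads)
optimizer.apply_gradients(list(zip(grads, [var])))
```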
@@ -626,11 +610,11 @@ class SimCLRFinetuneTask(base_task.Task):
       num_classes = self.task_config.model.supervised_head.num_classes
       labels = tf.one_hot(labels, num_classes)

     outputs = self.inference_step(features,
                                   model)[simclr_model.SUPERVISED_OUTPUT_KEY]
     outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
     loss = self.build_losses(
         model_outputs=outputs, labels=labels, aux_losses=model.losses)

     logs = {self.loss: loss}
     if metrics:
official/vision/beta/projects/simclr/train.py  (view file @ 59e1ab8a)

@@ -12,22 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# Lint as: python3
-# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""TensorFlow Model Garden Vision SimCLR training driver."""
+"""TensorFlow Model Garden Vision SimCLR trainer."""
 from absl import app
 from absl import flags
 import gin