Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
59e1ab8a
Commit
59e1ab8a
authored
Aug 03, 2021
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 388514034
parent
08b68031
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
69 additions
and
297 deletions
+69
-297
official/vision/beta/projects/simclr/common/registry_imports.py
...al/vision/beta/projects/simclr/common/registry_imports.py
+0
-14
official/vision/beta/projects/simclr/configs/simclr.py
official/vision/beta/projects/simclr/configs/simclr.py
+18
-34
official/vision/beta/projects/simclr/configs/simclr_test.py
official/vision/beta/projects/simclr/configs/simclr_test.py
+1
-17
official/vision/beta/projects/simclr/dataloaders/preprocess_ops.py
...vision/beta/projects/simclr/dataloaders/preprocess_ops.py
+0
-14
official/vision/beta/projects/simclr/dataloaders/simclr_input.py
...l/vision/beta/projects/simclr/dataloaders/simclr_input.py
+0
-14
official/vision/beta/projects/simclr/heads/simclr_head.py
official/vision/beta/projects/simclr/heads/simclr_head.py
+1
-15
official/vision/beta/projects/simclr/heads/simclr_head_test.py
...ial/vision/beta/projects/simclr/heads/simclr_head_test.py
+0
-16
official/vision/beta/projects/simclr/losses/contrastive_losses.py
.../vision/beta/projects/simclr/losses/contrastive_losses.py
+0
-15
official/vision/beta/projects/simclr/losses/contrastive_losses_test.py
...on/beta/projects/simclr/losses/contrastive_losses_test.py
+0
-16
official/vision/beta/projects/simclr/modeling/layers/nn_blocks.py
.../vision/beta/projects/simclr/modeling/layers/nn_blocks.py
+0
-16
official/vision/beta/projects/simclr/modeling/layers/nn_blocks_test.py
...on/beta/projects/simclr/modeling/layers/nn_blocks_test.py
+0
-16
official/vision/beta/projects/simclr/modeling/simclr_model.py
...cial/vision/beta/projects/simclr/modeling/simclr_model.py
+5
-20
official/vision/beta/projects/simclr/modeling/simclr_model_test.py
...vision/beta/projects/simclr/modeling/simclr_model_test.py
+1
-16
official/vision/beta/projects/simclr/tasks/simclr.py
official/vision/beta/projects/simclr/tasks/simclr.py
+42
-58
official/vision/beta/projects/simclr/train.py
official/vision/beta/projects/simclr/train.py
+1
-16
No files found.
official/vision/beta/projects/simclr/common/registry_imports.py
View file @
59e1ab8a
...
@@ -12,20 +12,6 @@
...
@@ -12,20 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""All necessary imports for registration."""
"""All necessary imports for registration."""
# pylint: disable=unused-import
# pylint: disable=unused-import
...
...
official/vision/beta/projects/simclr/configs/simclr.py
View file @
59e1ab8a
...
@@ -12,26 +12,11 @@
...
@@ -12,26 +12,11 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SimCLR configurations."""
"""SimCLR configurations."""
import
os
import
dataclasses
import
os.path
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
import
dataclasses
from
official.core
import
config_definitions
as
cfg
from
official.core
import
config_definitions
as
cfg
from
official.core
import
exp_factory
from
official.core
import
exp_factory
...
@@ -115,9 +100,7 @@ class SimCLRModel(hyperparams.Config):
...
@@ -115,9 +100,7 @@ class SimCLRModel(hyperparams.Config):
backbone
:
backbones
.
Backbone
=
backbones
.
Backbone
(
backbone
:
backbones
.
Backbone
=
backbones
.
Backbone
(
type
=
'resnet'
,
resnet
=
backbones
.
ResNet
())
type
=
'resnet'
,
resnet
=
backbones
.
ResNet
())
projection_head
:
ProjectionHead
=
ProjectionHead
(
projection_head
:
ProjectionHead
=
ProjectionHead
(
proj_output_dim
=
128
,
proj_output_dim
=
128
,
num_proj_layers
=
3
,
ft_proj_idx
=
1
)
num_proj_layers
=
3
,
ft_proj_idx
=
1
)
supervised_head
:
SupervisedHead
=
SupervisedHead
(
num_classes
=
1001
)
supervised_head
:
SupervisedHead
=
SupervisedHead
(
num_classes
=
1001
)
norm_activation
:
common
.
NormActivation
=
common
.
NormActivation
(
norm_activation
:
common
.
NormActivation
=
common
.
NormActivation
(
norm_momentum
=
0.9
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
False
)
norm_momentum
=
0.9
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
False
)
...
@@ -201,9 +184,7 @@ def simclr_pretraining_imagenet() -> cfg.ExperimentConfig:
...
@@ -201,9 +184,7 @@ def simclr_pretraining_imagenet() -> cfg.ExperimentConfig:
backbone
=
backbones
.
Backbone
(
backbone
=
backbones
.
Backbone
(
type
=
'resnet'
,
resnet
=
backbones
.
ResNet
(
model_id
=
50
)),
type
=
'resnet'
,
resnet
=
backbones
.
ResNet
(
model_id
=
50
)),
projection_head
=
ProjectionHead
(
projection_head
=
ProjectionHead
(
proj_output_dim
=
128
,
proj_output_dim
=
128
,
num_proj_layers
=
3
,
ft_proj_idx
=
1
),
num_proj_layers
=
3
,
ft_proj_idx
=
1
),
supervised_head
=
SupervisedHead
(
num_classes
=
1001
),
supervised_head
=
SupervisedHead
(
num_classes
=
1001
),
norm_activation
=
common
.
NormActivation
(
norm_activation
=
common
.
NormActivation
(
norm_momentum
=
0.9
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
True
)),
norm_momentum
=
0.9
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
True
)),
...
@@ -233,10 +214,13 @@ def simclr_pretraining_imagenet() -> cfg.ExperimentConfig:
...
@@ -233,10 +214,13 @@ def simclr_pretraining_imagenet() -> cfg.ExperimentConfig:
'optimizer'
:
{
'optimizer'
:
{
'type'
:
'lars'
,
'type'
:
'lars'
,
'lars'
:
{
'lars'
:
{
'momentum'
:
0.9
,
'momentum'
:
'weight_decay_rate'
:
0.000001
,
0.9
,
'weight_decay_rate'
:
0.000001
,
'exclude_from_weight_decay'
:
[
'exclude_from_weight_decay'
:
[
'batch_normalization'
,
'bias'
]
'batch_normalization'
,
'bias'
]
}
}
},
},
'learning_rate'
:
{
'learning_rate'
:
{
...
@@ -278,11 +262,8 @@ def simclr_finetuning_imagenet() -> cfg.ExperimentConfig:
...
@@ -278,11 +262,8 @@ def simclr_finetuning_imagenet() -> cfg.ExperimentConfig:
backbone
=
backbones
.
Backbone
(
backbone
=
backbones
.
Backbone
(
type
=
'resnet'
,
resnet
=
backbones
.
ResNet
(
model_id
=
50
)),
type
=
'resnet'
,
resnet
=
backbones
.
ResNet
(
model_id
=
50
)),
projection_head
=
ProjectionHead
(
projection_head
=
ProjectionHead
(
proj_output_dim
=
128
,
proj_output_dim
=
128
,
num_proj_layers
=
3
,
ft_proj_idx
=
1
),
num_proj_layers
=
3
,
supervised_head
=
SupervisedHead
(
num_classes
=
1001
,
zero_init
=
True
),
ft_proj_idx
=
1
),
supervised_head
=
SupervisedHead
(
num_classes
=
1001
,
zero_init
=
True
),
norm_activation
=
common
.
NormActivation
(
norm_activation
=
common
.
NormActivation
(
norm_momentum
=
0.9
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
False
)),
norm_momentum
=
0.9
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
False
)),
loss
=
ClassificationLosses
(),
loss
=
ClassificationLosses
(),
...
@@ -311,10 +292,13 @@ def simclr_finetuning_imagenet() -> cfg.ExperimentConfig:
...
@@ -311,10 +292,13 @@ def simclr_finetuning_imagenet() -> cfg.ExperimentConfig:
'optimizer'
:
{
'optimizer'
:
{
'type'
:
'lars'
,
'type'
:
'lars'
,
'lars'
:
{
'lars'
:
{
'momentum'
:
0.9
,
'momentum'
:
'weight_decay_rate'
:
0.0
,
0.9
,
'weight_decay_rate'
:
0.0
,
'exclude_from_weight_decay'
:
[
'exclude_from_weight_decay'
:
[
'batch_normalization'
,
'bias'
]
'batch_normalization'
,
'bias'
]
}
}
},
},
'learning_rate'
:
{
'learning_rate'
:
{
...
...
official/vision/beta/projects/simclr/configs/simclr_test.py
View file @
59e1ab8a
...
@@ -12,23 +12,7 @@
...
@@ -12,23 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
"""Tests for SimCLR config."""
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for simclr."""
# pylint: disable=unused-import
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
import
tensorflow
as
tf
import
tensorflow
as
tf
...
...
official/vision/beta/projects/simclr/dataloaders/preprocess_ops.py
View file @
59e1ab8a
...
@@ -12,20 +12,6 @@
...
@@ -12,20 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Preprocessing ops."""
"""Preprocessing ops."""
import
functools
import
functools
import
tensorflow
as
tf
import
tensorflow
as
tf
...
...
official/vision/beta/projects/simclr/dataloaders/simclr_input.py
View file @
59e1ab8a
...
@@ -12,20 +12,6 @@
...
@@ -12,20 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data parser and processing for SimCLR.
"""Data parser and processing for SimCLR.
For pre-training:
For pre-training:
...
...
official/vision/beta/projects/simclr/heads/simclr_head.py
View file @
59e1ab8a
...
@@ -12,21 +12,7 @@
...
@@ -12,21 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
"""SimCLR prediction heads."""
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dense prediction heads."""
from
typing
import
Text
,
Optional
from
typing
import
Text
,
Optional
...
...
official/vision/beta/projects/simclr/heads/simclr_head_test.py
View file @
59e1ab8a
...
@@ -12,22 +12,6 @@
...
@@ -12,22 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
import
numpy
as
np
import
numpy
as
np
...
...
official/vision/beta/projects/simclr/losses/contrastive_losses.py
View file @
59e1ab8a
...
@@ -12,21 +12,6 @@
...
@@ -12,21 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contrastive loss functions."""
"""Contrastive loss functions."""
import
functools
import
functools
...
...
official/vision/beta/projects/simclr/losses/contrastive_losses_test.py
View file @
59e1ab8a
...
@@ -12,22 +12,6 @@
...
@@ -12,22 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
import
numpy
as
np
import
numpy
as
np
...
...
official/vision/beta/projects/simclr/modeling/layers/nn_blocks.py
View file @
59e1ab8a
...
@@ -12,22 +12,6 @@
...
@@ -12,22 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains common building blocks for simclr neural networks."""
"""Contains common building blocks for simclr neural networks."""
from
typing
import
Text
,
Optional
from
typing
import
Text
,
Optional
...
...
official/vision/beta/projects/simclr/modeling/layers/nn_blocks_test.py
View file @
59e1ab8a
...
@@ -12,22 +12,6 @@
...
@@ -12,22 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
import
tensorflow
as
tf
import
tensorflow
as
tf
...
...
official/vision/beta/projects/simclr/modeling/simclr_model.py
View file @
59e1ab8a
...
@@ -12,22 +12,7 @@
...
@@ -12,22 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Build simclr models."""
"""Build simclr models."""
from
typing
import
Optional
from
typing
import
Optional
from
absl
import
logging
from
absl
import
logging
...
@@ -133,12 +118,12 @@ class SimCLRModel(tf.keras.Model):
...
@@ -133,12 +118,12 @@ class SimCLRModel(tf.keras.Model):
def
checkpoint_items
(
self
):
def
checkpoint_items
(
self
):
"""Returns a dictionary of items to be additionally checkpointed."""
"""Returns a dictionary of items to be additionally checkpointed."""
if
self
.
_supervised_head
is
not
None
:
if
self
.
_supervised_head
is
not
None
:
items
=
dict
(
backbone
=
self
.
backbone
,
items
=
dict
(
projection_head
=
self
.
projection_head
,
backbone
=
self
.
backbone
,
supervised_head
=
self
.
supervised_head
)
projection_head
=
self
.
projection_head
,
supervised_head
=
self
.
supervised_head
)
else
:
else
:
items
=
dict
(
backbone
=
self
.
backbone
,
items
=
dict
(
backbone
=
self
.
backbone
,
projection_head
=
self
.
projection_head
)
projection_head
=
self
.
projection_head
)
return
items
return
items
@
property
@
property
...
...
official/vision/beta/projects/simclr/modeling/simclr_model_test.py
View file @
59e1ab8a
...
@@ -12,22 +12,7 @@
...
@@ -12,22 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
"""Test for SimCLR model."""
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
import
numpy
as
np
import
numpy
as
np
...
...
official/vision/beta/projects/simclr/tasks/simclr.py
View file @
59e1ab8a
...
@@ -12,21 +12,6 @@
...
@@ -12,21 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image SimCLR task definition.
"""Image SimCLR task definition.
SimCLR training two different modes:
SimCLR training two different modes:
...
@@ -39,7 +24,6 @@ the task definition:
...
@@ -39,7 +24,6 @@ the task definition:
- training loss
- training loss
- projection_head and/or supervised_head
- projection_head and/or supervised_head
"""
"""
from
typing
import
Dict
,
Optional
from
typing
import
Dict
,
Optional
from
absl
import
logging
from
absl
import
logging
...
@@ -67,7 +51,8 @@ RuntimeConfig = config_definitions.RuntimeConfig
...
@@ -67,7 +51,8 @@ RuntimeConfig = config_definitions.RuntimeConfig
class
SimCLRPretrainTask
(
base_task
.
Task
):
class
SimCLRPretrainTask
(
base_task
.
Task
):
"""A task for image classification."""
"""A task for image classification."""
def
create_optimizer
(
self
,
optimizer_config
:
OptimizationConfig
,
def
create_optimizer
(
self
,
optimizer_config
:
OptimizationConfig
,
runtime_config
:
Optional
[
RuntimeConfig
]
=
None
):
runtime_config
:
Optional
[
RuntimeConfig
]
=
None
):
"""Creates an TF optimizer from configurations.
"""Creates an TF optimizer from configurations.
...
@@ -78,8 +63,8 @@ class SimCLRPretrainTask(base_task.Task):
...
@@ -78,8 +63,8 @@ class SimCLRPretrainTask(base_task.Task):
Returns:
Returns:
A tf.optimizers.Optimizer object.
A tf.optimizers.Optimizer object.
"""
"""
if
(
optimizer_config
.
optimizer
.
type
==
'lars'
if
(
optimizer_config
.
optimizer
.
type
==
'lars'
and
and
self
.
task_config
.
loss
.
l2_weight_decay
>
0.0
):
self
.
task_config
.
loss
.
l2_weight_decay
>
0.0
):
raise
ValueError
(
'The l2_weight_decay cannot be used together with lars '
raise
ValueError
(
'The l2_weight_decay cannot be used together with lars '
'optimizer. Please set it to 0.'
)
'optimizer. Please set it to 0.'
)
...
@@ -97,15 +82,16 @@ class SimCLRPretrainTask(base_task.Task):
...
@@ -97,15 +82,16 @@ class SimCLRPretrainTask(base_task.Task):
def
build_model
(
self
):
def
build_model
(
self
):
model_config
=
self
.
task_config
.
model
model_config
=
self
.
task_config
.
model
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
]
+
shape
=
[
None
]
+
model_config
.
input_size
)
model_config
.
input_size
)
l2_weight_decay
=
self
.
task_config
.
loss
.
l2_weight_decay
l2_weight_decay
=
self
.
task_config
.
loss
.
l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer
=
(
tf
.
keras
.
regularizers
.
l2
(
l2_regularizer
=
(
l2_weight_decay
/
2.0
)
if
l2_weight_decay
else
None
)
tf
.
keras
.
regularizers
.
l2
(
l2_weight_decay
/
2.0
)
if
l2_weight_decay
else
None
)
# Build backbone
# Build backbone
backbone
=
backbones
.
factory
.
build_backbone
(
backbone
=
backbones
.
factory
.
build_backbone
(
...
@@ -220,8 +206,7 @@ class SimCLRPretrainTask(base_task.Task):
...
@@ -220,8 +206,7 @@ class SimCLRPretrainTask(base_task.Task):
projection_outputs
=
model_outputs
[
simclr_model
.
PROJECTION_OUTPUT_KEY
]
projection_outputs
=
model_outputs
[
simclr_model
.
PROJECTION_OUTPUT_KEY
]
projection1
,
projection2
=
tf
.
split
(
projection_outputs
,
2
,
0
)
projection1
,
projection2
=
tf
.
split
(
projection_outputs
,
2
,
0
)
contrast_loss
,
(
contrast_logits
,
contrast_labels
)
=
con_losses_obj
(
contrast_loss
,
(
contrast_logits
,
contrast_labels
)
=
con_losses_obj
(
projection1
=
projection1
,
projection1
=
projection1
,
projection2
=
projection2
)
projection2
=
projection2
)
contrast_accuracy
=
tf
.
equal
(
contrast_accuracy
=
tf
.
equal
(
tf
.
argmax
(
contrast_labels
,
axis
=
1
),
tf
.
argmax
(
contrast_logits
,
axis
=
1
))
tf
.
argmax
(
contrast_labels
,
axis
=
1
),
tf
.
argmax
(
contrast_logits
,
axis
=
1
))
...
@@ -253,8 +238,8 @@ class SimCLRPretrainTask(base_task.Task):
...
@@ -253,8 +238,8 @@ class SimCLRPretrainTask(base_task.Task):
outputs
)
outputs
)
sup_loss
=
tf
.
reduce_mean
(
sup_loss
)
sup_loss
=
tf
.
reduce_mean
(
sup_loss
)
label_acc
=
tf
.
equal
(
tf
.
argmax
(
labels
,
axis
=
1
),
label_acc
=
tf
.
equal
(
tf
.
argmax
(
outputs
,
axis
=
1
))
tf
.
argmax
(
labels
,
axis
=
1
),
tf
.
argmax
(
outputs
,
axis
=
1
))
label_acc
=
tf
.
reduce_mean
(
tf
.
cast
(
label_acc
,
tf
.
float32
))
label_acc
=
tf
.
reduce_mean
(
tf
.
cast
(
label_acc
,
tf
.
float32
))
model_loss
=
contrast_loss
+
sup_loss
model_loss
=
contrast_loss
+
sup_loss
...
@@ -278,10 +263,7 @@ class SimCLRPretrainTask(base_task.Task):
...
@@ -278,10 +263,7 @@ class SimCLRPretrainTask(base_task.Task):
if
training
:
if
training
:
metrics
=
[]
metrics
=
[]
metric_names
=
[
metric_names
=
[
'total_loss'
,
'total_loss'
,
'contrast_loss'
,
'contrast_accuracy'
,
'contrast_entropy'
'contrast_loss'
,
'contrast_accuracy'
,
'contrast_entropy'
]
]
if
self
.
task_config
.
model
.
supervised_head
:
if
self
.
task_config
.
model
.
supervised_head
:
metric_names
.
extend
([
'supervised_loss'
,
'accuracy'
])
metric_names
.
extend
([
'supervised_loss'
,
'accuracy'
])
...
@@ -293,18 +275,20 @@ class SimCLRPretrainTask(base_task.Task):
...
@@ -293,18 +275,20 @@ class SimCLRPretrainTask(base_task.Task):
metrics
=
[
metrics
=
[
tf
.
keras
.
metrics
.
CategoricalAccuracy
(
name
=
'accuracy'
),
tf
.
keras
.
metrics
.
CategoricalAccuracy
(
name
=
'accuracy'
),
tf
.
keras
.
metrics
.
TopKCategoricalAccuracy
(
tf
.
keras
.
metrics
.
TopKCategoricalAccuracy
(
k
=
k
,
name
=
'top_{}_accuracy'
.
format
(
k
))]
k
=
k
,
name
=
'top_{}_accuracy'
.
format
(
k
))
]
else
:
else
:
metrics
=
[
metrics
=
[
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
name
=
'accuracy'
),
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
name
=
'accuracy'
),
tf
.
keras
.
metrics
.
SparseTopKCategoricalAccuracy
(
tf
.
keras
.
metrics
.
SparseTopKCategoricalAccuracy
(
k
=
k
,
name
=
'top_{}_accuracy'
.
format
(
k
))]
k
=
k
,
name
=
'top_{}_accuracy'
.
format
(
k
))
]
return
metrics
return
metrics
def
train_step
(
self
,
inputs
,
model
,
optimizer
,
metrics
=
None
):
def
train_step
(
self
,
inputs
,
model
,
optimizer
,
metrics
=
None
):
features
,
labels
=
inputs
features
,
labels
=
inputs
if
(
self
.
task_config
.
model
.
supervised_head
is
not
None
if
(
self
.
task_config
.
model
.
supervised_head
is
not
None
and
and
self
.
task_config
.
evaluation
.
one_hot
):
self
.
task_config
.
evaluation
.
one_hot
):
num_classes
=
self
.
task_config
.
model
.
supervised_head
.
num_classes
num_classes
=
self
.
task_config
.
model
.
supervised_head
.
num_classes
labels
=
tf
.
one_hot
(
labels
,
num_classes
)
labels
=
tf
.
one_hot
(
labels
,
num_classes
)
...
@@ -313,8 +297,7 @@ class SimCLRPretrainTask(base_task.Task):
...
@@ -313,8 +297,7 @@ class SimCLRPretrainTask(base_task.Task):
outputs
=
model
(
features
,
training
=
True
)
outputs
=
model
(
features
,
training
=
True
)
# Casting output layer as float32 is necessary when mixed_precision is
# Casting output layer as float32 is necessary when mixed_precision is
# mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
# mixed_float16 or mixed_bfloat16 to ensure output is casted as float32.
outputs
=
tf
.
nest
.
map_structure
(
outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
# Computes per-replica loss.
# Computes per-replica loss.
losses
=
self
.
build_losses
(
losses
=
self
.
build_losses
(
...
@@ -373,7 +356,8 @@ class SimCLRPretrainTask(base_task.Task):
...
@@ -373,7 +356,8 @@ class SimCLRPretrainTask(base_task.Task):
class
SimCLRFinetuneTask
(
base_task
.
Task
):
class
SimCLRFinetuneTask
(
base_task
.
Task
):
"""A task for image classification."""
"""A task for image classification."""
def
create_optimizer
(
self
,
optimizer_config
:
OptimizationConfig
,
def
create_optimizer
(
self
,
optimizer_config
:
OptimizationConfig
,
runtime_config
:
Optional
[
RuntimeConfig
]
=
None
):
runtime_config
:
Optional
[
RuntimeConfig
]
=
None
):
"""Creates an TF optimizer from configurations.
"""Creates an TF optimizer from configurations.
...
@@ -384,8 +368,8 @@ class SimCLRFinetuneTask(base_task.Task):
...
@@ -384,8 +368,8 @@ class SimCLRFinetuneTask(base_task.Task):
Returns:
Returns:
A tf.optimizers.Optimizer object.
A tf.optimizers.Optimizer object.
"""
"""
if
(
optimizer_config
.
optimizer
.
type
==
'lars'
if
(
optimizer_config
.
optimizer
.
type
==
'lars'
and
and
self
.
task_config
.
loss
.
l2_weight_decay
>
0.0
):
self
.
task_config
.
loss
.
l2_weight_decay
>
0.0
):
raise
ValueError
(
'The l2_weight_decay cannot be used together with lars '
raise
ValueError
(
'The l2_weight_decay cannot be used together with lars '
'optimizer. Please set it to 0.'
)
'optimizer. Please set it to 0.'
)
...
@@ -403,15 +387,16 @@ class SimCLRFinetuneTask(base_task.Task):
...
@@ -403,15 +387,16 @@ class SimCLRFinetuneTask(base_task.Task):
def
build_model
(
self
):
def
build_model
(
self
):
model_config
=
self
.
task_config
.
model
model_config
=
self
.
task_config
.
model
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
]
+
shape
=
[
None
]
+
model_config
.
input_size
)
model_config
.
input_size
)
l2_weight_decay
=
self
.
task_config
.
loss
.
l2_weight_decay
l2_weight_decay
=
self
.
task_config
.
loss
.
l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer
=
(
tf
.
keras
.
regularizers
.
l2
(
l2_regularizer
=
(
l2_weight_decay
/
2.0
)
if
l2_weight_decay
else
None
)
tf
.
keras
.
regularizers
.
l2
(
l2_weight_decay
/
2.0
)
if
l2_weight_decay
else
None
)
backbone
=
backbones
.
factory
.
build_backbone
(
backbone
=
backbones
.
factory
.
build_backbone
(
input_specs
=
input_specs
,
input_specs
=
input_specs
,
...
@@ -467,8 +452,8 @@ class SimCLRFinetuneTask(base_task.Task):
...
@@ -467,8 +452,8 @@ class SimCLRFinetuneTask(base_task.Task):
status
=
ckpt
.
restore
(
ckpt_dir_or_file
)
status
=
ckpt
.
restore
(
ckpt_dir_or_file
)
status
.
assert_consumed
()
status
.
assert_consumed
()
elif
self
.
task_config
.
init_checkpoint_modules
==
'backbone_projection'
:
elif
self
.
task_config
.
init_checkpoint_modules
==
'backbone_projection'
:
ckpt
=
tf
.
train
.
Checkpoint
(
backbone
=
model
.
backbone
,
ckpt
=
tf
.
train
.
Checkpoint
(
projection_head
=
model
.
projection_head
)
backbone
=
model
.
backbone
,
projection_head
=
model
.
projection_head
)
status
=
ckpt
.
restore
(
ckpt_dir_or_file
)
status
=
ckpt
.
restore
(
ckpt_dir_or_file
)
status
.
expect_partial
().
assert_existing_objects_matched
()
status
.
expect_partial
().
assert_existing_objects_matched
()
elif
self
.
task_config
.
init_checkpoint_modules
==
'backbone'
:
elif
self
.
task_config
.
init_checkpoint_modules
==
'backbone'
:
...
@@ -542,12 +527,14 @@ class SimCLRFinetuneTask(base_task.Task):
...
@@ -542,12 +527,14 @@ class SimCLRFinetuneTask(base_task.Task):
metrics
=
[
metrics
=
[
tf
.
keras
.
metrics
.
CategoricalAccuracy
(
name
=
'accuracy'
),
tf
.
keras
.
metrics
.
CategoricalAccuracy
(
name
=
'accuracy'
),
tf
.
keras
.
metrics
.
TopKCategoricalAccuracy
(
tf
.
keras
.
metrics
.
TopKCategoricalAccuracy
(
k
=
k
,
name
=
'top_{}_accuracy'
.
format
(
k
))]
k
=
k
,
name
=
'top_{}_accuracy'
.
format
(
k
))
]
else
:
else
:
metrics
=
[
metrics
=
[
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
name
=
'accuracy'
),
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
name
=
'accuracy'
),
tf
.
keras
.
metrics
.
SparseTopKCategoricalAccuracy
(
tf
.
keras
.
metrics
.
SparseTopKCategoricalAccuracy
(
k
=
k
,
name
=
'top_{}_accuracy'
.
format
(
k
))]
k
=
k
,
name
=
'top_{}_accuracy'
.
format
(
k
))
]
return
metrics
return
metrics
def
train_step
(
self
,
inputs
,
model
,
optimizer
,
metrics
=
None
):
def
train_step
(
self
,
inputs
,
model
,
optimizer
,
metrics
=
None
):
...
@@ -577,16 +564,14 @@ class SimCLRFinetuneTask(base_task.Task):
...
@@ -577,16 +564,14 @@ class SimCLRFinetuneTask(base_task.Task):
# Computes per-replica loss.
# Computes per-replica loss.
loss
=
self
.
build_losses
(
loss
=
self
.
build_losses
(
model_outputs
=
outputs
,
model_outputs
=
outputs
,
labels
=
labels
,
aux_losses
=
model
.
losses
)
labels
=
labels
,
aux_losses
=
model
.
losses
)
# Scales loss as the default gradients allreduce performs sum inside the
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
# optimizer.
scaled_loss
=
loss
/
num_replicas
scaled_loss
=
loss
/
num_replicas
# For mixed_precision policy, when LossScaleOptimizer is used, loss is
# For mixed_precision policy, when LossScaleOptimizer is used, loss is
# scaled for numerical stability.
# scaled for numerical stability.
if
isinstance
(
if
isinstance
(
optimizer
,
tf
.
keras
.
mixed_precision
.
LossScaleOptimizer
):
optimizer
,
tf
.
keras
.
mixed_precision
.
LossScaleOptimizer
):
scaled_loss
=
optimizer
.
get_scaled_loss
(
scaled_loss
)
scaled_loss
=
optimizer
.
get_scaled_loss
(
scaled_loss
)
tvars
=
model
.
trainable_variables
tvars
=
model
.
trainable_variables
...
@@ -596,8 +581,7 @@ class SimCLRFinetuneTask(base_task.Task):
...
@@ -596,8 +581,7 @@ class SimCLRFinetuneTask(base_task.Task):
grads
=
tape
.
gradient
(
scaled_loss
,
tvars
)
grads
=
tape
.
gradient
(
scaled_loss
,
tvars
)
# Scales back gradient before apply_gradients when LossScaleOptimizer is
# Scales back gradient before apply_gradients when LossScaleOptimizer is
# used.
# used.
if
isinstance
(
if
isinstance
(
optimizer
,
tf
.
keras
.
mixed_precision
.
LossScaleOptimizer
):
optimizer
,
tf
.
keras
.
mixed_precision
.
LossScaleOptimizer
):
grads
=
optimizer
.
get_unscaled_gradients
(
grads
)
grads
=
optimizer
.
get_unscaled_gradients
(
grads
)
optimizer
.
apply_gradients
(
list
(
zip
(
grads
,
tvars
)))
optimizer
.
apply_gradients
(
list
(
zip
(
grads
,
tvars
)))
...
@@ -626,11 +610,11 @@ class SimCLRFinetuneTask(base_task.Task):
...
@@ -626,11 +610,11 @@ class SimCLRFinetuneTask(base_task.Task):
num_classes
=
self
.
task_config
.
model
.
supervised_head
.
num_classes
num_classes
=
self
.
task_config
.
model
.
supervised_head
.
num_classes
labels
=
tf
.
one_hot
(
labels
,
num_classes
)
labels
=
tf
.
one_hot
(
labels
,
num_classes
)
outputs
=
self
.
inference_step
(
outputs
=
self
.
inference_step
(
features
,
features
,
model
)[
simclr_model
.
SUPERVISED_OUTPUT_KEY
]
model
)[
simclr_model
.
SUPERVISED_OUTPUT_KEY
]
outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
loss
=
self
.
build_losses
(
model_outputs
=
outputs
,
loss
=
self
.
build_losses
(
labels
=
labels
,
aux_losses
=
model
.
losses
)
model_outputs
=
outputs
,
labels
=
labels
,
aux_losses
=
model
.
losses
)
logs
=
{
self
.
loss
:
loss
}
logs
=
{
self
.
loss
:
loss
}
if
metrics
:
if
metrics
:
...
...
official/vision/beta/projects/simclr/train.py
View file @
59e1ab8a
...
@@ -12,22 +12,7 @@
...
@@ -12,22 +12,7 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# Lint as: python3
"""TensorFlow Model Garden Vision SimCLR trainer."""
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Model Garden Vision SimCLR training driver."""
from
absl
import
app
from
absl
import
app
from
absl
import
flags
from
absl
import
flags
import
gin
import
gin
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment