Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
a91f3779
Commit
a91f3779
authored
Aug 31, 2021
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 393981019
parent
bf9805c5
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
39 additions
and
10 deletions
+39
-10
official/vision/beta/projects/simclr/configs/multitask_config.py
...l/vision/beta/projects/simclr/configs/multitask_config.py
+3
-0
official/vision/beta/projects/simclr/modeling/multitask_model.py
...l/vision/beta/projects/simclr/modeling/multitask_model.py
+36
-10
No files found.
official/vision/beta/projects/simclr/configs/multitask_config.py
View file @
a91f3779
...
...
@@ -51,6 +51,9 @@ class SimCLRMTModelConfig(hyperparams.Config):
# L2 weight decay is used in the model, not in task.
# Note that this can not be used together with lars optimizer.
l2_weight_decay
:
float
=
0.0
init_checkpoint
:
str
=
''
# backbone_projection or backbone
init_checkpoint_modules
:
str
=
'backbone_projection'
@
exp_factory
.
register_config_factory
(
'multitask_simclr'
)
...
...
official/vision/beta/projects/simclr/modeling/multitask_model.py
View file @
a91f3779
...
...
@@ -14,6 +14,7 @@
"""Multi-task image multi-taskSimCLR model definition."""
from
typing
import
Dict
,
Text
from
absl
import
logging
import
tensorflow
as
tf
...
...
@@ -52,15 +53,10 @@ class SimCLRMTModel(base_model.MultiTaskBaseModel):
norm_activation_config
=
config
.
norm_activation
,
l2_regularizer
=
self
.
_l2_regularizer
)
super
().
__init__
(
**
kwargs
)
def
_instantiate_sub_tasks
(
self
)
->
Dict
[
Text
,
tf
.
keras
.
Model
]:
tasks
=
{}
# Build the shared projection head
norm_activation_config
=
self
.
_config
.
norm_activation
projection_head_config
=
self
.
_config
.
projection_head
projection_head
=
simclr_head
.
ProjectionHead
(
self
.
_
projection_head
=
simclr_head
.
ProjectionHead
(
proj_output_dim
=
projection_head_config
.
proj_output_dim
,
num_proj_layers
=
projection_head_config
.
num_proj_layers
,
ft_proj_idx
=
projection_head_config
.
ft_proj_idx
,
...
...
@@ -69,6 +65,11 @@ class SimCLRMTModel(base_model.MultiTaskBaseModel):
norm_momentum
=
norm_activation_config
.
norm_momentum
,
norm_epsilon
=
norm_activation_config
.
norm_epsilon
)
super
().
__init__
(
**
kwargs
)
def
_instantiate_sub_tasks
(
self
)
->
Dict
[
Text
,
tf
.
keras
.
Model
]:
tasks
=
{}
for
model_config
in
self
.
_config
.
heads
:
# Build supervised head
supervised_head_config
=
model_config
.
supervised_head
...
...
@@ -87,13 +88,38 @@ class SimCLRMTModel(base_model.MultiTaskBaseModel):
tasks
[
model_config
.
task_name
]
=
simclr_model
.
SimCLRModel
(
input_specs
=
self
.
_input_specs
,
backbone
=
self
.
_backbone
,
projection_head
=
projection_head
,
projection_head
=
self
.
_
projection_head
,
supervised_head
=
supervised_head
,
mode
=
model_config
.
mode
,
backbone_trainable
=
self
.
_config
.
backbone_trainable
)
return
tasks
# TODO(huythong): Implement initialize function to load the pretrained
# checkpoint of backbone.
# def initialize(self):
def
initialize
(
self
):
"""Loads the multi-task SimCLR model with a pretrained checkpoint."""
ckpt_dir_or_file
=
self
.
_config
.
init_checkpoint
if
tf
.
io
.
gfile
.
isdir
(
ckpt_dir_or_file
):
ckpt_dir_or_file
=
tf
.
train
.
latest_checkpoint
(
ckpt_dir_or_file
)
if
not
ckpt_dir_or_file
:
return
logging
.
info
(
'Loading pretrained %s'
,
self
.
_config
.
init_checkpoint_modules
)
if
self
.
_config
.
init_checkpoint_modules
==
'backbone'
:
pretrained_items
=
dict
(
backbone
=
self
.
_backbone
)
elif
self
.
_config
.
init_checkpoint_modules
==
'backbone_projection'
:
pretrained_items
=
dict
(
backbone
=
self
.
_backbone
,
projection_head
=
self
.
_projection_head
)
else
:
assert
(
"Only 'backbone_projection' or 'backbone' can be used to "
'initialize the model.'
)
ckpt
=
tf
.
train
.
Checkpoint
(
**
pretrained_items
)
status
=
ckpt
.
read
(
ckpt_dir_or_file
)
status
.
expect_partial
().
assert_existing_objects_matched
()
logging
.
info
(
'Finished loading pretrained checkpoint from %s'
,
ckpt_dir_or_file
)
@
property
def
checkpoint_items
(
self
):
"""Returns a dictionary of items to be additionally checkpointed."""
return
dict
(
backbone
=
self
.
_backbone
,
projection_head
=
self
.
_projection_head
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment