Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9465aa0e
Commit
9465aa0e
authored
Jul 27, 2021
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 387142853
parent
2de518be
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
325 additions
and
0 deletions
+325
-0
official/vision/beta/projects/simclr/configs/multitask_config.py
...l/vision/beta/projects/simclr/configs/multitask_config.py
+70
-0
official/vision/beta/projects/simclr/configs/multitask_config_test.py
...ion/beta/projects/simclr/configs/multitask_config_test.py
+39
-0
official/vision/beta/projects/simclr/modeling/multitask_model.py
...l/vision/beta/projects/simclr/modeling/multitask_model.py
+99
-0
official/vision/beta/projects/simclr/modeling/multitask_model_test.py
...ion/beta/projects/simclr/modeling/multitask_model_test.py
+43
-0
official/vision/beta/projects/simclr/multitask_train.py
official/vision/beta/projects/simclr/multitask_train.py
+74
-0
No files found.
official/vision/beta/projects/simclr/configs/multitask_config.py
0 → 100644
View file @
9465aa0e
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-task SimCLR configs."""
import
dataclasses
from
typing
import
List
,
Tuple
from
official.core
import
exp_factory
from
official.modeling
import
hyperparams
from
official.modeling.multitask
import
configs
as
multitask_configs
from
official.vision.beta.configs
import
backbones
from
official.vision.beta.configs
import
common
from
official.vision.beta.projects.simclr.configs
import
simclr
as
simclr_configs
from
official.vision.beta.projects.simclr.modeling
import
simclr_model
@
dataclasses
.
dataclass
class
SimCLRMTHeadConfig
(
hyperparams
.
Config
):
"""Per-task specific configs."""
# Supervised head is required for finetune, but optional for pretrain.
supervised_head
:
simclr_configs
.
SupervisedHead
=
simclr_configs
.
SupervisedHead
(
num_classes
=
1001
)
mode
:
str
=
simclr_model
.
PRETRAIN
@
dataclasses
.
dataclass
class
SimCLRMTModelConfig
(
hyperparams
.
Config
):
"""Model config for multi-task SimCLR model."""
input_size
:
List
[
int
]
=
dataclasses
.
field
(
default_factory
=
list
)
backbone
:
backbones
.
Backbone
=
backbones
.
Backbone
(
type
=
'resnet'
,
resnet
=
backbones
.
ResNet
())
backbone_trainable
:
bool
=
True
projection_head
:
simclr_configs
.
ProjectionHead
=
simclr_configs
.
ProjectionHead
(
proj_output_dim
=
128
,
num_proj_layers
=
3
,
ft_proj_idx
=
1
)
norm_activation
:
common
.
NormActivation
=
common
.
NormActivation
(
norm_momentum
=
0.9
,
norm_epsilon
=
1e-5
,
use_sync_bn
=
False
)
heads
:
Tuple
[
SimCLRMTHeadConfig
,
...]
=
()
# L2 weight decay is used in the model, not in task.
# Note that this can not be used together with lars optimizer.
l2_weight_decay
:
float
=
0.0
@
exp_factory
.
register_config_factory
(
'multitask_simclr'
)
def
multitask_simclr
()
->
multitask_configs
.
MultiTaskExperimentConfig
:
return
multitask_configs
.
MultiTaskExperimentConfig
(
task
=
multitask_configs
.
MultiTaskConfig
(
model
=
SimCLRMTModelConfig
(
heads
=
(
SimCLRMTHeadConfig
(
mode
=
simclr_model
.
PRETRAIN
),
SimCLRMTHeadConfig
(
mode
=
simclr_model
.
FINETUNE
))),
task_routines
=
(
multitask_configs
.
TaskRoutine
(
task_name
=
simclr_model
.
PRETRAIN
,
task_config
=
simclr_configs
.
SimCLRPretrainTask
(),
task_weight
=
2.0
),
multitask_configs
.
TaskRoutine
(
task_name
=
simclr_model
.
FINETUNE
,
task_config
=
simclr_configs
.
SimCLRFinetuneTask
(),
task_weight
=
1.0
))),
trainer
=
multitask_configs
.
MultiTaskTrainerConfig
())
official/vision/beta/projects/simclr/configs/multitask_config_test.py
0 → 100644
View file @
9465aa0e
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask_config."""
import
tensorflow
as
tf
from
official.core
import
exp_factory
from
official.modeling.multitask
import
configs
as
multitask_configs
from
official.vision.beta.projects.simclr.configs
import
multitask_config
as
simclr_multitask_config
from
official.vision.beta.projects.simclr.configs
import
simclr
as
exp_cfg
class
MultitaskConfigTest
(
tf
.
test
.
TestCase
):
def
test_simclr_configs
(
self
):
config
=
exp_factory
.
get_exp_config
(
'multitask_simclr'
)
self
.
assertIsInstance
(
config
,
multitask_configs
.
MultiTaskExperimentConfig
)
self
.
assertIsInstance
(
config
.
task
.
model
,
simclr_multitask_config
.
SimCLRMTModelConfig
)
self
.
assertIsInstance
(
config
.
task
.
task_routines
[
0
].
task_config
,
exp_cfg
.
SimCLRPretrainTask
)
self
.
assertIsInstance
(
config
.
task
.
task_routines
[
1
].
task_config
,
exp_cfg
.
SimCLRFinetuneTask
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/projects/simclr/modeling/multitask_model.py
0 → 100644
View file @
9465aa0e
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multi-task image multi-taskSimCLR model definition."""
from
typing
import
Dict
,
Text
import
tensorflow
as
tf
from
official.modeling.multitask
import
base_model
from
official.vision.beta.modeling
import
backbones
from
official.vision.beta.projects.simclr.configs
import
multitask_config
as
simclr_multitask_config
from
official.vision.beta.projects.simclr.heads
import
simclr_head
from
official.vision.beta.projects.simclr.modeling
import
simclr_model
PROJECTION_OUTPUT_KEY
=
'projection_outputs'
SUPERVISED_OUTPUT_KEY
=
'supervised_outputs'
class
SimCLRMTModel
(
base_model
.
MultiTaskBaseModel
):
"""A multi-task SimCLR model that does both pretrain and finetune."""
def
__init__
(
self
,
config
:
simclr_multitask_config
.
SimCLRMTModelConfig
,
**
kwargs
):
self
.
_config
=
config
# Build shared backbone.
self
.
_input_specs
=
tf
.
keras
.
layers
.
InputSpec
(
shape
=
[
None
]
+
config
.
input_size
)
l2_weight_decay
=
config
.
l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
self
.
_l2_regularizer
=
(
tf
.
keras
.
regularizers
.
l2
(
l2_weight_decay
/
2.0
)
if
l2_weight_decay
else
None
)
self
.
_backbone
=
backbones
.
factory
.
build_backbone
(
input_specs
=
self
.
_input_specs
,
backbone_config
=
config
.
backbone
,
norm_activation_config
=
config
.
norm_activation
,
l2_regularizer
=
self
.
_l2_regularizer
)
super
().
__init__
(
**
kwargs
)
def
_instantiate_sub_tasks
(
self
)
->
Dict
[
Text
,
tf
.
keras
.
Model
]:
tasks
=
{}
# Build the shared projection head
norm_activation_config
=
self
.
_config
.
norm_activation
projection_head_config
=
self
.
_config
.
projection_head
projection_head
=
simclr_head
.
ProjectionHead
(
proj_output_dim
=
projection_head_config
.
proj_output_dim
,
num_proj_layers
=
projection_head_config
.
num_proj_layers
,
ft_proj_idx
=
projection_head_config
.
ft_proj_idx
,
kernel_regularizer
=
self
.
_l2_regularizer
,
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
norm_momentum
=
norm_activation_config
.
norm_momentum
,
norm_epsilon
=
norm_activation_config
.
norm_epsilon
)
for
model_config
in
self
.
_config
.
heads
:
# Build supervised head
supervised_head_config
=
model_config
.
supervised_head
if
supervised_head_config
:
if
supervised_head_config
.
zero_init
:
s_kernel_initializer
=
'zeros'
else
:
s_kernel_initializer
=
'random_uniform'
supervised_head
=
simclr_head
.
ClassificationHead
(
num_classes
=
supervised_head_config
.
num_classes
,
kernel_initializer
=
s_kernel_initializer
,
kernel_regularizer
=
self
.
_l2_regularizer
)
else
:
supervised_head
=
None
tasks
[
model_config
.
mode
]
=
simclr_model
.
SimCLRModel
(
input_specs
=
self
.
_input_specs
,
backbone
=
self
.
_backbone
,
projection_head
=
projection_head
,
supervised_head
=
supervised_head
,
mode
=
model_config
.
mode
,
backbone_trainable
=
self
.
_config
.
backbone_trainable
)
return
tasks
# TODO(huythong): Implement initialize function to load the pretrained
# checkpoint of backbone.
# def initialize(self):
official/vision/beta/projects/simclr/modeling/multitask_model_test.py
0 → 100644
View file @
9465aa0e
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for multitask_model."""
import
os.path
import
tensorflow
as
tf
from
official.vision.beta.projects.simclr.configs
import
multitask_config
from
official.vision.beta.projects.simclr.modeling
import
multitask_model
from
official.vision.beta.projects.simclr.modeling
import
simclr_model
class
MultitaskModelTest
(
tf
.
test
.
TestCase
):
def
test_initialize_model_success
(
self
):
ckpt_dir
=
self
.
get_temp_dir
()
config
=
multitask_config
.
SimCLRMTModelConfig
(
input_size
=
[
64
,
64
,
3
],
heads
=
(
multitask_config
.
SimCLRMTHeadConfig
(
mode
=
simclr_model
.
PRETRAIN
),
multitask_config
.
SimCLRMTHeadConfig
(
mode
=
simclr_model
.
FINETUNE
)))
model
=
multitask_model
.
SimCLRMTModel
(
config
)
self
.
assertIn
(
simclr_model
.
PRETRAIN
,
model
.
sub_tasks
)
self
.
assertIn
(
simclr_model
.
FINETUNE
,
model
.
sub_tasks
)
ckpt
=
tf
.
train
.
Checkpoint
(
backbone
=
model
.
_backbone
)
ckpt
.
save
(
os
.
path
.
join
(
ckpt_dir
,
'ckpt'
))
model
.
initialize
()
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/projects/simclr/multitask_train.py
0 → 100644
View file @
9465aa0e
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer binary for multitask simclr."""
from
absl
import
app
from
absl
import
flags
import
gin
from
official.common
import
distribute_utils
from
official.common
import
flags
as
tfm_flags
from
official.core
import
train_utils
from
official.modeling
import
performance
from
official.modeling.multitask
import
multitask
from
official.modeling.multitask
import
train_lib
# pylint: disable=unused-import
from
official.vision.beta.projects.simclr.common
import
registry_imports
from
official.vision.beta.projects.simclr.configs
import
multitask_config
from
official.vision.beta.projects.simclr.modeling
import
multitask_model
# pylint: enable=unused-import
FLAGS
=
flags
.
FLAGS
def
main
(
_
):
gin
.
parse_config_files_and_bindings
(
FLAGS
.
gin_file
,
FLAGS
.
gin_params
)
params
=
train_utils
.
parse_configuration
(
FLAGS
)
model_dir
=
FLAGS
.
model_dir
if
'train'
in
FLAGS
.
mode
:
# Pure eval modes do not output yaml files. Otherwise continuous eval job
# may race against the train job for writing the same file.
train_utils
.
serialize_config
(
params
,
model_dir
)
# Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have significant impact on model speeds by utilizing float16 in case of
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16
if
params
.
runtime
.
mixed_precision_dtype
:
performance
.
set_mixed_precision_policy
(
params
.
runtime
.
mixed_precision_dtype
)
distribution_strategy
=
distribute_utils
.
get_distribution_strategy
(
distribution_strategy
=
params
.
runtime
.
distribution_strategy
,
all_reduce_alg
=
params
.
runtime
.
all_reduce_alg
,
num_gpus
=
params
.
runtime
.
num_gpus
,
tpu_address
=
params
.
runtime
.
tpu
)
with
distribution_strategy
.
scope
():
tasks
=
multitask
.
MultiTask
.
from_config
(
params
.
task
)
model
=
multitask_model
.
SimCLRMTModel
(
params
.
task
.
model
)
train_lib
.
run_experiment
(
distribution_strategy
=
distribution_strategy
,
task
=
tasks
,
model
=
model
,
mode
=
FLAGS
.
mode
,
params
=
params
,
model_dir
=
model_dir
)
train_utils
.
save_gin_config
(
FLAGS
.
mode
,
model_dir
)
if
__name__
==
'__main__'
:
tfm_flags
.
define_flags
()
app
.
run
(
main
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment