Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
aba78478
Commit
aba78478
authored
Mar 29, 2021
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 365713370
parent
f3f3ec34
Changes
24
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
986 additions
and
0 deletions
+986
-0
official/vision/beta/projects/simclr/modeling/simclr_model.py
...cial/vision/beta/projects/simclr/modeling/simclr_model.py
+177
-0
official/vision/beta/projects/simclr/modeling/simclr_model_test.py
...vision/beta/projects/simclr/modeling/simclr_model_test.py
+87
-0
official/vision/beta/projects/simclr/tasks/simclr.py
official/vision/beta/projects/simclr/tasks/simclr.py
+640
-0
official/vision/beta/projects/simclr/train.py
official/vision/beta/projects/simclr/train.py
+82
-0
No files found.
official/vision/beta/projects/simclr/modeling/simclr_model.py
0 → 100644
View file @
aba78478
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Build simclr models."""
from
typing
import
Optional
from
absl
import
logging
import
tensorflow
as
tf
layers
=
tf
.
keras
.
layers
PRETRAIN
=
'pretrain'
FINETUNE
=
'finetune'
PROJECTION_OUTPUT_KEY
=
'projection_outputs'
SUPERVISED_OUTPUT_KEY
=
'supervised_outputs'
@tf.keras.utils.register_keras_serializable(package='simclr')
class SimCLRModel(tf.keras.Model):
  """A classification model based on SimCLR framework."""

  def __init__(self,
               backbone: tf.keras.models.Model,
               projection_head: tf.keras.layers.Layer,
               supervised_head: Optional[tf.keras.layers.Layer] = None,
               input_specs=layers.InputSpec(shape=[None, None, None, 3]),
               mode: str = PRETRAIN,
               backbone_trainable: bool = True,
               **kwargs):
    """A classification model based on SimCLR framework.

    Args:
      backbone: a backbone network.
      projection_head: a projection head network.
      supervised_head: a head network for supervised learning, e.g.
        classification head.
      input_specs: `tf.keras.layers.InputSpec` specs of the input tensor.
      mode: `str` indicates mode of training to be executed.
      backbone_trainable: `bool` whether the backbone is trainable or not.
      **kwargs: keyword arguments to be passed.
    """
    super(SimCLRModel, self).__init__(**kwargs)
    # Constructor arguments are captured verbatim so get_config()/from_config()
    # can round-trip the model for Keras serialization.
    self._config_dict = {
        'backbone': backbone,
        'projection_head': projection_head,
        'supervised_head': supervised_head,
        'input_specs': input_specs,
        'mode': mode,
        'backbone_trainable': backbone_trainable,
    }
    self._input_specs = input_specs
    self._backbone = backbone
    self._projection_head = projection_head
    self._supervised_head = supervised_head
    self._mode = mode
    self._backbone_trainable = backbone_trainable

    # Set whether the backbone is trainable
    self._backbone.trainable = backbone_trainable

  def call(self, inputs, training=None, **kwargs):
    """Runs a forward pass.

    Args:
      inputs: image tensor(s) with the transform views stacked along the
        channel axis, i.e. shape (bsz, h, w, c * num_transforms).
      training: `bool` or None, forwarded to the sub-networks.
      **kwargs: additional keyword arguments (unused here).

    Returns:
      A dict with PROJECTION_OUTPUT_KEY and SUPERVISED_OUTPUT_KEY entries;
      the supervised entry is None when no supervised head is configured.
    """
    model_outputs = {}

    # Two transformed views are only present during pretraining; evaluation
    # and finetuning use a single view.
    if training and self._mode == PRETRAIN:
      num_transforms = 2
    else:
      num_transforms = 1

    # Split channels, and optionally apply extra batched augmentation.
    # (bsz, h, w, c*num_transforms) -> [(bsz, h, w, c), ....]
    features_list = tf.split(
        inputs, num_or_size_splits=num_transforms, axis=-1)
    # (num_transforms * bsz, h, w, c)
    features = tf.concat(features_list, 0)

    # Base network forward pass.
    endpoints = self._backbone(features, training=training)
    # Use the deepest endpoint; assumes keys sort so that max() picks the
    # last feature level -- TODO confirm against the backbone's endpoint keys.
    features = endpoints[max(endpoints.keys())]
    projection_inputs = layers.GlobalAveragePooling2D()(features)

    # Add heads.
    projection_outputs, supervised_inputs = self._projection_head(
        projection_inputs, training)

    if self._supervised_head is not None:
      if self._mode == PRETRAIN:
        logging.info('Ignoring gradient from supervised outputs !')
        # When performing pretraining and supervised_head together, we do not
        # want information from supervised evaluation flowing back into
        # pretraining network. So we put a stop_gradient.
        supervised_outputs = self._supervised_head(
            tf.stop_gradient(supervised_inputs), training)
      else:
        supervised_outputs = self._supervised_head(
            supervised_inputs, training)
    else:
      supervised_outputs = None

    model_outputs.update({
        PROJECTION_OUTPUT_KEY: projection_outputs,
        SUPERVISED_OUTPUT_KEY: supervised_outputs
    })

    return model_outputs

  @property
  def checkpoint_items(self):
    """Returns a dictionary of items to be additionally checkpointed."""
    if self._supervised_head is not None:
      items = dict(
          backbone=self.backbone,
          projection_head=self.projection_head,
          supervised_head=self.supervised_head)
    else:
      items = dict(
          backbone=self.backbone,
          projection_head=self.projection_head)
    return items

  @property
  def backbone(self):
    # Backbone feature-extractor network.
    return self._backbone

  @property
  def projection_head(self):
    # Contrastive projection head network.
    return self._projection_head

  @property
  def supervised_head(self):
    # Optional supervised (e.g. classification) head; may be None.
    return self._supervised_head

  @property
  def mode(self):
    # Current training mode (PRETRAIN or FINETUNE).
    return self._mode

  @mode.setter
  def mode(self, value):
    self._mode = value

  @property
  def backbone_trainable(self):
    # Whether backbone weights receive gradient updates.
    return self._backbone_trainable

  @backbone_trainable.setter
  def backbone_trainable(self, value):
    # Keep the flag and the actual Keras trainable state in sync.
    self._backbone_trainable = value
    self._backbone.trainable = value

  def get_config(self):
    return self._config_dict

  @classmethod
  def from_config(cls, config, custom_objects=None):
    return cls(**config)
official/vision/beta/projects/simclr/modeling/simclr_model_test.py
0 → 100644
View file @
aba78478
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from
absl.testing
import
parameterized
import
numpy
as
np
import
tensorflow
as
tf
from
official.vision.beta.modeling
import
backbones
from
official.vision.beta.projects.simclr.heads
import
simclr_head
from
official.vision.beta.projects.simclr.modeling
import
simclr_model
class SimCLRModelTest(parameterized.TestCase, tf.test.TestCase):
  """Tests SimCLRModel construction and forward-pass output shapes."""

  @parameterized.parameters(
      (128, 3, 0),
      (128, 3, 1),
      (128, 1, 0),
      (128, 1, 1),
  )
  def test_model_creation(self, project_dim, num_proj_layers, ft_proj_idx):
    """Builds a SimCLR model and checks projection/supervised output shapes.

    Args:
      project_dim: output dimension of the projection head.
      num_proj_layers: number of layers in the projection head.
      ft_proj_idx: index of the projection layer used for finetuning.
    """
    input_size = 224
    inputs = np.random.rand(2, input_size, input_size, 3)
    input_specs = tf.keras.layers.InputSpec(
        shape=[None, input_size, input_size, 3])

    tf.keras.backend.set_image_data_format('channels_last')

    backbone = backbones.ResNet(
        model_id=50, activation='relu', input_specs=input_specs)
    projection_head = simclr_head.ProjectionHead(
        proj_output_dim=project_dim,
        num_proj_layers=num_proj_layers,
        ft_proj_idx=ft_proj_idx)
    num_classes = 10
    # Pass the local variable (originally hard-coded as 10) so the head and
    # the shape assertion below cannot drift apart.
    supervised_head = simclr_head.ClassificationHead(num_classes=num_classes)

    model = simclr_model.SimCLRModel(
        input_specs=input_specs,
        backbone=backbone,
        projection_head=projection_head,
        supervised_head=supervised_head,
        mode=simclr_model.PRETRAIN)
    outputs = model(inputs)
    projection_outputs = outputs[simclr_model.PROJECTION_OUTPUT_KEY]
    supervised_outputs = outputs[simclr_model.SUPERVISED_OUTPUT_KEY]

    self.assertAllEqual(projection_outputs.shape.as_list(),
                        [2, project_dim])
    self.assertAllEqual([2, num_classes],
                        supervised_outputs.numpy().shape)
# Run the test suite when executed as a script.
if __name__ == '__main__':
  tf.test.main()
official/vision/beta/projects/simclr/tasks/simclr.py
0 → 100644
View file @
aba78478
This diff is collapsed.
Click to expand it.
official/vision/beta/projects/simclr/train.py
0 → 100644
View file @
aba78478
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Model Garden Vision SimCLR training driver."""
from
absl
import
app
from
absl
import
flags
import
gin
from
official.common
import
distribute_utils
from
official.common
import
flags
as
tfm_flags
from
official.core
import
task_factory
from
official.core
import
train_lib
from
official.core
import
train_utils
from
official.modeling
import
performance
from
official.vision.beta.projects.simclr.common
import
registry_imports
# pylint: disable=unused-import
FLAGS
=
flags
.
FLAGS
def main(_):
  """Parses configuration and launches the SimCLR experiment.

  Applies gin bindings, resolves the experiment params from flags, sets the
  mixed-precision policy and distribution strategy, then builds the task and
  runs train/eval via `train_lib.run_experiment`.

  Args:
    _: unused positional argument supplied by `app.run`.
  """
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)
  # NOTE: a stray debug `print(FLAGS.experiment)` was removed here.
  params = train_utils.parse_configuration(FLAGS)
  model_dir = FLAGS.model_dir
  if 'train' in FLAGS.mode:
    # Pure eval modes do not output yaml files. Otherwise continuous eval job
    # may race against the train job for writing the same file.
    train_utils.serialize_config(params, model_dir)

  # Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
  # can have significant impact on model speeds by utilizing float16 in case of
  # GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
  # dtype is float16
  if params.runtime.mixed_precision_dtype:
    performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
                                           params.runtime.loss_scale)
  distribution_strategy = distribute_utils.get_distribution_strategy(
      distribution_strategy=params.runtime.distribution_strategy,
      all_reduce_alg=params.runtime.all_reduce_alg,
      num_gpus=params.runtime.num_gpus,
      tpu_address=params.runtime.tpu)

  # Task construction must happen under the strategy scope so that variables
  # are created on the right devices.
  with distribution_strategy.scope():
    task = task_factory.get_task(params.task, logging_dir=model_dir)

  train_lib.run_experiment(
      distribution_strategy=distribution_strategy,
      task=task,
      mode=FLAGS.mode,
      params=params,
      model_dir=model_dir)
# Script entry point: register the standard Model Garden flags, then hand
# control to absl, which parses flags and invokes main().
if __name__ == '__main__':
  tfm_flags.define_flags()
  app.run(main)
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment