Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
96faaea8
Commit
96faaea8
authored
Mar 07, 2022
by
Abdullah Rashwan
Committed by
A. Unique TensorFlower
Mar 07, 2022
Browse files
Internal change
PiperOrigin-RevId: 433133947
parent
d6e3a60f
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
21 additions
and
18 deletions
+21
-18
official/vision/beta/projects/simclr/common/registry_imports.py
...al/vision/beta/projects/simclr/common/registry_imports.py
+1
-1
official/vision/beta/projects/simclr/configs/multitask_config.py
...l/vision/beta/projects/simclr/configs/multitask_config.py
+2
-2
official/vision/beta/projects/simclr/configs/simclr.py
official/vision/beta/projects/simclr/configs/simclr.py
+2
-2
official/vision/beta/projects/simclr/dataloaders/simclr_input.py
...l/vision/beta/projects/simclr/dataloaders/simclr_input.py
+3
-3
official/vision/beta/projects/simclr/modeling/multitask_model.py
...l/vision/beta/projects/simclr/modeling/multitask_model.py
+5
-4
official/vision/beta/projects/simclr/modeling/simclr_model_test.py
...vision/beta/projects/simclr/modeling/simclr_model_test.py
+1
-2
official/vision/beta/projects/simclr/tasks/simclr.py
official/vision/beta/projects/simclr/tasks/simclr.py
+7
-4
No files found.
official/vision/beta/projects/simclr/common/registry_imports.py
View file @
96faaea8
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
"""All necessary imports for registration."""
"""All necessary imports for registration."""
# pylint: disable=unused-import
# pylint: disable=unused-import
from
official.
comm
on
import
registry_imports
from
official.
visi
on
import
registry_imports
from
official.vision.beta.projects.simclr.configs
import
simclr
from
official.vision.beta.projects.simclr.configs
import
simclr
from
official.vision.beta.projects.simclr.losses
import
contrastive_losses
from
official.vision.beta.projects.simclr.losses
import
contrastive_losses
from
official.vision.beta.projects.simclr.modeling
import
simclr_model
from
official.vision.beta.projects.simclr.modeling
import
simclr_model
...
...
official/vision/beta/projects/simclr/configs/multitask_config.py
View file @
96faaea8
...
@@ -20,10 +20,10 @@ from typing import List, Tuple
...
@@ -20,10 +20,10 @@ from typing import List, Tuple
from
official.core
import
exp_factory
from
official.core
import
exp_factory
from
official.modeling
import
hyperparams
from
official.modeling
import
hyperparams
from
official.modeling.multitask
import
configs
as
multitask_configs
from
official.modeling.multitask
import
configs
as
multitask_configs
from
official.vision.beta.configs
import
backbones
from
official.vision.beta.configs
import
common
from
official.vision.beta.projects.simclr.configs
import
simclr
as
simclr_configs
from
official.vision.beta.projects.simclr.configs
import
simclr
as
simclr_configs
from
official.vision.beta.projects.simclr.modeling
import
simclr_model
from
official.vision.beta.projects.simclr.modeling
import
simclr_model
from
official.vision.configs
import
backbones
from
official.vision.configs
import
common
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
...
official/vision/beta/projects/simclr/configs/simclr.py
View file @
96faaea8
...
@@ -21,9 +21,9 @@ from official.core import config_definitions as cfg
...
@@ -21,9 +21,9 @@ from official.core import config_definitions as cfg
from
official.core
import
exp_factory
from
official.core
import
exp_factory
from
official.modeling
import
hyperparams
from
official.modeling
import
hyperparams
from
official.modeling
import
optimization
from
official.modeling
import
optimization
from
official.vision.beta.configs
import
backbones
from
official.vision.beta.configs
import
common
from
official.vision.beta.projects.simclr.modeling
import
simclr_model
from
official.vision.beta.projects.simclr.modeling
import
simclr_model
from
official.vision.configs
import
backbones
from
official.vision.configs
import
common
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
...
official/vision/beta/projects/simclr/dataloaders/simclr_input.py
View file @
96faaea8
...
@@ -40,11 +40,11 @@ from typing import List
...
@@ -40,11 +40,11 @@ from typing import List
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.vision.beta.dataloaders
import
decoder
from
official.vision.beta.dataloaders
import
parser
from
official.vision.beta.ops
import
preprocess_ops
from
official.vision.beta.projects.simclr.dataloaders
import
preprocess_ops
as
simclr_preprocess_ops
from
official.vision.beta.projects.simclr.dataloaders
import
preprocess_ops
as
simclr_preprocess_ops
from
official.vision.beta.projects.simclr.modeling
import
simclr_model
from
official.vision.beta.projects.simclr.modeling
import
simclr_model
from
official.vision.dataloaders
import
decoder
from
official.vision.dataloaders
import
parser
from
official.vision.ops
import
preprocess_ops
class
Decoder
(
decoder
.
Decoder
):
class
Decoder
(
decoder
.
Decoder
):
...
...
official/vision/beta/projects/simclr/modeling/multitask_model.py
View file @
96faaea8
...
@@ -14,15 +14,15 @@
...
@@ -14,15 +14,15 @@
"""Multi-task image multi-taskSimCLR model definition."""
"""Multi-task image multi-taskSimCLR model definition."""
from
typing
import
Dict
,
Text
from
typing
import
Dict
,
Text
from
absl
import
logging
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling.multitask
import
base_model
from
official.modeling.multitask
import
base_model
from
official.vision.beta.modeling
import
backbones
from
official.vision.beta.projects.simclr.configs
import
multitask_config
as
simclr_multitask_config
from
official.vision.beta.projects.simclr.configs
import
multitask_config
as
simclr_multitask_config
from
official.vision.beta.projects.simclr.heads
import
simclr_head
from
official.vision.beta.projects.simclr.heads
import
simclr_head
from
official.vision.beta.projects.simclr.modeling
import
simclr_model
from
official.vision.beta.projects.simclr.modeling
import
simclr_model
from
official.vision.modeling
import
backbones
PROJECTION_OUTPUT_KEY
=
'projection_outputs'
PROJECTION_OUTPUT_KEY
=
'projection_outputs'
SUPERVISED_OUTPUT_KEY
=
'supervised_outputs'
SUPERVISED_OUTPUT_KEY
=
'supervised_outputs'
...
@@ -110,7 +110,8 @@ class SimCLRMTModel(base_model.MultiTaskBaseModel):
...
@@ -110,7 +110,8 @@ class SimCLRMTModel(base_model.MultiTaskBaseModel):
pretrained_items
=
dict
(
pretrained_items
=
dict
(
backbone
=
self
.
_backbone
,
projection_head
=
self
.
_projection_head
)
backbone
=
self
.
_backbone
,
projection_head
=
self
.
_projection_head
)
else
:
else
:
assert
(
"Only 'backbone_projection' or 'backbone' can be used to "
raise
ValueError
(
"Only 'backbone_projection' or 'backbone' can be used to "
'initialize the model.'
)
'initialize the model.'
)
ckpt
=
tf
.
train
.
Checkpoint
(
**
pretrained_items
)
ckpt
=
tf
.
train
.
Checkpoint
(
**
pretrained_items
)
...
...
official/vision/beta/projects/simclr/modeling/simclr_model_test.py
View file @
96faaea8
...
@@ -14,13 +14,12 @@
...
@@ -14,13 +14,12 @@
"""Test for SimCLR model."""
"""Test for SimCLR model."""
from
absl.testing
import
parameterized
from
absl.testing
import
parameterized
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.vision.beta.modeling
import
backbones
from
official.vision.beta.projects.simclr.heads
import
simclr_head
from
official.vision.beta.projects.simclr.heads
import
simclr_head
from
official.vision.beta.projects.simclr.modeling
import
simclr_model
from
official.vision.beta.projects.simclr.modeling
import
simclr_model
from
official.vision.modeling
import
backbones
class
SimCLRModelTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
class
SimCLRModelTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
...
...
official/vision/beta/projects/simclr/tasks/simclr.py
View file @
96faaea8
...
@@ -36,12 +36,12 @@ from official.core import task_factory
...
@@ -36,12 +36,12 @@ from official.core import task_factory
from
official.modeling
import
optimization
from
official.modeling
import
optimization
from
official.modeling
import
performance
from
official.modeling
import
performance
from
official.modeling
import
tf_utils
from
official.modeling
import
tf_utils
from
official.vision.beta.modeling
import
backbones
from
official.vision.beta.projects.simclr.configs
import
simclr
as
exp_cfg
from
official.vision.beta.projects.simclr.configs
import
simclr
as
exp_cfg
from
official.vision.beta.projects.simclr.dataloaders
import
simclr_input
from
official.vision.beta.projects.simclr.dataloaders
import
simclr_input
from
official.vision.beta.projects.simclr.heads
import
simclr_head
from
official.vision.beta.projects.simclr.heads
import
simclr_head
from
official.vision.beta.projects.simclr.losses
import
contrastive_losses
from
official.vision.beta.projects.simclr.losses
import
contrastive_losses
from
official.vision.beta.projects.simclr.modeling
import
simclr_model
from
official.vision.beta.projects.simclr.modeling
import
simclr_model
from
official.vision.modeling
import
backbones
OptimizationConfig
=
optimization
.
OptimizationConfig
OptimizationConfig
=
optimization
.
OptimizationConfig
RuntimeConfig
=
config_definitions
.
RuntimeConfig
RuntimeConfig
=
config_definitions
.
RuntimeConfig
...
@@ -157,7 +157,8 @@ class SimCLRPretrainTask(base_task.Task):
...
@@ -157,7 +157,8 @@ class SimCLRPretrainTask(base_task.Task):
status
=
ckpt
.
read
(
ckpt_dir_or_file
)
status
=
ckpt
.
read
(
ckpt_dir_or_file
)
status
.
expect_partial
().
assert_existing_objects_matched
()
status
.
expect_partial
().
assert_existing_objects_matched
()
else
:
else
:
assert
"Only 'all' or 'backbone' can be used to initialize the model."
raise
ValueError
(
"Only 'all' or 'backbone' can be used to initialize the model."
)
logging
.
info
(
'Finished loading pretrained checkpoint from %s'
,
logging
.
info
(
'Finished loading pretrained checkpoint from %s'
,
ckpt_dir_or_file
)
ckpt_dir_or_file
)
...
@@ -335,7 +336,8 @@ class SimCLRPretrainTask(base_task.Task):
...
@@ -335,7 +336,8 @@ class SimCLRPretrainTask(base_task.Task):
def
validation_step
(
self
,
inputs
,
model
,
metrics
=
None
):
def
validation_step
(
self
,
inputs
,
model
,
metrics
=
None
):
if
self
.
task_config
.
model
.
supervised_head
is
None
:
if
self
.
task_config
.
model
.
supervised_head
is
None
:
assert
'Skipping eval during pretraining without supervised head.'
raise
ValueError
(
'Skipping eval during pretraining without supervised head.'
)
features
,
labels
=
inputs
features
,
labels
=
inputs
if
self
.
task_config
.
evaluation
.
one_hot
:
if
self
.
task_config
.
evaluation
.
one_hot
:
...
@@ -467,7 +469,8 @@ class SimCLRFinetuneTask(base_task.Task):
...
@@ -467,7 +469,8 @@ class SimCLRFinetuneTask(base_task.Task):
status
=
ckpt
.
read
(
ckpt_dir_or_file
)
status
=
ckpt
.
read
(
ckpt_dir_or_file
)
status
.
expect_partial
().
assert_existing_objects_matched
()
status
.
expect_partial
().
assert_existing_objects_matched
()
else
:
else
:
assert
"Only 'all' or 'backbone' can be used to initialize the model."
raise
ValueError
(
"Only 'all' or 'backbone' can be used to initialize the model."
)
# If the checkpoint is from pretraining, reset the following parameters
# If the checkpoint is from pretraining, reset the following parameters
model
.
backbone_trainable
=
self
.
task_config
.
model
.
backbone_trainable
model
.
backbone_trainable
=
self
.
task_config
.
model
.
backbone_trainable
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment