Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
485b25b4
Commit
485b25b4
authored
Jul 19, 2021
by
A. Unique TensorFlower
Browse files
Remove unused use_normalization variable, and move tf.split into if branch.
PiperOrigin-RevId: 385611054
parent
d744a7c1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
7 additions
and
7 deletions
+7
-7
official/vision/beta/projects/simclr/heads/simclr_head.py
official/vision/beta/projects/simclr/heads/simclr_head.py
+0
-1
official/vision/beta/projects/simclr/modeling/simclr_model.py
...cial/vision/beta/projects/simclr/modeling/simclr_model.py
+7
-6
No files found.
official/vision/beta/projects/simclr/heads/simclr_head.py
View file @
485b25b4
...
@@ -97,7 +97,6 @@ class ProjectionHead(tf.keras.layers.Layer):
...
@@ -97,7 +97,6 @@ class ProjectionHead(tf.keras.layers.Layer):
'kernel_initializer'
:
self
.
_kernel_initializer
,
'kernel_initializer'
:
self
.
_kernel_initializer
,
'kernel_regularizer'
:
self
.
_kernel_regularizer
,
'kernel_regularizer'
:
self
.
_kernel_regularizer
,
'bias_regularizer'
:
self
.
_bias_regularizer
,
'bias_regularizer'
:
self
.
_bias_regularizer
,
'use_normalization'
:
self
.
_use_normalization
,
'norm_momentum'
:
self
.
_norm_momentum
,
'norm_momentum'
:
self
.
_norm_momentum
,
'norm_epsilon'
:
self
.
_norm_epsilon
'norm_epsilon'
:
self
.
_norm_epsilon
}
}
...
...
official/vision/beta/projects/simclr/modeling/simclr_model.py
View file @
485b25b4
...
@@ -90,14 +90,15 @@ class SimCLRModel(tf.keras.Model):
...
@@ -90,14 +90,15 @@ class SimCLRModel(tf.keras.Model):
if
training
and
self
.
_mode
==
PRETRAIN
:
if
training
and
self
.
_mode
==
PRETRAIN
:
num_transforms
=
2
num_transforms
=
2
# Split channels, and optionally apply extra batched augmentation.
# (bsz, h, w, c*num_transforms) -> [(bsz, h, w, c), ....]
features_list
=
tf
.
split
(
inputs
,
num_or_size_splits
=
num_transforms
,
axis
=-
1
)
# (num_transforms * bsz, h, w, c)
features
=
tf
.
concat
(
features_list
,
0
)
else
:
else
:
num_transforms
=
1
num_transforms
=
1
features
=
inputs
# Split channels, and optionally apply extra batched augmentation.
# (bsz, h, w, c*num_transforms) -> [(bsz, h, w, c), ....]
features_list
=
tf
.
split
(
inputs
,
num_or_size_splits
=
num_transforms
,
axis
=-
1
)
# (num_transforms * bsz, h, w, c)
features
=
tf
.
concat
(
features_list
,
0
)
# Base network forward pass.
# Base network forward pass.
endpoints
=
self
.
_backbone
(
features
,
training
=
training
)
endpoints
=
self
.
_backbone
(
features
,
training
=
training
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment