Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
0dadbbc8
Commit
0dadbbc8
authored
Jan 08, 2022
by
Frederick Liu
Committed by
A. Unique TensorFlower
Jan 08, 2022
Browse files
Internal change
PiperOrigin-RevId: 420497751
parent
993dbf54
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
4 deletions
+10
-4
official/vision/beta/projects/vit/modeling/vit.py
official/vision/beta/projects/vit/modeling/vit.py
+10
-4
No files found.
official/vision/beta/projects/vit/modeling/vit.py
View file @
0dadbbc8
...
...
@@ -21,6 +21,7 @@ from official.vision.beta.modeling.backbones import factory
from
official.vision.beta.modeling.layers
import
nn_layers
from
official.vision.beta.projects.vit.modeling
import
nn_blocks
layers
=
tf
.
keras
.
layers
VIT_SPECS
=
{
...
...
@@ -121,6 +122,7 @@ class Encoder(tf.keras.layers.Layer):
inputs_positions
=
None
,
init_stochastic_depth_rate
=
0.0
,
kernel_initializer
=
'glorot_uniform'
,
add_pos_embed
=
True
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
_num_layers
=
num_layers
...
...
@@ -132,11 +134,13 @@ class Encoder(tf.keras.layers.Layer):
self
.
_inputs_positions
=
inputs_positions
self
.
_init_stochastic_depth_rate
=
init_stochastic_depth_rate
self
.
_kernel_initializer
=
kernel_initializer
self
.
_add_pos_embed
=
add_pos_embed
def
build
(
self
,
input_shape
):
self
.
_pos_embed
=
AddPositionEmbs
(
posemb_init
=
tf
.
keras
.
initializers
.
RandomNormal
(
stddev
=
0.02
),
name
=
'posembed_input'
)
if
self
.
_add_pos_embed
:
self
.
_pos_embed
=
AddPositionEmbs
(
posemb_init
=
tf
.
keras
.
initializers
.
RandomNormal
(
stddev
=
0.02
),
name
=
'posembed_input'
)
self
.
_dropout
=
layers
.
Dropout
(
rate
=
self
.
_dropout_rate
)
self
.
_encoder_layers
=
[]
...
...
@@ -160,7 +164,9 @@ class Encoder(tf.keras.layers.Layer):
super
().
build
(
input_shape
)
def
call
(
self
,
inputs
,
training
=
None
):
x
=
self
.
_pos_embed
(
inputs
,
inputs_positions
=
self
.
_inputs_positions
)
x
=
inputs
if
self
.
_add_pos_embed
:
x
=
self
.
_pos_embed
(
x
,
inputs_positions
=
self
.
_inputs_positions
)
x
=
self
.
_dropout
(
x
,
training
=
training
)
for
encoder_layer
in
self
.
_encoder_layers
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment