Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
423b506a
Commit
423b506a
authored
Jul 18, 2022
by
Liangzhe Yuan
Committed by
A. Unique TensorFlower
Jul 18, 2022
Browse files
Support 'none' classifier type to ViT. Also rename 'classifier' to 'pooler' for better naming.
PiperOrigin-RevId: 461706948
parent
1071c54e
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
33 additions
and
12 deletions
+33
-12
official/projects/vit/configs/backbones.py
official/projects/vit/configs/backbones.py
+1
-1
official/projects/vit/modeling/vit.py
official/projects/vit/modeling/vit.py
+15
-10
official/projects/vit/modeling/vit_test.py
official/projects/vit/modeling/vit_test.py
+17
-1
No files found.
official/projects/vit/configs/backbones.py
View file @
423b506a
...
...
@@ -34,7 +34,7 @@ class VisionTransformer(hyperparams.Config):
   """VisionTransformer config."""
   model_name: str = 'vit-b16'
   # pylint: disable=line-too-long
-  classifier: str = 'token'  # 'token' or 'gap'. If set to 'token', an extra classification token is added to sequence.
+  pooler: str = 'token'  # 'token', 'gap' or 'none'. If set to 'token', an extra classification token is added to sequence.
   # pylint: enable=line-too-long
   representation_size: int = 0
   hidden_size: int = 1
...
...
official/projects/vit/modeling/vit.py
View file @
423b506a
...
...
@@ -258,7 +258,7 @@ class VisionTransformer(tf.keras.Model):
                patch_size=16,
                hidden_size=768,
                representation_size=0,
-               classifier='token',
+               pooler='token',
                kernel_regularizer=None,
                original_init: bool = True,
                pos_embed_shape: Optional[Tuple[int, int]] = None):
...
...
@@ -289,7 +289,7 @@ class VisionTransformer(tf.keras.Model):
     x = tf.reshape(x, [-1, seq_len, hidden_size])

     # If we want to add a class token, add it here.
-    if classifier == 'token':
+    if pooler == 'token':
       x = TokenLayer(name='cls')(x)

     x = Encoder(
...
...
@@ -305,12 +305,14 @@ class VisionTransformer(tf.keras.Model):
         pos_embed_origin_shape=pos_embed_shape,
         pos_embed_target_shape=pos_embed_target_shape)(x)

-    if classifier == 'token':
+    if pooler == 'token':
       x = x[:, 0]
-    elif classifier == 'gap':
+    elif pooler == 'gap':
       x = tf.reduce_mean(x, axis=1)
+    elif pooler == 'none':
+      x = tf.identity(x, name='encoded_tokens')
     else:
-      raise ValueError(f'unrecognized classifier type: {classifier}')
+      raise ValueError(f'unrecognized pooler type: {pooler}')

     if representation_size:
       x = tf.keras.layers.Dense(
...
...
@@ -322,11 +324,14 @@ class VisionTransformer(tf.keras.Model):
       x = tf.nn.tanh(x)
     else:
       x = tf.identity(x, name='pre_logits')
-    endpoints = {
-        'pre_logits':
-            tf.reshape(x, [-1, 1, 1, representation_size or hidden_size])
-    }
+
+    if pooler == 'none':
+      endpoints = {'encoded_tokens': x}
+    else:
+      endpoints = {
+          'pre_logits':
+              tf.reshape(x, [-1, 1, 1, representation_size or hidden_size])
+      }

     super(VisionTransformer, self).__init__(inputs=inputs, outputs=endpoints)
...
...
@@ -354,7 +359,7 @@ def build_vit(input_specs,
       patch_size=backbone_cfg.patch_size,
       hidden_size=backbone_cfg.hidden_size,
       representation_size=backbone_cfg.representation_size,
-      classifier=backbone_cfg.classifier,
+      pooler=backbone_cfg.pooler,
       kernel_regularizer=l2_regularizer,
       original_init=backbone_cfg.original_init,
       pos_embed_shape=backbone_cfg.pos_embed_shape)
official/projects/vit/modeling/vit_test.py
View file @
423b506a
...
...
@@ -37,6 +37,22 @@ class VisionTransformerTest(parameterized.TestCase, tf.test.TestCase):
     _ = network(inputs)
     self.assertEqual(network.count_params(), params_count)

+  def test_network_none_pooler(self):
+    tf.keras.backend.set_image_data_format('channels_last')
+    input_size = 256
+    input_specs = tf.keras.layers.InputSpec(
+        shape=[2, input_size, input_size, 3])
+    network = vit.VisionTransformer(
+        input_specs=input_specs,
+        patch_size=16,
+        pooler='none',
+        representation_size=128,
+        pos_embed_shape=(14, 14))  # (224 // 16)
+    inputs = tf.keras.Input(
+        shape=(input_size, input_size, 3), batch_size=1)
+    output = network(inputs)['encoded_tokens']
+    self.assertEqual(output.shape, [1, 256, 128])
+
   def test_posembedding_interpolation(self):
     tf.keras.backend.set_image_data_format('channels_last')
     input_size = 256
...
...
@@ -45,7 +61,7 @@ class VisionTransformerTest(parameterized.TestCase, tf.test.TestCase):
     network = vit.VisionTransformer(
         input_specs=input_specs,
         patch_size=16,
-        classifier='gap',
+        pooler='gap',
         pos_embed_shape=(14, 14))  # (224 // 16)
     inputs = tf.keras.Input(
         shape=(input_size, input_size, 3), batch_size=1)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment