Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
6d28faff
Commit
6d28faff
authored
Aug 25, 2021
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Aug 25, 2021
Browse files
Internal change
PiperOrigin-RevId: 392968271
parent
33598c45
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
40 additions
and
6 deletions
+40
-6
official/nlp/modeling/layers/cls_head.py
official/nlp/modeling/layers/cls_head.py
+35
-6
official/nlp/modeling/layers/cls_head_test.py
official/nlp/modeling/layers/cls_head_test.py
+5
-0
No files found.
official/nlp/modeling/layers/cls_head.py
View file @
6d28faff
...
...
@@ -59,19 +59,33 @@ class ClassificationHead(tf.keras.layers.Layer):
activation
=
self
.
activation
,
kernel_initializer
=
self
.
initializer
,
name
=
"pooler_dense"
)
self
.
dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
dropout_rate
)
self
.
dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
dropout_rate
)
self
.
out_proj
=
tf
.
keras
.
layers
.
Dense
(
units
=
num_classes
,
kernel_initializer
=
self
.
initializer
,
name
=
"logits"
)
def
call
(
self
,
features
):
def
call
(
self
,
features
:
tf
.
Tensor
,
only_project
:
bool
=
False
):
"""Implements call().
Args:
features: a rank-3 Tensor when self.inner_dim is specified, otherwise
it is a rank-2 Tensor.
only_project: a boolean. If True, we return the intermediate Tensor
before projecting to class logits.
Returns:
a Tensor, if only_project is True, shape= [batch size, hidden size].
If only_project is False, shape= [batch size, num classes].
"""
if
not
self
.
inner_dim
:
x
=
features
else
:
x
=
features
[:,
self
.
cls_token_idx
,
:]
# take <CLS> token.
x
=
self
.
dense
(
x
)
x
=
self
.
dropout
(
x
)
if
only_project
:
return
x
x
=
self
.
dropout
(
x
)
x
=
self
.
out_proj
(
x
)
return
x
...
...
@@ -134,7 +148,7 @@ class MultiClsHeads(tf.keras.layers.Layer):
activation
=
self
.
activation
,
kernel_initializer
=
self
.
initializer
,
name
=
"pooler_dense"
)
self
.
dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
dropout_rate
)
self
.
dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
dropout_rate
)
self
.
out_projs
=
[]
for
name
,
num_classes
in
cls_list
:
self
.
out_projs
.
append
(
...
...
@@ -142,13 +156,28 @@ class MultiClsHeads(tf.keras.layers.Layer):
units
=
num_classes
,
kernel_initializer
=
self
.
initializer
,
name
=
name
))
def
call
(
self
,
features
):
def
call
(
self
,
features
:
tf
.
Tensor
,
only_project
:
bool
=
False
):
"""Implements call().
Args:
features: a rank-3 Tensor when self.inner_dim is specified, otherwise
it is a rank-2 Tensor.
only_project: a boolean. If True, we return the intermediate Tensor
before projecting to class logits.
Returns:
If only_project is True, a Tensor with shape= [batch size, hidden size].
If only_project is False, a dictionary of Tensors.
"""
if
not
self
.
inner_dim
:
x
=
features
else
:
x
=
features
[:,
self
.
cls_token_idx
,
:]
# take <CLS> token.
x
=
self
.
dense
(
x
)
x
=
self
.
dropout
(
x
)
if
only_project
:
return
x
x
=
self
.
dropout
(
x
)
outputs
=
{}
for
proj_layer
in
self
.
out_projs
:
...
...
official/nlp/modeling/layers/cls_head_test.py
View file @
6d28faff
...
...
@@ -39,6 +39,8 @@ class ClassificationHeadTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertAllClose
(
output
,
[[
0.
,
0.
],
[
0.
,
0.
]])
self
.
assertSameElements
(
test_layer
.
checkpoint_items
.
keys
(),
[
"pooler_dense"
])
outputs
=
test_layer
(
features
,
only_project
=
True
)
self
.
assertEqual
(
outputs
.
shape
,
(
2
,
5
))
def
test_layer_serialization
(
self
):
layer
=
cls_head
.
ClassificationHead
(
10
,
2
)
...
...
@@ -71,6 +73,9 @@ class MultiClsHeadsTest(tf.test.TestCase, parameterized.TestCase):
self
.
assertSameElements
(
test_layer
.
checkpoint_items
.
keys
(),
[
"pooler_dense"
,
"foo"
,
"bar"
])
outputs
=
test_layer
(
features
,
only_project
=
True
)
self
.
assertEqual
(
outputs
.
shape
,
(
2
,
5
))
def
test_layer_serialization
(
self
):
cls_list
=
[(
"foo"
,
2
),
(
"bar"
,
3
)]
test_layer
=
cls_head
.
MultiClsHeads
(
inner_dim
=
5
,
cls_list
=
cls_list
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment