Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
8c7699ce
Commit
8c7699ce
authored
May 04, 2022
by
Jiayu Ye
Committed by
A. Unique TensorFlower
May 04, 2022
Browse files
Internal change
PiperOrigin-RevId: 446536595
parent
71277b49
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
118 additions
and
0 deletions
+118
-0
official/nlp/modeling/layers/cls_head.py
official/nlp/modeling/layers/cls_head.py
+94
-0
official/nlp/modeling/layers/cls_head_test.py
official/nlp/modeling/layers/cls_head_test.py
+24
-0
No files found.
official/nlp/modeling/layers/cls_head.py
View file @
8c7699ce
...
...
@@ -364,3 +364,97 @@ def extract_spec_norm_kwargs(kwargs):
return
dict
(
iteration
=
kwargs
.
pop
(
"iteration"
,
1
),
norm_multiplier
=
kwargs
.
pop
(
"norm_multiplier"
,
.
99
))
class PerQueryDenseHead(tf.keras.layers.Layer):
  """Pooling head used for EncT5 style models.

  This module projects each query with its own, separate projection.

  For an input of shape= [bs, num_queries, hidden_size], it projects each query
  to (features), ending up with shape= [bs, num_queries, features].

  For example, for classification with a few classes, one may use num_queries
  as 1 and features as number of classes. For multilabel classification, one
  may use num_queries as number of classes and features as 2. So each query
  represents a binary classification of one label.
  """

  def __init__(self,
               num_queries: int,
               features: int,
               use_bias: bool = False,
               kernel_initializer: str = "glorot_uniform",
               **kwargs):
    """Initializes the `PerQueryDenseHead`.

    Args:
      num_queries: number of queries (the learnable embeddings in the input
        sequences) from the decoder.
      features: int with numbers of output features. Each query will be
        projected to this number with a different projection.
      use_bias: whether to add a bias to the output.
      kernel_initializer: Initializer for dense layer kernels.
      **kwargs: Keyword arguments.
    """
    super().__init__(**kwargs)
    self.num_queries = num_queries
    self.features = features
    self.use_bias = use_bias
    self.kernel_initializer = tf.keras.initializers.get(kernel_initializer)

  def build(self, input_shape):
    """Creates the per-query kernel and, if requested, the per-query bias.

    Args:
      input_shape: shape of the inputs; its last dimension is taken as the
        hidden size.

    Raises:
      ValueError: if the last dimension of `input_shape` is not defined.
    """
    input_shape = tf.TensorShape(input_shape)
    # Hidden size.
    last_dim = tf.compat.dimension_value(input_shape[-1])
    if last_dim is None:
      # The kernel shape depends on the hidden size, so it must be static.
      raise ValueError(
          "The last dimension of the inputs to `PerQueryDenseHead` should be "
          "defined. Found `None`.")
    self.hidden_size = last_dim
    # One [hidden_size, features] projection matrix per query.
    self.kernel = self.add_weight(
        "kernel",
        shape=[self.num_queries, last_dim, self.features],
        initializer=self.kernel_initializer,
        dtype=self.dtype,
        trainable=True)
    if self.use_bias:
      # NOTE(review): no initializer is passed here, so the `add_weight`
      # default applies -- confirm whether a zeros initializer was intended.
      self.bias = self.add_weight(
          "bias",
          shape=[
              self.num_queries,
              self.features,
          ],
          dtype=self.dtype,
          trainable=True)
    else:
      self.bias = None

  def call(self, inputs: tf.Tensor) -> tf.Tensor:
    """Implements call().

    Args:
      inputs: a rank-3 Tensor of shape= [bs, num_queries, hidden_size].

    Returns:
      A Tensor, shape= [batch size, num_queries, features].
    """
    # Contract the hidden axis separately for every query index `q`.
    outputs = tf.einsum("bqh,qhf->bqf", inputs, self.kernel)
    if self.use_bias:
      outputs += self.bias
    return outputs

  def get_config(self):
    """Returns the layer config, including every constructor argument."""
    config = {
        "num_queries":
            self.num_queries,
        "features":
            self.features,
        # `use_bias` must be part of the config; otherwise `from_config`
        # falls back to the default and silently rebuilds a bias-free layer.
        "use_bias":
            self.use_bias,
        # Serialize with the initializers module, the counterpart of the
        # `tf.keras.initializers.get` call used in `__init__`.
        "kernel_initializer":
            tf.keras.initializers.serialize(self.kernel_initializer),
    }
    config.update(super().get_config())
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Builds a layer from `config`; `custom_objects` is accepted but unused."""
    return cls(**config)
official/nlp/modeling/layers/cls_head_test.py
View file @
8c7699ce
...
...
@@ -199,5 +199,29 @@ class GaussianProcessClassificationHead(tf.test.TestCase,
self
.
assertEqual
(
layer_config
[
"norm_multiplier"
],
1.
)
self
.
assertEqual
(
layer_config
[
"num_inducing"
],
512
)
class PerQueryDenseHeadTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for `cls_head.PerQueryDenseHead`."""

  @parameterized.named_parameters(
      ("single_query", 1, 3, False),
      ("multi_queries", 10, 2, False),
      ("with_bias", 10, 2, True),
  )
  def test_layer_invocation(self, num_queries, features, use_bias):
    """Checks the output shape is [batch_size, num_queries, features]."""
    batch_size, hidden_size = 5, 10
    head = cls_head.PerQueryDenseHead(
        num_queries=num_queries, features=features, use_bias=use_bias)
    query_states = tf.zeros(
        shape=(batch_size, num_queries, hidden_size), dtype=tf.float32)
    projected = head(query_states)
    expected_shape = [batch_size, num_queries, features]
    self.assertEqual(projected.shape, expected_shape)

  def test_layer_serialization(self):
    """Checks a layer rebuilt from its config reports the same config."""
    original = cls_head.PerQueryDenseHead(
        num_queries=10, features=2, use_bias=True)
    restored = cls_head.PerQueryDenseHead.from_config(original.get_config())

    # A successful round trip yields a config equal to the original one.
    expected_config = original.get_config()
    actual_config = restored.get_config()
    self.assertAllEqual(expected_config, actual_config)
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
  tf.test.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment