Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
572fafb8
Commit
572fafb8
authored
Mar 21, 2021
by
Jeremiah Liu
Committed by
A. Unique TensorFlower
Mar 21, 2021
Browse files
Implements `GaussianProcessClassificationHead`.
PiperOrigin-RevId: 364226289
parent
4785b025
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
161 additions
and
0 deletions
+161
-0
official/nlp/modeling/layers/cls_head.py
official/nlp/modeling/layers/cls_head.py
+110
-0
official/nlp/modeling/layers/cls_head_test.py
official/nlp/modeling/layers/cls_head_test.py
+51
-0
No files found.
official/nlp/modeling/layers/cls_head.py
View file @
572fafb8
...
@@ -18,6 +18,9 @@ import tensorflow as tf
...
@@ -18,6 +18,9 @@ import tensorflow as tf
from
official.modeling
import
tf_utils
from
official.modeling
import
tf_utils
from
official.nlp.modeling.layers
import
gaussian_process
from
official.nlp.modeling.layers
import
spectral_normalization
class
ClassificationHead
(
tf
.
keras
.
layers
.
Layer
):
class
ClassificationHead
(
tf
.
keras
.
layers
.
Layer
):
"""Pooling head for sentence-level classification tasks."""
"""Pooling head for sentence-level classification tasks."""
...
@@ -160,3 +163,110 @@ class MultiClsHeads(tf.keras.layers.Layer):
...
@@ -160,3 +163,110 @@ class MultiClsHeads(tf.keras.layers.Layer):
items
=
{
self
.
dense
.
name
:
self
.
dense
}
items
=
{
self
.
dense
.
name
:
self
.
dense
}
items
.
update
({
v
.
name
:
v
for
v
in
self
.
out_projs
})
items
.
update
({
v
.
name
:
v
for
v
in
self
.
out_projs
})
return
items
return
items
class GaussianProcessClassificationHead(ClassificationHead):
  """Gaussian process-based pooling head for sentence classification.

  This class implements a classifier head for BERT encoder that is based on the
  spectral-normalized neural Gaussian process (SNGP) [1]. SNGP is a simple
  method to improve a neural network's uncertainty quantification ability
  without sacrificing accuracy or latency. It applies spectral normalization to
  the hidden pooler layer, and then replaces the dense output layer with a
  Gaussian process.

  [1]: Jeremiah Liu et al. Simple and Principled Uncertainty Estimation with
       Deterministic Deep Learning via Distance Awareness.
       In _Neural Information Processing Systems_, 2020.
       https://arxiv.org/abs/2006.10108
  """

  def __init__(self,
               inner_dim,
               num_classes,
               cls_token_idx=0,
               activation="tanh",
               dropout_rate=0.0,
               initializer="glorot_uniform",
               use_spec_norm=True,
               use_gp_layer=True,
               **kwargs):
    """Initializes the `GaussianProcessClassificationHead`.

    Args:
      inner_dim: The dimensionality of inner projection layer.
      num_classes: Number of output classes.
      cls_token_idx: The index inside the sequence to pool.
      activation: Dense layer activation.
      dropout_rate: Dropout probability.
      initializer: Initializer for dense layer kernels.
      use_spec_norm: Whether to apply spectral normalization to pooler layer.
      use_gp_layer: Whether to use Gaussian process as the output layer.
      **kwargs: Additional keyword arguments.
    """
    # Collects spectral normalization and Gaussian process args from kwargs.
    # The extract_* helpers pop their keys so the remaining kwargs are safe
    # to forward to the base `ClassificationHead` constructor.
    self.use_spec_norm = use_spec_norm
    self.use_gp_layer = use_gp_layer
    self.spec_norm_kwargs = extract_spec_norm_kwargs(kwargs)
    self.gp_layer_kwargs = extract_gp_layer_kwargs(kwargs)

    super().__init__(
        inner_dim=inner_dim,
        num_classes=num_classes,
        cls_token_idx=cls_token_idx,
        activation=activation,
        dropout_rate=dropout_rate,
        initializer=initializer,
        **kwargs)

    # Applies spectral normalization to the pooler layer.
    if use_spec_norm:
      self.dense = spectral_normalization.SpectralNormalization(
          self.dense, inhere_layer_name=True, **self.spec_norm_kwargs)

    # Replace Dense output layer with the Gaussian process layer.
    if use_gp_layer:
      self.out_proj = gaussian_process.RandomFeatureGaussianProcess(
          self.num_classes,
          kernel_initializer=self.initializer,
          name="logits",
          **self.gp_layer_kwargs)

  def get_config(self):
    """Returns the config, including the SNGP-specific keyword arguments."""
    config = dict(
        use_spec_norm=self.use_spec_norm, use_gp_layer=self.use_gp_layer)
    # Flatten the SNGP kwargs into the config so `from_config` can rebuild
    # the layer by passing them straight back through `__init__`'s **kwargs.
    config.update(self.spec_norm_kwargs)
    config.update(self.gp_layer_kwargs)
    config.update(super().get_config())
    return config
def extract_gp_layer_kwargs(kwargs):
  """Extracts Gaussian process layer configs from a given kwarg."""
  # Recognized Gaussian-process-layer arguments and their defaults. Every
  # matching key is popped from `kwargs` so the caller can forward the
  # remaining keyword arguments elsewhere without duplication.
  gp_defaults = {
      "num_inducing": 1024,
      "normalize_input": True,
      "gp_cov_momentum": 0.999,
      "gp_cov_ridge_penalty": 1e-6,
      "scale_random_features": False,
      "l2_regularization": 0.,
      "gp_cov_likelihood": "gaussian",
      "return_gp_cov": True,
      "return_random_features": False,
      "use_custom_random_features": True,
      "custom_random_features_initializer": "random_normal",
      "custom_random_features_activation": None,
  }
  return {name: kwargs.pop(name, default)
          for name, default in gp_defaults.items()}
def extract_spec_norm_kwargs(kwargs):
  """Extracts spectral normalization configs from a given kwarg."""
  # Pop each recognized key so it is not forwarded to the base class; absent
  # keys fall back to the spectral-normalization defaults.
  iteration = kwargs.pop("iteration", 1)
  norm_multiplier = kwargs.pop("norm_multiplier", .99)
  return {"iteration": iteration, "norm_multiplier": norm_multiplier}
official/nlp/modeling/layers/cls_head_test.py
View file @
572fafb8
...
@@ -58,5 +58,56 @@ class MultiClsHeadsTest(tf.test.TestCase):
...
@@ -58,5 +58,56 @@ class MultiClsHeadsTest(tf.test.TestCase):
self
.
assertAllEqual
(
test_layer
.
get_config
(),
new_layer
.
get_config
())
self
.
assertAllEqual
(
test_layer
.
get_config
(),
new_layer
.
get_config
())
class GaussianProcessClassificationHeadTest(tf.test.TestCase):
  """Tests for `cls_head.GaussianProcessClassificationHead`."""
  # NOTE(review): renamed from `GaussianProcessClassificationHead` so the
  # test case no longer shares a name with the layer under test, and follows
  # the `*Test` suffix used by sibling cases (e.g. `MultiClsHeadsTest`).
  # `tf.test.main` discovers cases by `TestCase` subclassing, not by name.

  def setUp(self):
    super().setUp()
    # Non-default SNGP kwargs; the serialization test below checks that these
    # exact values round-trip through `get_config`.
    self.spec_norm_kwargs = dict(norm_multiplier=1.,)
    self.gp_layer_kwargs = dict(num_inducing=512)

  def test_layer_invocation(self):
    test_layer = cls_head.GaussianProcessClassificationHead(
        inner_dim=5,
        num_classes=2,
        use_spec_norm=True,
        use_gp_layer=True,
        initializer="zeros",
        **self.spec_norm_kwargs,
        **self.gp_layer_kwargs)
    features = tf.zeros(shape=(2, 10, 10), dtype=tf.float32)
    # The GP output layer returns a (logits, covariance) tuple; with a "zeros"
    # initializer the logits must all be zero.
    output, _ = test_layer(features)
    self.assertAllClose(output, [[0., 0.], [0., 0.]])
    self.assertSameElements(test_layer.checkpoint_items.keys(),
                            ["pooler_dense"])

  def test_layer_serialization(self):
    layer = cls_head.GaussianProcessClassificationHead(
        inner_dim=5,
        num_classes=2,
        use_spec_norm=True,
        use_gp_layer=True,
        **self.spec_norm_kwargs,
        **self.gp_layer_kwargs)
    new_layer = cls_head.GaussianProcessClassificationHead.from_config(
        layer.get_config())

    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(layer.get_config(), new_layer.get_config())

  def test_sngp_kwargs_serialization(self):
    """Tests if SNGP-specific kwargs are added during serialization."""
    layer = cls_head.GaussianProcessClassificationHead(
        inner_dim=5,
        num_classes=2,
        use_spec_norm=True,
        use_gp_layer=True,
        **self.spec_norm_kwargs,
        **self.gp_layer_kwargs)
    layer_config = layer.get_config()

    # The config value should equal to those defined in setUp().
    self.assertEqual(layer_config["norm_multiplier"], 1.)
    self.assertEqual(layer_config["num_inducing"], 512)
# Runs every `tf.test.TestCase` suite defined in this module when the file is
# executed as a script (no effect when imported as a test module).
if __name__ == "__main__":
  tf.test.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment