Commit ac83d2b5 authored by Jeremiah Liu's avatar Jeremiah Liu Committed by A. Unique TensorFlower
Browse files

Implements `GaussianProcessClassificationHead`.

PiperOrigin-RevId: 364226289
parent a6bbd3fa
......@@ -18,6 +18,9 @@ import tensorflow as tf
from official.modeling import tf_utils
from official.nlp.modeling.layers import gaussian_process
from official.nlp.modeling.layers import spectral_normalization
class ClassificationHead(tf.keras.layers.Layer):
"""Pooling head for sentence-level classification tasks."""
......@@ -160,3 +163,110 @@ class MultiClsHeads(tf.keras.layers.Layer):
items = {self.dense.name: self.dense}
items.update({v.name: v for v in self.out_projs})
return items
class GaussianProcessClassificationHead(ClassificationHead):
"""Gaussian process-based pooling head for sentence classification.
This class implements a classifier head for BERT encoder that is based on the
spectral-normalized neural Gaussian process (SNGP) [1]. SNGP is a simple
method to improve a neural network's uncertainty quantification ability
without sacrificing accuracy or lantency. It applies spectral normalization to
the hidden pooler layer, and then replaces the dense output layer with a
Gaussian process.
[1]: Jeremiah Liu et al. Simple and Principled Uncertainty Estimation with
Deterministic Deep Learning via Distance Awareness.
In _Neural Information Processing Systems_, 2020.
https://arxiv.org/abs/2006.10108
"""
def __init__(self,
inner_dim,
num_classes,
cls_token_idx=0,
activation="tanh",
dropout_rate=0.0,
initializer="glorot_uniform",
use_spec_norm=True,
use_gp_layer=True,
**kwargs):
"""Initializes the `GaussianProcessClassificationHead`.
Args:
inner_dim: The dimensionality of inner projection layer.
num_classes: Number of output classes.
cls_token_idx: The index inside the sequence to pool.
activation: Dense layer activation.
dropout_rate: Dropout probability.
initializer: Initializer for dense layer kernels.
use_spec_norm: Whether to apply spectral normalization to pooler layer.
use_gp_layer: Whether to use Gaussian process as the output layer.
**kwargs: Additional keyword arguments.
"""
# Collects spectral normalization and Gaussian process args from kwargs.
self.use_spec_norm = use_spec_norm
self.use_gp_layer = use_gp_layer
self.spec_norm_kwargs = extract_spec_norm_kwargs(kwargs)
self.gp_layer_kwargs = extract_gp_layer_kwargs(kwargs)
super().__init__(
inner_dim=inner_dim,
num_classes=num_classes,
cls_token_idx=cls_token_idx,
activation=activation,
dropout_rate=dropout_rate,
initializer=initializer,
**kwargs)
# Applies spectral normalization to the pooler layer.
if use_spec_norm:
self.dense = spectral_normalization.SpectralNormalization(
self.dense, inhere_layer_name=True, **self.spec_norm_kwargs)
# Replace Dense output layer with the Gaussian process layer.
if use_gp_layer:
self.out_proj = gaussian_process.RandomFeatureGaussianProcess(
self.num_classes,
kernel_initializer=self.initializer,
name="logits",
**self.gp_layer_kwargs)
def get_config(self):
config = dict(
use_spec_norm=self.use_spec_norm, use_gp_layer=self.use_gp_layer)
config.update(self.spec_norm_kwargs)
config.update(self.gp_layer_kwargs)
config.update(super(GaussianProcessClassificationHead, self).get_config())
return config
def extract_gp_layer_kwargs(kwargs):
"""Extracts Gaussian process layer configs from a given kwarg."""
return dict(
num_inducing=kwargs.pop("num_inducing", 1024),
normalize_input=kwargs.pop("normalize_input", True),
gp_cov_momentum=kwargs.pop("gp_cov_momentum", 0.999),
gp_cov_ridge_penalty=kwargs.pop("gp_cov_ridge_penalty", 1e-6),
scale_random_features=kwargs.pop("scale_random_features", False),
l2_regularization=kwargs.pop("l2_regularization", 0.),
gp_cov_likelihood=kwargs.pop("gp_cov_likelihood", "gaussian"),
return_gp_cov=kwargs.pop("return_gp_cov", True),
return_random_features=kwargs.pop("return_random_features", False),
use_custom_random_features=kwargs.pop("use_custom_random_features", True),
custom_random_features_initializer=kwargs.pop(
"custom_random_features_initializer", "random_normal"),
custom_random_features_activation=kwargs.pop(
"custom_random_features_activation", None))
def extract_spec_norm_kwargs(kwargs):
"""Extracts spectral normalization configs from a given kwarg."""
return dict(
iteration=kwargs.pop("iteration", 1),
norm_multiplier=kwargs.pop("norm_multiplier", .99))
......@@ -58,5 +58,56 @@ class MultiClsHeadsTest(tf.test.TestCase):
self.assertAllEqual(test_layer.get_config(), new_layer.get_config())
class GaussianProcessClassificationHead(tf.test.TestCase):
def setUp(self):
super().setUp()
self.spec_norm_kwargs = dict(norm_multiplier=1.,)
self.gp_layer_kwargs = dict(num_inducing=512)
def test_layer_invocation(self):
test_layer = cls_head.GaussianProcessClassificationHead(
inner_dim=5,
num_classes=2,
use_spec_norm=True,
use_gp_layer=True,
initializer="zeros",
**self.spec_norm_kwargs,
**self.gp_layer_kwargs)
features = tf.zeros(shape=(2, 10, 10), dtype=tf.float32)
output, _ = test_layer(features)
self.assertAllClose(output, [[0., 0.], [0., 0.]])
self.assertSameElements(test_layer.checkpoint_items.keys(),
["pooler_dense"])
def test_layer_serialization(self):
layer = cls_head.GaussianProcessClassificationHead(
inner_dim=5,
num_classes=2,
use_spec_norm=True,
use_gp_layer=True,
**self.spec_norm_kwargs,
**self.gp_layer_kwargs)
new_layer = cls_head.GaussianProcessClassificationHead.from_config(
layer.get_config())
# If the serialization was successful, the new config should match the old.
self.assertAllEqual(layer.get_config(), new_layer.get_config())
def test_sngp_kwargs_serialization(self):
"""Tests if SNGP-specific kwargs are added during serialization."""
layer = cls_head.GaussianProcessClassificationHead(
inner_dim=5,
num_classes=2,
use_spec_norm=True,
use_gp_layer=True,
**self.spec_norm_kwargs,
**self.gp_layer_kwargs)
layer_config = layer.get_config()
# The config value should equal to those defined in setUp().
self.assertEqual(layer_config["norm_multiplier"], 1.)
self.assertEqual(layer_config["num_inducing"], 512)
if __name__ == "__main__":
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment