Commit 2ae06c8a authored by Jeremiah Liu, committed by A. Unique TensorFlower

Adds `call()` method to `GaussianProcessClassificationHead`.

This change allows `GaussianProcessClassificationHead` to output only the predictive logits during training and evaluation (instead of outputting a tuple `(logits, covmat)`). The goal is to make the layer more compatible with `SentencePredictionTask` and Keras' `model.fit()` API.

PiperOrigin-RevId: 366891298
parent e21b3e9e
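For context, a minimal usage sketch of the new behavior (not part of the commit). It assumes the TF Model Garden import path `official.nlp.modeling.layers.cls_head`; the shapes and constructor arguments below are illustrative, and unspecified Gaussian-process/spectral-norm options fall back to the defaults collected from `**kwargs`.

```python
import tensorflow as tf
from official.nlp.modeling.layers import cls_head

# Build the head; Gaussian-process / spectral-norm options not given here
# fall back to the defaults collected by extract_gp_layer_kwargs() and
# extract_spec_norm_kwargs().
head = cls_head.GaussianProcessClassificationHead(
    inner_dim=64, num_classes=2, use_gp_layer=True, temperature=1.0)

# Illustrative input: (batch_size, seq_length, hidden_dim) encoder features.
features = tf.random.normal(shape=(8, 16, 64))

# Training / evaluation: only the predictive logits come back, so the head
# drops into standard Keras pipelines such as model.fit().
logits = head(features)                              # shape (8, 2)

# Inference with uncertainty: also request the covariance matrix.
logits, covmat = head(features, return_covmat=True)  # covmat shape (8, 8)
```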
@@ -204,6 +204,7 @@ class GaussianProcessClassificationHead(ClassificationHead):
                initializer="glorot_uniform",
                use_spec_norm=True,
                use_gp_layer=True,
+               temperature=None,
                **kwargs):
     """Initializes the `GaussianProcessClassificationHead`.
@@ -217,6 +218,9 @@ class GaussianProcessClassificationHead(ClassificationHead):
       initializer: Initializer for dense layer kernels.
       use_spec_norm: Whether to apply spectral normalization to pooler layer.
       use_gp_layer: Whether to use Gaussian process as the output layer.
+      temperature: The temperature parameter to be used for mean-field
+        approximation during inference. If None then no mean-field adjustment
+        is applied.
       **kwargs: Additional keyword arguments.
     """
     # Collects spectral normalization and Gaussian process args from kwargs.
@@ -224,6 +228,7 @@ class GaussianProcessClassificationHead(ClassificationHead):
     self.use_gp_layer = use_gp_layer
     self.spec_norm_kwargs = extract_spec_norm_kwargs(kwargs)
     self.gp_layer_kwargs = extract_gp_layer_kwargs(kwargs)
+    self.temperature = temperature

     super().__init__(
         inner_dim=inner_dim,
@@ -247,12 +252,51 @@ class GaussianProcessClassificationHead(ClassificationHead):
           name="logits",
           **self.gp_layer_kwargs)

+  def call(self, features, return_covmat=False):
+    """Returns model output.
+
+    Arguments:
+      features: A tensor of input features, shape (batch_size, feature_dim).
+      return_covmat: Whether the model should also return covariance matrix if
+        `use_gp_layer=True`. During training, it is recommended to set
+        `return_covmat=False` to be compatible with the standard Keras
+        pipelines (e.g., `model.fit()`).
+
+    Returns:
+      logits: Uncertainty-adjusted predictive logits, shape
+        (batch_size, num_classes).
+      covmat: (Optional) Covariance matrix, shape (batch_size, batch_size).
+        Returned only when `return_covmat=True`.
+    """
+    logits = super().call(features)
+
+    # Extracts logits and covariance matrix from model output.
+    if self.use_gp_layer:
+      logits, covmat = logits
+    else:
+      covmat = None
+
+    # Computes the uncertainty-adjusted logits.
+    logits = gaussian_process.mean_field_logits(
+        logits, covmat, mean_field_factor=self.temperature)
+
+    if return_covmat and covmat is not None:
+      return logits, covmat
+    return logits
+
+  def reset_covariance_matrix(self):
+    """Resets covariance matrix of the Gaussian process layer."""
+    if hasattr(self.out_proj, "reset_covariance_matrix"):
+      self.out_proj.reset_covariance_matrix()
+
   def get_config(self):
     config = dict(
         use_spec_norm=self.use_spec_norm, use_gp_layer=self.use_gp_layer)
     config.update(self.spec_norm_kwargs)
     config.update(self.gp_layer_kwargs)
+    config["temperature"] = self.temperature
     config.update(super(GaussianProcessClassificationHead, self).get_config())
     return config
@@ -265,9 +309,9 @@ def extract_gp_layer_kwargs(kwargs):
       num_inducing=kwargs.pop("num_inducing", 1024),
       normalize_input=kwargs.pop("normalize_input", True),
       gp_cov_momentum=kwargs.pop("gp_cov_momentum", 0.999),
-      gp_cov_ridge_penalty=kwargs.pop("gp_cov_ridge_penalty", 1e-6),
+      gp_cov_ridge_penalty=kwargs.pop("gp_cov_ridge_penalty", 1.),
       scale_random_features=kwargs.pop("scale_random_features", False),
-      l2_regularization=kwargs.pop("l2_regularization", 0.),
+      l2_regularization=kwargs.pop("l2_regularization", 1e-6),
       gp_cov_likelihood=kwargs.pop("gp_cov_likelihood", "gaussian"),
       return_gp_cov=kwargs.pop("return_gp_cov", True),
       return_random_features=kwargs.pop("return_random_features", False),
...
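For readers unfamiliar with the `temperature` argument: it is passed to `gaussian_process.mean_field_logits` as the mean-field factor, which rescales the logits by the predictive variance before the softmax. Below is a simplified, self-contained sketch of that adjustment; the function name is illustrative, and the actual library utility may differ in details (e.g., how non-positive factors are handled).

```python
import tensorflow as tf

def mean_field_adjusted_logits(logits, covmat=None, mean_field_factor=None):
  """Simplified sketch of the mean-field logit adjustment.

  Divides each example's logits by sqrt(1 + lambda * var), where var is the
  predictive variance (the diagonal of covmat) and lambda is the mean-field
  factor, i.e. the `temperature` argument of the head.
  """
  if covmat is None or mean_field_factor is None:
    return logits
  variances = tf.linalg.diag_part(covmat)             # shape (batch_size,)
  scale = tf.sqrt(1. + variances * mean_field_factor)
  return logits / scale[:, tf.newaxis]
```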
@@ -115,11 +115,40 @@ class GaussianProcessClassificationHead(tf.test.TestCase,
         **self.spec_norm_kwargs,
         **self.gp_layer_kwargs)
     features = tf.zeros(shape=(2, 10, 10), dtype=tf.float32)
-    output, _ = test_layer(features)
+    output = test_layer(features)
     self.assertAllClose(output, [[0., 0.], [0., 0.]])
     self.assertSameElements(test_layer.checkpoint_items.keys(),
                             ["pooler_dense"])

+  @parameterized.named_parameters(
+      ("gp_layer_with_covmat", True, True),
+      ("gp_layer_no_covmat", True, False),
+      ("dense_layer_with_covmat", False, True),
+      ("dense_layer_no_covmat", False, False))
+  def test_sngp_output_shape(self, use_gp_layer, return_covmat):
+    batch_size = 32
+    num_classes = 2
+    test_layer = cls_head.GaussianProcessClassificationHead(
+        inner_dim=5,
+        num_classes=num_classes,
+        use_spec_norm=True,
+        use_gp_layer=use_gp_layer,
+        initializer="zeros",
+        **self.spec_norm_kwargs,
+        **self.gp_layer_kwargs)
+
+    features = tf.zeros(shape=(batch_size, 10, 10), dtype=tf.float32)
+    outputs = test_layer(features, return_covmat=return_covmat)
+
+    if use_gp_layer and return_covmat:
+      self.assertIsInstance(outputs, tuple)
+      self.assertEqual(outputs[0].shape, (batch_size, num_classes))
+      self.assertEqual(outputs[1].shape, (batch_size, batch_size))
+    else:
+      self.assertIsInstance(outputs, tf.Tensor)
+      self.assertEqual(outputs.shape, (batch_size, num_classes))
+
   def test_layer_serialization(self):
     layer = cls_head.GaussianProcessClassificationHead(
         inner_dim=5,
...
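A common way to use the new `reset_covariance_matrix()` hook with `model.fit()` is a small Keras callback that clears the Laplace covariance estimate at the start of each epoch, so the covariance is re-estimated from the most recent epoch (a common SNGP training recipe when the momentum-free covariance update is used). A sketch is below; the callback class name and the way the head is passed in are illustrative, not part of this commit.

```python
import tensorflow as tf

class ResetCovarianceCallback(tf.keras.callbacks.Callback):
  """Resets the GP covariance matrix at the beginning of every epoch."""

  def __init__(self, gp_head):
    super().__init__()
    self._gp_head = gp_head  # A GaussianProcessClassificationHead instance.

  def on_epoch_begin(self, epoch, logs=None):
    self._gp_head.reset_covariance_matrix()

# Usage (illustrative):
#   model.fit(train_data, epochs=3,
#             callbacks=[ResetCovarianceCallback(head)])
```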