Commit fd3e8141 authored by Jeremiah Liu, committed by A. Unique TensorFlower

Adds `mean_field_logits` function to GP layer.

PiperOrigin-RevId: 366175727
parent 80c50b6a
@@ -13,14 +13,7 @@
# limitations under the License.
# Lint as: python3
"""Definitions for random feature Gaussian process layer.
## References:
[1]: Ali Rahimi and Benjamin Recht. Random Features for Large-Scale Kernel
Machines. In _Neural Information Processing Systems_, 2007.
https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf
"""
"""Definitions for random feature Gaussian process layer."""
import math
import tensorflow as tf
@@ -29,7 +22,7 @@ _SUPPORTED_LIKELIHOOD = ('binary_logistic', 'poisson', 'gaussian')
class RandomFeatureGaussianProcess(tf.keras.layers.Layer):
"""Gaussian process layer with random feature approximation.
"""Gaussian process layer with random feature approximation [1].
During training, the model updates the maximum a posteriori (MAP) logits
estimates and posterior precision matrix using minibatch statistics. During
@@ -47,6 +40,10 @@ class RandomFeatureGaussianProcess(tf.keras.layers.Layer):
A linear kernel can also be specified by setting gp_kernel_type='linear' and
`use_custom_random_features=True`.
[1]: Ali Rahimi and Benjamin Recht. Random Features for Large-Scale Kernel
Machines. In _Neural Information Processing Systems_, 2007.
https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf
Attributes:
units: (int) The dimensionality of the layer.
num_inducing: (int) The number of random features for the approximation.
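For orientation, a minimal usage sketch of the layer described in this docstring. The module name `gaussian_process` is taken from the tests below; `units` and `num_inducing` follow the attribute list above, while the remaining call conventions, including the assumption that inference returns a `(logits, covariance)` pair, are inferred from this diff rather than confirmed API details:

```python
import tensorflow as tf

import gaussian_process  # hypothetical import path; adjust to the module's real location

# Random-feature GP output head. `units` and `num_inducing` follow the
# attribute names documented above; a linear kernel could instead be requested
# with gp_kernel_type='linear' and use_custom_random_features=True.
gp_layer = gaussian_process.RandomFeatureGaussianProcess(
    units=10, num_inducing=1024)

features = tf.random.normal([32, 128])  # minibatch of penultimate-layer features
# Assumed to return the MAP logits and a (batch_size, batch_size) predictive
# covariance at inference time, matching the covariance code later in this diff.
logits, covmat = gp_layer(features, training=False)
```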
@@ -458,3 +455,41 @@ class LaplaceRandomFeatureCovariance(tf.keras.layers.Layer):
else:
# Return covariance estimate during inference.
return self.compute_predictive_covariance(gp_feature=inputs)
def mean_field_logits(logits, covariance_matrix=None, mean_field_factor=1.):
"""Adjust the model logits so its softmax approximates the posterior mean [1].
[1]: Zhiyun Lu, Eugene Ie, Fei Sha. Uncertainty Estimation with Infinitesimal
Jackknife. _arXiv preprint arXiv:2006.07584_, 2020.
https://arxiv.org/abs/2006.07584
Arguments:
logits: A float tensor of shape (batch_size, num_classes).
covariance_matrix: The covariance matrix of shape (batch_size, batch_size).
If None, the covariance matrix is assumed to be an identity matrix.
mean_field_factor: The scale factor for the mean-field approximation, used to
adjust the influence of the posterior variance on the posterior mean
approximation. If covariance_matrix=None, it acts as the temperature
parameter for temperature scaling, with temperature sqrt(1. + mean_field_factor).
Returns:
Tensor of adjusted logits, shape (batch_size, num_classes).
"""
if mean_field_factor is None or mean_field_factor < 0:
return logits
# Compute standard deviation.
if covariance_matrix is None:
variances = 1.
else:
variances = tf.linalg.diag_part(covariance_matrix)
# Compute scaling coefficient for mean-field approximation.
logits_scale = tf.sqrt(1. + variances * mean_field_factor)
if len(logits.shape) > 1:
# Cast logits_scale to compatible dimension.
logits_scale = tf.expand_dims(logits_scale, axis=-1)
return logits / logits_scale
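Below is a short, hedged sketch of how the new `mean_field_logits` function might be used on the GP layer's inference-time outputs. The `(logits, covmat)` stand-ins and the choice `mean_field_factor=math.pi / 8.` (a common value in the mean-field approximation literature) are illustrative assumptions, not something this commit prescribes:

```python
import math
import tensorflow as tf

# Stand-ins for the GP layer's inference-time outputs (assumed shapes).
logits = tf.random.normal([32, 10])  # (batch_size, num_classes)
covmat = tf.eye(32)                  # (batch_size, batch_size)

# Mean-field adjustment: row i of the logits is divided by
# sqrt(1 + mean_field_factor * diag(covmat)[i]).
# mean_field_logits is the function defined just above in this module.
adjusted = mean_field_logits(logits, covmat, mean_field_factor=math.pi / 8.)
probs = tf.nn.softmax(adjusted, axis=-1)

# With covariance_matrix=None the call reduces to temperature scaling with
# temperature sqrt(1 + mean_field_factor); a factor of 3. divides the logits by 2.
scaled = mean_field_logits(logits, covariance_matrix=None, mean_field_factor=3.)
```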
@@ -223,5 +223,46 @@ class GaussianProcessTest(tf.test.TestCase, parameterized.TestCase):
self.assertAllClose(gp_covmat, gp_covmat_new, atol=1e-4)
class MeanFieldLogitsTest(tf.test.TestCase):
def testMeanFieldLogitsLikelihood(self):
"""Tests if scaling is correct under different likelihood."""
batch_size = 10
num_classes = 12
variance = 1.5
mean_field_factor = 2.
rng = np.random.RandomState(0)
tf.random.set_seed(1)
logits = rng.randn(batch_size, num_classes)
covmat = tf.linalg.diag([variance] * batch_size)
logits_logistic = gaussian_process.mean_field_logits(
logits, covmat, mean_field_factor=mean_field_factor)
self.assertAllClose(logits_logistic, logits / 2., atol=1e-4)
def testMeanFieldLogitsTemperatureScaling(self):
"""Tests using mean_field_logits as temperature scaling method."""
batch_size = 10
num_classes = 12
rng = np.random.RandomState(0)
tf.random.set_seed(1)
logits = rng.randn(batch_size, num_classes)
# Test if there's no change to logits when mean_field_factor < 0.
logits_no_change = gaussian_process.mean_field_logits(
logits, covariance_matrix=None, mean_field_factor=-1)
# Test if mean_field_logits functions as a temperature scaling method when
# mean_field_factor > 0, with temperature = sqrt(1. + mean_field_factor).
logits_scale_by_two = gaussian_process.mean_field_logits(
logits, covariance_matrix=None, mean_field_factor=3.)
self.assertAllClose(logits_no_change, logits, atol=1e-4)
self.assertAllClose(logits_scale_by_two, logits / 2., atol=1e-4)
if __name__ == '__main__':
tf.test.main()