Commit fab47e9e authored by Jeremiah Liu, committed by A. Unique TensorFlower

Disable temperature scaling during training.

For the `GaussianProcessClassificationHead`, temperature scaling needs to be disabled during training: applying it there amounts to an unexpected rescaling of the effective learning rate, which harms model quality. (Unfortunately, this requires adding a `training` argument to the `call` method.)
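For reference, here is a minimal standalone sketch of the mean-field adjustment that `temperature` (passed as `mean_field_factor`) performs; it assumes the common logits / sqrt(1 + factor * variance) form and is not a quote of the library implementation:

```python
import tensorflow as tf


def mean_field_adjusted_logits(logits, covmat, mean_field_factor=1.):
  """Sketch of mean-field logit scaling (assumed form, not the library code).

  Divides each example's logits by sqrt(1 + mean_field_factor * variance),
  where the variance is the diagonal of the GP covariance matrix.
  """
  variances = tf.linalg.diag_part(covmat)              # shape (batch_size,)
  scale = tf.sqrt(1. + mean_field_factor * variances)  # per-example scale
  return logits / scale[:, tf.newaxis]


logits = tf.constant([[2.0, -1.0]])
covmat = tf.constant([[4.0]])
# Evaluation path: scaled by 1 / sqrt(1 + 1.0 * 4.0), approx [[0.894, -0.447]].
# Training path (this commit): the raw [[2.0, -1.0]] is returned unchanged.
print(mean_field_adjusted_logits(logits, covmat))
```

Because the scaling multiplies every logit, and hence every gradient, by a per-example constant, applying it during training effectively changes the learning rate; skipping it in the training path avoids that.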

Also set the default of `gp_cov_ridge_penalty` in `RandomFeatureGaussianProcess` to 1 to be consistent with that in the `GaussianProcessClassificationHead`.
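A usage sketch of the aligned default (the import path and the `units` argument are illustrative assumptions; the diff below shows neither file names nor the full constructor):

```python
# Import path and constructor arguments are illustrative assumptions.
from official.nlp.modeling.layers import gaussian_process

# `gp_cov_ridge_penalty` now defaults to 1., the same value used by
# GaussianProcessClassificationHead; passing it explicitly keeps behavior
# stable across versions.
gp_layer = gaussian_process.RandomFeatureGaussianProcess(
    units=2, gp_cov_ridge_penalty=1.)
```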

PiperOrigin-RevId: 366917075
parent a9fe9ba9
@@ -252,11 +252,15 @@ class GaussianProcessClassificationHead(ClassificationHead):
           name="logits",
           **self.gp_layer_kwargs)

-  def call(self, features, return_covmat=False):
+  def call(self, features, training=False, return_covmat=False):
     """Returns model output.

+    During training, the model returns raw logits. During evaluation, the model
+    returns uncertainty-adjusted logits, and (optionally) the covariance matrix.
+
     Arguments:
       features: A tensor of input features, shape (batch_size, feature_dim).
+      training: Whether the model is in training mode.
       return_covmat: Whether the model should also return covariance matrix if
         `use_gp_layer=True`. During training, it is recommended to set
         `return_covmat=False` to be compatible with the standard Keras pipelines
@@ -276,13 +280,13 @@ class GaussianProcessClassificationHead(ClassificationHead):
     else:
       covmat = None

-    # Computes the uncertainty-adjusted logits.
-    logits = gaussian_process.mean_field_logits(
-        logits, covmat, mean_field_factor=self.temperature)
+    # Computes the uncertainty-adjusted logits during evaluation.
+    if not training:
+      logits = gaussian_process.mean_field_logits(
+          logits, covmat, mean_field_factor=self.temperature)

     if return_covmat and covmat is not None:
       return logits, covmat
     return logits

   def reset_covariance_matrix(self):
...
@@ -134,7 +134,6 @@ class GaussianProcessClassificationHead(tf.test.TestCase,
         num_classes=num_classes,
         use_spec_norm=True,
         use_gp_layer=use_gp_layer,
-        initializer="zeros",
         **self.spec_norm_kwargs,
         **self.gp_layer_kwargs)
@@ -149,6 +148,23 @@ class GaussianProcessClassificationHead(tf.test.TestCase,
     self.assertIsInstance(outputs, tf.Tensor)
     self.assertEqual(outputs.shape, (batch_size, num_classes))

+  def test_sngp_train_logits(self):
+    """Checks if temperature scaling is disabled during training."""
+    features = tf.zeros(shape=(5, 10, 10), dtype=tf.float32)
+    gp_layer = cls_head.GaussianProcessClassificationHead(
+        inner_dim=5, num_classes=2)
+
+    # Without temperature.
+    gp_layer.temperature = None
+    outputs_no_temp = gp_layer(features, training=True)
+
+    # With temperature.
+    gp_layer.temperature = 10.
+    outputs_with_temp = gp_layer(features, training=True)
+
+    self.assertAllEqual(outputs_no_temp, outputs_with_temp)
+
   def test_layer_serialization(self):
     layer = cls_head.GaussianProcessClassificationHead(
         inner_dim=5,
...
@@ -62,7 +62,7 @@ class RandomFeatureGaussianProcess(tf.keras.layers.Layer):
                gp_kernel_scale_trainable=False,
                gp_output_bias_trainable=False,
                gp_cov_momentum=0.999,
-               gp_cov_ridge_penalty=1e-6,
+               gp_cov_ridge_penalty=1.,
                scale_random_features=True,
                use_custom_random_features=True,
                custom_random_features_initializer=None,
@@ -292,7 +292,7 @@ class LaplaceRandomFeatureCovariance(tf.keras.layers.Layer):
   def __init__(self,
                momentum=0.999,
-               ridge_penalty=1e-6,
+               ridge_penalty=1.,
                likelihood='gaussian',
                dtype=None,
                name='laplace_covariance'):
...