    Disable temperature scaling during training. · ff3ed4cc
    Jeremiah Liu authored
    For the `GaussianProcessClassificationHead`, temperature scaling must be disabled during training; otherwise it unexpectedly modifies the learning rate, which harms model quality. (Unfortunately, this seems to require adding a `training` argument to the `call` method.)
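
A minimal sketch of the behavior described above, not the actual implementation in `gaussian_process.py`. It assumes that temperature scaling takes the common mean-field form `logits / sqrt(1 + temperature * variance)`; the class `ToyGPHead`, the helper `mean_field_logits`, and the scalar `variance` argument are hypothetical names introduced only for illustration. Because the scaling shrinks the logits, applying it during training would also rescale the gradients, which is presumably how it interferes with the learning rate.

```python
import math


def mean_field_logits(logits, variance, mean_field_factor):
    """Hypothetical helper: divide logits by sqrt(1 + factor * variance)."""
    scale = math.sqrt(1.0 + mean_field_factor * variance)
    return [z / scale for z in logits]


class ToyGPHead:
    """Toy stand-in for a GP classification head (assumption, not the real class)."""

    def __init__(self, temperature=1.0):
        self.temperature = temperature

    def call(self, logits, variance, training=False):
        # During training, return the raw logits so the loss gradients
        # are not rescaled by the temperature term.
        if training:
            return list(logits)
        # At inference time, apply the uncertainty-aware scaling.
        return mean_field_logits(logits, variance, self.temperature)


head = ToyGPHead(temperature=1.0)
train_out = head.call([2.0, -2.0], variance=3.0, training=True)
eval_out = head.call([2.0, -2.0], variance=3.0, training=False)
print(train_out)  # raw logits, unchanged
print(eval_out)   # logits divided by sqrt(1 + 1.0 * 3.0) = 2
```

Threading `training` through `call` keeps a single code path for both modes, at the cost of the extra argument the message mentions.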
    
    Also set the default of `gp_cov_ridge_penalty` in `RandomFeatureGaussianProcess` to 1, consistent with the default in `GaussianProcessClassificationHead`.
    
    PiperOrigin-RevId: 366917075
gaussian_process.py 20 KB