Adds `call()` method to `GaussianProcessClassificationHead`.
This change allows `GaussianProcessClassificationHead` to output only the predictive logits during training and evaluation (instead of outputting a tuple `(logits, covmat)`). The goal is to make the layer more compatible with `SentencePredictionTask` and Keras' `model.fit()` API. PiperOrigin-RevId: 366891298
Showing
Please register or sign in to comment