    Adds `call()` method to `GaussianProcessClassificationHead`. · 2ae06c8a
    Jeremiah Liu authored
    This change allows `GaussianProcessClassificationHead` to output only the predictive logits during training and evaluation, instead of a tuple `(logits, covmat)`. The goal is to make the layer more compatible with `SentencePredictionTask` and with Keras' `model.fit()` API.
    
    PiperOrigin-RevId: 366891298
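The return-type change described above can be illustrated with a framework-free sketch. It is not the code in `cls_head.py`. The class `ToyGPClassificationHead`, its helper `_gp_outputs`, and the opt-in `return_covmat` flag are all hypothetical names chosen for illustration, and the "GP layer" is a placeholder that echoes its inputs as logits and returns an identity covariance. The sketch only shows the contract: a single logits output by default, with the `(logits, covmat)` tuple available on request.

```python
from typing import List, Tuple, Union

Matrix = List[List[float]]


class ToyGPClassificationHead:
    """Hypothetical stand-in for a GP classification head (not the real class)."""

    def _gp_outputs(self, features: Matrix) -> Tuple[Matrix, Matrix]:
        # Placeholder GP layer: echoes the features as logits and returns an
        # identity covariance matrix over the batch (batch_size x batch_size).
        batch_size = len(features)
        logits = [list(row) for row in features]
        covmat = [[1.0 if i == j else 0.0 for j in range(batch_size)]
                  for i in range(batch_size)]
        return logits, covmat

    def call(self, features: Matrix, training: bool = False,
             return_covmat: bool = False) -> Union[Matrix, Tuple[Matrix, Matrix]]:
        logits, covmat = self._gp_outputs(features)
        # `training` is accepted for signature parity but unused in this sketch.
        if return_covmat:
            # Opt-in: callers that need uncertainty estimates get both outputs.
            return logits, covmat
        # Default: a single output, which is the shape of result a Keras loss
        # used by model.fit() expects to compare against the labels.
        return logits


head = ToyGPClassificationHead()
batch = [[0.2, 1.3], [0.5, -0.1]]
logits_only = head.call(batch)                      # a single logits matrix
logits, covmat = head.call(batch, return_covmat=True)  # the (logits, covmat) tuple
```

Making the single output the default means training code that treats the layer like any other classification head needs no special unpacking, while callers that want the covariance (for example, for uncertainty estimates) still have a path to it.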
cls_head.py 11.6 KB