Unverified Commit f064e0a4 authored by Michal Szutenberg's avatar Michal Szutenberg Committed by GitHub
Browse files

Cast logits to fp32 at the end of TF_T5 (#12332)

This change enables tf.keras.mixed_precision with bf16
parent b7439675
......@@ -1407,6 +1407,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
else:
logits = self.lm_head(sequence_output)
logits = tf.cast(logits, tf.float32)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not inputs["return_dict"]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment