"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "8f3b4a1d5bd97045541c43179efe8cd9c58adb76"
Unverified Commit f064e0a4 authored by Michal Szutenberg's avatar Michal Szutenberg Committed by GitHub
Browse files

Cast logits to fp32 at the end of TF_T5 (#12332)

This change enables tf.keras.mixed_precision with bf16
parent b7439675
...@@ -1407,6 +1407,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling ...@@ -1407,6 +1407,8 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
else: else:
logits = self.lm_head(sequence_output) logits = self.lm_head(sequence_output)
logits = tf.cast(logits, tf.float32)
loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits) loss = None if inputs["labels"] is None else self.compute_loss(inputs["labels"], logits)
if not inputs["return_dict"]: if not inputs["return_dict"]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment