Unverified Commit 9a2dabae authored by Matt's avatar Matt Committed by GitHub
Browse files

Fix dtype issue in TF BART (#15178)

parent 0167edc8
......@@ -1383,7 +1383,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
if inputs["labels"] is not None:
inputs["labels"] = tf.where(
inputs["labels"] == self.config.pad_token_id,
tf.fill(shape_list(inputs["labels"]), -100),
tf.cast(tf.fill(shape_list(inputs["labels"]), -100), inputs["labels"].dtype),
inputs["labels"],
)
inputs["use_cache"] = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment