Unverified Commit 9a2dabae authored by Matt's avatar Matt Committed by GitHub
Browse files

Fix dtype issue in TF BART (#15178)

parent 0167edc8
...@@ -1383,7 +1383,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode ...@@ -1383,7 +1383,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode
if inputs["labels"] is not None: if inputs["labels"] is not None:
inputs["labels"] = tf.where( inputs["labels"] = tf.where(
inputs["labels"] == self.config.pad_token_id, inputs["labels"] == self.config.pad_token_id,
tf.fill(shape_list(inputs["labels"]), -100), tf.cast(tf.fill(shape_list(inputs["labels"]), -100), inputs["labels"].dtype),
inputs["labels"], inputs["labels"],
) )
inputs["use_cache"] = False inputs["use_cache"] = False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment