Unverified Commit 461ae868 authored by Jay Yip's avatar Jay Yip Committed by GitHub
Browse files

Fix tf boolean mask in graph mode (#6741)

parent 925f34bb
......@@ -137,7 +137,7 @@ class TFCausalLanguageModelingLoss:
)
# make sure only labels that are not equal to -100
# are taken into account as loss
active_loss = tf.reshape(labels, (-1,)) != -100
active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
return loss_fn(labels, reduced_logits)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment