"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "2a4b9e09c02201ec5d45aa8cf49b667d04a5279e"
Unverified Commit 461ae868 authored by Jay Yip's avatar Jay Yip Committed by GitHub
Browse files

Fix tf boolean mask in graph mode (#6741)

parent 925f34bb
...@@ -137,7 +137,7 @@ class TFCausalLanguageModelingLoss: ...@@ -137,7 +137,7 @@ class TFCausalLanguageModelingLoss:
) )
# make sure only labels that are not equal to -100 # make sure only labels that are not equal to -100
# are taken into account as loss # are taken into account as loss
active_loss = tf.reshape(labels, (-1,)) != -100 active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss) reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss) labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
return loss_fn(labels, reduced_logits) return loss_fn(labels, reduced_logits)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment