"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "d164867d90c7b352445aa7d4028a6ba156a70a77"
Unverified Commit 6a346f03 authored by Muennighoff's avatar Muennighoff Committed by GitHub
Browse files

fix typo (#9708)



* fix typo
Co-authored-by: default avatarSuraj Patil <surajp815@gmail.com>
parent 4a20b7c4
...@@ -147,7 +147,7 @@ class TFCausalLanguageModelingLoss: ...@@ -147,7 +147,7 @@ class TFCausalLanguageModelingLoss:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy( loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
from_logits=True, reduction=tf.keras.losses.Reduction.NONE from_logits=True, reduction=tf.keras.losses.Reduction.NONE
) )
# make sure only labels that are not equal to -100 do not affect loss # make sure only labels that are not equal to -100 affect the loss
active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100) active_loss = tf.not_equal(tf.reshape(labels, (-1,)), -100)
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss) reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss) labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment