Commit 1b59b57b authored by thomwolf's avatar thomwolf
Browse files

ignore_index equal -100 in T5 model

parent 569da80c
......@@ -905,7 +905,7 @@ class T5WithLMHeadModel(T5PreTrainedModel):
if lm_labels is not None:
shift_logits = lm_logits[..., :-1, :].contiguous()
shift_labels = lm_labels[..., 1:].contiguous()
loss_fct = CrossEntropyLoss(ignore_index=-1)
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
decoder_outputs = (
loss,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment