"vscode:/vscode.git/clone" did not exist on "79d08997f163ff36e3772ceac96937ebde3bb56f"
Commit 7b3bb8c0 authored by Malte Pietsch's avatar Malte Pietsch
Browse files

fix typo in input for masked lm loss function

parent 257a3513
......@@ -678,7 +678,7 @@ class BertForPreTraining(PreTrainedBertModel):
if masked_lm_labels is not None and next_sentence_label is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels(-1))
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
total_loss = masked_lm_loss + next_sentence_loss
return total_loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment