"vscode:/vscode.git/clone" did not exist on "3c6035aa8ad5c429299630f8c6f673896fbd7b16"
Unverified Commit 3be2d048 authored by idoh's avatar idoh Committed by GitHub
Browse files

fix consistency CrossEntropyLoss in modeling_bart (#6265)

parent c72f9c90
......@@ -1040,7 +1040,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
masked_lm_loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss()
loss_fct = CrossEntropyLoss()
# TODO(SS): do we need to ignore pad tokens in labels?
masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
......@@ -1179,7 +1179,8 @@ class BartForSequenceClassification(PretrainedBartModel):
loss = None
if labels is not None:
loss = F.cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
if not return_dict:
output = (logits,) + outputs[1:]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment