Unverified Commit 3be2d048 authored by idoh, committed by GitHub

fix consistency CrossEntropyLoss in modeling_bart (#6265)

parent c72f9c90
@@ -1040,7 +1040,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
         masked_lm_loss = None
         if labels is not None:
-            loss_fct = nn.CrossEntropyLoss()
+            loss_fct = CrossEntropyLoss()
             # TODO(SS): do we need to ignore pad tokens in labels?
             masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
@@ -1179,7 +1179,8 @@ class BartForSequenceClassification(PretrainedBartModel):
         loss = None
         if labels is not None:
-            loss = F.cross_entropy(logits.view(-1, self.config.num_labels), labels.view(-1))
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
         if not return_dict:
             output = (logits,) + outputs[1:]
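The change only unifies style: both models now instantiate CrossEntropyLoss directly instead of mixing nn.CrossEntropyLoss() and F.cross_entropy. With default arguments the two compute the same value, so the loss itself is unchanged. A minimal sketch illustrating the equivalence (tensor shapes and the vocab size are illustrative, not taken from the commit):

# Sketch, not part of the commit: nn.CrossEntropyLoss() and F.cross_entropy
# give the same result with default arguments, so the diff only changes style.
import torch
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F

vocab_size = 50265                                 # illustrative vocab size
lm_logits = torch.randn(2, 8, vocab_size)          # (batch, seq_len, vocab)
labels = torch.randint(0, vocab_size, (2, 8))      # (batch, seq_len)

loss_fct = CrossEntropyLoss()
loss_a = loss_fct(lm_logits.view(-1, vocab_size), labels.view(-1))
loss_b = F.cross_entropy(lm_logits.view(-1, vocab_size), labels.view(-1))
assert torch.allclose(loss_a, loss_b)              # identical value, consistent API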