Unverified Commit d49d3cf6 authored by calpt's avatar calpt Committed by GitHub
Browse files

Use MSELoss in (M)BartForSequenceClassification (#11178)

parent f243a5ec
...@@ -23,7 +23,7 @@ import torch ...@@ -23,7 +23,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...file_utils import ( from ...file_utils import (
...@@ -1437,8 +1437,13 @@ class BartForSequenceClassification(BartPretrainedModel): ...@@ -1437,8 +1437,13 @@ class BartForSequenceClassification(BartPretrainedModel):
loss = None loss = None
if labels is not None: if labels is not None:
loss_fct = CrossEntropyLoss() if self.config.num_labels == 1:
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) # regression
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
if not return_dict: if not return_dict:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]
......
...@@ -22,7 +22,7 @@ import torch ...@@ -22,7 +22,7 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...file_utils import ( from ...file_utils import (
...@@ -1442,8 +1442,13 @@ class MBartForSequenceClassification(MBartPreTrainedModel): ...@@ -1442,8 +1442,13 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
loss = None loss = None
if labels is not None: if labels is not None:
loss_fct = CrossEntropyLoss() if self.config.num_labels == 1:
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) # regression
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
if not return_dict: if not return_dict:
output = (logits,) + outputs[1:] output = (logits,) + outputs[1:]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment