"tests/vscode:/vscode.git/clone" did not exist on "de4d71ea07b31c1bcef7ffccc3691f76658e291f"
Unverified Commit d49d3cf6 authored by calpt's avatar calpt Committed by GitHub
Browse files

Use MSELoss in (M)BartForSequenceClassification (#11178)

parent f243a5ec
......@@ -23,7 +23,7 @@ import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...file_utils import (
......@@ -1437,6 +1437,11 @@ class BartForSequenceClassification(BartPretrainedModel):
loss = None
if labels is not None:
if self.config.num_labels == 1:
# regression
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
......
......@@ -22,7 +22,7 @@ import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.nn import CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...file_utils import (
......@@ -1442,6 +1442,11 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
loss = None
if labels is not None:
if self.config.num_labels == 1:
# regression
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment