Unverified commit c28bc80b authored by Sylvain Gugger, committed by GitHub

Generalize problem_type to all sequence classification models (#14180)

* Generalize problem_type to all classification models

* Missing import

* Deberta BC and fix tests

* Fix template

* Missing imports

* Revert change to reformer test

* Fix style
parent 4ab6a4a0
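
The diffs below all apply one shared pattern: each `*ForSequenceClassification` head now reads `config.problem_type` to pick its loss, inferring the value from `num_labels` and the label dtype when it is unset. Distilled into a standalone sketch (the function name here is ours, not part of the commit):

import torch

def infer_problem_type(num_labels: int, labels: torch.Tensor) -> str:
    # Mirrors the inference block added to every forward() in this commit.
    if num_labels == 1:
        return "regression"
    elif num_labels > 1 and labels.dtype in (torch.long, torch.int):
        return "single_label_classification"
    else:
        return "multi_label_classification"

assert infer_problem_type(1, torch.tensor([0.5])) == "regression"
assert infer_problem_type(3, torch.tensor([2])) == "single_label_classification"
assert infer_problem_type(3, torch.tensor([[0.0, 1.0, 1.0]])) == "multi_label_classification"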
...
@@ -22,7 +22,7 @@ from typing import Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...file_utils import (
...
@@ -1475,14 +1475,26 @@ class BartForSequenceClassification(BartPretrainedModel):
         loss = None
         if labels is not None:
-            if self.config.num_labels == 1:
-                # regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.config.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.config.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
...
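
With the hunk above, opting into multi-label training on BART reduces to setting `problem_type` on the config. A minimal usage sketch, assuming a `facebook/bart-base` checkpoint (the checkpoint, label count, and inputs are illustrative, not from the commit):

import torch
from transformers import BartForSequenceClassification, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForSequenceClassification.from_pretrained(
    "facebook/bart-base", num_labels=3, problem_type="multi_label_classification"
)

inputs = tokenizer(["a sentence tagged with two topics"], return_tensors="pt")
labels = torch.tensor([[1.0, 0.0, 1.0]])  # float multi-hot targets, one column per label
outputs = model(**inputs, labels=labels)
outputs.loss.backward()  # loss comes from BCEWithLogitsLoss via the branch above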
...
@@ -23,7 +23,7 @@ from typing import Optional, Tuple
 import numpy as np
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...file_utils import (
...
@@ -2680,14 +2680,26 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.config.num_labels == 1:
-                # regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.config.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.config.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
...
...
@@ -20,7 +20,7 @@ from typing import Tuple
 import numpy as np
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutput
...
@@ -690,14 +690,26 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
 
         if not return_dict:
             output = (pooled_logits,) + transformer_outputs[2:]
             return ((loss,) + output) if loss is not None else output
...
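
Note what the dtype test in the inference block buys: with more than one label, integer labels are read as class indices (cross-entropy), and anything else, typically float multi-hot vectors, as multi-label targets (BCE). A quick sketch of the two label layouts the branches expect (values illustrative):

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss

logits = torch.randn(2, 3)  # (batch, num_labels)

# single_label_classification: one integer class index per example
int_labels = torch.tensor([0, 2])  # dtype is torch.long
ce_loss = CrossEntropyLoss()(logits.view(-1, 3), int_labels.view(-1))

# multi_label_classification: one float multi-hot vector per example
float_labels = torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
bce_loss = BCEWithLogitsLoss()(logits, float_labels)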
...
@@ -19,7 +19,7 @@ from collections.abc import Sequence
 
 import torch
 from torch import _softmax_backward_data, nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
...
@@ -1194,31 +1194,46 @@ class DebertaForSequenceClassification(DebertaPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # regression task
-                loss_fn = nn.MSELoss()
-                logits = logits.view(-1).to(labels.dtype)
-                loss = loss_fn(logits, labels.view(-1))
-            elif labels.dim() == 1 or labels.size(-1) == 1:
-                label_index = (labels >= 0).nonzero()
-                labels = labels.long()
-                if label_index.size(0) > 0:
-                    labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), logits.size(1)))
-                    labels = torch.gather(labels, 0, label_index.view(-1))
-                    loss_fct = CrossEntropyLoss()
-                    loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
-                else:
-                    loss = torch.tensor(0).to(logits)
-            else:
-                log_softmax = nn.LogSoftmax(-1)
-                loss = -((log_softmax(logits) * labels).sum(-1)).mean()
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    # regression task
+                    loss_fn = nn.MSELoss()
+                    logits = logits.view(-1).to(labels.dtype)
+                    loss = loss_fn(logits, labels.view(-1))
+                elif labels.dim() == 1 or labels.size(-1) == 1:
+                    label_index = (labels >= 0).nonzero()
+                    labels = labels.long()
+                    if label_index.size(0) > 0:
+                        labeled_logits = torch.gather(
+                            logits, 0, label_index.expand(label_index.size(0), logits.size(1))
+                        )
+                        labels = torch.gather(labels, 0, label_index.view(-1))
+                        loss_fct = CrossEntropyLoss()
+                        loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
+                    else:
+                        loss = torch.tensor(0).to(logits)
+                else:
+                    log_softmax = nn.LogSoftmax(-1)
+                    loss = -((log_softmax(logits) * labels).sum(-1)).mean()
+            elif self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
         else:
             return SequenceClassifierOutput(
                 loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
             )
 
 
 @add_start_docstrings(
...
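
This hunk is the "Deberta BC" item from the commit message: DeBERTa's historical loss handling (masked cross-entropy over `labels >= 0`, or a soft log-softmax loss for float labels) remains the default whenever `problem_type` is unset, rather than being replaced by the inferred branches. The shared behavior is opt-in, e.g. (a minimal sketch, not from the commit):

from transformers import DebertaConfig, DebertaForSequenceClassification

# problem_type left as None -> the legacy DeBERTa branches above still run
legacy_model = DebertaForSequenceClassification(DebertaConfig(num_labels=3))

# explicit problem_type -> the shared regression/cross-entropy/BCE branches
config = DebertaConfig(num_labels=3, problem_type="multi_label_classification")
model = DebertaForSequenceClassification(config)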
...
@@ -20,7 +20,7 @@ from collections.abc import Sequence
 import numpy as np
 import torch
 from torch import _softmax_backward_data, nn
-from torch.nn import CrossEntropyLoss, LayerNorm
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, LayerNorm, MSELoss
 
 from ...activations import ACT2FN
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
...
@@ -1304,31 +1304,46 @@ class DebertaV2ForSequenceClassification(DebertaV2PreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # regression task
-                loss_fn = nn.MSELoss()
-                logits = logits.view(-1).to(labels.dtype)
-                loss = loss_fn(logits, labels.view(-1))
-            elif labels.dim() == 1 or labels.size(-1) == 1:
-                label_index = (labels >= 0).nonzero()
-                labels = labels.long()
-                if label_index.size(0) > 0:
-                    labeled_logits = torch.gather(logits, 0, label_index.expand(label_index.size(0), logits.size(1)))
-                    labels = torch.gather(labels, 0, label_index.view(-1))
-                    loss_fct = CrossEntropyLoss()
-                    loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
-                else:
-                    loss = torch.tensor(0).to(logits)
-            else:
-                log_softmax = nn.LogSoftmax(-1)
-                loss = -((log_softmax(logits) * labels).sum(-1)).mean()
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    # regression task
+                    loss_fn = nn.MSELoss()
+                    logits = logits.view(-1).to(labels.dtype)
+                    loss = loss_fn(logits, labels.view(-1))
+                elif labels.dim() == 1 or labels.size(-1) == 1:
+                    label_index = (labels >= 0).nonzero()
+                    labels = labels.long()
+                    if label_index.size(0) > 0:
+                        labeled_logits = torch.gather(
+                            logits, 0, label_index.expand(label_index.size(0), logits.size(1))
+                        )
+                        labels = torch.gather(labels, 0, label_index.view(-1))
+                        loss_fct = CrossEntropyLoss()
+                        loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
+                    else:
+                        loss = torch.tensor(0).to(logits)
+                else:
+                    log_softmax = nn.LogSoftmax(-1)
+                    loss = -((log_softmax(logits) * labels).sum(-1)).mean()
+            elif self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
         else:
             return SequenceClassifierOutput(
                 loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
             )
 
 
 @add_start_docstrings(
...
...
@@ -23,7 +23,7 @@ import torch
 import torch.utils.checkpoint
 from packaging import version
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...file_utils import is_scipy_available
...
@@ -927,14 +927,26 @@ class FNetForSequenceClassification(FNetPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
 
         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
...
...
@@ -24,7 +24,7 @@ import torch
 import torch.utils.checkpoint
 from packaging import version
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 if version.parse(torch.__version__) >= version.parse("1.6"):
...
@@ -1406,14 +1406,26 @@ class GPT2ForSequenceClassification(GPT2PreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
 
         if not return_dict:
             output = (pooled_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
...
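
Two details of the regression branch are visible in these GPT-style hunks: the old `labels.to(self.dtype)` cast is gone, and regression no longer flattens everything. With `num_labels == 1` both sides are squeezed to shape (batch,), while `num_labels > 1` applies MSE elementwise over the full (batch, num_labels) tensors, so multi-target regression now works. A shape sketch (values illustrative):

import torch
from torch.nn import MSELoss

loss_fct = MSELoss()

# num_labels == 1: squeeze logits (batch, 1) and labels down to (batch,)
single_logits = torch.randn(4, 1)
single_labels = torch.randn(4)
loss_single = loss_fct(single_logits.squeeze(), single_labels.squeeze())

# num_labels > 1: elementwise MSE over matching (batch, num_labels) tensors
multi_logits = torch.randn(4, 3)
multi_labels = torch.randn(4, 3)
loss_multi = loss_fct(multi_logits, multi_labels)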
...
@@ -21,7 +21,7 @@ from typing import Tuple
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
...
@@ -895,14 +895,26 @@ class GPTNeoForSequenceClassification(GPTNeoPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
 
         if not return_dict:
             output = (pooled_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
...
...
@@ -19,7 +19,7 @@ from typing import Tuple
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
...
@@ -931,14 +931,26 @@ class GPTJForSequenceClassification(GPTJPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
 
         if not return_dict:
             output = (pooled_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
...
...
@@ -22,7 +22,7 @@ import math
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import gelu
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
...
@@ -1025,14 +1025,26 @@ class IBertForSequenceClassification(IBertPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
 
         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
...
...
@@ -20,7 +20,7 @@ import math
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
...
@@ -1059,14 +1059,26 @@ class LayoutLMForSequenceClassification(LayoutLMPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
 
         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
...
...
@@ -20,7 +20,7 @@ import math
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...file_utils import (
...
@@ -1076,14 +1076,26 @@ class LayoutLMv2ForSequenceClassification(LayoutLMv2PreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
 
         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
...
...
@@ -23,7 +23,7 @@ from typing import List, Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...file_utils import (
...
@@ -2536,9 +2536,26 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
         loss = None
         if labels is not None:
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.config.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.config.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
...
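
LED previously hard-coded cross-entropy, so it gains regression and multi-label support outright here. Also note that the inferred value is written back to `self.config.problem_type`, so inference runs once on the first labeled batch and later batches branch directly on the cached value. A sketch of that caching behavior (the helper and config stand-in are ours, distilled from the diff):

import torch

class FakeConfig:
    problem_type = None
    num_labels = 3

def pick_problem_type(config, labels):
    # Same write-back pattern as the forward() methods in this commit.
    if config.problem_type is None:
        if config.num_labels == 1:
            config.problem_type = "regression"
        elif config.num_labels > 1 and labels.dtype in (torch.long, torch.int):
            config.problem_type = "single_label_classification"
        else:
            config.problem_type = "multi_label_classification"
    return config.problem_type

cfg = FakeConfig()
pick_problem_type(cfg, torch.tensor([1, 2]))  # inferred and cached
pick_problem_type(cfg, torch.tensor([[1.0, 0.0]]))  # cached value wins from now on
assert cfg.problem_type == "single_label_classification"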
...
@@ -21,7 +21,7 @@ from typing import Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...file_utils import (
...
@@ -1475,14 +1475,26 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.config.num_labels == 1:
-                # regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.config.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.config.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
...
...
@@ -25,7 +25,7 @@ from typing import Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...file_utils import (
...
@@ -1525,14 +1525,26 @@ class MegatronBertForSequenceClassification(MegatronBertPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
 
         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
...
...
@@ -20,7 +20,7 @@ import math
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN, gelu
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
...
@@ -736,14 +736,26 @@ class MPNetForSequenceClassification(MPNetPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
 
         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
...
...
@@ -24,7 +24,7 @@ from typing import Optional, Tuple
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import gelu_new, silu
 from ...file_utils import (
...
@@ -823,14 +823,26 @@ class OpenAIGPTForSequenceClassification(OpenAIGPTPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
 
         if not return_dict:
             output = (pooled_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
...
...
@@ -21,7 +21,7 @@ import os
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...file_utils import (
...
@@ -1220,14 +1220,26 @@ class RemBertForSequenceClassification(RemBertPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
 
         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
...
...
@@ -23,7 +23,7 @@ import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...file_utils import (
...
@@ -1287,14 +1287,26 @@ class RoFormerForSequenceClassification(RoFormerPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
...
...
@@ -24,7 +24,7 @@ from typing import Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
 from ...activations import ACT2FN
 from ...file_utils import (
...
@@ -1532,14 +1532,26 @@ class TapasForSequenceClassification(TapasPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
 
         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
...