"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "6af0854efa3693e0b38c936707966685ec3d0ae8"
Unverified commit c40c7e21, authored by abhishek thakur and committed by GitHub

Add multi-class, multi-label and regression to transformers (#11012)



* add to bert

* review comments

* Update src/transformers/configuration_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/configuration_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* self.config.problem_type

* fix style

* fix

* fin

* fix

* update doc

* fix

* test

* Test more problem types

* Update src/transformers/configuration_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix

* remove

* fix

* quality

* make fix-copies

* remove test
Co-authored-by: abhishek thakur <abhishekkrthakur@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
parent 7c622482
@@ -163,6 +163,14 @@ class PretrainedConfig(PushToHubMixin):
           typically for a classification task.
     - **task_specific_params** (:obj:`Dict[str, Any]`, `optional`) -- Additional keyword arguments to store for the
       current task.
+    - **problem_type** (:obj:`str`, `optional`) -- Problem type for :obj:`XxxForSequenceClassification` models. Can
+      be one of (:obj:`"regression"`, :obj:`"single_label_classification"`, :obj:`"multi_label_classification"`).
+      Please note that this parameter is only available in the following models: `AlbertForSequenceClassification`,
+      `BertForSequenceClassification`, `BigBirdForSequenceClassification`, `ConvBertForSequenceClassification`,
+      `DistilBertForSequenceClassification`, `ElectraForSequenceClassification`, `FunnelForSequenceClassification`,
+      `LongformerForSequenceClassification`, `MobileBertForSequenceClassification`,
+      `ReformerForSequenceClassification`, `RobertaForSequenceClassification`,
+      `SqueezeBertForSequenceClassification`, `XLMForSequenceClassification` and `XLNetForSequenceClassification`.

 Parameters linked to the tokenizer
@@ -260,6 +268,15 @@ class PretrainedConfig(PushToHubMixin):
         # task specific arguments
         self.task_specific_params = kwargs.pop("task_specific_params", None)

+        # regression / multi-label classification
+        self.problem_type = kwargs.pop("problem_type", None)
+        allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification")
+        if self.problem_type is not None and self.problem_type not in allowed_problem_types:
+            raise ValueError(
+                f"The config parameter `problem_type` was not understood: received {self.problem_type}, "
+                "but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
+            )
+
         # TPU arguments
         if kwargs.pop("xla_device", None) is not None:
             logger.warning(
......
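For context on how the new config option is meant to be used: `problem_type`, like `num_labels`, is forwarded to the config when a model is loaded with `from_pretrained`. A minimal sketch of the intended usage; the checkpoint name, token ids and label values below are placeholders, not part of this commit:

import torch
from transformers import BertForSequenceClassification

# Multi-label setup: each example may belong to several of the three classes.
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=3, problem_type="multi_label_classification"
)
inputs = {
    "input_ids": torch.tensor([[101, 7592, 2088, 102]]),    # arbitrary token ids
    "attention_mask": torch.tensor([[1, 1, 1, 1]]),
    "labels": torch.tensor([[1.0, 0.0, 1.0]]),               # one float target per label
}
loss = model(**inputs).loss   # computed with BCEWithLogitsLoss

# Any other string is rejected as soon as the config is built, per the check above:
# BertConfig(problem_type="binary")  ->  ValueError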
@@ -21,7 +21,7 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...file_utils import (
@@ -970,6 +970,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
+        self.config = config

         self.albert = AlbertModel(config)
         self.dropout = nn.Dropout(config.classifier_dropout_prob)
@@ -1024,13 +1025,23 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[2:]
......
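When `problem_type` is left unset, the patched `forward` above (and the identical pattern in every model file below) infers it from `num_labels` and the dtype of the labels before choosing a loss. A standalone sketch of that decision rule, pulled out here purely for illustration:

import torch

def infer_problem_type(num_labels, labels):
    # Mirrors the inference branch added to the *ForSequenceClassification heads.
    if num_labels == 1:
        return "regression"                      # -> MSELoss
    elif num_labels > 1 and labels.dtype in (torch.long, torch.int):
        return "single_label_classification"     # -> CrossEntropyLoss
    else:
        return "multi_label_classification"      # -> BCEWithLogitsLoss

print(infer_problem_type(1, torch.tensor([[0.7]])))            # regression
print(infer_problem_type(3, torch.tensor([2])))                # single_label_classification
print(infer_problem_type(3, torch.tensor([[1.0, 0.0, 1.0]])))  # multi_label_classification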
@@ -25,7 +25,7 @@ from typing import Optional, Tuple
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...file_utils import (
@@ -1381,7 +1381,7 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
-        **kwargs
+        **kwargs,
     ):
         r"""
         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
@@ -1463,6 +1463,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
+        self.config = config

         self.bert = BertModel(config)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
@@ -1517,14 +1518,23 @@ class BertForSequenceClassification(BertPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
......
@@ -25,7 +25,7 @@ import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...file_utils import (
@@ -2609,6 +2609,7 @@ class BigBirdForSequenceClassification(BigBirdPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
+        self.config = config

         self.bert = BigBirdModel(config)
         self.classifier = BigBirdClassificationHead(config)
@@ -2659,13 +2660,23 @@ class BigBirdForSequenceClassification(BigBirdPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[2:]
......
@@ -22,7 +22,7 @@ from operator import attrgetter
 import torch
 import torch.utils.checkpoint
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN, get_activation
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
@@ -962,6 +962,7 @@ class ConvBertForSequenceClassification(ConvBertPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
+        self.config = config

         self.convbert = ConvBertModel(config)
         self.classifier = ConvBertClassificationHead(config)
@@ -1012,13 +1013,23 @@ class ConvBertForSequenceClassification(ConvBertPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[1:]
......
@@ -24,7 +24,7 @@ import math
 import numpy as np
 import torch
 import torch.nn as nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import gelu
 from ...file_utils import (
@@ -579,6 +579,7 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
+        self.config = config

         self.distilbert = DistilBertModel(config)
         self.pre_classifier = nn.Linear(config.dim, config.dim)
@@ -631,12 +632,23 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                loss_fct = nn.MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = nn.CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + distilbert_output[1:]
......
@@ -22,7 +22,7 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN, get_activation
 from ...file_utils import (
@@ -903,6 +903,7 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
+        self.config = config

         self.electra = ElectraModel(config)
         self.classifier = ElectraClassificationHead(config)
@@ -953,13 +954,23 @@ class ElectraForSequenceClassification(ElectraPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + discriminator_hidden_states[1:]
......
@@ -21,7 +21,7 @@ from typing import Optional, Tuple
 import numpy as np
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from torch.nn import functional as F

 from ...activations import ACT2FN
@@ -1240,6 +1240,7 @@ class FunnelForSequenceClassification(FunnelPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
+        self.config = config

         self.funnel = FunnelBaseModel(config)
         self.classifier = FunnelClassificationHead(config, config.num_labels)
@@ -1287,13 +1288,23 @@ class FunnelForSequenceClassification(FunnelPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[1:]
......
@@ -21,7 +21,7 @@ from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from torch.nn import functional as F

 from ...activations import ACT2FN, gelu
@@ -1803,6 +1803,7 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
+        self.config = config

         self.longformer = LongformerModel(config, add_pooling_layer=False)
         self.classifier = LongformerClassificationHead(config)
@@ -1861,13 +1862,23 @@ class LongformerForSequenceClassification(LongformerPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[2:]
......
@@ -29,7 +29,7 @@ from typing import Optional, Tuple
 import torch
 import torch.nn.functional as F
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...file_utils import (
@@ -1214,6 +1214,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
+        self.config = config

         self.mobilebert = MobileBertModel(config)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
@@ -1268,14 +1269,23 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output
......
@@ -26,7 +26,7 @@ import numpy as np
 import torch
 from torch import nn
 from torch.autograd.function import Function
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...file_utils import (
@@ -366,7 +366,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin):
         past_buckets_states=None,
         use_cache=False,
         output_attentions=False,
-        **kwargs
+        **kwargs,
     ):
         sequence_length = hidden_states.shape[1]
         batch_size = hidden_states.shape[0]
@@ -1045,7 +1045,7 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin):
         past_buckets_states=None,
         use_cache=False,
         output_attentions=False,
-        **kwargs
+        **kwargs,
     ):
         sequence_length = hidden_states.shape[1]
         batch_size = hidden_states.shape[0]
@@ -2381,6 +2381,7 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
+        self.config = config

         self.reformer = ReformerModel(config)
         self.classifier = ReformerClassificationHead(config)
@@ -2434,13 +2435,23 @@ class ReformerForSequenceClassification(ReformerPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[2:]
......
@@ -20,7 +20,7 @@ import math
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN, gelu
 from ...file_utils import (
@@ -1117,6 +1117,7 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
+        self.config = config

         self.roberta = RobertaModel(config, add_pooling_layer=False)
         self.classifier = RobertaClassificationHead(config)
@@ -1167,13 +1168,23 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[2:]
......
@@ -19,7 +19,7 @@ import math
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
@@ -733,6 +733,7 @@ class SqueezeBertForSequenceClassification(SqueezeBertPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
+        self.config = config

         self.transformer = SqueezeBertModel(config)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
@@ -787,13 +788,23 @@ class SqueezeBertForSequenceClassification(SqueezeBertPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[2:]
......
@@ -24,7 +24,7 @@ from typing import Optional, Tuple
 import numpy as np
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from torch.nn import functional as F

 from ...activations import gelu
@@ -779,6 +779,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
+        self.config = config

         self.transformer = XLMModel(config)
         self.sequence_summary = SequenceSummary(config)
@@ -836,13 +837,23 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + transformer_outputs[1:]
......
@@ -22,7 +22,7 @@ from typing import List, Optional, Tuple
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from torch.nn import functional as F

 from ...activations import ACT2FN
@@ -1488,6 +1488,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels
+        self.config = config

         self.transformer = XLNetModel(config)
         self.sequence_summary = SequenceSummary(config)
@@ -1551,13 +1552,23 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + transformer_outputs[1:]
......
@@ -230,6 +230,8 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
         else ()
     )

+    test_sequence_classification_problem_types = True
+
     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
......
@@ -439,6 +439,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         else ()
     )
     all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
+    test_sequence_classification_problem_types = True

     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
......
@@ -433,6 +433,7 @@ class BigBirdModelTest(ModelTesterMixin, unittest.TestCase):
     # head masking & pruning is currently not supported for big bird
     test_head_masking = False
     test_pruning = False
+    test_sequence_classification_problem_types = True

     # torchscript should be possible, but takes prohibitively long to test.
     # Also torchscript is not an important feature to have in the beginning.
......
@@ -89,6 +89,7 @@ class ModelTesterMixin:
     test_missing_keys = True
     test_model_parallel = False
     is_encoder_decoder = False
+    test_sequence_classification_problem_types = False

     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = copy.deepcopy(inputs_dict)
@@ -1238,6 +1239,42 @@ class ModelTesterMixin:
             model.parallelize()
             model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)

+    def test_problem_types(self):
+        if not self.test_sequence_classification_problem_types:
+            return
+
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        problem_types = [
+            {"title": "multi_label_classification", "num_labels": 2, "dtype": torch.float},
+            {"title": "single_label_classification", "num_labels": 1, "dtype": torch.long},
+            {"title": "regression", "num_labels": 1, "dtype": torch.float},
+        ]
+
+        for model_class in self.all_model_classes:
+            if model_class not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
+                continue
+
+            for problem_type in problem_types:
+                with self.subTest(msg=f"Testing {model_class} with {problem_type['title']}"):
+
+                    config.problem_type = problem_type["title"]
+                    config.num_labels = problem_type["num_labels"]
+
+                    model = model_class(config)
+                    model.to(torch_device)
+                    model.train()
+
+                    inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+
+                    if problem_type["num_labels"] > 1:
+                        inputs["labels"] = inputs["labels"].unsqueeze(1).repeat(1, problem_type["num_labels"])
+
+                    inputs["labels"] = inputs["labels"].to(problem_type["dtype"])
+
+                    loss = model(**inputs).loss
+                    loss.backward()
+
 global_rng = random.Random()
......
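A note on the new common test: `test_problem_types` reuses the single-label test labels, so for the multi-label case a `(batch,)` tensor of class ids is expanded to `(batch, num_labels)` and cast to float before it reaches `BCEWithLogitsLoss`. The reshaping in isolation (the values are arbitrary, for illustration only):

import torch

labels = torch.tensor([1, 0, 1])             # shape (batch,), as prepared for single-label tests
num_labels = 2
multi_labels = labels.unsqueeze(1).repeat(1, num_labels).to(torch.float)
print(multi_labels.shape)                    # torch.Size([3, 2])
print(multi_labels)                          # tensor([[1., 1.], [0., 0.], [1., 1.]])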
@@ -260,6 +260,7 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
     )
     test_pruning = False
     test_head_masking = False
+    test_sequence_classification_problem_types = True

     def setUp(self):
         self.model_tester = ConvBertModelTester(self)
......
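The hunks above only change how the training loss is computed; reading predictions out of the logits is unchanged. With `problem_type="multi_label_classification"` the logits are independent per-label scores, so the usual practice (general convention, not part of this commit) is a sigmoid plus a per-label threshold, rather than the argmax used for single-label heads:

import torch

logits = torch.tensor([[2.3, -1.1, 0.4]])    # one example, three labels
probs = torch.sigmoid(logits)                 # independent per-label probabilities
preds = (probs > 0.5).long()                  # tensor([[1, 0, 1]]); 0.5 is a conventional threshold

# For problem_type="single_label_classification" you would instead use:
# preds = logits.argmax(dim=-1)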