Unverified Commit c28bc80b authored by Sylvain Gugger, committed by GitHub

Generalize problem_type to all sequence classification models (#14180)

* Generalize problem_type to all classification models

* Missing import

* Deberta BC and fix tests

* Fix template

* Missing imports

* Revert change to reformer test

* Fix style
parent 4ab6a4a0
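
In practice, this change lets any sequence classification model be driven by `config.problem_type`. A minimal usage sketch (the `bert-base-uncased` checkpoint and the label values are illustrative, not taken from this diff):

# Minimal usage sketch; checkpoint and labels are illustrative.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=3,
    problem_type="multi_label_classification",  # or "regression" / "single_label_classification"
)

inputs = tokenizer("a sentence that touches several topics", return_tensors="pt")
labels = torch.tensor([[1.0, 0.0, 1.0]])  # float multi-hot targets for multi-label
loss = model(**inputs, labels=labels).loss  # routed to BCEWithLogitsLoss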
@@ -22,7 +22,7 @@ from typing import List, Optional, Tuple
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...file_utils import (
     ModelOutput,
@@ -1234,13 +1234,26 @@ class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel):
         loss = None
         if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)
         if not return_dict:
             output = (pooled_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
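
The auto-detection rule above is the same in every model this commit touches. A standalone sketch of that rule (the helper name `infer_problem_type` is ours, not part of the commit):

# Standalone sketch of the detection logic added above.
import torch

def infer_problem_type(num_labels: int, labels: torch.Tensor) -> str:
    if num_labels == 1:
        return "regression"                   # one continuous target per example
    if labels.dtype in (torch.long, torch.int):
        return "single_label_classification"  # integer class indices -> CrossEntropyLoss
    return "multi_label_classification"       # float multi-hot targets -> BCEWithLogitsLoss

assert infer_problem_type(1, torch.tensor([0.7])) == "regression"
assert infer_problem_type(3, torch.tensor([2])) == "single_label_classification"
assert infer_problem_type(3, torch.tensor([[1.0, 0.0, 1.0]])) == "multi_label_classification"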
@@ -24,7 +24,7 @@ import torch
 import torch.utils.checkpoint
 from packaging import version
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -1265,14 +1265,26 @@ class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         loss = None
         if labels is not None:
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
-                # We are doing regression
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
                 loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -1564,7 +1576,7 @@ from typing import Optional, Tuple
 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from ...activations import ACT2FN
 from ...file_utils import (
@@ -2981,9 +2993,26 @@ class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutter.camelcase_modelname}}PreTrainedModel):
         loss = None
         if labels is not None:
+            if self.config.problem_type is None:
+                if self.config.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.config.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)
         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
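
Since this template stamps out new models, the label conventions the three branches expect are worth spelling out. A sketch of the standard PyTorch conventions (the tensors are illustrative, not code from this commit):

# Label conventions per branch; tensors are illustrative.
import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

batch_size, num_labels = 4, 3
logits = torch.randn(batch_size, num_labels)

# regression (num_labels == 1): float targets, squeezed to match the logits
reg_loss = MSELoss()(torch.randn(batch_size, 1).squeeze(), torch.randn(batch_size))

# single-label: one integer class index per example, in [0, num_labels)
ce_loss = CrossEntropyLoss()(logits.view(-1, num_labels), torch.tensor([0, 2, 1, 0]))

# multi-label: float multi-hot targets with the same shape as the logits
bce_loss = BCEWithLogitsLoss()(logits, torch.randint(0, 2, (batch_size, num_labels)).float())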
@@ -234,8 +234,6 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
     fx_ready_model_classes = all_model_classes
     fx_dynamic_ready_model_classes = all_model_classes
-    test_sequence_classification_problem_types = True

     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
@@ -446,7 +446,6 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
     fx_ready_model_classes = all_model_classes
     fx_dynamic_ready_model_classes = all_model_classes
-    test_sequence_classification_problem_types = True

     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
@@ -435,7 +435,6 @@ class BigBirdModelTest(ModelTesterMixin, unittest.TestCase):
     # head masking & pruning is currently not supported for big bird
     test_head_masking = False
     test_pruning = False
-    test_sequence_classification_problem_types = True

     # torchscript should be possible, but takes prohibitively long to test.
     # Also torchscript is not an important feature to have in the beginning.
@@ -113,7 +113,6 @@ class ModelTesterMixin:
     test_missing_keys = True
     test_model_parallel = False
     is_encoder_decoder = False
-    test_sequence_classification_problem_types = False

     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = copy.deepcopy(inputs_dict)
@@ -387,12 +386,13 @@ class ModelTesterMixin:
         if not self.model_tester.is_training:
             return

-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.return_dict = True
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            config.return_dict = True

-        for model_class in self.all_model_classes:
             if model_class in get_values(MODEL_MAPPING):
                 continue
+
             model = model_class(config)
             model.to(torch_device)
             model.train()
@@ -401,14 +401,14 @@ class ModelTesterMixin:
             loss.backward()

     def test_training_gradient_checkpointing(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         if not self.model_tester.is_training:
             return

-        config.use_cache = False
-        config.return_dict = True
-
-        for model_class in self.all_model_classes:
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            config.use_cache = False
+            config.return_dict = True
+
             if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing:
                 continue
             model = model_class(config)
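
The commit message does not say why the config preparation moved inside the loop, but the loss code above writes the inferred value back into config.problem_type, so a config shared across model classes would carry stale state into the next iteration. A toy illustration of that mutation (the names here are ours, not from the commit):

# Toy illustration (hypothetical names) of the config mutation.
class ToyConfig:
    def __init__(self):
        self.problem_type = None

def toy_forward(config, num_labels):
    if config.problem_type is None:  # detection only runs while unset
        config.problem_type = "regression" if num_labels == 1 else "single_label_classification"
    return config.problem_type

shared = ToyConfig()
toy_forward(shared, num_labels=1)
assert shared.problem_type == "regression"  # sticks to the config object
# A second model class reusing `shared` would skip detection entirely,
# hence the fresh prepare_config_and_inputs_for_common() call per class.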
@@ -1842,9 +1842,6 @@ class ModelTesterMixin:
             model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)

     def test_problem_types(self):
-        if not self.test_sequence_classification_problem_types:
-            return
-
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

         problem_types = [
@@ -1880,7 +1877,11 @@ class ModelTesterMixin:
                     # See https://github.com/huggingface/transformers/issues/11780
                     with warnings.catch_warnings(record=True) as warning_list:
                         loss = model(**inputs).loss
-                    self.assertListEqual(warning_list, [])
+                    for w in warning_list:
+                        if "Using a target size that is different to the input size" in str(w.message):
+                            raise ValueError(
+                                f"Something is going wrong in the regression problem: intercepted {w.message}"
+                            )

                     loss.backward()
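
The warning intercepted above signals a silent-broadcast bug: if regression logits keep a trailing dimension that the labels lack, MSELoss broadcasts the two tensors and computes the wrong loss. A sketch of that failure mode (the shapes are illustrative):

# Demonstration of the broadcast warning the test guards against.
import warnings
import torch
from torch.nn import MSELoss

logits = torch.randn(8, 1)  # un-squeezed regression logits
labels = torch.randn(8)     # targets without the trailing dimension

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    MSELoss()(logits, labels)  # broadcasts to (8, 8) -> wrong loss

assert any("target size" in str(w.message) for w in caught)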
@@ -2184,7 +2185,6 @@ class ModelPushToHubTester(unittest.TestCase):
             f.write(FAKE_MODEL_CODE)

         repo.push_to_hub()
-        print(os.listdir(tmp_dir))

         new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True)
         for p1, p2 in zip(model.parameters(), new_model.parameters()):
@@ -262,7 +262,6 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
     )
     test_pruning = False
     test_head_masking = False
-    test_sequence_classification_problem_types = True

     def setUp(self):
         self.model_tester = ConvBertModelTester(self)
@@ -214,7 +214,6 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
     test_pruning = True
     test_torchscript = True
     test_resize_embeddings = True
-    test_sequence_classification_problem_types = True
     test_resize_position_embeddings = True

     def setUp(self):
@@ -291,7 +291,6 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
     )
     fx_ready_model_classes = all_model_classes
     fx_dynamic_ready_model_classes = all_model_classes
-    test_sequence_classification_problem_types = True

     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
@@ -362,7 +362,6 @@ class FunnelModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
-    test_sequence_classification_problem_types = True

     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
@@ -278,7 +278,6 @@ class LongformerModelTester:
 class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
     test_pruning = False  # pruning is not supported
     test_torchscript = False
-    test_sequence_classification_problem_types = True

     all_model_classes = (
         (
@@ -271,7 +271,6 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
     )
     fx_ready_model_classes = all_model_classes
     fx_dynamic_ready_model_classes = all_model_classes
-    test_sequence_classification_problem_types = True

     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
@@ -143,7 +143,7 @@ class OpenAIGPTModelTester:
         model = OpenAIGPTForSequenceClassification(config)
         model.to(torch_device)
         model.eval()
-        # print(config.num_labels, sequence_labels.size())
+        sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
         result = model(input_ids, token_type_ids=token_type_ids, labels=sequence_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
@@ -795,6 +795,10 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
             [expected_shape] * len(iter_hidden_states),
         )

+    def test_problem_types(self):
+        # Fails because the sequence length is not a multiple of 4
+        pass
+

 @require_torch
 @require_sentencepiece
@@ -356,7 +356,6 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         else ()
     )
     all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else ()
-    test_sequence_classification_problem_types = True

     def setUp(self):
         self.model_tester = RobertaModelTester(self)
@@ -232,7 +232,6 @@ class SqueezeBertModelTest(ModelTesterMixin, unittest.TestCase):
     test_torchscript = True
     test_resize_embeddings = True
     test_head_masking = False
-    test_sequence_classification_problem_types = True

     def setUp(self):
         self.model_tester = SqueezeBertModelTester(self)
@@ -350,7 +350,6 @@ class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_generative_model_classes = (
         (XLMWithLMHeadModel,) if is_torch_available() else ()
     )  # TODO (PVP): Check other models whether language generation is also applicable
-    test_sequence_classification_problem_types = True

     # XLM has 2 QA models -> need to manually set the correct labels for one of them here
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
@@ -527,7 +527,6 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         (XLNetLMHeadModel,) if is_torch_available() else ()
     )  # TODO (PVP): Check other models whether language generation is also applicable
     test_pruning = False
-    test_sequence_classification_problem_types = True

     # XLNet has 2 QA models -> need to manually set the correct labels for one of them here
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):