Unverified Commit c28bc80b authored by Sylvain Gugger, committed by GitHub

Generalize problem_type to all sequence classification models (#14180)

* Generalize problem_type to all classification models

* Missing import

* Deberta BC and fix tests

* Fix template

* Missing imports

* Revert change to reformer test

* Fix style
parent 4ab6a4a0
@@ -22,7 +22,7 @@ from typing import List, Optional, Tuple

 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...file_utils import (
     ModelOutput,
@@ -1234,13 +1234,26 @@ class TransfoXLForSequenceClassification(TransfoXLPreTrainedModel):

         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                loss_fct = MSELoss()
-                loss = loss_fct(pooled_logits.view(-1), labels.to(self.dtype).view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(pooled_logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(pooled_logits, labels)

         if not return_dict:
             output = (pooled_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
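Note: the auto-detection rule this commit copies into every sequence classification head can be read in isolation. A minimal standalone sketch of the same logic (illustrative, not code from the diff):

```python
import torch

def infer_problem_type(num_labels: int, labels: torch.Tensor) -> str:
    """Mirror of the problem_type auto-detection added in the hunk above."""
    if num_labels == 1:
        return "regression"
    elif num_labels > 1 and labels.dtype in (torch.long, torch.int):
        return "single_label_classification"
    else:
        return "multi_label_classification"

print(infer_problem_type(3, torch.tensor([0, 2])))             # single_label_classification
print(infer_problem_type(3, torch.tensor([[1.0, 0.0, 1.0]])))  # multi_label_classification (float targets)
print(infer_problem_type(1, torch.tensor([[0.7]])))            # regression
```

Integer labels keep the old CrossEntropyLoss behavior, so existing single-label users are unaffected; float labels with more than one label now opt in to BCEWithLogitsLoss.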
@@ -24,7 +24,7 @@ import torch
 import torch.utils.checkpoint
 from packaging import version
 from torch import nn
-from torch.nn import CrossEntropyLoss, MSELoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...file_utils import (
@@ -1265,14 +1265,26 @@ class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutt

         loss = None
         if labels is not None:
-            if self.num_labels == 1:
-                # We are doing regression
-                loss_fct = MSELoss()
-                loss = loss_fct(logits.view(-1), labels.view(-1))
-            else:
-                loss_fct = CrossEntropyLoss()
-                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)

         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -1564,7 +1576,7 @@ from typing import Optional, Tuple

 import torch
 from torch import nn
-from torch.nn import CrossEntropyLoss
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

 from ...activations import ACT2FN
 from ...file_utils import (
@@ -2981,9 +2993,26 @@ class {{cookiecutter.camelcase_modelname}}ForSequenceClassification({{cookiecutt

         loss = None
         if labels is not None:
-            loss_fct = CrossEntropyLoss()
-            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            if self.config.problem_type is None:
+                if self.config.num_labels == 1:
+                    self.config.problem_type = "regression"
+                elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+                    self.config.problem_type = "single_label_classification"
+                else:
+                    self.config.problem_type = "multi_label_classification"
+
+            if self.config.problem_type == "regression":
+                loss_fct = MSELoss()
+                if self.config.num_labels == 1:
+                    loss = loss_fct(logits.squeeze(), labels.squeeze())
+                else:
+                    loss = loss_fct(logits, labels)
+            elif self.config.problem_type == "single_label_classification":
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+            elif self.config.problem_type == "multi_label_classification":
+                loss_fct = BCEWithLogitsLoss()
+                loss = loss_fct(logits, labels)

         if not return_dict:
             output = (logits,) + outputs[1:]
             return ((loss,) + output) if loss is not None else output
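Note: once a model carries this block, multi-label training no longer needs a hand-rolled BCEWithLogitsLoss; setting problem_type on the config is enough. A usage sketch (the checkpoint name and toy labels are illustrative, not from this PR):

```python
import torch
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

config = AutoConfig.from_pretrained(
    "bert-base-uncased",
    num_labels=3,
    problem_type="multi_label_classification",  # selects BCEWithLogitsLoss explicitly
)
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", config=config)
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

inputs = tokenizer("a toy example", return_tensors="pt")
labels = torch.tensor([[1.0, 0.0, 1.0]])  # multi-hot float targets, shape (batch_size, num_labels)
loss = model(**inputs, labels=labels).loss
loss.backward()
```

Leaving problem_type unset falls back to the dtype-based inference shown earlier.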
@@ -234,8 +234,6 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
     fx_ready_model_classes = all_model_classes
     fx_dynamic_ready_model_classes = all_model_classes

-    test_sequence_classification_problem_types = True
-
     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
@@ -446,7 +446,6 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_generative_model_classes = (BertLMHeadModel,) if is_torch_available() else ()
     fx_ready_model_classes = all_model_classes
     fx_dynamic_ready_model_classes = all_model_classes
-    test_sequence_classification_problem_types = True

     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
@@ -435,7 +435,6 @@ class BigBirdModelTest(ModelTesterMixin, unittest.TestCase):
     # head masking & pruning is currently not supported for big bird
     test_head_masking = False
     test_pruning = False
-    test_sequence_classification_problem_types = True

     # torchscript should be possible, but takes prohibitively long to test.
     # Also torchscript is not an important feature to have in the beginning.
@@ -113,7 +113,6 @@ class ModelTesterMixin:
     test_missing_keys = True
     test_model_parallel = False
     is_encoder_decoder = False
-    test_sequence_classification_problem_types = False

     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
         inputs_dict = copy.deepcopy(inputs_dict)
@@ -387,12 +386,13 @@ class ModelTesterMixin:
         if not self.model_tester.is_training:
             return

-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-        config.return_dict = True
-
         for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            config.return_dict = True
+
             if model_class in get_values(MODEL_MAPPING):
                 continue
+
             model = model_class(config)
             model.to(torch_device)
             model.train()
@@ -401,14 +401,14 @@ class ModelTesterMixin:
             loss.backward()

     def test_training_gradient_checkpointing(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         if not self.model_tester.is_training:
             return

-        config.use_cache = False
-        config.return_dict = True
-
         for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            config.use_cache = False
+            config.return_dict = True
+
             if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing:
                 continue
             model = model_class(config)
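Note: both refactors above exist because the new forward passes write the inferred problem_type back onto the config, so a config shared across model classes would leak that state from one iteration to the next. A sketch of the side effect (BERT and the toy inputs are chosen for illustration):

```python
import torch
from transformers import BertConfig, BertForSequenceClassification

config = BertConfig(num_labels=2)  # config.problem_type starts out as None
model = BertForSequenceClassification(config)

input_ids = torch.tensor([[101, 2023, 102]])
model(input_ids, labels=torch.tensor([0]))  # forward infers and stores the problem type

print(config.problem_type)  # "single_label_classification" -- mutated in place
```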
@@ -1842,9 +1842,6 @@ class ModelTesterMixin:
             model.generate(**cast_to_device(inputs_dict, "cuda:0"), num_beams=2)

     def test_problem_types(self):
-        if not self.test_sequence_classification_problem_types:
-            return
-
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

         problem_types = [
@@ -1880,7 +1877,11 @@ class ModelTesterMixin:
                 # See https://github.com/huggingface/transformers/issues/11780
                 with warnings.catch_warnings(record=True) as warning_list:
                     loss = model(**inputs).loss
-                self.assertListEqual(warning_list, [])
+                for w in warning_list:
+                    if "Using a target size that is different to the input size" in str(w.message):
+                        raise ValueError(
+                            f"Something is going wrong in the regression problem: intercepted {w.message}"
+                        )

                 loss.backward()
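Note: the string matched above is the UserWarning PyTorch emits when MSELoss silently broadcasts mismatched shapes (the bug tracked in issue 11780); the `.squeeze()` calls in the new regression branches exist to avoid it. A standalone reproduction of the warning:

```python
import warnings
import torch
from torch.nn import MSELoss

logits = torch.randn(8, 1)  # model output of shape (batch_size, 1)
labels = torch.randn(8)     # targets of shape (batch_size,)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    MSELoss()(logits, labels)  # broadcasts to (8, 8) before reducing, which is wrong

print(caught[0].message)
# Using a target size (torch.Size([8])) that is different to the input size (torch.Size([8, 1])) ...
```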
@@ -2184,7 +2185,6 @@ class ModelPushToHubTester(unittest.TestCase):
                 f.write(FAKE_MODEL_CODE)
             repo.push_to_hub()

-        print(os.listdir(tmp_dir))
         new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True)
         for p1, p2 in zip(model.parameters(), new_model.parameters()):
@@ -262,7 +262,6 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
     )
     test_pruning = False
     test_head_masking = False
-    test_sequence_classification_problem_types = True

     def setUp(self):
         self.model_tester = ConvBertModelTester(self)
@@ -214,7 +214,6 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
     test_pruning = True
     test_torchscript = True
     test_resize_embeddings = True
-    test_sequence_classification_problem_types = True
     test_resize_position_embeddings = True

     def setUp(self):
@@ -291,7 +291,6 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
     )
     fx_ready_model_classes = all_model_classes
     fx_dynamic_ready_model_classes = all_model_classes
-    test_sequence_classification_problem_types = True

     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
@@ -362,7 +362,6 @@ class FunnelModelTest(ModelTesterMixin, unittest.TestCase):
         if is_torch_available()
         else ()
     )
-    test_sequence_classification_problem_types = True

     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
@@ -278,7 +278,6 @@ class LongformerModelTester:
 class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
     test_pruning = False  # pruning is not supported
     test_torchscript = False
-    test_sequence_classification_problem_types = True

     all_model_classes = (
         (
@@ -271,7 +271,6 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
     )
     fx_ready_model_classes = all_model_classes
     fx_dynamic_ready_model_classes = all_model_classes
-    test_sequence_classification_problem_types = True

     # special case for ForPreTraining model
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
@@ -143,7 +143,7 @@ class OpenAIGPTModelTester:
         model = OpenAIGPTForSequenceClassification(config)
         model.to(torch_device)
         model.eval()
-        # print(config.num_labels, sequence_labels.size())
+
         sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
         result = model(input_ids, token_type_ids=token_type_ids, labels=sequence_labels)
         self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
@@ -795,6 +795,10 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, Generation
             [expected_shape] * len(iter_hidden_states),
         )

+    def test_problem_types(self):
+        # Fails because the sequence length is not a multiple of 4
+        pass
+

 @require_torch
 @require_sentencepiece
@@ -356,7 +356,6 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
         else ()
     )
     all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else ()
-    test_sequence_classification_problem_types = True

     def setUp(self):
         self.model_tester = RobertaModelTester(self)
@@ -232,7 +232,6 @@ class SqueezeBertModelTest(ModelTesterMixin, unittest.TestCase):
     test_torchscript = True
     test_resize_embeddings = True
     test_head_masking = False
-    test_sequence_classification_problem_types = True

     def setUp(self):
         self.model_tester = SqueezeBertModelTester(self)
@@ -350,7 +350,6 @@ class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     all_generative_model_classes = (
         (XLMWithLMHeadModel,) if is_torch_available() else ()
     )  # TODO (PVP): Check other models whether language generation is also applicable
-    test_sequence_classification_problem_types = True

     # XLM has 2 QA models -> need to manually set the correct labels for one of them here
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
@@ -527,7 +527,6 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
         (XLNetLMHeadModel,) if is_torch_available() else ()
     )  # TODO (PVP): Check other models whether language generation is also applicable
     test_pruning = False
-    test_sequence_classification_problem_types = True

     # XLNet has 2 QA models -> need to manually set the correct labels for one of them here
     def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):