Unverified Commit c40c7e21 authored by abhishek thakur's avatar abhishek thakur Committed by GitHub
Browse files

Add multi-class, multi-label and regression to transformers (#11012)



* add to  bert

* review comments

* Update src/transformers/configuration_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/configuration_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* self.config.problem_type

* fix style

* fix

* fin

* fix

* update doc

* fix

* test

* Test more problem types

* Update src/transformers/configuration_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix

* remove

* fix

* quality

* make fix-copies

* remove test
Co-authored-by: default avatarabhishek thakur <abhishekkrthakur@users.noreply.github.com>
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarLysandre <lysandre.debut@reseau.eseo.fr>
parent 7c622482
......@@ -211,6 +211,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
test_pruning = True
test_torchscript = True
test_resize_embeddings = True
test_sequence_classification_problem_types = True
def setUp(self):
self.model_tester = DistilBertModelTester(self)
......
......@@ -287,6 +287,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
test_sequence_classification_problem_types = True
# special case for ForPreTraining model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
......
......@@ -360,6 +360,7 @@ class FunnelModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
test_sequence_classification_problem_types = True
# special case for ForPreTraining model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
......
......@@ -274,6 +274,7 @@ class LongformerModelTester:
class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
test_pruning = False # pruning is not supported
test_torchscript = False
test_sequence_classification_problem_types = True
all_model_classes = (
(
......
......@@ -267,6 +267,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
if is_torch_available()
else ()
)
test_sequence_classification_problem_types = True
# special case for ForPreTraining model
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
......
......@@ -590,6 +590,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
test_pruning = False
test_headmasking = False
test_torchscript = False
test_sequence_classification_problem_types = True
def prepare_kwargs(self):
return {
......
......@@ -351,6 +351,7 @@ class RobertaModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
else ()
)
all_generative_model_classes = (RobertaForCausalLM,) if is_torch_available() else ()
test_sequence_classification_problem_types = True
def setUp(self):
self.model_tester = RobertaModelTester(self)
......
......@@ -231,6 +231,7 @@ class SqueezeBertModelTest(ModelTesterMixin, unittest.TestCase):
test_torchscript = True
test_resize_embeddings = True
test_head_masking = False
test_sequence_classification_problem_types = True
def setUp(self):
self.model_tester = SqueezeBertModelTester(self)
......
......@@ -349,6 +349,7 @@ class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
all_generative_model_classes = (
(XLMWithLMHeadModel,) if is_torch_available() else ()
) # TODO (PVP): Check other models whether language generation is also applicable
test_sequence_classification_problem_types = True
# XLM has 2 QA models -> need to manually set the correct labels for one of them here
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
......
......@@ -526,6 +526,7 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase)
(XLNetLMHeadModel,) if is_torch_available() else ()
) # TODO (PVP): Check other models whether language generation is also applicable
test_pruning = False
test_sequence_classification_problem_types = True
# XLNet has 2 QA models -> need to manually set the correct labels for one of them here
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment