Unverified Commit ba8b1f47 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add support for multiple models for one config in auto classes (#11150)

* Add support for multiple models for one config in auto classes

* Use get_values everywhere

* Prettier doc
parent 97ccf67b
...@@ -32,6 +32,7 @@ from transformers import ( ...@@ -32,6 +32,7 @@ from transformers import (
is_torch_available, is_torch_available,
) )
from transformers.file_utils import cached_property from transformers.file_utils import cached_property
from transformers.models.auto import get_values
from transformers.testing_utils import require_scatter, require_torch, slow, torch_device from transformers.testing_utils import require_scatter, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -425,7 +426,7 @@ class TapasModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -425,7 +426,7 @@ class TapasModelTest(ModelTesterMixin, unittest.TestCase):
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = copy.deepcopy(inputs_dict) inputs_dict = copy.deepcopy(inputs_dict)
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values(): if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
inputs_dict = { inputs_dict = {
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous() k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
if isinstance(v, torch.Tensor) and v.ndim > 1 if isinstance(v, torch.Tensor) and v.ndim > 1
...@@ -434,9 +435,9 @@ class TapasModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -434,9 +435,9 @@ class TapasModelTest(ModelTesterMixin, unittest.TestCase):
} }
if return_labels: if return_labels:
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values(): if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device) inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
elif model_class in MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING.values(): elif model_class in get_values(MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING):
inputs_dict["labels"] = torch.zeros( inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
) )
...@@ -457,17 +458,17 @@ class TapasModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -457,17 +458,17 @@ class TapasModelTest(ModelTesterMixin, unittest.TestCase):
self.model_tester.batch_size, dtype=torch.float, device=torch_device self.model_tester.batch_size, dtype=torch.float, device=torch_device
) )
elif model_class in [ elif model_class in [
*MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values(), *get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
*MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values(), *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING),
]: ]:
inputs_dict["labels"] = torch.zeros( inputs_dict["labels"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device self.model_tester.batch_size, dtype=torch.long, device=torch_device
) )
elif model_class in [ elif model_class in [
*MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(), *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
*MODEL_FOR_CAUSAL_LM_MAPPING.values(), *get_values(MODEL_FOR_CAUSAL_LM_MAPPING),
*MODEL_FOR_MASKED_LM_MAPPING.values(), *get_values(MODEL_FOR_MASKED_LM_MAPPING),
*MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(), *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
]: ]:
inputs_dict["labels"] = torch.zeros( inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import unittest import unittest
from transformers import AlbertConfig, is_tf_available from transformers import AlbertConfig, is_tf_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_tf, slow from transformers.testing_utils import require_tf, slow
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -249,7 +250,7 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -249,7 +250,7 @@ class TFAlbertModelTest(TFModelTesterMixin, unittest.TestCase):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels: if return_labels:
if model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.values(): if model_class in get_values(TF_MODEL_FOR_PRETRAINING_MAPPING):
inputs_dict["sentence_order_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32) inputs_dict["sentence_order_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
return inputs_dict return inputs_dict
......
...@@ -13,7 +13,8 @@ ...@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
import tempfile
import unittest import unittest
from transformers import is_tf_available from transformers import is_tf_available
...@@ -39,6 +40,8 @@ if is_tf_available(): ...@@ -39,6 +40,8 @@ if is_tf_available():
TFBertForQuestionAnswering, TFBertForQuestionAnswering,
TFBertForSequenceClassification, TFBertForSequenceClassification,
TFBertModel, TFBertModel,
TFFunnelBaseModel,
TFFunnelModel,
TFGPT2LMHeadModel, TFGPT2LMHeadModel,
TFRobertaForMaskedLM, TFRobertaForMaskedLM,
TFT5ForConditionalGeneration, TFT5ForConditionalGeneration,
...@@ -176,6 +179,21 @@ class TFAutoModelTest(unittest.TestCase): ...@@ -176,6 +179,21 @@ class TFAutoModelTest(unittest.TestCase):
self.assertEqual(model.num_parameters(), 14410) self.assertEqual(model.num_parameters(), 14410)
self.assertEqual(model.num_parameters(only_trainable=True), 14410) self.assertEqual(model.num_parameters(only_trainable=True), 14410)
def test_from_pretrained_with_tuple_values(self):
# For the auto model mapping, FunnelConfig has two models: FunnelModel and FunnelBaseModel
model = TFAutoModel.from_pretrained("sgugger/funnel-random-tiny")
self.assertIsInstance(model, TFFunnelModel)
config = copy.deepcopy(model.config)
config.architectures = ["FunnelBaseModel"]
model = TFAutoModel.from_config(config)
self.assertIsInstance(model, TFFunnelBaseModel)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
model = TFAutoModel.from_pretrained(tmp_dir)
self.assertIsInstance(model, TFFunnelBaseModel)
def test_parents_and_children_in_mappings(self): def test_parents_and_children_in_mappings(self):
# Test that the children are placed before the parents in the mappings, as the `instanceof` will be triggered # Test that the children are placed before the parents in the mappings, as the `instanceof` will be triggered
# by the parents and will return the wrong configuration type when using auto models # by the parents and will return the wrong configuration type when using auto models
...@@ -197,4 +215,12 @@ class TFAutoModelTest(unittest.TestCase): ...@@ -197,4 +215,12 @@ class TFAutoModelTest(unittest.TestCase):
for parent_config, parent_model in mapping[: index + 1]: for parent_config, parent_model in mapping[: index + 1]:
with self.subTest(msg=f"Testing if {child_config.__name__} is child of {parent_config.__name__}"): with self.subTest(msg=f"Testing if {child_config.__name__} is child of {parent_config.__name__}"):
self.assertFalse(issubclass(child_config, parent_config)) self.assertFalse(issubclass(child_config, parent_config))
self.assertFalse(issubclass(child_model, parent_model))
# Tuplify child_model and parent_model since some of them could be tuples.
if not isinstance(child_model, (list, tuple)):
child_model = (child_model,)
if not isinstance(parent_model, (list, tuple)):
parent_model = (parent_model,)
for child, parent in [(a, b) for a in child_model for b in parent_model]:
assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import unittest import unittest
from transformers import BertConfig, is_tf_available from transformers import BertConfig, is_tf_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_tf, slow from transformers.testing_utils import require_tf, slow
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -282,7 +283,7 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase): ...@@ -282,7 +283,7 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels: if return_labels:
if model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.values(): if model_class in get_values(TF_MODEL_FOR_PRETRAINING_MAPPING):
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32) inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
return inputs_dict return inputs_dict
......
...@@ -25,6 +25,7 @@ from importlib import import_module ...@@ -25,6 +25,7 @@ from importlib import import_module
from typing import List, Tuple from typing import List, Tuple
from transformers import is_tf_available from transformers import is_tf_available
from transformers.models.auto import get_values
from transformers.testing_utils import ( from transformers.testing_utils import (
_tf_gpu_memory_limit, _tf_gpu_memory_limit,
is_pt_tf_cross_test, is_pt_tf_cross_test,
...@@ -89,7 +90,7 @@ class TFModelTesterMixin: ...@@ -89,7 +90,7 @@ class TFModelTesterMixin:
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> dict: def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> dict:
inputs_dict = copy.deepcopy(inputs_dict) inputs_dict = copy.deepcopy(inputs_dict)
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values(): if model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
inputs_dict = { inputs_dict = {
k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices) + (1,) * (v.ndim - 1)) k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices) + (1,) * (v.ndim - 1))
if isinstance(v, tf.Tensor) and v.ndim > 0 if isinstance(v, tf.Tensor) and v.ndim > 0
...@@ -98,21 +99,21 @@ class TFModelTesterMixin: ...@@ -98,21 +99,21 @@ class TFModelTesterMixin:
} }
if return_labels: if return_labels:
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values(): if model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
inputs_dict["labels"] = tf.ones(self.model_tester.batch_size, dtype=tf.int32) inputs_dict["labels"] = tf.ones(self.model_tester.batch_size, dtype=tf.int32)
elif model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.values(): elif model_class in get_values(TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32) inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32) inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values(): elif model_class in get_values(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32) inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
elif model_class in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values(): elif model_class in get_values(TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING):
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32) inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
elif model_class in [ elif model_class in [
*TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(), *get_values(TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
*TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(), *get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING),
*TF_MODEL_FOR_MASKED_LM_MAPPING.values(), *get_values(TF_MODEL_FOR_MASKED_LM_MAPPING),
*TF_MODEL_FOR_PRETRAINING_MAPPING.values(), *get_values(TF_MODEL_FOR_PRETRAINING_MAPPING),
*TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(), *get_values(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
]: ]:
inputs_dict["labels"] = tf.zeros( inputs_dict["labels"] = tf.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32 (self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
...@@ -580,7 +581,7 @@ class TFModelTesterMixin: ...@@ -580,7 +581,7 @@ class TFModelTesterMixin:
), ),
"input_ids": tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32"), "input_ids": tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32"),
} }
elif model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values(): elif model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
input_ids = tf.keras.Input(batch_shape=(4, 2, max_input), name="input_ids", dtype="int32") input_ids = tf.keras.Input(batch_shape=(4, 2, max_input), name="input_ids", dtype="int32")
else: else:
input_ids = tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32") input_ids = tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32")
...@@ -796,9 +797,9 @@ class TFModelTesterMixin: ...@@ -796,9 +797,9 @@ class TFModelTesterMixin:
def test_model_common_attributes(self): def test_model_common_attributes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
list_lm_models = ( list_lm_models = (
list(TF_MODEL_FOR_CAUSAL_LM_MAPPING.values()) get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING)
+ list(TF_MODEL_FOR_MASKED_LM_MAPPING.values()) + get_values(TF_MODEL_FOR_MASKED_LM_MAPPING)
+ list(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values()) + get_values(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING)
) )
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
...@@ -1128,7 +1129,7 @@ class TFModelTesterMixin: ...@@ -1128,7 +1129,7 @@ class TFModelTesterMixin:
] ]
loss_size = tf.size(added_label) loss_size = tf.size(added_label)
if model.__class__ in TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(): if model.__class__ in get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING):
# if loss is causal lm loss, labels are shift, so that one label per batch # if loss is causal lm loss, labels are shift, so that one label per batch
# is cut # is cut
loss_size = loss_size - self.model_tester.batch_size loss_size = loss_size - self.model_tester.batch_size
......
...@@ -19,6 +19,8 @@ import os ...@@ -19,6 +19,8 @@ import os
import re import re
from pathlib import Path from pathlib import Path
from transformers.models.auto import get_values
# All paths are set with the intent you should run this script from the root of the repo with the command # All paths are set with the intent you should run this script from the root of the repo with the command
# python utils/check_repo.py # python utils/check_repo.py
...@@ -86,7 +88,6 @@ IGNORE_NON_AUTO_CONFIGURED = [ ...@@ -86,7 +88,6 @@ IGNORE_NON_AUTO_CONFIGURED = [
"DPRReader", "DPRReader",
"DPRSpanPredictor", "DPRSpanPredictor",
"FlaubertForQuestionAnswering", "FlaubertForQuestionAnswering",
"FunnelBaseModel",
"GPT2DoubleHeadsModel", "GPT2DoubleHeadsModel",
"OpenAIGPTDoubleHeadsModel", "OpenAIGPTDoubleHeadsModel",
"RagModel", "RagModel",
...@@ -95,7 +96,6 @@ IGNORE_NON_AUTO_CONFIGURED = [ ...@@ -95,7 +96,6 @@ IGNORE_NON_AUTO_CONFIGURED = [
"T5Stack", "T5Stack",
"TFDPRReader", "TFDPRReader",
"TFDPRSpanPredictor", "TFDPRSpanPredictor",
"TFFunnelBaseModel",
"TFGPT2DoubleHeadsModel", "TFGPT2DoubleHeadsModel",
"TFOpenAIGPTDoubleHeadsModel", "TFOpenAIGPTDoubleHeadsModel",
"TFRagModel", "TFRagModel",
...@@ -153,7 +153,7 @@ def get_model_modules(): ...@@ -153,7 +153,7 @@ def get_model_modules():
def get_models(module): def get_models(module):
""" Get the objects in module that are models.""" """ Get the objects in module that are models."""
models = [] models = []
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel) model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
for attr_name in dir(module): for attr_name in dir(module):
if "Pretrained" in attr_name or "PreTrained" in attr_name: if "Pretrained" in attr_name or "PreTrained" in attr_name:
continue continue
...@@ -249,10 +249,13 @@ def get_all_auto_configured_models(): ...@@ -249,10 +249,13 @@ def get_all_auto_configured_models():
result = set() # To avoid duplicates we concatenate all model classes in a set. result = set() # To avoid duplicates we concatenate all model classes in a set.
for attr_name in dir(transformers.models.auto.modeling_auto): for attr_name in dir(transformers.models.auto.modeling_auto):
if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"): if attr_name.startswith("MODEL_") and attr_name.endswith("MAPPING"):
result = result | set(getattr(transformers.models.auto.modeling_auto, attr_name).values()) result = result | set(get_values(getattr(transformers.models.auto.modeling_auto, attr_name)))
for attr_name in dir(transformers.models.auto.modeling_tf_auto): for attr_name in dir(transformers.models.auto.modeling_tf_auto):
if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING"): if attr_name.startswith("TF_MODEL_") and attr_name.endswith("MAPPING"):
result = result | set(getattr(transformers.models.auto.modeling_tf_auto, attr_name).values()) result = result | set(get_values(getattr(transformers.models.auto.modeling_tf_auto, attr_name)))
for attr_name in dir(transformers.models.auto.modeling_flax_auto):
if attr_name.startswith("FLAX_MODEL_") and attr_name.endswith("MAPPING"):
result = result | set(get_values(getattr(transformers.models.auto.modeling_flax_auto, attr_name)))
return [cls.__name__ for cls in result] return [cls.__name__ for cls in result]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment