Unverified Commit ba8b1f47 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add support for multiple models for one config in auto classes (#11150)

* Add support for multiple models for one config in auto classes

* Use get_values everywhere

* Prettier doc
parent 97ccf67b
...@@ -387,6 +387,7 @@ class FlaxPreTrainedModel(ABC): ...@@ -387,6 +387,7 @@ class FlaxPreTrainedModel(ABC):
# get abs dir # get abs dir
save_directory = os.path.abspath(save_directory) save_directory = os.path.abspath(save_directory)
# save config as well # save config as well
self.config.architectures = [self.__class__.__name__[4:]]
self.config.save_pretrained(save_directory) self.config.save_pretrained(save_directory)
# save model # save model
......
...@@ -1037,6 +1037,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin): ...@@ -1037,6 +1037,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
logger.info(f"Saved model created in {saved_model_dir}") logger.info(f"Saved model created in {saved_model_dir}")
# Save configuration file # Save configuration file
self.config.architectures = [self.__class__.__name__[2:]]
self.config.save_pretrained(save_directory) self.config.save_pretrained(save_directory)
# If we save using the predefined names, we can load using `from_pretrained` # If we save using the predefined names, we can load using `from_pretrained`
......
...@@ -22,6 +22,7 @@ from ...file_utils import _BaseLazyModule, is_flax_available, is_tf_available, i ...@@ -22,6 +22,7 @@ from ...file_utils import _BaseLazyModule, is_flax_available, is_tf_available, i
_import_structure = { _import_structure = {
"auto_factory": ["get_values"],
"configuration_auto": ["ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig"], "configuration_auto": ["ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CONFIG_MAPPING", "MODEL_NAMES_MAPPING", "AutoConfig"],
"feature_extraction_auto": ["FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor"], "feature_extraction_auto": ["FEATURE_EXTRACTOR_MAPPING", "AutoFeatureExtractor"],
"tokenization_auto": ["TOKENIZER_MAPPING", "AutoTokenizer"], "tokenization_auto": ["TOKENIZER_MAPPING", "AutoTokenizer"],
...@@ -104,6 +105,7 @@ if is_flax_available(): ...@@ -104,6 +105,7 @@ if is_flax_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from .auto_factory import get_values
from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, AutoConfig from .configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, CONFIG_MAPPING, MODEL_NAMES_MAPPING, AutoConfig
from .feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor from .feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, AutoFeatureExtractor
from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer from .tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
......
...@@ -328,6 +328,26 @@ FROM_PRETRAINED_FLAX_DOCSTRING = """ ...@@ -328,6 +328,26 @@ FROM_PRETRAINED_FLAX_DOCSTRING = """
""" """
def _get_model_class(config, model_mapping):
supported_models = model_mapping[type(config)]
if not isinstance(supported_models, (list, tuple)):
return supported_models
name_to_model = {model.__name__: model for model in supported_models}
architectures = getattr(config, "architectures", [])
for arch in architectures:
if arch in name_to_model:
return name_to_model[arch]
elif f"TF{arch}" in name_to_model:
return name_to_model[f"TF{arch}"]
elif f"Flax{arch}" in name_to_model:
return name_to_model[f"Flax{arch}"]
# If not architecture is set in the config or match the supported models, the first element of the tuple is the
# defaults.
return supported_models[0]
class _BaseAutoModelClass: class _BaseAutoModelClass:
# Base class for auto models. # Base class for auto models.
_model_mapping = None _model_mapping = None
...@@ -341,7 +361,8 @@ class _BaseAutoModelClass: ...@@ -341,7 +361,8 @@ class _BaseAutoModelClass:
def from_config(cls, config, **kwargs): def from_config(cls, config, **kwargs):
if type(config) in cls._model_mapping.keys(): if type(config) in cls._model_mapping.keys():
return cls._model_mapping[type(config)](config, **kwargs) model_class = _get_model_class(config, cls._model_mapping)
return model_class(config, **kwargs)
raise ValueError( raise ValueError(
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}." f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
...@@ -356,9 +377,8 @@ class _BaseAutoModelClass: ...@@ -356,9 +377,8 @@ class _BaseAutoModelClass:
) )
if type(config) in cls._model_mapping.keys(): if type(config) in cls._model_mapping.keys():
return cls._model_mapping[type(config)].from_pretrained( model_class = _get_model_class(config, cls._model_mapping)
pretrained_model_name_or_path, *model_args, config=config, **kwargs return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
)
raise ValueError( raise ValueError(
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}." f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
...@@ -418,3 +438,14 @@ def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-ca ...@@ -418,3 +438,14 @@ def auto_class_factory(name, model_mapping, checkpoint_for_example="bert-base-ca
from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained) from_pretrained = replace_list_option_in_docstrings(model_mapping)(from_pretrained)
new_class.from_pretrained = classmethod(from_pretrained) new_class.from_pretrained = classmethod(from_pretrained)
return new_class return new_class
def get_values(model_mapping):
    """
    Return a flat list of all model classes stored in `model_mapping`.

    Values that are tuples/lists (configs mapped to several model classes) are expanded in place, so the result is
    always a flat list of classes.
    """
    flattened = []
    for entry in model_mapping.values():
        flattened.extend(entry if isinstance(entry, (list, tuple)) else (entry,))
    return flattened
...@@ -247,29 +247,38 @@ MODEL_NAMES_MAPPING = OrderedDict( ...@@ -247,29 +247,38 @@ MODEL_NAMES_MAPPING = OrderedDict(
) )
def _get_class_name(model_class):
if isinstance(model_class, (list, tuple)):
return " or ".join([f":class:`~transformers.{c.__name__}`" for c in model_class])
return f":class:`~transformers.{model_class.__name__}`"
def _list_model_options(indent, config_to_class=None, use_model_types=True): def _list_model_options(indent, config_to_class=None, use_model_types=True):
if config_to_class is None and not use_model_types: if config_to_class is None and not use_model_types:
raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.") raise ValueError("Using `use_model_types=False` requires a `config_to_class` dictionary.")
if use_model_types: if use_model_types:
if config_to_class is None: if config_to_class is None:
model_type_to_name = {model_type: config.__name__ for model_type, config in CONFIG_MAPPING.items()} model_type_to_name = {
model_type: f":class:`~transformers.{config.__name__}`"
for model_type, config in CONFIG_MAPPING.items()
}
else: else:
model_type_to_name = { model_type_to_name = {
model_type: config_to_class[config].__name__ model_type: _get_class_name(config_to_class[config])
for model_type, config in CONFIG_MAPPING.items() for model_type, config in CONFIG_MAPPING.items()
if config in config_to_class if config in config_to_class
} }
lines = [ lines = [
f"{indent}- **{model_type}** -- :class:`~transformers.{model_type_to_name[model_type]}` ({MODEL_NAMES_MAPPING[model_type]} model)" f"{indent}- **{model_type}** -- {model_type_to_name[model_type]} ({MODEL_NAMES_MAPPING[model_type]} model)"
for model_type in sorted(model_type_to_name.keys()) for model_type in sorted(model_type_to_name.keys())
] ]
else: else:
config_to_name = {config.__name__: clas.__name__ for config, clas in config_to_class.items()} config_to_name = {config.__name__: _get_class_name(clas) for config, clas in config_to_class.items()}
config_to_model_name = { config_to_model_name = {
config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items() config.__name__: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING.items()
} }
lines = [ lines = [
f"{indent}- :class:`~transformers.{config_name}` configuration class: :class:`~transformers.{config_to_name[config_name]}` ({config_to_model_name[config_name]} model)" f"{indent}- :class:`~transformers.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
for config_name in sorted(config_to_name.keys()) for config_name in sorted(config_to_name.keys())
] ]
return "\n".join(lines) return "\n".join(lines)
......
...@@ -124,6 +124,7 @@ from ..flaubert.modeling_flaubert import ( ...@@ -124,6 +124,7 @@ from ..flaubert.modeling_flaubert import (
) )
from ..fsmt.modeling_fsmt import FSMTForConditionalGeneration, FSMTModel from ..fsmt.modeling_fsmt import FSMTForConditionalGeneration, FSMTModel
from ..funnel.modeling_funnel import ( from ..funnel.modeling_funnel import (
FunnelBaseModel,
FunnelForMaskedLM, FunnelForMaskedLM,
FunnelForMultipleChoice, FunnelForMultipleChoice,
FunnelForPreTraining, FunnelForPreTraining,
...@@ -377,7 +378,7 @@ MODEL_MAPPING = OrderedDict( ...@@ -377,7 +378,7 @@ MODEL_MAPPING = OrderedDict(
(CTRLConfig, CTRLModel), (CTRLConfig, CTRLModel),
(ElectraConfig, ElectraModel), (ElectraConfig, ElectraModel),
(ReformerConfig, ReformerModel), (ReformerConfig, ReformerModel),
(FunnelConfig, FunnelModel), (FunnelConfig, (FunnelModel, FunnelBaseModel)),
(LxmertConfig, LxmertModel), (LxmertConfig, LxmertModel),
(BertGenerationConfig, BertGenerationEncoder), (BertGenerationConfig, BertGenerationEncoder),
(DebertaConfig, DebertaModel), (DebertaConfig, DebertaModel),
......
...@@ -91,6 +91,7 @@ from ..flaubert.modeling_tf_flaubert import ( ...@@ -91,6 +91,7 @@ from ..flaubert.modeling_tf_flaubert import (
TFFlaubertWithLMHeadModel, TFFlaubertWithLMHeadModel,
) )
from ..funnel.modeling_tf_funnel import ( from ..funnel.modeling_tf_funnel import (
TFFunnelBaseModel,
TFFunnelForMaskedLM, TFFunnelForMaskedLM,
TFFunnelForMultipleChoice, TFFunnelForMultipleChoice,
TFFunnelForPreTraining, TFFunnelForPreTraining,
...@@ -242,7 +243,7 @@ TF_MODEL_MAPPING = OrderedDict( ...@@ -242,7 +243,7 @@ TF_MODEL_MAPPING = OrderedDict(
(XLMConfig, TFXLMModel), (XLMConfig, TFXLMModel),
(CTRLConfig, TFCTRLModel), (CTRLConfig, TFCTRLModel),
(ElectraConfig, TFElectraModel), (ElectraConfig, TFElectraModel),
(FunnelConfig, TFFunnelModel), (FunnelConfig, (TFFunnelModel, TFFunnelBaseModel)),
(DPRConfig, TFDPRQuestionEncoder), (DPRConfig, TFDPRQuestionEncoder),
(MPNetConfig, TFMPNetModel), (MPNetConfig, TFMPNetModel),
(BartConfig, TFBartModel), (BartConfig, TFBartModel),
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import unittest import unittest
from transformers import is_torch_available from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -234,7 +235,7 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -234,7 +235,7 @@ class AlbertModelTest(ModelTesterMixin, unittest.TestCase):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels: if return_labels:
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values(): if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
inputs_dict["labels"] = torch.zeros( inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
) )
......
...@@ -13,7 +13,8 @@ ...@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import copy
import tempfile
import unittest import unittest
from transformers import is_torch_available from transformers import is_torch_available
...@@ -46,6 +47,8 @@ if is_torch_available(): ...@@ -46,6 +47,8 @@ if is_torch_available():
BertForSequenceClassification, BertForSequenceClassification,
BertForTokenClassification, BertForTokenClassification,
BertModel, BertModel,
FunnelBaseModel,
FunnelModel,
GPT2Config, GPT2Config,
GPT2LMHeadModel, GPT2LMHeadModel,
RobertaForMaskedLM, RobertaForMaskedLM,
...@@ -218,6 +221,21 @@ class AutoModelTest(unittest.TestCase): ...@@ -218,6 +221,21 @@ class AutoModelTest(unittest.TestCase):
self.assertEqual(model.num_parameters(), 14410) self.assertEqual(model.num_parameters(), 14410)
self.assertEqual(model.num_parameters(only_trainable=True), 14410) self.assertEqual(model.num_parameters(only_trainable=True), 14410)
def test_from_pretrained_with_tuple_values(self):
    """Check that when a config type maps to a tuple of model classes, `config.architectures` picks the class."""
    # For the auto model mapping, FunnelConfig has two models: FunnelModel and FunnelBaseModel
    # With no matching architecture, the first element of the tuple is used as the default.
    model = AutoModel.from_pretrained("sgugger/funnel-random-tiny")
    self.assertIsInstance(model, FunnelModel)
    # Setting `architectures` explicitly should select the second class of the tuple.
    config = copy.deepcopy(model.config)
    config.architectures = ["FunnelBaseModel"]
    model = AutoModel.from_config(config)
    self.assertIsInstance(model, FunnelBaseModel)
    # The selection must survive a save/reload round trip (save_pretrained records the architecture in the config).
    with tempfile.TemporaryDirectory() as tmp_dir:
        model.save_pretrained(tmp_dir)
        model = AutoModel.from_pretrained(tmp_dir)
        self.assertIsInstance(model, FunnelBaseModel)
def test_parents_and_children_in_mappings(self): def test_parents_and_children_in_mappings(self):
# Test that the children are placed before the parents in the mappings, as the `instanceof` will be triggered # Test that the children are placed before the parents in the mappings, as the `instanceof` will be triggered
# by the parents and will return the wrong configuration type when using auto models # by the parents and will return the wrong configuration type when using auto models
...@@ -242,6 +260,12 @@ class AutoModelTest(unittest.TestCase): ...@@ -242,6 +260,12 @@ class AutoModelTest(unittest.TestCase):
assert not issubclass( assert not issubclass(
child_config, parent_config child_config, parent_config
), f"{child_config.__name__} is child of {parent_config.__name__}" ), f"{child_config.__name__} is child of {parent_config.__name__}"
assert not issubclass(
child_model, parent_model # Tuplify child_model and parent_model since some of them could be tuples.
), f"{child_config.__name__} is child of {parent_config.__name__}" if not isinstance(child_model, (list, tuple)):
child_model = (child_model,)
if not isinstance(parent_model, (list, tuple)):
parent_model = (parent_model,)
for child, parent in [(a, b) for a in child_model for b in parent_model]:
assert not issubclass(child, parent), f"{child.__name__} is child of {parent.__name__}"
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import unittest import unittest
from transformers import is_torch_available from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -444,7 +445,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ...@@ -444,7 +445,7 @@ class BertModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels: if return_labels:
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values(): if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
inputs_dict["labels"] = torch.zeros( inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
) )
......
...@@ -19,6 +19,7 @@ import unittest ...@@ -19,6 +19,7 @@ import unittest
from tests.test_modeling_common import floats_tensor from tests.test_modeling_common import floats_tensor
from transformers import is_torch_available from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.models.big_bird.tokenization_big_bird import BigBirdTokenizer from transformers.models.big_bird.tokenization_big_bird import BigBirdTokenizer
from transformers.testing_utils import require_torch, slow, torch_device from transformers.testing_utils import require_torch, slow, torch_device
...@@ -458,7 +459,7 @@ class BigBirdModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -458,7 +459,7 @@ class BigBirdModelTest(ModelTesterMixin, unittest.TestCase):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels: if return_labels:
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values(): if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
inputs_dict["labels"] = torch.zeros( inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
) )
......
...@@ -24,6 +24,7 @@ from typing import List, Tuple ...@@ -24,6 +24,7 @@ from typing import List, Tuple
from transformers import is_torch_available from transformers import is_torch_available
from transformers.file_utils import WEIGHTS_NAME from transformers.file_utils import WEIGHTS_NAME
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device
...@@ -79,7 +80,7 @@ class ModelTesterMixin: ...@@ -79,7 +80,7 @@ class ModelTesterMixin:
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False): def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
inputs_dict = copy.deepcopy(inputs_dict) inputs_dict = copy.deepcopy(inputs_dict)
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values(): if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
inputs_dict = { inputs_dict = {
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous() k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
if isinstance(v, torch.Tensor) and v.ndim > 1 if isinstance(v, torch.Tensor) and v.ndim > 1
...@@ -88,9 +89,9 @@ class ModelTesterMixin: ...@@ -88,9 +89,9 @@ class ModelTesterMixin:
} }
if return_labels: if return_labels:
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values(): if model_class in get_values(MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device) inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
elif model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values(): elif model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
inputs_dict["start_positions"] = torch.zeros( inputs_dict["start_positions"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device self.model_tester.batch_size, dtype=torch.long, device=torch_device
) )
...@@ -98,18 +99,18 @@ class ModelTesterMixin: ...@@ -98,18 +99,18 @@ class ModelTesterMixin:
self.model_tester.batch_size, dtype=torch.long, device=torch_device self.model_tester.batch_size, dtype=torch.long, device=torch_device
) )
elif model_class in [ elif model_class in [
*MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values(), *get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING),
*MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values(), *get_values(MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING),
*MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING.values(), *get_values(MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING),
]: ]:
inputs_dict["labels"] = torch.zeros( inputs_dict["labels"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device self.model_tester.batch_size, dtype=torch.long, device=torch_device
) )
elif model_class in [ elif model_class in [
*MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(), *get_values(MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING),
*MODEL_FOR_CAUSAL_LM_MAPPING.values(), *get_values(MODEL_FOR_CAUSAL_LM_MAPPING),
*MODEL_FOR_MASKED_LM_MAPPING.values(), *get_values(MODEL_FOR_MASKED_LM_MAPPING),
*MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(), *get_values(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING),
]: ]:
inputs_dict["labels"] = torch.zeros( inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
...@@ -229,7 +230,7 @@ class ModelTesterMixin: ...@@ -229,7 +230,7 @@ class ModelTesterMixin:
config.return_dict = True config.return_dict = True
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
if model_class in MODEL_MAPPING.values(): if model_class in get_values(MODEL_MAPPING):
continue continue
model = model_class(config) model = model_class(config)
model.to(torch_device) model.to(torch_device)
...@@ -248,7 +249,7 @@ class ModelTesterMixin: ...@@ -248,7 +249,7 @@ class ModelTesterMixin:
config.return_dict = True config.return_dict = True
for model_class in self.all_model_classes: for model_class in self.all_model_classes:
if model_class in MODEL_MAPPING.values(): if model_class in get_values(MODEL_MAPPING):
continue continue
model = model_class(config) model = model_class(config)
model.to(torch_device) model.to(torch_device)
...@@ -312,7 +313,7 @@ class ModelTesterMixin: ...@@ -312,7 +313,7 @@ class ModelTesterMixin:
if "labels" in inputs_dict: if "labels" in inputs_dict:
correct_outlen += 1 # loss is added to beginning correct_outlen += 1 # loss is added to beginning
# Question Answering model returns start_logits and end_logits # Question Answering model returns start_logits and end_logits
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values(): if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
correct_outlen += 1 # start_logits and end_logits instead of only 1 output correct_outlen += 1 # start_logits and end_logits instead of only 1 output
if "past_key_values" in outputs: if "past_key_values" in outputs:
correct_outlen += 1 # past_key_values have been returned correct_outlen += 1 # past_key_values have been returned
......
...@@ -19,6 +19,7 @@ import unittest ...@@ -19,6 +19,7 @@ import unittest
from tests.test_modeling_common import floats_tensor from tests.test_modeling_common import floats_tensor
from transformers import is_torch_available from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -352,7 +353,7 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -352,7 +353,7 @@ class ConvBertModelTest(ModelTesterMixin, unittest.TestCase):
if "labels" in inputs_dict: if "labels" in inputs_dict:
correct_outlen += 1 # loss is added to beginning correct_outlen += 1 # loss is added to beginning
# Question Answering model returns start_logits and end_logits # Question Answering model returns start_logits and end_logits
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values(): if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
correct_outlen += 1 # start_logits and end_logits instead of only 1 output correct_outlen += 1 # start_logits and end_logits instead of only 1 output
if "past_key_values" in outputs: if "past_key_values" in outputs:
correct_outlen += 1 # past_key_values have been returned correct_outlen += 1 # past_key_values have been returned
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import unittest import unittest
from transformers import is_torch_available from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -292,7 +293,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -292,7 +293,7 @@ class ElectraModelTest(ModelTesterMixin, unittest.TestCase):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels: if return_labels:
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values(): if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
inputs_dict["labels"] = torch.zeros( inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
) )
......
...@@ -29,6 +29,7 @@ if is_flax_available(): ...@@ -29,6 +29,7 @@ if is_flax_available():
FlaxBertForNextSentencePrediction, FlaxBertForNextSentencePrediction,
FlaxBertForPreTraining, FlaxBertForPreTraining,
FlaxBertForQuestionAnswering, FlaxBertForQuestionAnswering,
FlaxBertForSequenceClassification,
FlaxBertForTokenClassification, FlaxBertForTokenClassification,
FlaxBertModel, FlaxBertModel,
) )
...@@ -125,6 +126,7 @@ class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase): ...@@ -125,6 +126,7 @@ class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
FlaxBertForMultipleChoice, FlaxBertForMultipleChoice,
FlaxBertForQuestionAnswering, FlaxBertForQuestionAnswering,
FlaxBertForNextSentencePrediction, FlaxBertForNextSentencePrediction,
FlaxBertForSequenceClassification,
FlaxBertForTokenClassification, FlaxBertForTokenClassification,
FlaxBertForQuestionAnswering, FlaxBertForQuestionAnswering,
) )
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import unittest import unittest
from transformers import FunnelTokenizer, is_torch_available from transformers import FunnelTokenizer, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -365,7 +366,7 @@ class FunnelModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -365,7 +366,7 @@ class FunnelModelTest(ModelTesterMixin, unittest.TestCase):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels: if return_labels:
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values(): if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
inputs_dict["labels"] = torch.zeros( inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
) )
......
...@@ -21,6 +21,7 @@ import unittest ...@@ -21,6 +21,7 @@ import unittest
from transformers import is_torch_available from transformers import is_torch_available
from transformers.file_utils import cached_property from transformers.file_utils import cached_property
from transformers.models.auto import get_values
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -412,7 +413,7 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): ...@@ -412,7 +413,7 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
if "labels" in inputs_dict: if "labels" in inputs_dict:
correct_outlen += 1 # loss is added to beginning correct_outlen += 1 # loss is added to beginning
# Question Answering model returns start_logits and end_logits # Question Answering model returns start_logits and end_logits
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values(): if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
correct_outlen += 1 # start_logits and end_logits instead of only 1 output correct_outlen += 1 # start_logits and end_logits instead of only 1 output
if "past_key_values" in outputs: if "past_key_values" in outputs:
correct_outlen += 1 # past_key_values have been returned correct_outlen += 1 # past_key_values have been returned
......
...@@ -18,6 +18,7 @@ import copy ...@@ -18,6 +18,7 @@ import copy
import unittest import unittest
from transformers import is_torch_available from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, slow, torch_device from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -532,11 +533,11 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -532,11 +533,11 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase):
inputs_dict = copy.deepcopy(inputs_dict) inputs_dict = copy.deepcopy(inputs_dict)
if return_labels: if return_labels:
if model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values(): if model_class in get_values(MODEL_FOR_QUESTION_ANSWERING_MAPPING):
inputs_dict["labels"] = torch.zeros( inputs_dict["labels"] = torch.zeros(
self.model_tester.batch_size, dtype=torch.long, device=torch_device self.model_tester.batch_size, dtype=torch.long, device=torch_device
) )
elif model_class in MODEL_FOR_PRETRAINING_MAPPING.values(): elif model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
# special case for models like BERT that use multi-loss training for PreTraining # special case for models like BERT that use multi-loss training for PreTraining
inputs_dict["labels"] = torch.zeros( inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
......
...@@ -21,6 +21,7 @@ import os ...@@ -21,6 +21,7 @@ import os
import unittest import unittest
from transformers import is_torch_available from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -290,7 +291,7 @@ class MegatronBertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -290,7 +291,7 @@ class MegatronBertModelTest(ModelTesterMixin, unittest.TestCase):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels: if return_labels:
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values(): if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
inputs_dict["labels"] = torch.zeros( inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
) )
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import unittest import unittest
from transformers import is_torch_available from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
...@@ -272,7 +273,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase): ...@@ -272,7 +273,7 @@ class MobileBertModelTest(ModelTesterMixin, unittest.TestCase):
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels) inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels: if return_labels:
if model_class in MODEL_FOR_PRETRAINING_MAPPING.values(): if model_class in get_values(MODEL_FOR_PRETRAINING_MAPPING):
inputs_dict["labels"] = torch.zeros( inputs_dict["labels"] = torch.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device (self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment