Unverified Commit ec0a945c authored by Patrick von Platen, committed by GitHub

[AutoModels] Fix config params handling of all PT and TF AutoModels (#5665)

* fix auto model causal lm

* leverage given functionality

* apply unused kwargs to all auto models
parent 8ab565a4
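
The practical effect: keyword arguments that belong on the config no longer leak through to the model constructor. A minimal sketch of the behavior this commit enables, assuming a transformers version that contains it (the checkpoint name is only an example):

```python
from transformers import AutoModelForCausalLM

# Before this commit, `output_attentions` stayed in `kwargs` after the
# config was built and was forwarded to the model's __init__, which does
# not accept it. With return_unused_kwargs=True, AutoConfig absorbs it
# and only genuinely unused kwargs reach the model's from_pretrained.
model = AutoModelForCausalLM.from_pretrained("gpt2", output_attentions=True)
assert model.config.output_attentions is True
```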
@@ -498,7 +498,9 @@ class AutoModel:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in MODEL_MAPPING.items():
             if isinstance(config, config_class):
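
The same three-line change is applied verbatim to every Auto class below. The mechanism is `return_unused_kwargs=True`: `AutoConfig.from_pretrained` sets any kwarg that matches a config attribute on the config and hands everything else back, so the leftover `kwargs` can be forwarded to the model. A small illustration; the `unused_demo_kwarg` key is made up purely to show the split:

```python
from transformers import AutoConfig

# Recognized config attributes are applied to the config; anything the
# config does not know about comes back untouched in the second value.
config, unused = AutoConfig.from_pretrained(
    "bert-base-uncased",
    output_attentions=True,
    unused_demo_kwarg=42,
    return_unused_kwargs=True,
)
assert config.output_attentions is True
assert unused == {"unused_demo_kwarg": 42}
```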
@@ -645,7 +647,9 @@ class AutoModelForPreTraining:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in MODEL_FOR_PRETRAINING_MAPPING.items():
             if isinstance(config, config_class):
@@ -802,7 +806,9 @@ class AutoModelWithLMHead:
         )
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():
             if isinstance(config, config_class):
@@ -937,7 +943,9 @@ class AutoModelForCausalLM:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in MODEL_FOR_CAUSAL_LM_MAPPING.items():
             if isinstance(config, config_class):
@@ -1078,7 +1086,9 @@ class AutoModelForMaskedLM:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in MODEL_FOR_MASKED_LM_MAPPING.items():
             if isinstance(config, config_class):
@@ -1209,7 +1219,9 @@ class AutoModelForSeq2SeqLM:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items():
             if isinstance(config, config_class):
@@ -1359,7 +1371,9 @@ class AutoModelForSequenceClassification:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
             if isinstance(config, config_class):
@@ -1501,7 +1515,9 @@ class AutoModelForQuestionAnswering:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
             if isinstance(config, config_class):
@@ -1651,7 +1667,9 @@ class AutoModelForTokenClassification:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
             if isinstance(config, config_class):
@@ -1703,7 +1721,9 @@ class AutoModelForMultipleChoice:
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items():
             if isinstance(config, config_class):
@@ -450,7 +450,9 @@ class TFAutoModel(object):
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in TF_MODEL_MAPPING.items():
             if isinstance(config, config_class):
@@ -601,7 +603,9 @@ class TFAutoModelForPreTraining(object):
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.items():
             if isinstance(config, config_class):
@@ -776,7 +780,9 @@ class TFAutoModelWithLMHead(object):
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in TF_MODEL_WITH_LM_HEAD_MAPPING.items():
             # Not using isinstance() here to do not take into account inheritance
@@ -923,7 +929,9 @@ class TFAutoModelForMultipleChoice:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items():
             if isinstance(config, config_class):
@@ -1058,7 +1066,9 @@ class TFAutoModelForCausalLM:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in TF_MODEL_FOR_CAUSAL_LM_MAPPING.items():
             if isinstance(config, config_class):
@@ -1198,7 +1208,9 @@ class TFAutoModelForMaskedLM:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in TF_MODEL_FOR_MASKED_LM_MAPPING.items():
             if isinstance(config, config_class):
@@ -1323,7 +1335,9 @@ class TFAutoModelForSeq2SeqLM:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items():
             if isinstance(config, config_class):
@@ -1482,7 +1496,9 @@ class TFAutoModelForSequenceClassification(object):
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
             if isinstance(config, config_class):
@@ -1644,7 +1660,9 @@ class TFAutoModelForQuestionAnswering(object):
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
             if isinstance(config, config_class):
@@ -1775,7 +1793,9 @@ class TFAutoModelForTokenClassification:
         """
         config = kwargs.pop("config", None)
         if not isinstance(config, PretrainedConfig):
-            config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
+            config, kwargs = AutoConfig.from_pretrained(
+                pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
+            )
 
         for config_class, model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
             if isinstance(config, config_class):
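
Taken together, every `from_pretrained` hunk above follows the same dispatch pattern. Below is a condensed, self-contained sketch of that pattern, not the verbatim library code: the function name and the `mapping` parameter are illustrative, and error handling is simplified.

```python
from transformers import AutoConfig, PretrainedConfig


def from_pretrained_sketch(name_or_path, mapping, *model_args, **kwargs):
    """Dispatch to the concrete model class registered for the config type."""
    config = kwargs.pop("config", None)
    if not isinstance(config, PretrainedConfig):
        # Config-level kwargs are absorbed into `config`; the remainder
        # stays in `kwargs` and is forwarded to the model below.
        config, kwargs = AutoConfig.from_pretrained(
            name_or_path, return_unused_kwargs=True, **kwargs
        )

    for config_class, model_class in mapping.items():
        if isinstance(config, config_class):
            return model_class.from_pretrained(
                name_or_path, *model_args, config=config, **kwargs
            )
    raise ValueError(f"Unrecognized configuration class {config.__class__}")
```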