"docs/vscode:/vscode.git/clone" did not exist on "7a7fdf71f80452fcae064bd016f06e9a0f0f19ed"
Unverified commit ec0a945c, authored by Patrick von Platen and committed by GitHub
Browse files

[AutoModels] Fix config params handling of all PT and TF AutoModels (#5665)

* fix auto model causal lm

* leverage given functionality

* apply unused kwargs to all auto models
parent 8ab565a4
...@@ -498,7 +498,9 @@ class AutoModel: ...@@ -498,7 +498,9 @@ class AutoModel:
""" """
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
for config_class, model_class in MODEL_MAPPING.items(): for config_class, model_class in MODEL_MAPPING.items():
if isinstance(config, config_class): if isinstance(config, config_class):
...@@ -645,7 +647,9 @@ class AutoModelForPreTraining: ...@@ -645,7 +647,9 @@ class AutoModelForPreTraining:
""" """
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
for config_class, model_class in MODEL_FOR_PRETRAINING_MAPPING.items(): for config_class, model_class in MODEL_FOR_PRETRAINING_MAPPING.items():
if isinstance(config, config_class): if isinstance(config, config_class):
...@@ -802,7 +806,9 @@ class AutoModelWithLMHead: ...@@ -802,7 +806,9 @@ class AutoModelWithLMHead:
) )
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items(): for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():
if isinstance(config, config_class): if isinstance(config, config_class):
...@@ -937,7 +943,9 @@ class AutoModelForCausalLM: ...@@ -937,7 +943,9 @@ class AutoModelForCausalLM:
""" """
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
for config_class, model_class in MODEL_FOR_CAUSAL_LM_MAPPING.items(): for config_class, model_class in MODEL_FOR_CAUSAL_LM_MAPPING.items():
if isinstance(config, config_class): if isinstance(config, config_class):
...@@ -1078,7 +1086,9 @@ class AutoModelForMaskedLM: ...@@ -1078,7 +1086,9 @@ class AutoModelForMaskedLM:
""" """
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
for config_class, model_class in MODEL_FOR_MASKED_LM_MAPPING.items(): for config_class, model_class in MODEL_FOR_MASKED_LM_MAPPING.items():
if isinstance(config, config_class): if isinstance(config, config_class):
...@@ -1209,7 +1219,9 @@ class AutoModelForSeq2SeqLM: ...@@ -1209,7 +1219,9 @@ class AutoModelForSeq2SeqLM:
""" """
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
for config_class, model_class in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items(): for config_class, model_class in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items():
if isinstance(config, config_class): if isinstance(config, config_class):
...@@ -1359,7 +1371,9 @@ class AutoModelForSequenceClassification: ...@@ -1359,7 +1371,9 @@ class AutoModelForSequenceClassification:
""" """
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items(): for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
if isinstance(config, config_class): if isinstance(config, config_class):
...@@ -1501,7 +1515,9 @@ class AutoModelForQuestionAnswering: ...@@ -1501,7 +1515,9 @@ class AutoModelForQuestionAnswering:
""" """
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items(): for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
if isinstance(config, config_class): if isinstance(config, config_class):
...@@ -1651,7 +1667,9 @@ class AutoModelForTokenClassification: ...@@ -1651,7 +1667,9 @@ class AutoModelForTokenClassification:
""" """
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
for config_class, model_class in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items(): for config_class, model_class in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
if isinstance(config, config_class): if isinstance(config, config_class):
...@@ -1703,7 +1721,9 @@ class AutoModelForMultipleChoice: ...@@ -1703,7 +1721,9 @@ class AutoModelForMultipleChoice:
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
for config_class, model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items(): for config_class, model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items():
if isinstance(config, config_class): if isinstance(config, config_class):
......
...@@ -450,7 +450,9 @@ class TFAutoModel(object): ...@@ -450,7 +450,9 @@ class TFAutoModel(object):
""" """
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
for config_class, model_class in TF_MODEL_MAPPING.items(): for config_class, model_class in TF_MODEL_MAPPING.items():
if isinstance(config, config_class): if isinstance(config, config_class):
...@@ -601,7 +603,9 @@ class TFAutoModelForPreTraining(object): ...@@ -601,7 +603,9 @@ class TFAutoModelForPreTraining(object):
""" """
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
for config_class, model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.items(): for config_class, model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.items():
if isinstance(config, config_class): if isinstance(config, config_class):
...@@ -776,7 +780,9 @@ class TFAutoModelWithLMHead(object): ...@@ -776,7 +780,9 @@ class TFAutoModelWithLMHead(object):
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
for config_class, model_class in TF_MODEL_WITH_LM_HEAD_MAPPING.items(): for config_class, model_class in TF_MODEL_WITH_LM_HEAD_MAPPING.items():
# Not using isinstance() here to do not take into account inheritance # Not using isinstance() here to do not take into account inheritance
...@@ -923,7 +929,9 @@ class TFAutoModelForMultipleChoice: ...@@ -923,7 +929,9 @@ class TFAutoModelForMultipleChoice:
""" """
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
for config_class, model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items(): for config_class, model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items():
if isinstance(config, config_class): if isinstance(config, config_class):
...@@ -1058,7 +1066,9 @@ class TFAutoModelForCausalLM: ...@@ -1058,7 +1066,9 @@ class TFAutoModelForCausalLM:
""" """
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
for config_class, model_class in TF_MODEL_FOR_CAUSAL_LM_MAPPING.items(): for config_class, model_class in TF_MODEL_FOR_CAUSAL_LM_MAPPING.items():
if isinstance(config, config_class): if isinstance(config, config_class):
...@@ -1198,7 +1208,9 @@ class TFAutoModelForMaskedLM: ...@@ -1198,7 +1208,9 @@ class TFAutoModelForMaskedLM:
""" """
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
for config_class, model_class in TF_MODEL_FOR_MASKED_LM_MAPPING.items(): for config_class, model_class in TF_MODEL_FOR_MASKED_LM_MAPPING.items():
if isinstance(config, config_class): if isinstance(config, config_class):
...@@ -1323,7 +1335,9 @@ class TFAutoModelForSeq2SeqLM: ...@@ -1323,7 +1335,9 @@ class TFAutoModelForSeq2SeqLM:
""" """
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
for config_class, model_class in TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items(): for config_class, model_class in TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items():
if isinstance(config, config_class): if isinstance(config, config_class):
...@@ -1482,7 +1496,9 @@ class TFAutoModelForSequenceClassification(object): ...@@ -1482,7 +1496,9 @@ class TFAutoModelForSequenceClassification(object):
""" """
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
for config_class, model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items(): for config_class, model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
if isinstance(config, config_class): if isinstance(config, config_class):
...@@ -1644,7 +1660,9 @@ class TFAutoModelForQuestionAnswering(object): ...@@ -1644,7 +1660,9 @@ class TFAutoModelForQuestionAnswering(object):
""" """
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
for config_class, model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.items(): for config_class, model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
if isinstance(config, config_class): if isinstance(config, config_class):
...@@ -1775,7 +1793,9 @@ class TFAutoModelForTokenClassification: ...@@ -1775,7 +1793,9 @@ class TFAutoModelForTokenClassification:
""" """
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
)
for config_class, model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items(): for config_class, model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
if isinstance(config, config_class): if isinstance(config, config_class):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment