Unverified commit 7cdd9da5 authored by Lysandre Debut, committed by GitHub

Check config type using `type` instead of `isinstance` (#7363)

* Check config type instead of instance


Bad merge

* Remove for loops

* Style
parent 3c6bf899
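
The same mechanical change is applied to every `Auto*`/`TFAuto*` class below: a linear scan over the mapping with `isinstance` checks is replaced by a direct dictionary lookup keyed on the exact `type(config)`. Because `isinstance` also matches subclasses, a config class that derives from another could be dispatched to whichever parent entry came first in the mapping; an exact type lookup is unambiguous and drops the loop entirely. A minimal sketch of the difference, using hypothetical `ParentConfig`/`ChildConfig` stand-ins rather than the real config classes:

```python
# Illustrative only: toy stand-ins for config/model classes, not the real
# transformers ones.
class ParentConfig: ...
class ChildConfig(ParentConfig): ...

class ParentModel:
    def __init__(self, config): self.config = config

class ChildModel:
    def __init__(self, config): self.config = config

MAPPING = {ParentConfig: ParentModel, ChildConfig: ChildModel}
config = ChildConfig()

# Old behaviour: the first isinstance() hit wins, so a ChildConfig is routed
# to ParentModel because ParentConfig happens to come first in the mapping.
for config_class, model_class in MAPPING.items():
    if isinstance(config, config_class):
        print("isinstance dispatch:", model_class.__name__)  # ParentModel
        break

# New behaviour: an exact type lookup is unambiguous and needs no loop.
if type(config) in MAPPING:
    print("type dispatch:", MAPPING[type(config)].__name__)  # ChildModel
```

The hunks below apply exactly this swap in each `from_config` and `from_pretrained` method.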
@@ -544,9 +544,8 @@ class AutoModel:
             >>> config = AutoConfig.from_pretrained('bert-base-uncased')
             >>> model = AutoModel.from_config(config)
         """
-        for config_class, model_class in MODEL_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class(config)
+        if type(config) in MODEL_MAPPING.keys():
+            return MODEL_MAPPING[type(config)](config)
         raise ValueError(
             "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -585,9 +584,10 @@ class AutoModel:
                 pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
             )
 
-        for config_class, model_class in MODEL_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        if type(config) in MODEL_MAPPING.keys():
+            return MODEL_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
         raise ValueError(
             "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -638,9 +638,8 @@ class AutoModelForPreTraining:
             >>> config = AutoConfig.from_pretrained('bert-base-uncased')
             >>> model = AutoModelForPreTraining.from_config(config)
         """
-        for config_class, model_class in MODEL_FOR_PRETRAINING_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class(config)
+        if type(config) in MODEL_FOR_PRETRAINING_MAPPING.keys():
+            return MODEL_FOR_PRETRAINING_MAPPING[type(config)](config)
         raise ValueError(
             "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -679,9 +678,10 @@ class AutoModelForPreTraining:
                 pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
             )
 
-        for config_class, model_class in MODEL_FOR_PRETRAINING_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        if type(config) in MODEL_FOR_PRETRAINING_MAPPING.keys():
+            return MODEL_FOR_PRETRAINING_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
         raise ValueError(
             "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
             "Model type should be one of {}.".format(
...@@ -744,9 +744,8 @@ class AutoModelWithLMHead: ...@@ -744,9 +744,8 @@ class AutoModelWithLMHead:
"`AutoModelForSeq2SeqLM` for encoder-decoder models.", "`AutoModelForSeq2SeqLM` for encoder-decoder models.",
FutureWarning, FutureWarning,
) )
for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items(): if type(config) in MODEL_WITH_LM_HEAD_MAPPING.keys():
if isinstance(config, config_class): return MODEL_WITH_LM_HEAD_MAPPING[type(config)](config)
return model_class(config)
raise ValueError( raise ValueError(
"Unrecognized configuration class {} for this kind of AutoModel: {}.\n" "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
"Model type should be one of {}.".format( "Model type should be one of {}.".format(
@@ -791,9 +790,10 @@ class AutoModelWithLMHead:
                 pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
             )
 
-        for config_class, model_class in MODEL_WITH_LM_HEAD_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        if type(config) in MODEL_WITH_LM_HEAD_MAPPING.keys():
+            return MODEL_WITH_LM_HEAD_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
         raise ValueError(
             "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -844,9 +844,8 @@ class AutoModelForCausalLM:
             >>> config = AutoConfig.from_pretrained('gpt2')
             >>> model = AutoModelForCausalLM.from_config(config)
         """
-        for config_class, model_class in MODEL_FOR_CAUSAL_LM_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class(config)
+        if type(config) in MODEL_FOR_CAUSAL_LM_MAPPING.keys():
+            return MODEL_FOR_CAUSAL_LM_MAPPING[type(config)](config)
         raise ValueError(
             "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -885,9 +884,10 @@ class AutoModelForCausalLM:
                 pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
             )
 
-        for config_class, model_class in MODEL_FOR_CAUSAL_LM_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        if type(config) in MODEL_FOR_CAUSAL_LM_MAPPING.keys():
+            return MODEL_FOR_CAUSAL_LM_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
         raise ValueError(
             "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -938,9 +938,8 @@ class AutoModelForMaskedLM:
             >>> config = AutoConfig.from_pretrained('bert-base-uncased')
             >>> model = AutoModelForMaskedLM.from_config(config)
         """
-        for config_class, model_class in MODEL_FOR_MASKED_LM_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class(config)
+        if type(config) in MODEL_FOR_MASKED_LM_MAPPING.keys():
+            return MODEL_FOR_MASKED_LM_MAPPING[type(config)](config)
         raise ValueError(
             "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -979,9 +978,10 @@ class AutoModelForMaskedLM:
                 pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
             )
 
-        for config_class, model_class in MODEL_FOR_MASKED_LM_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        if type(config) in MODEL_FOR_MASKED_LM_MAPPING.keys():
+            return MODEL_FOR_MASKED_LM_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
         raise ValueError(
             "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -1032,9 +1032,8 @@ class AutoModelForSeq2SeqLM:
             >>> config = AutoConfig.from_pretrained('t5')
             >>> model = AutoModelForSeq2SeqLM.from_config(config)
         """
-        for config_class, model_class in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class(config)
+        if type(config) in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys():
+            return MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING[type(config)](config)
         raise ValueError(
             "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -1075,9 +1074,10 @@ class AutoModelForSeq2SeqLM:
                 pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
             )
 
-        for config_class, model_class in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        if type(config) in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys():
+            return MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
         raise ValueError(
             "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -1130,9 +1130,8 @@ class AutoModelForSequenceClassification:
             >>> config = AutoConfig.from_pretrained('bert-base-uncased')
             >>> model = AutoModelForSequenceClassification.from_config(config)
         """
-        for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class(config)
+        if type(config) in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys():
+            return MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING[type(config)](config)
         raise ValueError(
             "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -1173,9 +1172,10 @@ class AutoModelForSequenceClassification:
                 pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
             )
 
-        for config_class, model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        if type(config) in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys():
+            return MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
         raise ValueError(
             "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -1227,9 +1227,8 @@ class AutoModelForQuestionAnswering:
             >>> config = AutoConfig.from_pretrained('bert-base-uncased')
             >>> model = AutoModelForQuestionAnswering.from_config(config)
         """
-        for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class(config)
+        if type(config) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys():
+            return MODEL_FOR_QUESTION_ANSWERING_MAPPING[type(config)](config)
         raise ValueError(
             "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -1271,9 +1270,10 @@ class AutoModelForQuestionAnswering:
                 pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
             )
 
-        for config_class, model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        if type(config) in MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys():
+            return MODEL_FOR_QUESTION_ANSWERING_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
         raise ValueError(
             "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -1326,9 +1326,8 @@ class AutoModelForTokenClassification:
             >>> config = AutoConfig.from_pretrained('bert-base-uncased')
             >>> model = AutoModelForTokenClassification.from_config(config)
         """
-        for config_class, model_class in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class(config)
+        if type(config) in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys():
+            return MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING[type(config)](config)
         raise ValueError(
             "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -1370,9 +1369,10 @@ class AutoModelForTokenClassification:
                 pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
             )
 
-        for config_class, model_class in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        if type(config) in MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys():
+            return MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
         raise ValueError(
             "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -1426,9 +1426,8 @@ class AutoModelForMultipleChoice:
             >>> config = AutoConfig.from_pretrained('bert-base-uncased')
             >>> model = AutoModelForMultipleChoice.from_config(config)
         """
-        for config_class, model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class(config)
+        if type(config) in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys():
+            return MODEL_FOR_MULTIPLE_CHOICE_MAPPING[type(config)](config)
         raise ValueError(
             "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -1470,9 +1469,10 @@ class AutoModelForMultipleChoice:
                 pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
             )
 
-        for config_class, model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        if type(config) in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys():
+            return MODEL_FOR_MULTIPLE_CHOICE_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
         raise ValueError(
             "Unrecognized configuration class {} for this kind of AutoModel: {}.\n"
             "Model type should be one of {}.".format(
...
@@ -453,9 +453,8 @@ class TFAutoModel(object):
             >>> config = TFAutoConfig.from_pretrained('bert-base-uncased')
             >>> model = TFAutoModel.from_config(config)
         """
-        for config_class, model_class in TF_MODEL_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class(config)
+        if type(config) in TF_MODEL_MAPPING.keys():
+            return TF_MODEL_MAPPING[type(config)](config)
         raise ValueError(
             "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -494,9 +493,10 @@ class TFAutoModel(object):
                 pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
             )
 
-        for config_class, model_class in TF_MODEL_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        if type(config) in TF_MODEL_MAPPING.keys():
+            return TF_MODEL_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
         raise ValueError(
             "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -547,9 +547,8 @@ class TFAutoModelForPreTraining(object):
             >>> config = AutoConfig.from_pretrained('bert-base-uncased')
             >>> model = TFAutoModelForPreTraining.from_config(config)
         """
-        for config_class, model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class(config)
+        if type(config) in TF_MODEL_FOR_PRETRAINING_MAPPING.keys():
+            return TF_MODEL_FOR_PRETRAINING_MAPPING[type(config)](config)
         raise ValueError(
             "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -588,9 +587,10 @@ class TFAutoModelForPreTraining(object):
                 pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
             )
 
-        for config_class, model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        if type(config) in TF_MODEL_FOR_PRETRAINING_MAPPING.keys():
+            return TF_MODEL_FOR_PRETRAINING_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
         raise ValueError(
             "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -653,9 +653,8 @@ class TFAutoModelWithLMHead(object):
             "and `TFAutoModelForSeq2SeqLM` for encoder-decoder models.",
             FutureWarning,
         )
-        for config_class, model_class in TF_MODEL_WITH_LM_HEAD_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class(config)
+        if type(config) in TF_MODEL_WITH_LM_HEAD_MAPPING.keys():
+            return TF_MODEL_WITH_LM_HEAD_MAPPING[type(config)](config)
         raise ValueError(
             "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -701,10 +700,10 @@ class TFAutoModelWithLMHead(object):
                 pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
             )
 
-        for config_class, model_class in TF_MODEL_WITH_LM_HEAD_MAPPING.items():
-            # Not using isinstance() here to do not take into account inheritance
-            if config_class == type(config):
-                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        if type(config) in TF_MODEL_WITH_LM_HEAD_MAPPING.keys():
+            return TF_MODEL_WITH_LM_HEAD_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
         raise ValueError(
             "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -755,9 +754,8 @@ class TFAutoModelForCausalLM:
             >>> config = AutoConfig.from_pretrained('gpt2')
             >>> model = TFAutoModelForCausalLM.from_config(config)
         """
-        for config_class, model_class in TF_MODEL_FOR_CAUSAL_LM_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class(config)
+        if type(config) in TF_MODEL_FOR_CAUSAL_LM_MAPPING.keys():
+            return TF_MODEL_FOR_CAUSAL_LM_MAPPING[type(config)](config)
         raise ValueError(
             "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -796,9 +794,10 @@ class TFAutoModelForCausalLM:
                 pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
             )
 
-        for config_class, model_class in TF_MODEL_FOR_CAUSAL_LM_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        if type(config) in TF_MODEL_FOR_CAUSAL_LM_MAPPING.keys():
+            return TF_MODEL_FOR_CAUSAL_LM_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
         raise ValueError(
             "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -849,9 +848,8 @@ class TFAutoModelForMaskedLM:
             >>> config = AutoConfig.from_pretrained('bert-base-uncased')
             >>> model = TFAutoModelForMaskedLM.from_config(config)
         """
-        for config_class, model_class in TF_MODEL_FOR_MASKED_LM_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class(config)
+        if type(config) in TF_MODEL_FOR_MASKED_LM_MAPPING.keys():
+            return TF_MODEL_FOR_MASKED_LM_MAPPING[type(config)](config)
         raise ValueError(
             "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -890,9 +888,10 @@ class TFAutoModelForMaskedLM:
                 pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
             )
 
-        for config_class, model_class in TF_MODEL_FOR_MASKED_LM_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        if type(config) in TF_MODEL_FOR_MASKED_LM_MAPPING.keys():
+            return TF_MODEL_FOR_MASKED_LM_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
         raise ValueError(
             "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -943,9 +942,8 @@ class TFAutoModelForSeq2SeqLM:
             >>> config = AutoConfig.from_pretrained('t5')
             >>> model = TFAutoModelForSeq2SeqLM.from_config(config)
         """
-        for config_class, model_class in TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class(config)
+        if type(config) in TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys():
+            return TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING[type(config)](config)
         raise ValueError(
             "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -986,9 +984,10 @@ class TFAutoModelForSeq2SeqLM:
                 pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
             )
 
-        for config_class, model_class in TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        if type(config) in TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.keys():
+            return TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
         raise ValueError(
             "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -1041,9 +1040,8 @@ class TFAutoModelForSequenceClassification(object):
             >>> config = AutoConfig.from_pretrained('bert-base-uncased')
             >>> model = TFAutoModelForSequenceClassification.from_config(config)
         """
-        for config_class, model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class(config)
+        if type(config) in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys():
+            return TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING[type(config)](config)
         raise ValueError(
             "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -1084,9 +1082,10 @@ class TFAutoModelForSequenceClassification(object):
                 pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
             )
 
-        for config_class, model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        if type(config) in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.keys():
+            return TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
         raise ValueError(
             "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -1138,9 +1137,8 @@ class TFAutoModelForQuestionAnswering(object):
             >>> config = AutoConfig.from_pretrained('bert-base-uncased')
             >>> model = TFAutoModelForQuestionAnswering.from_config(config)
         """
-        for config_class, model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class(config)
+        if type(config) in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys():
+            return TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING[type(config)](config)
         raise ValueError(
             "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -1181,9 +1179,10 @@ class TFAutoModelForQuestionAnswering(object):
                 pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
             )
 
-        for config_class, model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        if type(config) in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys():
+            return TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
         raise ValueError(
             "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -1235,9 +1234,8 @@ class TFAutoModelForTokenClassification:
             >>> config = AutoConfig.from_pretrained('bert-base-uncased')
             >>> model = TFAutoModelForTokenClassification.from_config(config)
         """
-        for config_class, model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class(config)
+        if type(config) in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys():
+            return TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING[type(config)](config)
         raise ValueError(
             "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -1278,9 +1276,10 @@ class TFAutoModelForTokenClassification:
                 pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
             )
 
-        for config_class, model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        if type(config) in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.keys():
+            return TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
         raise ValueError(
             "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -1333,9 +1332,8 @@ class TFAutoModelForMultipleChoice:
             >>> config = AutoConfig.from_pretrained('bert-base-uncased')
             >>> model = TFAutoModelForMultipleChoice.from_config(config)
         """
-        for config_class, model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class(config)
+        if type(config) in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys():
+            return TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING[type(config)](config)
         raise ValueError(
             "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
             "Model type should be one of {}.".format(
@@ -1376,9 +1374,10 @@ class TFAutoModelForMultipleChoice:
                 pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs
             )
 
-        for config_class, model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.items():
-            if isinstance(config, config_class):
-                return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
+        if type(config) in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.keys():
+            return TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING[type(config)].from_pretrained(
+                pretrained_model_name_or_path, *model_args, config=config, **kwargs
+            )
         raise ValueError(
             "Unrecognized configuration class {} for this kind of TFAutoModel: {}.\n"
             "Model type should be one of {}.".format(
...
@@ -243,8 +243,8 @@ class AutoTokenizer:
             )
             config = config.encoder
 
-        for config_class, (tokenizer_class_py, tokenizer_class_fast) in TOKENIZER_MAPPING.items():
-            if type(config) is config_class:
+        if type(config) in TOKENIZER_MAPPING.keys():
+            tokenizer_class_py, tokenizer_class_fast = TOKENIZER_MAPPING[type(config)]
             if tokenizer_class_fast and use_fast:
                 return tokenizer_class_fast.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
             else:
...
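
None of this changes the public `Auto*` API; as in the docstring examples quoted in the hunks above, model selection is still driven by the configuration's class. A short usage sketch:

```python
from transformers import AutoConfig, AutoModel

# Resolve the concrete model class from the config's exact type
# (randomly initialized weights, as in the from_config docstrings above).
config = AutoConfig.from_pretrained("bert-base-uncased")
model = AutoModel.from_config(config)

# Loading pretrained weights goes through the same type-keyed lookup.
model = AutoModel.from_pretrained("bert-base-uncased")
```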