"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "698c9e2dbdbc35aed588b58f080afbdbfa0c3c04"
Unverified Commit 6aadb8d0 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Allow existing configs to be registered (#24760)

parent 4c0e251d
...@@ -418,7 +418,7 @@ class _BaseAutoModelClass: ...@@ -418,7 +418,7 @@ class _BaseAutoModelClass:
else: else:
repo_id = config.name_or_path repo_id = config.name_or_path
model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs) model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
cls._model_mapping.register(config.__class__, model_class) cls._model_mapping.register(config.__class__, model_class, exist_ok=True)
_ = kwargs.pop("code_revision", None) _ = kwargs.pop("code_revision", None)
return model_class._from_config(config, **kwargs) return model_class._from_config(config, **kwargs)
elif type(config) in cls._model_mapping.keys(): elif type(config) in cls._model_mapping.keys():
...@@ -477,7 +477,7 @@ class _BaseAutoModelClass: ...@@ -477,7 +477,7 @@ class _BaseAutoModelClass:
class_ref, pretrained_model_name_or_path, **hub_kwargs, **kwargs class_ref, pretrained_model_name_or_path, **hub_kwargs, **kwargs
) )
_ = hub_kwargs.pop("code_revision", None) _ = hub_kwargs.pop("code_revision", None)
cls._model_mapping.register(config.__class__, model_class) cls._model_mapping.register(config.__class__, model_class, exist_ok=True)
return model_class.from_pretrained( return model_class.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
) )
...@@ -492,7 +492,7 @@ class _BaseAutoModelClass: ...@@ -492,7 +492,7 @@ class _BaseAutoModelClass:
) )
@classmethod @classmethod
def register(cls, config_class, model_class): def register(cls, config_class, model_class, exist_ok=False):
""" """
Register a new model for this class. Register a new model for this class.
...@@ -508,7 +508,7 @@ class _BaseAutoModelClass: ...@@ -508,7 +508,7 @@ class _BaseAutoModelClass:
f"config class you passed (model has {model_class.config_class} and you passed {config_class}. Fix " f"config class you passed (model has {model_class.config_class} and you passed {config_class}. Fix "
"one of those so they match!" "one of those so they match!"
) )
cls._model_mapping.register(config_class, model_class) cls._model_mapping.register(config_class, model_class, exist_ok=exist_ok)
class _BaseAutoBackboneClass(_BaseAutoModelClass): class _BaseAutoBackboneClass(_BaseAutoModelClass):
...@@ -719,13 +719,13 @@ class _LazyAutoMapping(OrderedDict): ...@@ -719,13 +719,13 @@ class _LazyAutoMapping(OrderedDict):
model_type = self._reverse_config_mapping[item.__name__] model_type = self._reverse_config_mapping[item.__name__]
return model_type in self._model_mapping return model_type in self._model_mapping
def register(self, key, value): def register(self, key, value, exist_ok=False):
""" """
Register a new model in this mapping. Register a new model in this mapping.
""" """
if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping: if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
model_type = self._reverse_config_mapping[key.__name__] model_type = self._reverse_config_mapping[key.__name__]
if model_type in self._model_mapping.keys(): if model_type in self._model_mapping.keys() and not exist_ok:
raise ValueError(f"'{key}' is already used by a Transformers model.") raise ValueError(f"'{key}' is already used by a Transformers model.")
self._extra_content[key] = value self._extra_content[key] = value
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment