Unverified Commit 6aadb8d0 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Allow existing configs to be registered (#24760)

parent 4c0e251d
......@@ -418,7 +418,7 @@ class _BaseAutoModelClass:
else:
repo_id = config.name_or_path
model_class = get_class_from_dynamic_module(class_ref, repo_id, **kwargs)
cls._model_mapping.register(config.__class__, model_class)
cls._model_mapping.register(config.__class__, model_class, exist_ok=True)
_ = kwargs.pop("code_revision", None)
return model_class._from_config(config, **kwargs)
elif type(config) in cls._model_mapping.keys():
......@@ -477,7 +477,7 @@ class _BaseAutoModelClass:
class_ref, pretrained_model_name_or_path, **hub_kwargs, **kwargs
)
_ = hub_kwargs.pop("code_revision", None)
cls._model_mapping.register(config.__class__, model_class)
cls._model_mapping.register(config.__class__, model_class, exist_ok=True)
return model_class.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
)
......@@ -492,7 +492,7 @@ class _BaseAutoModelClass:
)
@classmethod
def register(cls, config_class, model_class):
def register(cls, config_class, model_class, exist_ok=False):
"""
Register a new model for this class.
......@@ -508,7 +508,7 @@ class _BaseAutoModelClass:
f"config class you passed (model has {model_class.config_class} and you passed {config_class}. Fix "
"one of those so they match!"
)
cls._model_mapping.register(config_class, model_class)
cls._model_mapping.register(config_class, model_class, exist_ok=exist_ok)
class _BaseAutoBackboneClass(_BaseAutoModelClass):
......@@ -719,13 +719,13 @@ class _LazyAutoMapping(OrderedDict):
model_type = self._reverse_config_mapping[item.__name__]
return model_type in self._model_mapping
def register(self, key, value):
def register(self, key, value, exist_ok=False):
"""
Register a new model in this mapping.
"""
if hasattr(key, "__name__") and key.__name__ in self._reverse_config_mapping:
model_type = self._reverse_config_mapping[key.__name__]
if model_type in self._model_mapping.keys():
if model_type in self._model_mapping.keys() and not exist_ok:
raise ValueError(f"'{key}' is already used by a Transformers model.")
self._extra_content[key] = value
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment