Unverified Commit d7e2d7b4 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Preserve hub-related kwargs in AutoModel.from_pretrained (#18545)

* Preserve hub-related kwargs in AutoModel.from_pretrained

* Fix tests

* Remove debug statement
parent 34aad0da
...@@ -419,9 +419,24 @@ class _BaseAutoModelClass: ...@@ -419,9 +419,24 @@ class _BaseAutoModelClass:
config = kwargs.pop("config", None) config = kwargs.pop("config", None)
trust_remote_code = kwargs.pop("trust_remote_code", False) trust_remote_code = kwargs.pop("trust_remote_code", False)
kwargs["_from_auto"] = True kwargs["_from_auto"] = True
hub_kwargs_names = [
"cache_dir",
"force_download",
"local_files_only",
"proxies",
"resume_download",
"revision",
"subfolder",
"use_auth_token",
]
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
if not isinstance(config, PretrainedConfig): if not isinstance(config, PretrainedConfig):
config, kwargs = AutoConfig.from_pretrained( config, kwargs = AutoConfig.from_pretrained(
pretrained_model_name_or_path, return_unused_kwargs=True, trust_remote_code=trust_remote_code, **kwargs pretrained_model_name_or_path,
return_unused_kwargs=True,
trust_remote_code=trust_remote_code,
**hub_kwargs,
**kwargs,
) )
if hasattr(config, "auto_map") and cls.__name__ in config.auto_map: if hasattr(config, "auto_map") and cls.__name__ in config.auto_map:
if not trust_remote_code: if not trust_remote_code:
...@@ -430,7 +445,7 @@ class _BaseAutoModelClass: ...@@ -430,7 +445,7 @@ class _BaseAutoModelClass:
"on your local machine. Make sure you have read the code there to avoid malicious use, then set " "on your local machine. Make sure you have read the code there to avoid malicious use, then set "
"the option `trust_remote_code=True` to remove this error." "the option `trust_remote_code=True` to remove this error."
) )
if kwargs.get("revision", None) is None: if hub_kwargs.get("revision", None) is None:
logger.warning( logger.warning(
"Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure " "Explicitly passing a `revision` is encouraged when loading a model with custom code to ensure "
"no malicious code has been contributed in a newer revision." "no malicious code has been contributed in a newer revision."
...@@ -438,12 +453,16 @@ class _BaseAutoModelClass: ...@@ -438,12 +453,16 @@ class _BaseAutoModelClass:
class_ref = config.auto_map[cls.__name__] class_ref = config.auto_map[cls.__name__]
module_file, class_name = class_ref.split(".") module_file, class_name = class_ref.split(".")
model_class = get_class_from_dynamic_module( model_class = get_class_from_dynamic_module(
pretrained_model_name_or_path, module_file + ".py", class_name, **kwargs pretrained_model_name_or_path, module_file + ".py", class_name, **hub_kwargs, **kwargs
)
return model_class.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
) )
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs)
elif type(config) in cls._model_mapping.keys(): elif type(config) in cls._model_mapping.keys():
model_class = _get_model_class(config, cls._model_mapping) model_class = _get_model_class(config, cls._model_mapping)
return model_class.from_pretrained(pretrained_model_name_or_path, *model_args, config=config, **kwargs) return model_class.from_pretrained(
pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
)
raise ValueError( raise ValueError(
f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n" f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}." f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
......
...@@ -728,7 +728,7 @@ class AutoConfig: ...@@ -728,7 +728,7 @@ class AutoConfig:
kwargs["_from_auto"] = True kwargs["_from_auto"] = True
kwargs["name_or_path"] = pretrained_model_name_or_path kwargs["name_or_path"] = pretrained_model_name_or_path
trust_remote_code = kwargs.pop("trust_remote_code", False) trust_remote_code = kwargs.pop("trust_remote_code", False)
config_dict, _ = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs) config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
if "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]: if "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]:
if not trust_remote_code: if not trust_remote_code:
raise ValueError( raise ValueError(
...@@ -749,13 +749,13 @@ class AutoConfig: ...@@ -749,13 +749,13 @@ class AutoConfig:
return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs) return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
elif "model_type" in config_dict: elif "model_type" in config_dict:
config_class = CONFIG_MAPPING[config_dict["model_type"]] config_class = CONFIG_MAPPING[config_dict["model_type"]]
return config_class.from_dict(config_dict, **kwargs) return config_class.from_dict(config_dict, **unused_kwargs)
else: else:
# Fallback: use pattern matching on the string. # Fallback: use pattern matching on the string.
# We go from longer names to shorter names to catch roberta before bert (for instance) # We go from longer names to shorter names to catch roberta before bert (for instance)
for pattern in sorted(CONFIG_MAPPING.keys(), key=len, reverse=True): for pattern in sorted(CONFIG_MAPPING.keys(), key=len, reverse=True):
if pattern in str(pretrained_model_name_or_path): if pattern in str(pretrained_model_name_or_path):
return CONFIG_MAPPING[pattern].from_dict(config_dict, **kwargs) return CONFIG_MAPPING[pattern].from_dict(config_dict, **unused_kwargs)
raise ValueError( raise ValueError(
f"Unrecognized model in {pretrained_model_name_or_path}. " f"Unrecognized model in {pretrained_model_name_or_path}. "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment