Unverified Commit 2febd506 authored by Sylvain Gugger, committed by GitHub

Fix number of minimal calls to the Hub with peft integration (#25715)

* Fix number of minimal calls to the Hub with peft integration

* Alternate design

* And this way?

* Revert

* Address comments
parent 70b49f02
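
Note on the pattern applied below: each call site now resolves `config.json` once, extracts the commit hash from the resolved file, and passes that hash to every subsequent Hub lookup. A minimal sketch of this pattern, using the helpers that appear in the diff (`cached_file`, `extract_commit_hash`, `find_adapter_config_file`); the wrapper function and its name are illustrative only and not part of the library, and `hub_kwargs` is assumed to contain only the standard Hub arguments (cache_dir, revision, token, proxies, ...):

from transformers.utils import CONFIG_NAME, cached_file, extract_commit_hash, find_adapter_config_file


def resolve_with_minimal_hub_calls(model_id, **hub_kwargs):
    """Illustrative helper (not in transformers): pin a commit hash once, reuse it afterwards."""
    # The one call that may hit the Hub; missing files and connection errors return None instead of raising.
    resolved_config_file = cached_file(
        model_id,
        CONFIG_NAME,
        _raise_exceptions_for_missing_entries=False,
        _raise_exceptions_for_connection_errors=False,
        **hub_kwargs,
    )
    commit_hash = extract_commit_hash(resolved_config_file, None)

    # Later lookups reuse the pinned hash, so cached files are found without re-resolving the revision.
    adapter_config = find_adapter_config_file(model_id, _commit_hash=commit_hash, **hub_kwargs)
    return commit_hash, adapter_config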
@@ -51,6 +51,7 @@ from .pytorch_utils import (  # noqa: F401
 from .utils import (
     ADAPTER_SAFE_WEIGHTS_NAME,
     ADAPTER_WEIGHTS_NAME,
+    CONFIG_NAME,
     DUMMY_INPUTS,
     FLAX_WEIGHTS_NAME,
     SAFE_WEIGHTS_INDEX_NAME,
@@ -65,6 +66,7 @@ from .utils import (
     cached_file,
     copy_func,
     download_url,
+    extract_commit_hash,
     has_file,
     is_accelerate_available,
     is_auto_gptq_available,
@@ -2368,13 +2370,39 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 " ignored."
             )
 
+        if commit_hash is None:
+            if not isinstance(config, PretrainedConfig):
+                # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
+                resolved_config_file = cached_file(
+                    pretrained_model_name_or_path,
+                    CONFIG_NAME,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    _raise_exceptions_for_missing_entries=False,
+                    _raise_exceptions_for_connection_errors=False,
+                )
+                commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
+            else:
+                commit_hash = getattr(config, "_commit_hash", None)
+
         if is_peft_available() and _adapter_model_path is None:
             maybe_adapter_model_path = find_adapter_config_file(
                 pretrained_model_name_or_path,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
                 revision=revision,
                 subfolder=subfolder,
-                token=token,
-                commit_hash=commit_hash,
+                _commit_hash=commit_hash,
             )
         elif is_peft_available() and _adapter_model_path is not None:
             maybe_adapter_model_path = _adapter_model_path
@@ -2622,9 +2650,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 " `pip install --upgrade bitsandbytes`."
             )
 
-        if commit_hash is None:
-            commit_hash = getattr(config, "_commit_hash", None)
-
         # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
         # index of the files.
         is_sharded = False
...
@@ -22,7 +22,16 @@ from collections import OrderedDict
 from ...configuration_utils import PretrainedConfig
 from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
-from ...utils import copy_func, find_adapter_config_file, is_peft_available, logging, requires_backends
+from ...utils import (
+    CONFIG_NAME,
+    cached_file,
+    copy_func,
+    extract_commit_hash,
+    find_adapter_config_file,
+    is_peft_available,
+    logging,
+    requires_backends,
+)
 from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings
@@ -443,7 +452,6 @@ class _BaseAutoModelClass:
         kwargs["_from_auto"] = True
         hub_kwargs_names = [
             "cache_dir",
-            "code_revision",
             "force_download",
             "local_files_only",
             "proxies",
@@ -454,6 +462,8 @@ class _BaseAutoModelClass:
             "token",
         ]
         hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
+        code_revision = kwargs.pop("code_revision", None)
+        commit_hash = kwargs.pop("_commit_hash", None)
         token = hub_kwargs.pop("token", None)
         use_auth_token = hub_kwargs.pop("use_auth_token", None)
@@ -470,12 +480,23 @@ class _BaseAutoModelClass:
         if token is not None:
             hub_kwargs["token"] = token
 
-        if is_peft_available():
-            revision = kwargs.get("revision", None)
-            subfolder = kwargs.get("subfolder", None)
+        if commit_hash is None:
+            if not isinstance(config, PretrainedConfig):
+                # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
+                resolved_config_file = cached_file(
+                    pretrained_model_name_or_path,
+                    CONFIG_NAME,
+                    _raise_exceptions_for_missing_entries=False,
+                    _raise_exceptions_for_connection_errors=False,
+                    **hub_kwargs,
+                )
+                commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
+            else:
+                commit_hash = getattr(config, "_commit_hash", None)
 
+        if is_peft_available():
             maybe_adapter_path = find_adapter_config_file(
-                pretrained_model_name_or_path, revision=revision, token=token, subfolder=subfolder
+                pretrained_model_name_or_path, _commit_hash=commit_hash, **hub_kwargs
             )
 
             if maybe_adapter_path is not None:
@@ -499,6 +520,8 @@ class _BaseAutoModelClass:
                 pretrained_model_name_or_path,
                 return_unused_kwargs=True,
                 trust_remote_code=trust_remote_code,
+                code_revision=code_revision,
+                _commit_hash=commit_hash,
                 **hub_kwargs,
                 **kwargs,
             )
@@ -517,7 +540,7 @@ class _BaseAutoModelClass:
         if has_remote_code and trust_remote_code:
             class_ref = config.auto_map[cls.__name__]
             model_class = get_class_from_dynamic_module(
-                class_ref, pretrained_model_name_or_path, **hub_kwargs, **kwargs
+                class_ref, pretrained_model_name_or_path, code_revision=code_revision, **hub_kwargs, **kwargs
             )
             _ = hub_kwargs.pop("code_revision", None)
             if os.path.isdir(pretrained_model_name_or_path):
...
@@ -1007,6 +1007,8 @@ class AutoConfig:
         kwargs["_from_auto"] = True
         kwargs["name_or_path"] = pretrained_model_name_or_path
         trust_remote_code = kwargs.pop("trust_remote_code", None)
+        code_revision = kwargs.pop("code_revision", None)
+
         config_dict, unused_kwargs = PretrainedConfig.get_config_dict(pretrained_model_name_or_path, **kwargs)
         has_remote_code = "auto_map" in config_dict and "AutoConfig" in config_dict["auto_map"]
         has_local_code = "model_type" in config_dict and config_dict["model_type"] in CONFIG_MAPPING
@@ -1016,10 +1018,11 @@ class AutoConfig:
         if has_remote_code and trust_remote_code:
             class_ref = config_dict["auto_map"]["AutoConfig"]
-            config_class = get_class_from_dynamic_module(class_ref, pretrained_model_name_or_path, **kwargs)
+            config_class = get_class_from_dynamic_module(
+                class_ref, pretrained_model_name_or_path, code_revision=code_revision, **kwargs
+            )
             if os.path.isdir(pretrained_model_name_or_path):
                 config_class.register_for_auto_class()
-            _ = kwargs.pop("code_revision", None)
             return config_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
         elif "model_type" in config_dict:
             config_class = CONFIG_MAPPING[config_dict["model_type"]]
...
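
For the `code_revision` handling in the two auto-class diffs above, the user-facing behavior is that `code_revision` is now popped from the kwargs explicitly and forwarded to `get_class_from_dynamic_module` / `AutoConfig` instead of travelling inside `hub_kwargs`. A hedged usage sketch; the repository id is a placeholder, and it assumes a repo that ships custom modeling code:

from transformers import AutoModelForCausalLM

# "some-user/custom-model" is a made-up repo id for a model with remote code.
model = AutoModelForCausalLM.from_pretrained(
    "some-user/custom-model",
    trust_remote_code=True,
    revision="main",       # revision used for the config and weights
    code_revision="main",  # revision used for the remote modeling code
)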
@@ -13,7 +13,7 @@
 # limitations under the License.
 import importlib
 import os
-from typing import Optional
+from typing import Dict, Optional, Union
 
 from packaging import version
@@ -28,10 +28,15 @@ ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"
 def find_adapter_config_file(
     model_id: str,
-    revision: str = None,
-    subfolder: str = None,
-    token: Optional[str] = None,
-    commit_hash: Optional[str] = None,
+    cache_dir: Optional[Union[str, os.PathLike]] = None,
+    force_download: bool = False,
+    resume_download: bool = False,
+    proxies: Optional[Dict[str, str]] = None,
+    token: Optional[Union[bool, str]] = None,
+    revision: Optional[str] = None,
+    local_files_only: bool = False,
+    subfolder: str = "",
+    _commit_hash: Optional[str] = None,
 ) -> Optional[str]:
     r"""
     Simply checks if the model stored on the Hub or locally is an adapter model or not, return the path the the adapter
@@ -40,6 +45,20 @@ def find_adapter_config_file(
     Args:
         model_id (`str`):
             The identifier of the model to look for, can be either a local path or an id to the repository on the Hub.
+        cache_dir (`str` or `os.PathLike`, *optional*):
+            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
+            cache should not be used.
+        force_download (`bool`, *optional*, defaults to `False`):
+            Whether or not to force to (re-)download the configuration files and override the cached versions if they
+            exist.
+        resume_download (`bool`, *optional*, defaults to `False`):
+            Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
+        proxies (`Dict[str, str]`, *optional*):
+            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
+        token (`str` or *bool*, *optional*):
+            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+            when running `huggingface-cli login` (stored in `~/.huggingface`).
         revision (`str`, *optional*, defaults to `"main"`):
             The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
             git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
@@ -51,12 +70,11 @@ def find_adapter_config_file(
 
             </Tip>
 
+        local_files_only (`bool`, *optional*, defaults to `False`):
+            If `True`, will only try to load the tokenizer configuration from local files.
         subfolder (`str`, *optional*, defaults to `""`):
             In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
             specify the folder name here.
-        token (`str`, `optional`):
-            Whether to use authentication token to load the remote folder. Userful to load private repositories that
-            are on HuggingFace Hub. You might need to call `huggingface-cli login` and paste your tokens to cache it.
     """
     adapter_cached_filename = None
     if model_id is None:
@@ -69,10 +87,15 @@ def find_adapter_config_file(
         adapter_cached_filename = cached_file(
             model_id,
             ADAPTER_CONFIG_NAME,
-            revision=revision,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            resume_download=resume_download,
+            proxies=proxies,
             token=token,
-            _commit_hash=commit_hash,
+            revision=revision,
+            local_files_only=local_files_only,
             subfolder=subfolder,
+            _commit_hash=_commit_hash,
             _raise_exceptions_for_missing_entries=False,
             _raise_exceptions_for_connection_errors=False,
         )
...
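
Finally, a small usage sketch of `find_adapter_config_file` with its expanded signature; the repository id is a placeholder, and whether it actually hosts an `adapter_config.json` is an assumption:

from transformers.utils import find_adapter_config_file

# Placeholder repo id; returns the local path to adapter_config.json if the repo is a PEFT
# adapter, or None otherwise. The Hub-related arguments now mirror those of cached_file.
adapter_config_path = find_adapter_config_file(
    "some-user/some-peft-adapter",
    revision="main",
    local_files_only=False,
)
if adapter_config_path is None:
    print("Not an adapter model")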