Unverified commit 8f3b4a1d, authored by Lucain and committed by GitHub
Browse files

Little cleanup: let huggingface_hub manage token retrieval (#21333)

* Let huggingface_hub manage token retrieval

* flake8

* code quality

* adapt in every PushToHubMixin children

* add explicit return type
parent 0dff407d
......@@ -438,7 +438,7 @@ class PretrainedConfig(PushToHubMixin):
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs)
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
......@@ -454,7 +454,11 @@ class PretrainedConfig(PushToHubMixin):
if push_to_hub:
self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
)
@classmethod
......
......@@ -22,7 +22,7 @@ import sys
from pathlib import Path
from typing import Dict, Optional, Union
from huggingface_hub import HfFolder, model_info
from huggingface_hub import model_info
from .utils import HF_MODULES_CACHE, TRANSFORMERS_DYNAMIC_MODULE_NAME, cached_file, is_offline_mode, logging
......@@ -251,14 +251,7 @@ def get_cached_module_file(
else:
# Get the commit hash
# TODO: we will get this info in the etag soon, so retrieve it from there and not here.
if isinstance(use_auth_token, str):
token = use_auth_token
elif use_auth_token is True:
token = HfFolder.get_token()
else:
token = None
commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=token).sha
commit_hash = model_info(pretrained_model_name_or_path, revision=revision, token=use_auth_token).sha
# The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
# benefit of versioning.
......
......@@ -353,7 +353,7 @@ class FeatureExtractionMixin(PushToHubMixin):
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs)
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
......@@ -369,7 +369,11 @@ class FeatureExtractionMixin(PushToHubMixin):
if push_to_hub:
self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
)
return [output_feature_extractor_file]
......
......@@ -337,7 +337,7 @@ class GenerationConfig(PushToHubMixin):
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs)
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)
output_config_file = os.path.join(save_directory, config_file_name)
......@@ -347,7 +347,11 @@ class GenerationConfig(PushToHubMixin):
if push_to_hub:
self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
)
@classmethod
......
......@@ -185,7 +185,7 @@ class ImageProcessingMixin(PushToHubMixin):
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs)
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
......@@ -201,7 +201,11 @@ class ImageProcessingMixin(PushToHubMixin):
if push_to_hub:
self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
)
return [output_image_processor_file]
......
......@@ -1018,7 +1018,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs)
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)
# get abs dir
......@@ -1077,7 +1077,11 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
if push_to_hub:
self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
)
@classmethod
......
......@@ -2277,7 +2277,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs)
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)
if saved_model:
......@@ -2363,7 +2363,11 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
if push_to_hub:
self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
)
@classmethod
......@@ -2946,7 +2950,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
else:
working_dir = repo_id.split("/")[-1]
repo_id, token = self._create_repo(
repo_id = self._create_repo(
repo_id, private=private, use_auth_token=use_auth_token, repo_url=repo_url, organization=organization
)
......@@ -2968,7 +2972,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
self.create_model_card(**base_model_card_args)
self._upload_modified_files(
work_dir, repo_id, files_timestamps, commit_message=commit_message, token=token
work_dir, repo_id, files_timestamps, commit_message=commit_message, token=use_auth_token
)
@classmethod
......
......@@ -1633,7 +1633,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs)
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)
# Only save the model itself if we are using distributed training
......@@ -1717,7 +1717,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if push_to_hub:
self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
)
def get_memory_footprint(self, return_buffers=True):
......
......@@ -121,7 +121,7 @@ class ProcessorMixin(PushToHubMixin):
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs)
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)
# If we have a custom config, we copy the file defining it in the folder and set the attributes so it can be
# loaded from the Hub.
......@@ -147,7 +147,11 @@ class ProcessorMixin(PushToHubMixin):
if push_to_hub:
self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
)
@classmethod
......
......@@ -2098,7 +2098,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id, token = self._create_repo(repo_id, **kwargs)
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)
special_tokens_map_file = os.path.join(
......@@ -2177,7 +2177,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
if push_to_hub:
self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=token
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("use_auth_token"),
)
return save_files
......
......@@ -31,7 +31,6 @@ import huggingface_hub
import requests
from huggingface_hub import (
CommitOperationAdd,
HfFolder,
create_commit,
create_repo,
get_hf_file_metadata,
......@@ -45,6 +44,7 @@ from huggingface_hub.utils import (
LocalEntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
build_hf_headers,
hf_raise_for_status,
)
from requests.exceptions import HTTPError
......@@ -583,7 +583,7 @@ def has_file(
use_auth_token: Optional[Union[bool, str]] = None,
):
"""
Checks if a repo contains a given file wihtout downloading it. Works for remote repos and local folders.
Checks if a repo contains a given file without downloading it. Works for remote repos and local folders.
<Tip warning={false}>
......@@ -596,15 +596,7 @@ def has_file(
return os.path.isfile(os.path.join(path_or_repo, filename))
url = hf_hub_url(path_or_repo, filename=filename, revision=revision)
headers = {"user-agent": http_user_agent()}
if isinstance(use_auth_token, str):
headers["authorization"] = f"Bearer {use_auth_token}"
elif use_auth_token:
token = HfFolder.get_token()
if token is None:
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
headers["authorization"] = f"Bearer {token}"
headers = build_hf_headers(use_auth_token=use_auth_token, user_agent=http_user_agent())
r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=10)
try:
......@@ -636,10 +628,10 @@ class PushToHubMixin:
use_auth_token: Optional[Union[bool, str]] = None,
repo_url: Optional[str] = None,
organization: Optional[str] = None,
):
) -> str:
"""
Create the repo if needed, cleans up repo_id with deprecated kwards `repo_url` and `organization`, retrives the
token.
Create the repo if needed, cleans up repo_id with deprecated kwargs `repo_url` and `organization`, and returns
the resulting repo_id. Token retrieval is left to `huggingface_hub`.
"""
if repo_url is not None:
warnings.warn(
......@@ -657,13 +649,12 @@ class PushToHubMixin:
repo_id = repo_id.split("/")[-1]
repo_id = f"{organization}/{repo_id}"
token = HfFolder.get_token() if use_auth_token is True else use_auth_token
url = create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True)
url = create_repo(repo_id=repo_id, token=use_auth_token, private=private, exist_ok=True)
# If the namespace is not there, add it or `upload_file` will complain
if "/" not in repo_id and url != f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/{repo_id}":
repo_id = get_full_repo_name(repo_id, token=token)
return repo_id, token
repo_id = get_full_repo_name(repo_id, token=use_auth_token)
return repo_id
def _get_files_timestamps(self, working_dir: Union[str, os.PathLike]):
"""
......@@ -677,7 +668,7 @@ class PushToHubMixin:
repo_id: str,
files_timestamps: Dict[str, float],
commit_message: Optional[str] = None,
token: Optional[str] = None,
token: Optional[Union[bool, str]] = None,
create_pr: bool = False,
):
"""
......@@ -776,7 +767,7 @@ class PushToHubMixin:
else:
working_dir = repo_id.split("/")[-1]
repo_id, token = self._create_repo(
repo_id = self._create_repo(
repo_id, private=private, use_auth_token=use_auth_token, repo_url=repo_url, organization=organization
)
......@@ -790,13 +781,16 @@ class PushToHubMixin:
self.save_pretrained(work_dir, max_shard_size=max_shard_size)
return self._upload_modified_files(
work_dir, repo_id, files_timestamps, commit_message=commit_message, token=token, create_pr=create_pr
work_dir,
repo_id,
files_timestamps,
commit_message=commit_message,
token=use_auth_token,
create_pr=create_pr,
)
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
......@@ -1040,8 +1034,6 @@ def move_cache(cache_dir=None, new_cache_dir=None, token=None):
cache_dir = str(old_cache)
else:
cache_dir = new_cache_dir
if token is None:
token = HfFolder.get_token()
cached_files = get_all_cached_files(cache_dir=cache_dir)
logger.info(f"Moving {len(cached_files)} files to the new cache system")
......@@ -1050,7 +1042,7 @@ def move_cache(cache_dir=None, new_cache_dir=None, token=None):
url = file_info.pop("url")
if url not in hub_metadata:
try:
hub_metadata[url] = get_hf_file_metadata(url, use_auth_token=token)
hub_metadata[url] = get_hf_file_metadata(url, token=token)
except requests.HTTPError:
continue
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment