Unverified Commit d2a980ec authored by Lucain's avatar Lucain Committed by GitHub
Browse files

Harmonize HF environment variables + other cleaning (#27564)

* Harmonize HF environment variables + other cleaning

* backward compat

* switch from HUGGINGFACE_HUB_CACHE to HF_HUB_CACHE

* revert
parent 7f043738
...@@ -25,6 +25,8 @@ import warnings ...@@ -25,6 +25,8 @@ import warnings
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from huggingface_hub import try_to_load_from_cache
from .utils import ( from .utils import (
HF_MODULES_CACHE, HF_MODULES_CACHE,
TRANSFORMERS_DYNAMIC_MODULE_NAME, TRANSFORMERS_DYNAMIC_MODULE_NAME,
...@@ -32,7 +34,6 @@ from .utils import ( ...@@ -32,7 +34,6 @@ from .utils import (
extract_commit_hash, extract_commit_hash,
is_offline_mode, is_offline_mode,
logging, logging,
try_to_load_from_cache,
) )
......
...@@ -31,13 +31,16 @@ from uuid import uuid4 ...@@ -31,13 +31,16 @@ from uuid import uuid4
import huggingface_hub import huggingface_hub
import requests import requests
from huggingface_hub import ( from huggingface_hub import (
_CACHED_NO_EXIST,
CommitOperationAdd, CommitOperationAdd,
constants,
create_branch, create_branch,
create_commit, create_commit,
create_repo, create_repo,
get_hf_file_metadata, get_hf_file_metadata,
hf_hub_download, hf_hub_download,
hf_hub_url, hf_hub_url,
try_to_load_from_cache,
) )
from huggingface_hub.file_download import REGEX_COMMIT_HASH, http_get from huggingface_hub.file_download import REGEX_COMMIT_HASH, http_get
from huggingface_hub.utils import ( from huggingface_hub.utils import (
...@@ -49,7 +52,9 @@ from huggingface_hub.utils import ( ...@@ -49,7 +52,9 @@ from huggingface_hub.utils import (
RevisionNotFoundError, RevisionNotFoundError,
build_hf_headers, build_hf_headers,
hf_raise_for_status, hf_raise_for_status,
send_telemetry,
) )
from huggingface_hub.utils._deprecation import _deprecate_method
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from . import __version__, logging from . import __version__, logging
...@@ -75,17 +80,25 @@ def is_offline_mode(): ...@@ -75,17 +80,25 @@ def is_offline_mode():
torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch")) torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
default_cache_path = constants.default_cache_path
old_default_cache_path = os.path.join(torch_cache_home, "transformers") old_default_cache_path = os.path.join(torch_cache_home, "transformers")
# New default cache, shared with the Datasets library
hf_cache_home = os.path.expanduser( # Determine default cache directory. Lots of legacy environment variables to ensure backward compatibility.
os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface")) # The best way to set the cache path is with the environment variable HF_HOME. For more details, checkout this
) # documentation page: https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables.
default_cache_path = os.path.join(hf_cache_home, "hub") #
# In code, use `HF_HUB_CACHE` as the default cache path. This variable is set by the library and is guaranteed
# to be set to the right value.
#
# TODO: clean this for v5?
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", constants.HF_HUB_CACHE)
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
# Onetime move from the old location to the new one if no ENV variable has been set. # Onetime move from the old location to the new one if no ENV variable has been set.
if ( if (
os.path.isdir(old_default_cache_path) os.path.isdir(old_default_cache_path)
and not os.path.isdir(default_cache_path) and not os.path.isdir(constants.HF_HUB_CACHE)
and "PYTORCH_PRETRAINED_BERT_CACHE" not in os.environ and "PYTORCH_PRETRAINED_BERT_CACHE" not in os.environ
and "PYTORCH_TRANSFORMERS_CACHE" not in os.environ and "PYTORCH_TRANSFORMERS_CACHE" not in os.environ
and "TRANSFORMERS_CACHE" not in os.environ and "TRANSFORMERS_CACHE" not in os.environ
...@@ -97,16 +110,26 @@ if ( ...@@ -97,16 +110,26 @@ if (
" '~/.cache/huggingface/transformers' to avoid redownloading models you have already in the cache. You should" " '~/.cache/huggingface/transformers' to avoid redownloading models you have already in the cache. You should"
" only see this message once." " only see this message once."
) )
shutil.move(old_default_cache_path, default_cache_path) shutil.move(old_default_cache_path, constants.HF_HUB_CACHE)
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path) HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(constants.HF_HOME, "modules"))
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
HUGGINGFACE_HUB_CACHE = os.getenv("HUGGINGFACE_HUB_CACHE", PYTORCH_TRANSFORMERS_CACHE)
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", HUGGINGFACE_HUB_CACHE)
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
TRANSFORMERS_DYNAMIC_MODULE_NAME = "transformers_modules" TRANSFORMERS_DYNAMIC_MODULE_NAME = "transformers_modules"
SESSION_ID = uuid4().hex SESSION_ID = uuid4().hex
DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", False) in ENV_VARS_TRUE_VALUES DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", constants.HF_HUB_DISABLE_TELEMETRY) in ENV_VARS_TRUE_VALUES
# Add deprecation warning for old environment variables.
for key in ("PYTORCH_PRETRAINED_BERT_CACHE", "PYTORCH_TRANSFORMERS_CACHE", "TRANSFORMERS_CACHE"):
if os.getenv(key) is not None:
warnings.warn(
f"Using `{key}` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.",
FutureWarning,
)
if os.getenv("DISABLE_TELEMETRY") is not None:
warnings.warn(
"Using `DISABLE_TELEMETRY` is deprecated and will be removed in v5 of Transformers. Use `HF_HUB_DISABLE_TELEMETRY` instead.",
FutureWarning,
)
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert" S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co" CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
...@@ -126,15 +149,16 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", HUGGINGFACE_CO_R ...@@ -126,15 +149,16 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", HUGGINGFACE_CO_R
HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}" HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"
HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples" HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples"
# Return value when trying to load a file from cache but the file does not exist in the distant repo.
_CACHED_NO_EXIST = object()
def is_remote_url(url_or_filename): def is_remote_url(url_or_filename):
parsed = urlparse(url_or_filename) parsed = urlparse(url_or_filename)
return parsed.scheme in ("http", "https") return parsed.scheme in ("http", "https")
# TODO: remove this once fully deprecated
# TODO? remove from './examples/research_projects/lxmert/utils.py' as well
# TODO? remove from './examples/research_projects/visual_bert/utils.py' as well
@_deprecate_method(version="4.39.0", message="This method is outdated and does not support the new cache system.")
def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]: def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
""" """
Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape `(model_url, Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape `(model_url,
...@@ -219,7 +243,7 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str: ...@@ -219,7 +243,7 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
return ua return ua
def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]): def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]) -> Optional[str]:
""" """
Extracts the commit hash from a resolved filename toward a cache file. Extracts the commit hash from a resolved filename toward a cache file.
""" """
...@@ -233,73 +257,6 @@ def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str] ...@@ -233,73 +257,6 @@ def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]
return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None
def try_to_load_from_cache(
repo_id: str,
filename: str,
cache_dir: Union[str, Path, None] = None,
revision: Optional[str] = None,
repo_type: Optional[str] = None,
) -> Optional[str]:
"""
Explores the cache to return the latest cached file for a given revision if found.
This function will not raise any exception if the file in not cached.
Args:
cache_dir (`str` or `os.PathLike`):
The folder where the cached files lie.
repo_id (`str`):
The ID of the repo on huggingface.co.
filename (`str`):
The filename to look for inside `repo_id`.
revision (`str`, *optional*):
The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
provided either.
repo_type (`str`, *optional*):
The type of the repo.
Returns:
`Optional[str]` or `_CACHED_NO_EXIST`:
Will return `None` if the file was not cached. Otherwise:
- The exact path to the cached file if it's found in the cache
- A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was
cached.
"""
if revision is None:
revision = "main"
if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE
object_id = repo_id.replace("/", "--")
if repo_type is None:
repo_type = "model"
repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}")
if not os.path.isdir(repo_cache):
# No cache for this model
return None
for subfolder in ["refs", "snapshots"]:
if not os.path.isdir(os.path.join(repo_cache, subfolder)):
return None
# Resolve refs (for instance to convert main to the associated commit sha)
cached_refs = os.listdir(os.path.join(repo_cache, "refs"))
if revision in cached_refs:
with open(os.path.join(repo_cache, "refs", revision)) as f:
revision = f.read()
if os.path.isfile(os.path.join(repo_cache, ".no_exist", revision, filename)):
return _CACHED_NO_EXIST
cached_shas = os.listdir(os.path.join(repo_cache, "snapshots"))
if revision not in cached_shas:
# No cache for this revision and we won't try to return a random revision
return None
cached_file = os.path.join(repo_cache, "snapshots", revision, filename)
return cached_file if os.path.isfile(cached_file) else None
def cached_file( def cached_file(
path_or_repo_id: Union[str, os.PathLike], path_or_repo_id: Union[str, os.PathLike],
filename: str, filename: str,
...@@ -317,7 +274,7 @@ def cached_file( ...@@ -317,7 +274,7 @@ def cached_file(
_raise_exceptions_for_connection_errors: bool = True, _raise_exceptions_for_connection_errors: bool = True,
_commit_hash: Optional[str] = None, _commit_hash: Optional[str] = None,
**deprecated_kwargs, **deprecated_kwargs,
): ) -> Optional[str]:
""" """
Tries to locate a file in a local folder and repo, downloads and cache it if necessary. Tries to locate a file in a local folder and repo, downloads and cache it if necessary.
...@@ -369,7 +326,8 @@ def cached_file( ...@@ -369,7 +326,8 @@ def cached_file(
```python ```python
# Download a model weight from the Hub and cache it. # Download a model weight from the Hub and cache it.
model_weights_file = cached_file("bert-base-uncased", "pytorch_model.bin") model_weights_file = cached_file("bert-base-uncased", "pytorch_model.bin")
```""" ```
"""
use_auth_token = deprecated_kwargs.pop("use_auth_token", None) use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
if use_auth_token is not None: if use_auth_token is not None:
warnings.warn( warnings.warn(
...@@ -499,6 +457,10 @@ def cached_file( ...@@ -499,6 +457,10 @@ def cached_file(
return resolved_file return resolved_file
# TODO: deprecate `get_file_from_repo` or document it differently?
# Docstring is exactly the same as `cached_repo` but behavior is slightly different. If file is missing or if
# there is a connection error, `cached_repo` will return None while `get_file_from_repo` will raise an error.
# IMO we should keep only 1 method and have a single `raise_error` argument (to be discussed).
def get_file_from_repo( def get_file_from_repo(
path_or_repo: Union[str, os.PathLike], path_or_repo: Union[str, os.PathLike],
filename: str, filename: str,
...@@ -564,7 +526,8 @@ def get_file_from_repo( ...@@ -564,7 +526,8 @@ def get_file_from_repo(
tokenizer_config = get_file_from_repo("bert-base-uncased", "tokenizer_config.json") tokenizer_config = get_file_from_repo("bert-base-uncased", "tokenizer_config.json")
# This model does not have a tokenizer config so the result will be None. # This model does not have a tokenizer config so the result will be None.
tokenizer_config = get_file_from_repo("xlm-roberta-base", "tokenizer_config.json") tokenizer_config = get_file_from_repo("xlm-roberta-base", "tokenizer_config.json")
```""" ```
"""
use_auth_token = deprecated_kwargs.pop("use_auth_token", None) use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
if use_auth_token is not None: if use_auth_token is not None:
warnings.warn( warnings.warn(
...@@ -609,10 +572,11 @@ def download_url(url, proxies=None): ...@@ -609,10 +572,11 @@ def download_url(url, proxies=None):
f"Using `from_pretrained` with the url of a file (here {url}) is deprecated and won't be possible anymore in" f"Using `from_pretrained` with the url of a file (here {url}) is deprecated and won't be possible anymore in"
" v5 of Transformers. You should host your file on the Hub (hf.co) instead and use the repository ID. Note" " v5 of Transformers. You should host your file on the Hub (hf.co) instead and use the repository ID. Note"
" that this is not compatible with the caching system (your file will be downloaded at each execution) or" " that this is not compatible with the caching system (your file will be downloaded at each execution) or"
" multiple processes (each process will download the file in a different temporary file)." " multiple processes (each process will download the file in a different temporary file).",
FutureWarning,
) )
tmp_file = tempfile.mkstemp()[1] tmp_fd, tmp_file = tempfile.mkstemp()
with open(tmp_file, "wb") as f: with os.fdopen(tmp_fd, "wb") as f:
http_get(url, f, proxies=proxies) http_get(url, f, proxies=proxies)
return tmp_file return tmp_file
...@@ -947,13 +911,10 @@ def send_example_telemetry(example_name, *example_args, framework="pytorch"): ...@@ -947,13 +911,10 @@ def send_example_telemetry(example_name, *example_args, framework="pytorch"):
script_name = script_name.replace("_no_trainer", "") script_name = script_name.replace("_no_trainer", "")
data["dataset_name"] = f"{script_name}-{args_as_dict['task_name']}" data["dataset_name"] = f"{script_name}-{args_as_dict['task_name']}"
headers = {"user-agent": http_user_agent(data)} # Send telemetry in the background
try: send_telemetry(
r = requests.head(HUGGINGFACE_CO_EXAMPLES_TELEMETRY, headers=headers) topic="examples", library_name="transformers", library_version=__version__, user_agent=http_user_agent(data)
r.raise_for_status() )
except Exception:
# We don't want to error in case of connection errors of any kind.
pass
def convert_file_size_to_int(size: Union[int, str]): def convert_file_size_to_int(size: Union[int, str]):
...@@ -1258,7 +1219,7 @@ if cache_version < 1 and cache_is_not_empty: ...@@ -1258,7 +1219,7 @@ if cache_version < 1 and cache_is_not_empty:
"`transformers.utils.move_cache()`." "`transformers.utils.move_cache()`."
) )
try: try:
if TRANSFORMERS_CACHE != default_cache_path: if TRANSFORMERS_CACHE != constants.HF_HUB_CACHE:
# Users set some env variable to customize cache storage # Users set some env variable to customize cache storage
move_cache(TRANSFORMERS_CACHE, TRANSFORMERS_CACHE) move_cache(TRANSFORMERS_CACHE, TRANSFORMERS_CACHE)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment