Unverified commit d2a980ec authored by Lucain, committed by GitHub

Harmonize HF environment variables + other cleaning (#27564)

* Harmonize HF environment variables + other cleaning

* backward compat

* switch from HUGGINGFACE_HUB_CACHE to HF_HUB_CACHE

* revert
@@ -25,6 +25,8 @@ import warnings
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

+from huggingface_hub import try_to_load_from_cache
+
from .utils import (
    HF_MODULES_CACHE,
    TRANSFORMERS_DYNAMIC_MODULE_NAME,
@@ -32,7 +34,6 @@ from .utils import (
    extract_commit_hash,
    is_offline_mode,
    logging,
-    try_to_load_from_cache,
)
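For context, `try_to_load_from_cache` (and the `_CACHED_NO_EXIST` sentinel it can return) now come straight from `huggingface_hub` rather than `transformers.utils`. A minimal sketch of the call, assuming a locally populated cache:

```python
from huggingface_hub import _CACHED_NO_EXIST, try_to_load_from_cache

# Purely local lookup: never hits the network and never raises.
filepath = try_to_load_from_cache("bert-base-uncased", "config.json")

if filepath is _CACHED_NO_EXIST:
    print("File is recorded as absent from the repo")
elif filepath is None:
    print("File is not in the local cache")
else:
    print(f"Cached at {filepath}")
```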
@@ -31,13 +31,16 @@ from uuid import uuid4
import huggingface_hub
import requests
from huggingface_hub import (
+    _CACHED_NO_EXIST,
    CommitOperationAdd,
+    constants,
    create_branch,
    create_commit,
    create_repo,
    get_hf_file_metadata,
    hf_hub_download,
    hf_hub_url,
+    try_to_load_from_cache,
)
from huggingface_hub.file_download import REGEX_COMMIT_HASH, http_get
from huggingface_hub.utils import (
@@ -49,7 +52,9 @@ from huggingface_hub.utils import (
    RevisionNotFoundError,
    build_hf_headers,
    hf_raise_for_status,
+    send_telemetry,
)
+from huggingface_hub.utils._deprecation import _deprecate_method
from requests.exceptions import HTTPError

from . import __version__, logging
@@ -75,17 +80,25 @@ def is_offline_mode():
torch_cache_home = os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
+default_cache_path = constants.default_cache_path
old_default_cache_path = os.path.join(torch_cache_home, "transformers")

-# New default cache, shared with the Datasets library
-hf_cache_home = os.path.expanduser(
-    os.getenv("HF_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"))
-)
-default_cache_path = os.path.join(hf_cache_home, "hub")
+# Determine default cache directory. Lots of legacy environment variables to ensure backward compatibility.
+# The best way to set the cache path is with the environment variable HF_HOME. For more details, check out this
+# documentation page: https://huggingface.co/docs/huggingface_hub/package_reference/environment_variables.
+#
+# In code, use `HF_HUB_CACHE` as the default cache path. This variable is set by the library and is guaranteed
+# to be set to the right value.
+#
+# TODO: clean this for v5?
+PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", constants.HF_HUB_CACHE)
+PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
+TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)

# One-time move from the old location to the new one if no ENV variable has been set.
if (
    os.path.isdir(old_default_cache_path)
-    and not os.path.isdir(default_cache_path)
+    and not os.path.isdir(constants.HF_HUB_CACHE)
    and "PYTORCH_PRETRAINED_BERT_CACHE" not in os.environ
    and "PYTORCH_TRANSFORMERS_CACHE" not in os.environ
    and "TRANSFORMERS_CACHE" not in os.environ
@@ -97,16 +110,26 @@ if (
        " '~/.cache/huggingface/transformers' to avoid redownloading models you have already in the cache. You should"
        " only see this message once."
    )
-    shutil.move(old_default_cache_path, default_cache_path)
+    shutil.move(old_default_cache_path, constants.HF_HUB_CACHE)

-PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
-PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
-HUGGINGFACE_HUB_CACHE = os.getenv("HUGGINGFACE_HUB_CACHE", PYTORCH_TRANSFORMERS_CACHE)
-TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", HUGGINGFACE_HUB_CACHE)
-HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
+HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(constants.HF_HOME, "modules"))
TRANSFORMERS_DYNAMIC_MODULE_NAME = "transformers_modules"
SESSION_ID = uuid4().hex
-DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", False) in ENV_VARS_TRUE_VALUES
+DISABLE_TELEMETRY = os.getenv("DISABLE_TELEMETRY", constants.HF_HUB_DISABLE_TELEMETRY) in ENV_VARS_TRUE_VALUES
+
+# Add deprecation warning for old environment variables.
+for key in ("PYTORCH_PRETRAINED_BERT_CACHE", "PYTORCH_TRANSFORMERS_CACHE", "TRANSFORMERS_CACHE"):
+    if os.getenv(key) is not None:
+        warnings.warn(
+            f"Using `{key}` is deprecated and will be removed in v5 of Transformers. Use `HF_HOME` instead.",
+            FutureWarning,
+        )
+if os.getenv("DISABLE_TELEMETRY") is not None:
+    warnings.warn(
+        "Using `DISABLE_TELEMETRY` is deprecated and will be removed in v5 of Transformers. Use `HF_HUB_DISABLE_TELEMETRY` instead.",
+        FutureWarning,
+    )

S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
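With the warnings in place, the recommended migration is a single `HF_HOME` variable. A sketch, assuming `huggingface_hub` has not yet been imported in the process (its constants are read from the environment at import time; `/data/hf` is an illustrative path):

```python
import os

os.environ["HF_HOME"] = "/data/hf"  # replaces TRANSFORMERS_CACHE & friends

from huggingface_hub import constants  # noqa: E402  (import after setting env)

print(constants.HF_HOME)       # /data/hf
print(constants.HF_HUB_CACHE)  # /data/hf/hub, unless HF_HUB_CACHE overrides it
```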
@@ -126,15 +149,16 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", HUGGINGFACE_CO_R
HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"
HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples"

-# Return value when trying to load a file from cache but the file does not exist in the distant repo.
-_CACHED_NO_EXIST = object()

def is_remote_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")

+# TODO: remove this once fully deprecated
+# TODO? remove from './examples/research_projects/lxmert/utils.py' as well
+# TODO? remove from './examples/research_projects/visual_bert/utils.py' as well
+@_deprecate_method(version="4.39.0", message="This method is outdated and does not support the new cache system.")
def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
    """
    Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape `(model_url,
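`_deprecate_method` is a private `huggingface_hub` helper; roughly, it wraps the function so that every call emits a `FutureWarning` before delegating. A stand-alone sketch of that behavior (not the actual `huggingface_hub` implementation):

```python
import functools
import warnings

def deprecate_method(version, message=""):
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Warn on every call, then run the original function unchanged.
            warnings.warn(
                f"'{fn.__name__}' is deprecated and will be removed in version {version}. {message}",
                FutureWarning,
            )
            return fn(*args, **kwargs)
        return wrapper
    return decorator

@deprecate_method(version="4.39.0", message="This method is outdated.")
def get_cached_models(cache_dir=None):
    return []
```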
@@ -219,7 +243,7 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
    return ua

-def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]):
+def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]) -> Optional[str]:
    """
    Extracts the commit hash from a resolved filename toward a cache file.
    """
@@ -233,73 +257,6 @@ def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]
    return commit_hash if REGEX_COMMIT_HASH.match(commit_hash) else None

-def try_to_load_from_cache(
-    repo_id: str,
-    filename: str,
-    cache_dir: Union[str, Path, None] = None,
-    revision: Optional[str] = None,
-    repo_type: Optional[str] = None,
-) -> Optional[str]:
-    """
-    Explores the cache to return the latest cached file for a given revision if found.
-
-    This function will not raise any exception if the file is not cached.
-
-    Args:
-        cache_dir (`str` or `os.PathLike`):
-            The folder where the cached files lie.
-        repo_id (`str`):
-            The ID of the repo on huggingface.co.
-        filename (`str`):
-            The filename to look for inside `repo_id`.
-        revision (`str`, *optional*):
-            The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
-            provided either.
-        repo_type (`str`, *optional*):
-            The type of the repo.
-
-    Returns:
-        `Optional[str]` or `_CACHED_NO_EXIST`:
-            Will return `None` if the file was not cached. Otherwise:
-            - The exact path to the cached file if it's found in the cache
-            - A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was
-              cached.
-    """
-    if revision is None:
-        revision = "main"
-    if cache_dir is None:
-        cache_dir = TRANSFORMERS_CACHE
-
-    object_id = repo_id.replace("/", "--")
-    if repo_type is None:
-        repo_type = "model"
-    repo_cache = os.path.join(cache_dir, f"{repo_type}s--{object_id}")
-    if not os.path.isdir(repo_cache):
-        # No cache for this model
-        return None
-    for subfolder in ["refs", "snapshots"]:
-        if not os.path.isdir(os.path.join(repo_cache, subfolder)):
-            return None
-
-    # Resolve refs (for instance to convert main to the associated commit sha)
-    cached_refs = os.listdir(os.path.join(repo_cache, "refs"))
-    if revision in cached_refs:
-        with open(os.path.join(repo_cache, "refs", revision)) as f:
-            revision = f.read()
-
-    if os.path.isfile(os.path.join(repo_cache, ".no_exist", revision, filename)):
-        return _CACHED_NO_EXIST
-
-    cached_shas = os.listdir(os.path.join(repo_cache, "snapshots"))
-    if revision not in cached_shas:
-        # No cache for this revision and we won't try to return a random revision
-        return None
-
-    cached_file = os.path.join(repo_cache, "snapshots", revision, filename)
-    return cached_file if os.path.isfile(cached_file) else None

def cached_file(
    path_or_repo_id: Union[str, os.PathLike],
    filename: str,
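For reference, the cache layout the deleted helper used to walk, and which the `huggingface_hub` replacement still handles, looks roughly like this (names are illustrative):

```
<cache_dir>/models--bert-base-uncased/
├── refs/
│   └── main             # text file holding the commit sha that "main" points to
├── snapshots/
│   └── <sha>/
│       └── config.json  # the resolved, cached file
└── .no_exist/
    └── <sha>/
        └── missing.txt  # marker: this file is known to be absent upstream
```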
@@ -317,7 +274,7 @@ def cached_file(
    _raise_exceptions_for_connection_errors: bool = True,
    _commit_hash: Optional[str] = None,
    **deprecated_kwargs,
-):
+) -> Optional[str]:
    """
    Tries to locate a file in a local folder and repo, downloads and caches it if necessary.
@@ -369,7 +326,8 @@
    ```python
    # Download a model weight from the Hub and cache it.
    model_weights_file = cached_file("bert-base-uncased", "pytorch_model.bin")
-    ```"""
+    ```
+    """
    use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
@@ -499,6 +457,10 @@ def cached_file(
    return resolved_file

+# TODO: deprecate `get_file_from_repo` or document it differently?
+# Docstring is exactly the same as `cached_file`, but the behavior is slightly different. If the file is missing or
+# there is a connection error, `cached_file` will raise an error while `get_file_from_repo` will return None.
+# IMO we should keep only 1 method and have a single `raise_error` argument (to be discussed).
def get_file_from_repo(
    path_or_repo: Union[str, os.PathLike],
    filename: str,
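The difference flagged in that TODO is easy to demonstrate. A sketch, assuming network access to the Hub:

```python
from transformers.utils import cached_file, get_file_from_repo

# get_file_from_repo returns None when the file is absent from the repo...
assert get_file_from_repo("bert-base-uncased", "no_such_file.json") is None

# ...while cached_file (with default arguments) raises instead.
try:
    cached_file("bert-base-uncased", "no_such_file.json")
except EnvironmentError as err:
    print(f"cached_file raised: {err}")
```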
@@ -564,7 +526,8 @@ def get_file_from_repo(
    tokenizer_config = get_file_from_repo("bert-base-uncased", "tokenizer_config.json")
    # This model does not have a tokenizer config so the result will be None.
    tokenizer_config = get_file_from_repo("xlm-roberta-base", "tokenizer_config.json")
-    ```"""
+    ```
+    """
    use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
    if use_auth_token is not None:
        warnings.warn(
@@ -609,10 +572,11 @@ def download_url(url, proxies=None):
        f"Using `from_pretrained` with the url of a file (here {url}) is deprecated and won't be possible anymore in"
        " v5 of Transformers. You should host your file on the Hub (hf.co) instead and use the repository ID. Note"
        " that this is not compatible with the caching system (your file will be downloaded at each execution) or"
-        " multiple processes (each process will download the file in a different temporary file)."
+        " multiple processes (each process will download the file in a different temporary file).",
+        FutureWarning,
    )
-    tmp_file = tempfile.mkstemp()[1]
-    with open(tmp_file, "wb") as f:
+    tmp_fd, tmp_file = tempfile.mkstemp()
+    with os.fdopen(tmp_fd, "wb") as f:
        http_get(url, f, proxies=proxies)
    return tmp_file
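The `mkstemp` change fixes a file-descriptor leak: `mkstemp()` returns an already-open OS-level descriptor alongside the path, and reopening the path with `open()` left that descriptor dangling. A minimal illustration:

```python
import os
import tempfile

tmp_fd, tmp_file = tempfile.mkstemp()
with os.fdopen(tmp_fd, "wb") as f:  # reuses tmp_fd and closes it on exit
    f.write(b"payload")

print(tmp_file)      # the path outlives the descriptor; the caller cleans it up
os.remove(tmp_file)
```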
@@ -947,13 +911,10 @@ def send_example_telemetry(example_name, *example_args, framework="pytorch"):
        script_name = script_name.replace("_no_trainer", "")
        data["dataset_name"] = f"{script_name}-{args_as_dict['task_name']}"

-    headers = {"user-agent": http_user_agent(data)}
-    try:
-        r = requests.head(HUGGINGFACE_CO_EXAMPLES_TELEMETRY, headers=headers)
-        r.raise_for_status()
-    except Exception:
-        # We don't want to error in case of connection errors of any kind.
-        pass
+    # Send telemetry in the background
+    send_telemetry(
+        topic="examples", library_name="transformers", library_version=__version__, user_agent=http_user_agent(data)
+    )
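`send_telemetry` ships with `huggingface_hub` and posts from a background thread while swallowing connection errors, which is why the manual `requests.head` + `try/except` block could be dropped. A sketch of an equivalent standalone call (the version string and user agent are illustrative):

```python
from huggingface_hub.utils import send_telemetry

send_telemetry(
    topic="examples",
    library_name="transformers",
    library_version="4.36.0",
    user_agent={"framework": "pytorch"},
)
```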
def convert_file_size_to_int(size: Union[int, str]):
@@ -1258,7 +1219,7 @@ if cache_version < 1 and cache_is_not_empty:
            "`transformers.utils.move_cache()`."
        )
    try:
-        if TRANSFORMERS_CACHE != default_cache_path:
+        if TRANSFORMERS_CACHE != constants.HF_HUB_CACHE:
            # Users set some env variable to customize cache storage
            move_cache(TRANSFORMERS_CACHE, TRANSFORMERS_CACHE)
        else: