Unverified commit 377cdded authored by Sylvain Gugger, committed by GitHub

Clean up hub (#18497)

* Clean up utils.hub

* Remove imports

* More fixes

* Last fix
parent a4562552
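Every call site below follows the same migration: the two-step `hf_bucket_url` + `cached_path` dance is replaced by a single `cached_file` call that takes a repo id and a filename. A minimal sketch of the before/after pattern, assuming a transformers version that includes this commit (the repo id and filename are illustrative):

```python
from transformers.utils import cached_file

# Before (removed by this commit):
#   url = hf_bucket_url("bert-base-uncased", filename="config.json")
#   resolved = cached_path(url)

# After: resolve a repo id + filename to a locally cached path in one call.
resolved = cached_file("bert-base-uncased", "config.json")
print(resolved)  # local path inside the Hugging Face cache
```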
@@ -441,7 +441,6 @@ _import_structure = {
"TensorType",
"add_end_docstrings",
"add_start_docstrings",
"cached_path",
"is_apex_available",
"is_datasets_available",
"is_faiss_available",
@@ -3214,7 +3213,6 @@ if TYPE_CHECKING:
TensorType,
add_end_docstrings,
add_start_docstrings,
cached_path,
is_apex_available,
is_datasets_available,
is_faiss_available,
......
@@ -38,7 +38,6 @@ from . import (
T5_PRETRAINED_CONFIG_ARCHIVE_MAP,
TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP,
WAV_2_VEC_2_PRETRAINED_CONFIG_ARCHIVE_MAP,
WEIGHTS_NAME,
XLM_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP,
XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -91,11 +90,10 @@ from . import (
XLMConfig,
XLMRobertaConfig,
XLNetConfig,
cached_path,
is_torch_available,
load_pytorch_checkpoint_in_tf2_model,
)
from .utils import hf_bucket_url, logging
from .utils import CONFIG_NAME, WEIGHTS_NAME, cached_file, logging
if is_torch_available():
@@ -311,7 +309,7 @@ def convert_pt_checkpoint_to_tf(
# Initialise TF model
if config_file in aws_config_map:
config_file = cached_path(aws_config_map[config_file], force_download=not use_cached_models)
config_file = cached_file(config_file, CONFIG_NAME, force_download=not use_cached_models)
config = config_class.from_json_file(config_file)
config.output_hidden_states = True
config.output_attentions = True
@@ -320,8 +318,9 @@ def convert_pt_checkpoint_to_tf(
# Load weights from tf checkpoint
if pytorch_checkpoint_path in aws_config_map.keys():
pytorch_checkpoint_url = hf_bucket_url(pytorch_checkpoint_path, filename=WEIGHTS_NAME)
pytorch_checkpoint_path = cached_path(pytorch_checkpoint_url, force_download=not use_cached_models)
pytorch_checkpoint_path = cached_file(
pytorch_checkpoint_path, WEIGHTS_NAME, force_download=not use_cached_models
)
# Load PyTorch checkpoint in tf2 model:
tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, pytorch_checkpoint_path)
@@ -395,14 +394,14 @@ def convert_all_pt_checkpoints_to_tf(
print("-" * 100)
if config_shortcut_name in aws_config_map:
config_file = cached_path(aws_config_map[config_shortcut_name], force_download=not use_cached_models)
config_file = cached_file(config_shortcut_name, CONFIG_NAME, force_download=not use_cached_models)
else:
config_file = cached_path(config_shortcut_name, force_download=not use_cached_models)
config_file = config_shortcut_name
if model_shortcut_name in aws_model_maps:
model_file = cached_path(aws_model_maps[model_shortcut_name], force_download=not use_cached_models)
model_file = cached_file(model_shortcut_name, WEIGHTS_NAME, force_download=not use_cached_models)
else:
model_file = cached_path(model_shortcut_name, force_download=not use_cached_models)
model_file = model_shortcut_name
if os.path.isfile(model_shortcut_name):
model_shortcut_name = "converted_model"
......
@@ -24,14 +24,7 @@ from typing import Dict, Optional, Union
from huggingface_hub import HfFolder, model_info
from .utils import (
HF_MODULES_CACHE,
TRANSFORMERS_DYNAMIC_MODULE_NAME,
cached_path,
hf_bucket_url,
is_offline_mode,
logging,
)
from .utils import HF_MODULES_CACHE, TRANSFORMERS_DYNAMIC_MODULE_NAME, cached_file, is_offline_mode, logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -219,18 +212,15 @@ def get_cached_module_file(
# Download and cache module_file from the repo `pretrained_model_name_or_path` or grab it if it's a local file.
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
submodule = "local"
else:
module_file_or_url = hf_bucket_url(
pretrained_model_name_or_path, filename=module_file, revision=revision, mirror=None
)
submodule = pretrained_model_name_or_path.replace("/", os.path.sep)
try:
# Load from URL or cache if already cached
resolved_module_file = cached_path(
module_file_or_url,
resolved_module_file = cached_file(
pretrained_model_name_or_path,
module_file,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
......
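The `get_cached_module_file` rewrite above is representative: instead of building a per-file URL, it hands the repo id and module filename to `cached_file` and forwards the usual download kwargs. A rough sketch under those assumptions (the repo id and filename are illustrative, chosen to look like the remote-code repos used in the tests):

```python
from transformers.utils import cached_file

# Illustrative: resolve a remote-code module file the way
# get_cached_module_file now does, forwarding download kwargs.
resolved_module_file = cached_file(
    "hf-internal-testing/test_dynamic_model",  # hypothetical repo id
    "modeling.py",
    cache_dir=None,
    force_download=False,
    proxies=None,
)
```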
@@ -69,20 +69,14 @@ from .utils import (
add_end_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
cached_path,
cached_property,
copy_func,
default_cache_path,
define_sagemaker_information,
filename_to_url,
get_cached_models,
get_file_from_repo,
get_from_cache,
get_full_repo_name,
get_list_of_files,
has_file,
hf_bucket_url,
http_get,
http_user_agent,
is_apex_available,
is_coloredlogs_available,
@@ -94,7 +88,6 @@ from .utils import (
is_in_notebook,
is_ipex_available,
is_librosa_available,
is_local_clone,
is_offline_mode,
is_onnx_available,
is_pandas_available,
@@ -105,7 +98,6 @@ from .utils import (
is_pyctcdecode_available,
is_pytesseract_available,
is_pytorch_quantization_available,
is_remote_url,
is_rjieba_available,
is_sagemaker_dp_enabled,
is_sagemaker_mp_enabled,
@@ -141,5 +133,4 @@ from .utils import (
torch_only_method,
torch_required,
torch_version,
url_to_filename,
)
@@ -43,15 +43,10 @@ from .models.auto.modeling_auto import (
)
from .training_args import ParallelMode
from .utils import (
CONFIG_NAME,
MODEL_CARD_NAME,
TF2_WEIGHTS_NAME,
WEIGHTS_NAME,
cached_path,
hf_bucket_url,
cached_file,
is_datasets_available,
is_offline_mode,
is_remote_url,
is_tf_available,
is_tokenizers_available,
is_torch_available,
@@ -153,11 +148,6 @@ class ModelCard:
A dictionary of proxy servers to use by protocol or endpoint, e.g.: {'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}. The proxies are used on each request.
find_from_standard_name: (*optional*) boolean, default True:
If the pretrained_model_name_or_path ends with our standard model or config filenames, replace them
with our standard modelcard filename. Can be used to directly feed a model/config url and access the
colocated modelcard.
return_unused_kwargs: (*optional*) bool:
- If False, then this function returns just the final model card object.
@@ -168,21 +158,15 @@ class ModelCard:
Examples:
```python
modelcard = ModelCard.from_pretrained(
"bert-base-uncased"
) # Download model card from huggingface.co and cache.
modelcard = ModelCard.from_pretrained(
"./test/saved_model/"
) # E.g. model card was saved using *save_pretrained('./test/saved_model/')*
# Download model card from huggingface.co and cache.
modelcard = ModelCard.from_pretrained("bert-base-uncased")
# Model card was saved using *save_pretrained('./test/saved_model/')*
modelcard = ModelCard.from_pretrained("./test/saved_model/")
modelcard = ModelCard.from_pretrained("./test/saved_model/modelcard.json")
modelcard = ModelCard.from_pretrained("bert-base-uncased", output_attentions=True, foo=False)
```"""
# This imports every model so let's do it dynamically here.
from transformers.models.auto.configuration_auto import ALL_PRETRAINED_CONFIG_ARCHIVE_MAP
cache_dir = kwargs.pop("cache_dir", None)
proxies = kwargs.pop("proxies", None)
find_from_standard_name = kwargs.pop("find_from_standard_name", True)
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
from_pipeline = kwargs.pop("_from_pipeline", None)
@@ -190,37 +174,30 @@ class ModelCard:
if from_pipeline is not None:
user_agent["using_pipeline"] = from_pipeline
if pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
# For simplicity we use the same pretrained url as the configuration files
# but with a different suffix (modelcard.json). This suffix is replaced below.
model_card_file = ALL_PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
elif os.path.isdir(pretrained_model_name_or_path):
model_card_file = os.path.join(pretrained_model_name_or_path, MODEL_CARD_NAME)
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
model_card_file = pretrained_model_name_or_path
is_local = os.path.isdir(pretrained_model_name_or_path)
if os.path.isfile(pretrained_model_name_or_path):
resolved_model_card_file = pretrained_model_name_or_path
is_local = True
else:
model_card_file = hf_bucket_url(pretrained_model_name_or_path, filename=MODEL_CARD_NAME, mirror=None)
if find_from_standard_name or pretrained_model_name_or_path in ALL_PRETRAINED_CONFIG_ARCHIVE_MAP:
model_card_file = model_card_file.replace(CONFIG_NAME, MODEL_CARD_NAME)
model_card_file = model_card_file.replace(WEIGHTS_NAME, MODEL_CARD_NAME)
model_card_file = model_card_file.replace(TF2_WEIGHTS_NAME, MODEL_CARD_NAME)
try:
# Load from URL or cache if already cached
resolved_model_card_file = cached_path(
model_card_file, cache_dir=cache_dir, proxies=proxies, user_agent=user_agent
)
if resolved_model_card_file == model_card_file:
logger.info(f"loading model card file {model_card_file}")
else:
logger.info(f"loading model card file {model_card_file} from cache at {resolved_model_card_file}")
# Load model card
modelcard = cls.from_json_file(resolved_model_card_file)
try:
# Load from URL or cache if already cached
resolved_model_card_file = cached_file(
pretrained_model_name_or_path,
filename=MODEL_CARD_NAME,
cache_dir=cache_dir,
proxies=proxies,
user_agent=user_agent,
)
if is_local:
logger.info(f"loading model card file {resolved_model_card_file}")
else:
logger.info(f"loading model card file {MODEL_CARD_NAME} from cache at {resolved_model_card_file}")
# Load model card
modelcard = cls.from_json_file(resolved_model_card_file)
except (EnvironmentError, json.JSONDecodeError):
# We fall back on creating an empty model card
modelcard = cls()
except (EnvironmentError, json.JSONDecodeError):
# We fall back on creating an empty model card
modelcard = cls()
# Update model card with kwargs if needed
to_remove = []
......
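Two details of the rewritten `ModelCard.from_pretrained` are worth noting: whether the input is local is now decided up front with `os.path.isdir`/`os.path.isfile` rather than by comparing the resolved path to the input, and any resolution failure still falls back to an empty card. A condensed sketch of that flow (the repo id is illustrative):

```python
import json
import os

from transformers.utils import cached_file

name_or_path = "bert-base-uncased"  # illustrative
is_local = os.path.isdir(name_or_path) or os.path.isfile(name_or_path)
try:
    resolved = cached_file(name_or_path, filename="modelcard.json")
    origin = resolved if is_local else f"cache at {resolved}"
    print(f"loading model card from {origin}")
except (EnvironmentError, json.JSONDecodeError):
    print("no model card found, falling back to an empty one")
```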
@@ -2156,7 +2156,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
trust_remote_code = kwargs.pop("trust_remote_code", None)
mirror = kwargs.pop("mirror", None)
_ = kwargs.pop("mirror", None)
load_weight_prefix = kwargs.pop("load_weight_prefix", None)
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
@@ -2270,7 +2270,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
# message.
has_file_kwargs = {
"revision": revision,
"mirror": mirror,
"proxies": proxies,
"use_auth_token": use_auth_token,
}
@@ -2321,7 +2320,6 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
mirror=mirror,
)
config.name_or_path = pretrained_model_name_or_path
......
@@ -1784,7 +1784,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
trust_remote_code = kwargs.pop("trust_remote_code", None)
mirror = kwargs.pop("mirror", None)
_ = kwargs.pop("mirror", None)
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
_fast_init = kwargs.pop("_fast_init", True)
@@ -1955,7 +1955,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# message.
has_file_kwargs = {
"revision": revision,
"mirror": mirror,
"proxies": proxies,
"use_auth_token": use_auth_token,
}
@@ -2012,7 +2011,6 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
mirror=mirror,
subfolder=subfolder,
)
......
@@ -23,7 +23,7 @@ import numpy as np
from ...tokenization_utils import PreTrainedTokenizer
from ...tokenization_utils_base import BatchEncoding
from ...utils import cached_path, is_datasets_available, is_faiss_available, is_remote_url, logging, requires_backends
from ...utils import cached_file, is_datasets_available, is_faiss_available, logging, requires_backends
from .configuration_rag import RagConfig
from .tokenization_rag import RagTokenizer
@@ -111,22 +111,21 @@ class LegacyIndex(Index):
self._index_initialized = False
def _resolve_path(self, index_path, filename):
assert os.path.isdir(index_path) or is_remote_url(index_path), "Please specify a valid `index_path`."
archive_file = os.path.join(index_path, filename)
is_local = os.path.isdir(index_path)
try:
# Load from URL or cache if already cached
resolved_archive_file = cached_path(archive_file)
resolved_archive_file = cached_file(index_path, filename)
except EnvironmentError:
msg = (
f"Can't load '{archive_file}'. Make sure that:\n\n"
f"Can't load '{filename}'. Make sure that:\n\n"
f"- '{index_path}' is a correct remote path to a directory containing a file named {filename}\n\n"
f"- or '{index_path}' is the correct path to a directory containing a file named {filename}.\n\n"
)
raise EnvironmentError(msg)
if resolved_archive_file == archive_file:
logger.info(f"loading file {archive_file}")
if is_local:
logger.info(f"loading file {resolved_archive_file}")
else:
logger.info(f"loading file {archive_file} from cache at {resolved_archive_file}")
logger.info(f"loading file {filename} from cache at {resolved_archive_file}")
return resolved_archive_file
def _load_passages(self):
......
@@ -29,7 +29,7 @@ import numpy as np
from ...tokenization_utils import PreTrainedTokenizer
from ...utils import (
cached_path,
cached_file,
is_sacremoses_available,
is_torch_available,
logging,
@@ -681,24 +681,21 @@ class TransfoXLCorpus(object):
Instantiate a pre-processed corpus.
"""
vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP:
corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME)
is_local = os.path.isdir(pretrained_model_name_or_path)
# redirect to the cache, if necessary
try:
resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
resolved_corpus_file = cached_file(pretrained_model_name_or_path, CORPUS_NAME, cache_dir=cache_dir)
except EnvironmentError:
logger.error(
f"Corpus '{pretrained_model_name_or_path}' was not found in corpus list"
f" ({', '.join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys())}. We assumed '{pretrained_model_name_or_path}'"
f" was a path or url but couldn't find files {corpus_file} at this path or url."
f" was a path or url but couldn't find files {CORPUS_NAME} at this path or url."
)
return None
if resolved_corpus_file == corpus_file:
logger.info(f"loading corpus file {corpus_file}")
if is_local:
logger.info(f"loading corpus file {resolved_corpus_file}")
else:
logger.info(f"loading corpus file {corpus_file} from cache at {resolved_corpus_file}")
logger.info(f"loading corpus file {CORPUS_NAME} from cache at {resolved_corpus_file}")
# Instantiate tokenizer.
corpus = cls(*inputs, **kwargs)
......
@@ -25,6 +25,8 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from numpy import isin
from huggingface_hub.file_download import http_get
from ..configuration_utils import PretrainedConfig
from ..dynamic_module_utils import get_class_from_dynamic_module
from ..feature_extraction_utils import PreTrainedFeatureExtractor
@@ -33,7 +35,7 @@ from ..models.auto.feature_extraction_auto import FEATURE_EXTRACTOR_MAPPING, Aut
from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
from ..tokenization_utils import PreTrainedTokenizer
from ..tokenization_utils_fast import PreTrainedTokenizerFast
from ..utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, http_get, is_tf_available, is_torch_available, logging
from ..utils import HUGGINGFACE_CO_RESOLVE_ENDPOINT, is_tf_available, is_torch_available, logging
from .audio_classification import AudioClassificationPipeline
from .automatic_speech_recognition import AutomaticSpeechRecognitionPipeline
from .base import (
......
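The pipelines module still streams large downloads through an `http_get` helper, but now imports it from `huggingface_hub.file_download` instead of the transformers copy deleted later in this commit. A minimal usage sketch (the URL is illustrative):

```python
import tempfile

from huggingface_hub.file_download import http_get

# Illustrative: stream a small public file to a temporary location.
url = "https://huggingface.co/bert-base-uncased/resolve/main/config.json"
with tempfile.NamedTemporaryFile(mode="wb", delete=False) as tmp:
    http_get(url, tmp)
print(tmp.name)
```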
@@ -61,25 +61,16 @@ from .hub import (
RepositoryNotFoundError,
RevisionNotFoundError,
cached_file,
cached_path,
default_cache_path,
define_sagemaker_information,
filename_to_url,
get_cached_models,
get_file_from_repo,
get_from_cache,
get_full_repo_name,
get_list_of_files,
has_file,
hf_bucket_url,
http_get,
http_user_agent,
is_local_clone,
is_offline_mode,
is_remote_url,
move_cache,
send_example_telemetry,
url_to_filename,
)
from .import_utils import (
ENV_VARS_TRUE_AND_AUTO_VALUES,
......
@@ -14,44 +14,32 @@
"""
Hub utilities: utilities related to downloading and caching models
"""
import copy
import fnmatch
import io
import json
import os
import re
import shutil
import subprocess
import sys
import tarfile
import tempfile
import traceback
import warnings
from contextlib import contextmanager
from functools import partial
from hashlib import sha256
from pathlib import Path
from typing import BinaryIO, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse
from typing import Dict, List, Optional, Tuple, Union
from uuid import uuid4
from zipfile import ZipFile, is_zipfile
import huggingface_hub
import requests
from filelock import FileLock
from huggingface_hub import (
CommitOperationAdd,
HfFolder,
create_commit,
create_repo,
hf_hub_download,
list_repo_files,
hf_hub_url,
whoami,
)
from huggingface_hub.constants import HUGGINGFACE_HEADER_X_LINKED_ETAG, HUGGINGFACE_HEADER_X_REPO_COMMIT
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests.exceptions import HTTPError
from requests.models import Response
from transformers.utils.logging import tqdm
from . import __version__, logging
@@ -128,93 +116,6 @@ HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{
HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples"
def is_remote_url(url_or_filename):
parsed = urlparse(url_or_filename)
return parsed.scheme in ("http", "https")
def hf_bucket_url(
model_id: str, filename: str, subfolder: Optional[str] = None, revision: Optional[str] = None, mirror=None
) -> str:
"""
Resolve a model identifier, a file name, and an optional revision id, to a huggingface.co-hosted url, redirecting
to Cloudfront (a Content Delivery Network, or CDN) for large files.
Cloudfront is replicated over the globe so downloads are way faster for the end user (and it also lowers our
bandwidth costs).
Cloudfront aggressively caches files by default (default TTL is 24 hours); however, this is not an issue here
because we migrated to a git-based versioning system on huggingface.co, so we now store the files on S3/Cloudfront
in a content-addressable way (i.e., the file name is its hash). Using content-addressable filenames means the cache
can't ever be stale.
In terms of client-side caching from this library, we base our caching on the objects' ETag. An object's ETag is
its sha1 if stored in git, or its sha256 if stored in git-lfs. Files cached locally from transformers before v3.5.0
are not shared with those new files, because the cached file's name contains a hash of the url (which changed).
"""
if subfolder is not None:
filename = f"{subfolder}/{filename}"
if mirror:
if mirror in ["tuna", "bfsu"]:
raise ValueError("The Tuna and BFSU mirrors are no longer available. Try removing the mirror argument.")
legacy_format = "/" not in model_id
if legacy_format:
return f"{mirror}/{model_id}-{filename}"
else:
return f"{mirror}/{model_id}/{filename}"
if revision is None:
revision = "main"
return HUGGINGFACE_CO_PREFIX.format(model_id=model_id, revision=revision, filename=filename)
def url_to_filename(url: str, etag: Optional[str] = None) -> str:
"""
Convert `url` into a hashed filename in a repeatable way. If `etag` is specified, append its hash to the url's,
delimited by a period. If the url ends with .h5 (Keras HDF5 weights) adds '.h5' to the name so that TF 2.0 can
identify it as a HDF5 file (see
https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1380)
"""
url_bytes = url.encode("utf-8")
filename = sha256(url_bytes).hexdigest()
if etag:
etag_bytes = etag.encode("utf-8")
filename += "." + sha256(etag_bytes).hexdigest()
if url.endswith(".h5"):
filename += ".h5"
return filename
def filename_to_url(filename, cache_dir=None):
"""
Return the url and etag (which may be `None`) stored for *filename*. Raise `EnvironmentError` if *filename* or its
stored metadata do not exist.
"""
if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
cache_path = os.path.join(cache_dir, filename)
if not os.path.exists(cache_path):
raise EnvironmentError(f"file {cache_path} not found")
meta_path = cache_path + ".json"
if not os.path.exists(meta_path):
raise EnvironmentError(f"file {meta_path} not found")
with open(meta_path, encoding="utf-8") as meta_file:
metadata = json.load(meta_file)
url = metadata["url"]
etag = metadata["etag"]
return url, etag
def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
"""
Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape `(model_url,
@@ -248,108 +149,6 @@ def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
return cached_models
def cached_path(
url_or_filename,
cache_dir=None,
force_download=False,
proxies=None,
resume_download=False,
user_agent: Union[Dict, str, None] = None,
extract_compressed_file=False,
force_extract=False,
use_auth_token: Union[bool, str, None] = None,
local_files_only=False,
) -> Optional[str]:
"""
Given something that might be a URL (or might be a local path), determine which. If it's a URL, download the file
and cache it, and return the path to the cached file. If it's already a local path, make sure the file exists and
then return the path
Args:
cache_dir: specify a cache directory to save the file to (overwrite the default cache dir).
force_download: if True, re-download the file even if it's already cached in the cache dir.
resume_download: if True, resume the download if incompletely received file is found.
user_agent: Optional string or dict that will be appended to the user-agent on remote requests.
use_auth_token: Optional string or boolean to use as Bearer token for remote files. If True,
will get token from ~/.huggingface.
extract_compressed_file: if True and the path points to a zip or tar file, extract the compressed
file in a folder alongside the archive.
force_extract: if True, when extract_compressed_file is True and the archive was already extracted,
re-extract the archive and overwrite the folder where it was extracted.
Return:
Local path (string) of file or if networking is off, last version of file cached on disk.
Raises:
In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
"""
if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE
if isinstance(url_or_filename, Path):
url_or_filename = str(url_or_filename)
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
if is_offline_mode() and not local_files_only:
logger.info("Offline mode: forcing local_files_only=True")
local_files_only = True
if is_remote_url(url_or_filename):
# URL, so get it from the cache (downloading if necessary)
output_path = get_from_cache(
url_or_filename,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
user_agent=user_agent,
use_auth_token=use_auth_token,
local_files_only=local_files_only,
)
elif os.path.exists(url_or_filename):
# File, and it exists.
output_path = url_or_filename
elif urlparse(url_or_filename).scheme == "":
# File, but it doesn't exist.
raise EnvironmentError(f"file {url_or_filename} not found")
else:
# Something unknown
raise ValueError(f"unable to parse {url_or_filename} as a URL or as a local path")
if extract_compressed_file:
if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
return output_path
# Path where we extract compressed archives
# We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
output_dir, output_file = os.path.split(output_path)
output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
output_path_extracted = os.path.join(output_dir, output_extract_dir_name)
if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
return output_path_extracted
# Prevent parallel extractions
lock_path = output_path + ".lock"
with FileLock(lock_path):
shutil.rmtree(output_path_extracted, ignore_errors=True)
os.makedirs(output_path_extracted)
if is_zipfile(output_path):
with ZipFile(output_path, "r") as zip_file:
zip_file.extractall(output_path_extracted)
zip_file.close()
elif tarfile.is_tarfile(output_path):
tar_file = tarfile.open(output_path)
tar_file.extractall(output_path_extracted)
tar_file.close()
else:
raise EnvironmentError(f"Archive format of {output_path} could not be identified")
return output_path_extracted
return output_path
def define_sagemaker_information():
try:
instance_data = requests.get(os.environ["ECS_CONTAINER_METADATA_URI"]).json()
@@ -399,234 +198,6 @@ def http_user_agent(user_agent: Union[Dict, str, None] = None) -> str:
return ua
def _raise_for_status(response: Response):
"""
Internal version of `response.raise_for_status()` that will refine a potential HTTPError.
"""
if "X-Error-Code" in response.headers:
error_code = response.headers["X-Error-Code"]
if error_code == "RepoNotFound":
raise RepositoryNotFoundError(f"404 Client Error: Repository Not Found for url: {response.url}")
elif error_code == "EntryNotFound":
raise EntryNotFoundError(f"404 Client Error: Entry Not Found for url: {response.url}")
elif error_code == "RevisionNotFound":
raise RevisionNotFoundError(f"404 Client Error: Revision Not Found for url: {response.url}")
if response.status_code == 401:
# The repo was not found and the user is not Authenticated
raise RepositoryNotFoundError(
f"401 Client Error: Repository not found for url: {response.url}. "
"If the repo is private, make sure you are authenticated."
)
response.raise_for_status()
def http_get(
url: str,
temp_file: BinaryIO,
proxies=None,
resume_size=0,
headers: Optional[Dict[str, str]] = None,
file_name: Optional[str] = None,
):
"""
Download remote file. Do not gobble up errors.
"""
headers = copy.deepcopy(headers)
if resume_size > 0:
headers["Range"] = f"bytes={resume_size}-"
r = requests.get(url, stream=True, proxies=proxies, headers=headers)
_raise_for_status(r)
content_length = r.headers.get("Content-Length")
total = resume_size + int(content_length) if content_length is not None else None
# `tqdm` behavior is determined by `utils.logging.is_progress_bar_enabled()`
# and can be set using `utils.logging.enable/disable_progress_bar()`
progress = tqdm(
unit="B",
unit_scale=True,
unit_divisor=1024,
total=total,
initial=resume_size,
desc=f"Downloading {file_name}" if file_name is not None else "Downloading",
)
for chunk in r.iter_content(chunk_size=1024):
if chunk: # filter out keep-alive new chunks
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
def get_from_cache(
url: str,
cache_dir=None,
force_download=False,
proxies=None,
etag_timeout=10,
resume_download=False,
user_agent: Union[Dict, str, None] = None,
use_auth_token: Union[bool, str, None] = None,
local_files_only=False,
) -> Optional[str]:
"""
Given a URL, look for the corresponding file in the local cache. If it's not there, download it. Then return the
path to the cached file.
Return:
Local path (string) of file or if networking is off, last version of file cached on disk.
Raises:
In case of non-recoverable file (non-existent or inaccessible url + no cache on disk).
"""
if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
os.makedirs(cache_dir, exist_ok=True)
headers = {"user-agent": http_user_agent(user_agent)}
if isinstance(use_auth_token, str):
headers["authorization"] = f"Bearer {use_auth_token}"
elif use_auth_token:
token = HfFolder.get_token()
if token is None:
raise EnvironmentError("You specified use_auth_token=True, but a huggingface token was not found.")
headers["authorization"] = f"Bearer {token}"
url_to_download = url
etag = None
if not local_files_only:
try:
r = requests.head(url, headers=headers, allow_redirects=False, proxies=proxies, timeout=etag_timeout)
_raise_for_status(r)
etag = r.headers.get("X-Linked-Etag") or r.headers.get("ETag")
# We favor a custom header indicating the etag of the linked resource, and
# we fallback to the regular etag header.
# If we don't have any of those, raise an error.
if etag is None:
raise OSError(
"Distant resource does not have an ETag, we won't be able to reliably ensure reproducibility."
)
# In case of a redirect,
# save an extra redirect on the request.get call,
# and ensure we download the exact atomic version even if it changed
# between the HEAD and the GET (unlikely, but hey).
if 300 <= r.status_code <= 399:
url_to_download = r.headers["Location"]
except (
requests.exceptions.SSLError,
requests.exceptions.ProxyError,
RepositoryNotFoundError,
EntryNotFoundError,
RevisionNotFoundError,
):
# Actually raise for those subclasses of ConnectionError
# Also raise the custom errors coming from a non existing repo/branch/file as they are caught later on.
raise
except (HTTPError, requests.exceptions.ConnectionError, requests.exceptions.Timeout):
# Otherwise, our Internet connection is down.
# etag is None
pass
filename = url_to_filename(url, etag)
# get cache path to put the file
cache_path = os.path.join(cache_dir, filename)
# etag is None == we don't have a connection or we passed local_files_only.
# try to get the last downloaded one
if etag is None:
if os.path.exists(cache_path):
return cache_path
else:
matching_files = [
file
for file in fnmatch.filter(os.listdir(cache_dir), filename.split(".")[0] + ".*")
if not file.endswith(".json") and not file.endswith(".lock")
]
if len(matching_files) > 0:
return os.path.join(cache_dir, matching_files[-1])
else:
# If files cannot be found and local_files_only=True,
# the models might've been found if local_files_only=False
# Notify the user about that
if local_files_only:
fname = url.split("/")[-1]
raise EntryNotFoundError(
f"Cannot find the requested file ({fname}) in the cached path and outgoing traffic has been"
" disabled. To enable model look-ups and downloads online, set 'local_files_only'"
" to False."
)
else:
raise ValueError(
"Connection error, and we cannot find the requested files in the cached path."
" Please try again or make sure your Internet connection is on."
)
# From now on, etag is not None.
if os.path.exists(cache_path) and not force_download:
return cache_path
# Prevent parallel downloads of the same file with a lock.
lock_path = cache_path + ".lock"
with FileLock(lock_path):
# If the download just completed while the lock was activated.
if os.path.exists(cache_path) and not force_download:
# Even if returning early like here, the lock will be released.
return cache_path
if resume_download:
incomplete_path = cache_path + ".incomplete"
@contextmanager
def _resumable_file_manager() -> "io.BufferedWriter":
with open(incomplete_path, "ab") as f:
yield f
temp_file_manager = _resumable_file_manager
if os.path.exists(incomplete_path):
resume_size = os.stat(incomplete_path).st_size
else:
resume_size = 0
else:
temp_file_manager = partial(tempfile.NamedTemporaryFile, mode="wb", dir=cache_dir, delete=False)
resume_size = 0
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
with temp_file_manager() as temp_file:
logger.info(f"{url} not found in cache or force_download set to True, downloading to {temp_file.name}")
# The url_to_download might be messy, so we extract the file name from the original url.
file_name = url.split("/")[-1]
http_get(
url_to_download,
temp_file,
proxies=proxies,
resume_size=resume_size,
headers=headers,
file_name=file_name,
)
logger.info(f"storing {url} in cache at {cache_path}")
os.replace(temp_file.name, cache_path)
# NamedTemporaryFile creates a file with hardwired 0600 perms (ignoring umask), so fixing it.
umask = os.umask(0o666)
os.umask(umask)
os.chmod(cache_path, 0o666 & ~umask)
logger.info(f"creating metadata file for {cache_path}")
meta = {"url": url, "etag": etag}
meta_path = cache_path + ".json"
with open(meta_path, "w") as meta_file:
json.dump(meta, meta_file)
return cache_path
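All of the URL-hashed caching machinery deleted above (`url_to_filename`, `filename_to_url`, `cached_path`, `get_from_cache`) is superseded by `huggingface_hub`, whose cache is keyed by repo, revision, and filename instead of hashed URLs. The rough modern equivalent of a `get_from_cache` call is a direct `hf_hub_download`, which `cached_file` wraps (the repo id is illustrative):

```python
from huggingface_hub import hf_hub_download

# Illustrative: the low-level download/caching that cached_file
# now delegates to huggingface_hub.
path = hf_hub_download(repo_id="bert-base-uncased", filename="config.json")
print(path)
```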
def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None):
"""
Explores the cache to return the latest cached file for a given revision.
@@ -919,7 +490,6 @@ def has_file(
path_or_repo: Union[str, os.PathLike],
filename: str,
revision: Optional[str] = None,
mirror: Optional[str] = None,
proxies: Optional[Dict[str, str]] = None,
use_auth_token: Optional[Union[bool, str]] = None,
):
@@ -936,7 +506,7 @@ def has_file(
if os.path.isdir(path_or_repo):
return os.path.isfile(os.path.join(path_or_repo, filename))
url = hf_bucket_url(path_or_repo, filename=filename, revision=revision, mirror=mirror)
url = hf_hub_url(path_or_repo, filename=filename, revision=revision)
headers = {"user-agent": http_user_agent()}
if isinstance(use_auth_token, str):
@@ -965,89 +535,6 @@ def has_file(
return False
def get_list_of_files(
path_or_repo: Union[str, os.PathLike],
revision: Optional[str] = None,
use_auth_token: Optional[Union[bool, str]] = None,
local_files_only: bool = False,
) -> List[str]:
"""
Gets the list of files inside `path_or_repo`.
Args:
path_or_repo (`str` or `os.PathLike`):
Can be either the id of a repo on huggingface.co or a path to a *directory*.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
local_files_only (`bool`, *optional*, defaults to `False`):
Whether or not to only rely on local files and not to attempt to download any files.
<Tip warning={true}>
This API is not optimized, so calling it a lot may result in connection errors.
</Tip>
Returns:
`List[str]`: The list of files available in `path_or_repo`.
"""
path_or_repo = str(path_or_repo)
# If path_or_repo is a folder, we just return what is inside (subdirectories included).
if os.path.isdir(path_or_repo):
list_of_files = []
for path, dir_names, file_names in os.walk(path_or_repo):
list_of_files.extend([os.path.join(path, f) for f in file_names])
return list_of_files
# Can't grab the files if we are on offline mode.
if is_offline_mode() or local_files_only:
return []
# Otherwise we grab the token and use the list_repo_files method.
if isinstance(use_auth_token, str):
token = use_auth_token
elif use_auth_token is True:
token = HfFolder.get_token()
else:
token = None
try:
return list_repo_files(path_or_repo, revision=revision, token=token)
except HTTPError as e:
raise ValueError(
f"{path_or_repo} is not a local path or a model identifier on the model Hub. Did you make a typo?"
) from e
def is_local_clone(repo_path, repo_url):
"""
Checks if the folder in `repo_path` is a local clone of `repo_url`.
"""
# First double-check that `repo_path` is a git repo
if not os.path.exists(os.path.join(repo_path, ".git")):
return False
test_git = subprocess.run("git branch".split(), cwd=repo_path)
if test_git.returncode != 0:
return False
# Then look at its remotes
remotes = subprocess.run(
"git remote -v".split(),
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
check=True,
encoding="utf-8",
cwd=repo_path,
).stdout
return repo_url in remotes.split()
class PushToHubMixin:
"""
A Mixin containing the functionality to push a model or tokenizer to the hub.
@@ -1310,7 +797,6 @@ def get_checkpoint_shard_files(
use_auth_token=None,
user_agent=None,
revision=None,
mirror=None,
subfolder="",
):
"""
@@ -1343,18 +829,11 @@
# At this stage pretrained_model_name_or_path is a model identifier on the Hub
cached_filenames = []
for shard_filename in shard_filenames:
shard_url = hf_bucket_url(
pretrained_model_name_or_path,
filename=shard_filename,
revision=revision,
mirror=mirror,
subfolder=subfolder if len(subfolder) > 0 else None,
)
try:
# Load from URL
cached_filename = cached_path(
shard_url,
cached_filename = cached_file(
pretrained_model_name_or_path,
shard_filename,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
@@ -1362,6 +841,8 @@
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
revision=revision,
subfolder=subfolder,
)
# We have already dealt with RepositoryNotFoundError and RevisionNotFoundError when getting the index, so
# we don't have to catch them here.
......
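In the surviving hub utilities, `has_file` still probes the Hub with a bare HEAD request but builds its URL with `huggingface_hub.hf_hub_url` instead of the deleted `hf_bucket_url`, and `get_checkpoint_shard_files` resolves each shard through `cached_file`, passing `revision` and `subfolder` straight through. A sketch of the URL helper (arguments are illustrative):

```python
from huggingface_hub import hf_hub_url

# Illustrative: the resolve URL that has_file now probes with a HEAD request.
url = hf_hub_url("bert-base-uncased", filename="config.json", revision="main")
print(url)  # e.g. https://huggingface.co/bert-base-uncased/resolve/main/config.json
```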
@@ -26,20 +26,13 @@ import transformers
from transformers import * # noqa F406
from transformers.testing_utils import DUMMY_UNKNOWN_IDENTIFIER
from transformers.utils import (
CONFIG_NAME,
FLAX_WEIGHTS_NAME,
TF2_WEIGHTS_NAME,
WEIGHTS_NAME,
ContextManagers,
EntryNotFoundError,
RepositoryNotFoundError,
RevisionNotFoundError,
filename_to_url,
find_labels,
get_file_from_repo,
get_from_cache,
has_file,
hf_bucket_url,
is_flax_available,
is_tf_available,
is_torch_available,
@@ -85,60 +78,6 @@ class TestImportMechanisms(unittest.TestCase):
class GetFromCacheTests(unittest.TestCase):
def test_bogus_url(self):
# This lets us simulate no connection
# as the error raised is the same
# `ConnectionError`
url = "https://bogus"
with self.assertRaisesRegex(ValueError, "Connection error"):
_ = get_from_cache(url)
def test_file_not_found(self):
# Valid revision (None) but missing file.
url = hf_bucket_url(MODEL_ID, filename="missing.bin")
with self.assertRaisesRegex(EntryNotFoundError, "404 Client Error"):
_ = get_from_cache(url)
def test_model_not_found_not_authenticated(self):
# Invalid model id.
url = hf_bucket_url("bert-base", filename="pytorch_model.bin")
with self.assertRaisesRegex(RepositoryNotFoundError, "401 Client Error"):
_ = get_from_cache(url)
@unittest.skip("No authentication when testing against prod")
def test_model_not_found_authenticated(self):
# Invalid model id.
url = hf_bucket_url("bert-base", filename="pytorch_model.bin")
with self.assertRaisesRegex(RepositoryNotFoundError, "404 Client Error"):
_ = get_from_cache(url, use_auth_token="hf_sometoken")
# ^ TODO - if we decide to unskip this: use a real / functional token
def test_revision_not_found(self):
# Valid file but missing revision
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_INVALID)
with self.assertRaisesRegex(RevisionNotFoundError, "404 Client Error"):
_ = get_from_cache(url)
def test_standard_object(self):
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_DEFAULT)
filepath = get_from_cache(url, force_download=True)
metadata = filename_to_url(filepath)
self.assertEqual(metadata, (url, f'"{PINNED_SHA1}"'))
def test_standard_object_rev(self):
# Same object, but different revision
url = hf_bucket_url(MODEL_ID, filename=CONFIG_NAME, revision=REVISION_ID_ONE_SPECIFIC_COMMIT)
filepath = get_from_cache(url, force_download=True)
metadata = filename_to_url(filepath)
self.assertNotEqual(metadata[1], f'"{PINNED_SHA1}"')
# Caution: check that the etag is *not* equal to the one from `test_standard_object`
def test_lfs_object(self):
url = hf_bucket_url(MODEL_ID, filename=WEIGHTS_NAME, revision=REVISION_ID_DEFAULT)
filepath = get_from_cache(url, force_download=True)
metadata = filename_to_url(filepath)
self.assertEqual(metadata, (url, f'"{PINNED_SHA256}"'))
def test_has_file(self):
self.assertTrue(has_file("hf-internal-testing/tiny-bert-pt-only", WEIGHTS_NAME))
self.assertFalse(has_file("hf-internal-testing/tiny-bert-pt-only", TF2_WEIGHTS_NAME))
......
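The deleted tests exercised URL-level helpers (`hf_bucket_url`, `get_from_cache`, `filename_to_url`) that no longer exist; what remains is the repo-level `has_file` check. Roughly what the surviving assertions boil down to:

```python
from transformers.utils import TF2_WEIGHTS_NAME, WEIGHTS_NAME, has_file

# The surviving test checks file existence at the repo level.
assert has_file("hf-internal-testing/tiny-bert-pt-only", WEIGHTS_NAME)
assert not has_file("hf-internal-testing/tiny-bert-pt-only", TF2_WEIGHTS_NAME)
```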
@@ -614,7 +614,6 @@ UNDOCUMENTED_OBJECTS = [
"absl", # External module
"add_end_docstrings", # Internal, should never have been in the main init.
"add_start_docstrings", # Internal, should never have been in the main init.
"cached_path", # Internal used for downloading models.
"convert_tf_weight_name_to_pt_weight_name", # Internal used to convert model weights
"logger", # Internal logger
"logging", # External module
......