Unverified Commit 71ff88fa authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Further reduce the number of alls to head for cached objects (#18871)

* Further reduce the number of alls to head for cached models/tokenizers/pipelines

* Fix tests

* Address review comments
parent 6678350c
......@@ -120,6 +120,9 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", HUGGINGFACE_CO_R
HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"
HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples"
# Return value when trying to load a file from cache but the file does not exist in the distant repo.
_CACHED_NO_EXIST = object()
def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
"""
......@@ -222,6 +225,22 @@ def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]
def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_hash=None):
"""
Explores the cache to return the latest cached file for a given revision.
Args:
cache_dir (`str` or `os.PathLike`): The folder where the cached files lie.
repo_id (`str`): The ID of the repo on huggingface.co.
filename (`str`): The filename to look for inside `repo_id`.
revision (`str`, *optional*):
The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
provided either.
commit_hash (`str`, *optional*): The (full) commit hash to look for inside the cache.
Returns:
`Optional[str]` or `_CACHED_NO_EXIST`:
Will return `None` if the file was not cached. Otherwise:
- The exact path to the cached file if it's found in the cache
- A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was
cached.
"""
if commit_hash is not None and revision is not None:
raise ValueError("`commit_hash` and `revision` are mutually exclusive, pick one only.")
......@@ -244,6 +263,9 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_h
with open(os.path.join(model_cache, "refs", revision)) as f:
commit_hash = f.read()
if os.path.isfile(os.path.join(model_cache, ".no_exist", commit_hash, filename)):
return _CACHED_NO_EXIST
cached_shas = os.listdir(os.path.join(model_cache, "snapshots"))
if commit_hash not in cached_shas:
# No cache for this revision and we won't try to return a random revision
......@@ -338,7 +360,10 @@ def cached_file(
resolved_file = os.path.join(os.path.join(path_or_repo_id, subfolder), filename)
if not os.path.isfile(resolved_file):
if _raise_exceptions_for_missing_entries:
raise EnvironmentError(f"Could not locate {full_filename} inside {path_or_repo_id}.")
raise EnvironmentError(
f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
f"'https://huggingface.co/{path_or_repo_id}/{revision}' for available files."
)
else:
return None
return resolved_file
......@@ -352,7 +377,12 @@ def cached_file(
# If the file is cached under that commit hash, we return it directly.
resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, commit_hash=_commit_hash)
if resolved_file is not None:
if resolved_file is not _CACHED_NO_EXIST:
return resolved_file
elif not _raise_exceptions_for_missing_entries:
return None
else:
raise EnvironmentError(f"Could not locate {full_filename} inside {path_or_repo_id}.")
user_agent = http_user_agent(user_agent)
try:
......
......@@ -370,6 +370,5 @@ class AutoModelTest(unittest.TestCase):
with RequestCounter() as counter:
_ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
self.assertEqual(counter.get_request_count, 0)
# There is no pytorch_model.bin so we still get one call for this one.
self.assertEqual(counter.head_request_count, 2)
self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0)
......@@ -303,6 +303,5 @@ class TFAutoModelTest(unittest.TestCase):
with RequestCounter() as counter:
_ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
self.assertEqual(counter.get_request_count, 0)
# There is no pytorch_model.bin so we still get one call for this one.
self.assertEqual(counter.head_request_count, 2)
self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0)
......@@ -349,6 +349,5 @@ class AutoTokenizerTest(unittest.TestCase):
with RequestCounter() as counter:
_ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
self.assertEqual(counter.get_request_count, 0)
# We still have one extra call because the model does not have a added_tokens.json file
self.assertEqual(counter.head_request_count, 2)
self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0)
......@@ -884,8 +884,7 @@ class CustomPipelineTest(unittest.TestCase):
with RequestCounter() as counter:
_ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
self.assertEqual(counter.get_request_count, 0)
# We still have one extra call because the model does not have a added_tokens.json file
self.assertEqual(counter.head_request_count, 2)
self.assertEqual(counter.head_request_count, 1)
self.assertEqual(counter.other_request_count, 0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment