"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "e0d7c831c7691a0069a57ba03993a8d531343de1"
Unverified Commit 71ff88fa authored by Sylvain Gugger, committed by GitHub

Further reduce the number of calls to head for cached objects (#18871)

* Further reduce the number of calls to head for cached models/tokenizers/pipelines

* Fix tests

* Address review comments
parent 6678350c
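
The change builds on the fact that a miss can be cached just like a hit: once the Hub has answered 404 for a file at a given commit, that negative result is stored locally, so later loads never repeat the HEAD request for it. Per the updated tests below, reloading an already-cached model then costs a single HEAD request (to resolve the revision to a commit hash) instead of one per candidate file. A rough illustration of the user-visible effect, assuming the download path records misses in the `.no_exist` folder this diff reads:

```python
from transformers import AutoModel

# First load downloads the files and populates the local cache, including
# (under the stated assumption) markers for files the repo does not have.
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")

# Second load resolves everything, present or known-missing, from the cache;
# per the tests in this commit, only one HEAD request remains.
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert")
```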
@@ -120,6 +120,9 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", HUGGINGFACE_CO_R
 HUGGINGFACE_CO_PREFIX = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/{model_id}/resolve/{revision}/{filename}"
 HUGGINGFACE_CO_EXAMPLES_TELEMETRY = HUGGINGFACE_CO_RESOLVE_ENDPOINT + "/api/telemetry/examples"
 
+# Return value when trying to load a file from cache but the file does not exist in the distant repo.
+_CACHED_NO_EXIST = object()
+
 
 def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
     """
@@ -222,6 +225,22 @@ def extract_commit_hash(resolved_file: Optional[str], commit_hash: Optional[str]
 def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_hash=None):
     """
     Explores the cache to return the latest cached file for a given revision.
+
+    Args:
+        cache_dir (`str` or `os.PathLike`): The folder where the cached files lie.
+        repo_id (`str`): The ID of the repo on huggingface.co.
+        filename (`str`): The filename to look for inside `repo_id`.
+        revision (`str`, *optional*):
+            The specific model version to use. Will default to `"main"` if it's not provided and no `commit_hash` is
+            provided either.
+        commit_hash (`str`, *optional*): The (full) commit hash to look for inside the cache.
+
+    Returns:
+        `Optional[str]` or `_CACHED_NO_EXIST`:
+            Will return `None` if the file was not cached. Otherwise:
+            - The exact path to the cached file if it's found in the cache
+            - A special value `_CACHED_NO_EXIST` if the file does not exist at the given commit hash and this fact was
+              cached.
     """
     if commit_hash is not None and revision is not None:
         raise ValueError("`commit_hash` and `revision` are mutually exclusive, pick one only.")
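
Callers of the new contract need a three-way check rather than a simple `None` test. A sketch of consuming it (the cache path is illustrative; the import location is assumed to be the hub utilities module this diff modifies):

```python
from transformers.utils.hub import _CACHED_NO_EXIST, try_to_load_from_cache

resolved = try_to_load_from_cache(
    "/path/to/cache", "hf-internal-testing/tiny-random-bert", "config.json"
)
if resolved is None:
    ...  # nothing cached: only a network call can tell us more
elif resolved is _CACHED_NO_EXIST:
    ...  # we already know the repo has no such file; skip the network
else:
    print(resolved)  # path to the cached file on disk
```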
@@ -244,6 +263,9 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None, commit_h
         with open(os.path.join(model_cache, "refs", revision)) as f:
             commit_hash = f.read()
 
+    if os.path.isfile(os.path.join(model_cache, ".no_exist", commit_hash, filename)):
+        return _CACHED_NO_EXIST
+
     cached_shas = os.listdir(os.path.join(model_cache, "snapshots"))
     if commit_hash not in cached_shas:
         # No cache for this revision and we won't try to return a random revision
@@ -338,7 +360,10 @@ def cached_file(
         resolved_file = os.path.join(os.path.join(path_or_repo_id, subfolder), filename)
         if not os.path.isfile(resolved_file):
             if _raise_exceptions_for_missing_entries:
-                raise EnvironmentError(f"Could not locate {full_filename} inside {path_or_repo_id}.")
+                raise EnvironmentError(
+                    f"{path_or_repo_id} does not appear to have a file named {full_filename}. Checkout "
+                    f"'https://huggingface.co/{path_or_repo_id}/{revision}' for available files."
+                )
             else:
                 return None
         return resolved_file
@@ -352,7 +377,12 @@ def cached_file(
         # If the file is cached under that commit hash, we return it directly.
         resolved_file = try_to_load_from_cache(cache_dir, path_or_repo_id, full_filename, commit_hash=_commit_hash)
         if resolved_file is not None:
-            return resolved_file
+            if resolved_file is not _CACHED_NO_EXIST:
+                return resolved_file
+            elif not _raise_exceptions_for_missing_entries:
+                return None
+            else:
+                raise EnvironmentError(f"Could not locate {full_filename} inside {path_or_repo_id}.")
 
     user_agent = http_user_agent(user_agent)
     try:
...
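
Note the two-step test: `resolved_file is not None` first separates "nothing cached" (fall through to the network code below) from the two cached outcomes, and only then does the identity check against `_CACHED_NO_EXIST` separate a real path from a cached 404. A truthiness test could not do this, since the sentinel is truthy:

```python
sentinel = object()
assert sentinel is not None    # so `is not None` alone cannot detect it
assert bool(sentinel) is True  # and `if resolved_file:` would treat a
                               # cached 404 like a real file path
```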
@@ -370,6 +370,5 @@ class AutoModelTest(unittest.TestCase):
         with RequestCounter() as counter:
             _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
             self.assertEqual(counter.get_request_count, 0)
-            # There is no pytorch_model.bin so we still get one call for this one.
-            self.assertEqual(counter.head_request_count, 2)
+            self.assertEqual(counter.head_request_count, 1)
             self.assertEqual(counter.other_request_count, 0)
@@ -303,6 +303,5 @@ class TFAutoModelTest(unittest.TestCase):
         with RequestCounter() as counter:
             _ = TFAutoModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
             self.assertEqual(counter.get_request_count, 0)
-            # There is no pytorch_model.bin so we still get one call for this one.
-            self.assertEqual(counter.head_request_count, 2)
+            self.assertEqual(counter.head_request_count, 1)
             self.assertEqual(counter.other_request_count, 0)
@@ -349,6 +349,5 @@ class AutoTokenizerTest(unittest.TestCase):
         with RequestCounter() as counter:
             _ = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
             self.assertEqual(counter.get_request_count, 0)
-            # We still have one extra call because the model does not have a added_tokens.json file
-            self.assertEqual(counter.head_request_count, 2)
+            self.assertEqual(counter.head_request_count, 1)
             self.assertEqual(counter.other_request_count, 0)
@@ -884,8 +884,7 @@ class CustomPipelineTest(unittest.TestCase):
         with RequestCounter() as counter:
             _ = pipeline("text-classification", model="hf-internal-testing/tiny-random-bert")
             self.assertEqual(counter.get_request_count, 0)
-            # We still have one extra call because the model does not have a added_tokens.json file
-            self.assertEqual(counter.head_request_count, 2)
+            self.assertEqual(counter.head_request_count, 1)
             self.assertEqual(counter.other_request_count, 0)
...
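
The assertions rely on `RequestCounter` from `transformers.testing_utils`, which tallies HTTP traffic by method during the `with` block. For a rough reproduction without that helper, one option is to wrap `requests.Session.request` yourself; this is a sketch, not how the real counter is implemented:

```python
import requests
from transformers import AutoModel

class SimpleRequestCounter:
    """Counts HTTP requests by method while active (illustrative only)."""

    def __enter__(self):
        self.counts = {}
        self._orig = requests.Session.request
        counter = self

        def counting_request(session, method, url, **kwargs):
            counter.counts[method.upper()] = counter.counts.get(method.upper(), 0) + 1
            return counter._orig(session, method, url, **kwargs)

        requests.Session.request = counting_request
        return self

    def __exit__(self, *exc):
        requests.Session.request = self._orig

# Usage mirroring the tests (the model must already be cached):
with SimpleRequestCounter() as counter:
    _ = AutoModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
print(counter.counts.get("HEAD", 0))  # expected to be 1 after this change
```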