Unverified Commit 47e16762 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

New cache fixes: add safeguard before looking in folders (#18522)

parent 74959240
...@@ -133,6 +133,8 @@ def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]: ...@@ -133,6 +133,8 @@ def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
cache_dir = TRANSFORMERS_CACHE cache_dir = TRANSFORMERS_CACHE
elif isinstance(cache_dir, Path): elif isinstance(cache_dir, Path):
cache_dir = str(cache_dir) cache_dir = str(cache_dir)
if not os.path.isdir(cache_dir):
return []
cached_models = [] cached_models = []
for file in os.listdir(cache_dir): for file in os.listdir(cache_dir):
...@@ -210,6 +212,9 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None): ...@@ -210,6 +212,9 @@ def try_to_load_from_cache(cache_dir, repo_id, filename, revision=None):
if not os.path.isdir(model_cache): if not os.path.isdir(model_cache):
# No cache for this model # No cache for this model
return None return None
for subfolder in ["refs", "snapshots"]:
if not os.path.isdir(os.path.join(model_cache, subfolder)):
return None
# Resolve refs (for instance to convert main to the associated commit sha) # Resolve refs (for instance to convert main to the associated commit sha)
cached_refs = os.listdir(os.path.join(model_cache, "refs")) cached_refs = os.listdir(os.path.join(model_cache, "refs"))
...@@ -873,6 +878,8 @@ def get_all_cached_files(cache_dir=None): ...@@ -873,6 +878,8 @@ def get_all_cached_files(cache_dir=None):
cache_dir = TRANSFORMERS_CACHE cache_dir = TRANSFORMERS_CACHE
else: else:
cache_dir = str(cache_dir) cache_dir = str(cache_dir)
if not os.path.isdir(cache_dir):
return []
cached_files = [] cached_files = []
for file in os.listdir(cache_dir): for file in os.listdir(cache_dir):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment