Unverified Commit c581d8af authored by Charles's avatar Charles Committed by GitHub
Browse files

Add utility function for retrieving locally cached models (#8836)



* add get_cached_models function

* add List type to import

* fix code quality

* Update src/transformers/file_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/file_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/file_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/file_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/file_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Fix style
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 8eb7f26d
......@@ -32,7 +32,7 @@ from dataclasses import fields
from functools import partial, wraps
from hashlib import sha256
from pathlib import Path
from typing import Any, BinaryIO, Dict, Optional, Tuple, Union
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile
......@@ -1030,6 +1030,39 @@ def filename_to_url(filename, cache_dir=None):
return url, etag
def get_cached_models(cache_dir: Union[str, Path] = None) -> List[Tuple]:
"""
Returns a list of tuples representing model binaries that are cached locally. Each tuple has shape
:obj:`(model_url, etag, size_MB)`. Filenames in :obj:`cache_dir` are use to get the metadata for each model, only
urls ending with `.bin` are added.
Args:
cache_dir (:obj:`Union[str, Path]`, `optional`):
The cache directory to search for models within. Will default to the transformers cache if unset.
Returns:
List[Tuple]: List of tuples each with shape :obj:`(model_url, etag, size_MB)`
"""
if cache_dir is None:
cache_dir = TRANSFORMERS_CACHE
elif isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
cached_models = []
for file in os.listdir(cache_dir):
if file.endswith(".json"):
meta_path = os.path.join(cache_dir, file)
with open(meta_path, encoding="utf-8") as meta_file:
metadata = json.load(meta_file)
url = metadata["url"]
etag = metadata["etag"]
if url.endswith(".bin"):
size_MB = os.path.getsize(meta_path.strip(".json")) / 1e6
cached_models.append((url, etag, size_MB))
return cached_models
def cached_path(
url_or_filename,
cache_dir=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment