Unverified Commit 100f93d2 authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

Filter safetensors files to download if .safetensors.index.json exists (#30537)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent 96bf50a2
...@@ -23,6 +23,7 @@ import torch ...@@ -23,6 +23,7 @@ import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load, load_file, safe_open, save_file from safetensors.torch import load, load_file, safe_open, save_file
from tqdm.auto import tqdm from tqdm.auto import tqdm
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm import envs from vllm import envs
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -448,10 +449,29 @@ def download_weights_from_hf( ...@@ -448,10 +449,29 @@ def download_weights_from_hf(
fs = HfFileSystem() fs = HfFileSystem()
file_list = fs.ls(model_name_or_path, detail=False, revision=revision) file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
# If downloading safetensors and an index file exists, use the
# specific file names from the index to avoid downloading
# unnecessary files (e.g., from subdirectories like "original/").
index_file = f"{model_name_or_path}/{SAFE_WEIGHTS_INDEX_NAME}"
if "*.safetensors" in allow_patterns and index_file in file_list:
index_path = hf_hub_download(
repo_id=model_name_or_path,
filename=SAFE_WEIGHTS_INDEX_NAME,
cache_dir=cache_dir,
revision=revision,
)
with open(index_path) as f:
weight_map = json.load(f)["weight_map"]
if weight_map:
# Extra [] so that weight_map files are treated as a
# single allow_pattern in the loop below
allow_patterns = [list(set(weight_map.values()))] # type: ignore[list-item]
else:
allow_patterns = ["*.safetensors"]
else:
# Use the first pattern found in the HF repo's files. # Use the first pattern found in the HF repo's files.
for pattern in allow_patterns: for pattern in allow_patterns:
matching = fnmatch.filter(file_list, pattern) if fnmatch.filter(file_list, pattern):
if len(matching) > 0:
allow_patterns = [pattern] allow_patterns = [pattern]
break break
except Exception as e: except Exception as e:
...@@ -480,6 +500,9 @@ def download_weights_from_hf( ...@@ -480,6 +500,9 @@ def download_weights_from_hf(
) )
# If we have downloaded weights for this allow_pattern, # If we have downloaded weights for this allow_pattern,
# we don't need to check the rest. # we don't need to check the rest.
# allow_pattern can be a list (from weight_map) or str (glob)
if isinstance(allow_pattern, list):
break
if any(Path(hf_folder).glob(allow_pattern)): if any(Path(hf_folder).glob(allow_pattern)):
break break
time_taken = time.perf_counter() - start_time time_taken = time.perf_counter() - start_time
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment