Unverified Commit 43fa9f22 authored by Mick's avatar Mick Committed by GitHub
Browse files

fix: check if weights are already local before downloading (#11015)

parent e98d9346
......@@ -8,7 +8,6 @@ import hashlib
import json
import logging
import os
import queue
import tempfile
from collections import defaultdict
from typing import (
......@@ -38,7 +37,8 @@ from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.layers.dp_attention import get_attention_tp_rank
from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
from sglang.srt.utils import print_warning_once
from sglang.srt.utils import find_local_repo_dir, print_warning_once
from sglang.utils import is_in_ci
logger = logging.getLogger(__name__)
......@@ -236,6 +236,92 @@ def get_quant_config(
return quant_cls.from_config(config)
def find_local_hf_snapshot_dir(
model_name_or_path: str,
cache_dir: Optional[str],
allow_patterns: List[str],
revision: Optional[str] = None,
) -> Optional[str]:
"""If the weights are already local, skip downloading and returns the path
Only applied in ci
"""
if not is_in_ci() or os.path.isdir(model_name_or_path):
return None
found_local_snapshot_dir = None
# Check custom cache_dir (if provided)
if cache_dir:
try:
repo_folder = os.path.join(
cache_dir,
huggingface_hub.constants.REPO_ID_SEPARATOR.join(
["models", *model_name_or_path.split("/")]
),
)
rev_to_use = revision
if not rev_to_use:
ref_main = os.path.join(repo_folder, "refs", "main")
if os.path.isfile(ref_main):
with open(ref_main) as f:
rev_to_use = f.read().strip()
if rev_to_use:
rev_dir = os.path.join(repo_folder, "snapshots", rev_to_use)
if os.path.isdir(rev_dir):
found_local_snapshot_dir = rev_dir
except Exception as e:
logger.warning(
"Failed to find local snapshot in custom cache_dir %s: %s",
cache_dir,
e,
)
# Check default HF cache as well
if not found_local_snapshot_dir:
try:
rev_dir = find_local_repo_dir(model_name_or_path, revision)
if rev_dir and os.path.isdir(rev_dir):
found_local_snapshot_dir = rev_dir
except Exception as e:
logger.warning("Failed to find local snapshot in default HF cache: %s", e)
# If local snapshot exists, validate it contains at least one weight file
# matching allow_patterns before skipping download.
if found_local_snapshot_dir is None:
return None
local_weight_files: List[str] = []
try:
for pattern in allow_patterns:
local_weight_files.extend(
glob.glob(os.path.join(found_local_snapshot_dir, pattern))
)
except Exception as e:
logger.warning(
"Failed to scan local snapshot %s with patterns %s: %s",
found_local_snapshot_dir,
allow_patterns,
e,
)
local_weight_files = []
if len(local_weight_files) > 0:
logger.info(
"Found local HF snapshot for %s at %s; skipping download.",
model_name_or_path,
found_local_snapshot_dir,
)
return found_local_snapshot_dir
else:
logger.info(
"Local HF snapshot at %s has no files matching %s; will attempt download.",
found_local_snapshot_dir,
allow_patterns,
)
return None
def download_weights_from_hf(
model_name_or_path: str,
cache_dir: Optional[str],
......@@ -260,6 +346,13 @@ def download_weights_from_hf(
Returns:
str: The path to the downloaded model weights.
"""
path = find_local_hf_snapshot_dir(
model_name_or_path, cache_dir, allow_patterns, revision
)
if path is not None:
return path
if not huggingface_hub.constants.HF_HUB_OFFLINE:
# Before we download we look at that is available:
fs = HfFileSystem()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment