Unverified Commit 43fa9f22 authored by Mick's avatar Mick Committed by GitHub
Browse files

fix: check if weights are already local before downloading (#11015)

parent e98d9346
...@@ -8,7 +8,6 @@ import hashlib ...@@ -8,7 +8,6 @@ import hashlib
import json import json
import logging import logging
import os import os
import queue
import tempfile import tempfile
from collections import defaultdict from collections import defaultdict
from typing import ( from typing import (
...@@ -38,7 +37,8 @@ from sglang.srt.distributed import get_tensor_model_parallel_rank ...@@ -38,7 +37,8 @@ from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.layers.dp_attention import get_attention_tp_rank from sglang.srt.layers.dp_attention import get_attention_tp_rank
from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
from sglang.srt.utils import print_warning_once from sglang.srt.utils import find_local_repo_dir, print_warning_once
from sglang.utils import is_in_ci
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -236,6 +236,92 @@ def get_quant_config( ...@@ -236,6 +236,92 @@ def get_quant_config(
return quant_cls.from_config(config) return quant_cls.from_config(config)
def _find_snapshot_in_cache_dir(
    cache_dir: str,
    model_name_or_path: str,
    revision: Optional[str],
) -> Optional[str]:
    """Locate a repo snapshot directory inside a custom HF-style cache_dir.

    The lookup assumes the Hugging Face cache layout
    ``<cache_dir>/models--<org>--<name>/{refs,snapshots}``.

    Args:
        cache_dir: Root of the custom cache directory.
        model_name_or_path: Repo id such as ``"org/name"``.
        revision: Commit hash, branch or tag name. ``None`` (or empty) means
            the ``main`` ref.

    Returns:
        The path of the existing snapshot directory, or ``None`` if it cannot
        be resolved on disk.
    """
    repo_folder = os.path.join(
        cache_dir,
        huggingface_hub.constants.REPO_ID_SEPARATOR.join(
            ["models", *model_name_or_path.split("/")]
        ),
    )
    snapshots_dir = os.path.join(repo_folder, "snapshots")

    # An explicit revision may already be a commit hash naming a snapshot.
    if revision:
        rev_dir = os.path.join(snapshots_dir, revision)
        if os.path.isdir(rev_dir):
            return rev_dir

    # Otherwise treat it as a ref name (branch/tag, default "main") and
    # resolve it to the commit hash stored in refs/<name>. Without this step
    # a revision like "main" would never match, since snapshot directories
    # are keyed by commit hash.
    ref_file = os.path.join(repo_folder, "refs", revision or "main")
    if not os.path.isfile(ref_file):
        return None
    with open(ref_file, encoding="utf-8") as f:
        commit_hash = f.read().strip()
    if not commit_hash:
        return None

    rev_dir = os.path.join(snapshots_dir, commit_hash)
    return rev_dir if os.path.isdir(rev_dir) else None


def find_local_hf_snapshot_dir(
    model_name_or_path: str,
    cache_dir: Optional[str],
    allow_patterns: List[str],
    revision: Optional[str] = None,
) -> Optional[str]:
    """Return a local HF snapshot directory that already holds the weights.

    Lets the caller skip downloading when the weights are already on disk.
    Only applied in CI: outside CI, or when ``model_name_or_path`` is itself
    a local directory, this returns ``None`` immediately.

    Args:
        model_name_or_path: Hugging Face repo id (or a local directory path).
        cache_dir: Optional custom cache directory, checked first.
        allow_patterns: Glob patterns for weight files; at least one file in
            the snapshot must match for the snapshot to be accepted.
        revision: Optional revision (commit hash, branch or tag).

    Returns:
        The snapshot directory path if one exists locally and contains at
        least one file matching ``allow_patterns``; otherwise ``None``.
        Lookup errors are logged and treated as "not found" (best effort).
    """
    if not is_in_ci() or os.path.isdir(model_name_or_path):
        return None

    found_local_snapshot_dir = None

    # Check custom cache_dir (if provided)
    if cache_dir:
        try:
            found_local_snapshot_dir = _find_snapshot_in_cache_dir(
                cache_dir, model_name_or_path, revision
            )
        except Exception as e:
            # Best effort: a broken cache must not prevent a normal download.
            logger.warning(
                "Failed to find local snapshot in custom cache_dir %s: %s",
                cache_dir,
                e,
            )

    # Check default HF cache as well
    if not found_local_snapshot_dir:
        try:
            rev_dir = find_local_repo_dir(model_name_or_path, revision)
            if rev_dir and os.path.isdir(rev_dir):
                found_local_snapshot_dir = rev_dir
        except Exception as e:
            logger.warning("Failed to find local snapshot in default HF cache: %s", e)

    if not found_local_snapshot_dir:
        return None

    # Validate that the snapshot contains at least one weight file matching
    # allow_patterns before skipping the download. Only existence matters,
    # so stop at the first pattern that yields a match.
    try:
        has_weight_files = any(
            glob.glob(os.path.join(found_local_snapshot_dir, pattern))
            for pattern in allow_patterns
        )
    except Exception as e:
        logger.warning(
            "Failed to scan local snapshot %s with patterns %s: %s",
            found_local_snapshot_dir,
            allow_patterns,
            e,
        )
        has_weight_files = False

    if not has_weight_files:
        logger.info(
            "Local HF snapshot at %s has no files matching %s; will attempt download.",
            found_local_snapshot_dir,
            allow_patterns,
        )
        return None

    logger.info(
        "Found local HF snapshot for %s at %s; skipping download.",
        model_name_or_path,
        found_local_snapshot_dir,
    )
    return found_local_snapshot_dir
def download_weights_from_hf( def download_weights_from_hf(
model_name_or_path: str, model_name_or_path: str,
cache_dir: Optional[str], cache_dir: Optional[str],
...@@ -260,6 +346,13 @@ def download_weights_from_hf( ...@@ -260,6 +346,13 @@ def download_weights_from_hf(
Returns: Returns:
str: The path to the downloaded model weights. str: The path to the downloaded model weights.
""" """
path = find_local_hf_snapshot_dir(
model_name_or_path, cache_dir, allow_patterns, revision
)
if path is not None:
return path
if not huggingface_hub.constants.HF_HUB_OFFLINE: if not huggingface_hub.constants.HF_HUB_OFFLINE:
# Before we download we look at that is available: # Before we download we look at that is available:
fs = HfFileSystem() fs = HfFileSystem()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment