"tests/vscode:/vscode.git/clone" did not exist on "7aa6af1138b206bec10ab3af23a365c0f573b67d"
Unverified commit 97d966a7, authored by Mick, committed by GitHub

ci: make find_local_hf_snapshot_dir more robust (#11248)

parent 8e66d87f
@@ -8,6 +8,7 @@ import hashlib
import json
import logging
import os
import re
import tempfile
from collections import defaultdict
from typing import (
@@ -283,7 +284,24 @@ def find_local_hf_snapshot_dir(
    except Exception as e:
        logger.warning("Failed to find local snapshot in default HF cache: %s", e)

    # If local snapshot exists, validate it contains at least one weight file
    # if any incomplete file exists, force re-download by returning None
    if found_local_snapshot_dir:
        repo_folder = os.path.abspath(
            os.path.join(found_local_snapshot_dir, "..", "..")
        )
        blobs_dir = os.path.join(repo_folder, "blobs")
        if os.path.isdir(blobs_dir) and glob.glob(
            os.path.join(blobs_dir, "*.incomplete")
        ):
            logger.info(
                "Found .incomplete files in %s for %s. "
                "Considering local snapshot incomplete.",
                blobs_dir,
                model_name_or_path,
            )
            return None

    # if local snapshot exists, validate it contains at least one weight file
    # matching allow_patterns before skipping download.
    if found_local_snapshot_dir is None:
        return None
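For context on the check above: in the Hugging Face hub cache, a snapshot directory sits at `<repo_folder>/snapshots/<revision>/`, while partially downloaded blobs are written as `<repo_folder>/blobs/<hash>.incomplete`, which is why the code walks two levels up from the snapshot dir. A minimal, self-contained sketch of the same logic (the temp layout and the `_has_incomplete_blobs` helper are illustrative, not part of the patch):

```python
import glob
import os
import tempfile


def _has_incomplete_blobs(snapshot_dir: str) -> bool:
    # Mirror of the patch's check: go from .../snapshots/<rev> up to the
    # repo folder and look for partially downloaded blobs.
    repo_folder = os.path.abspath(os.path.join(snapshot_dir, "..", ".."))
    blobs_dir = os.path.join(repo_folder, "blobs")
    return os.path.isdir(blobs_dir) and bool(
        glob.glob(os.path.join(blobs_dir, "*.incomplete"))
    )


with tempfile.TemporaryDirectory() as root:
    # Fake cache entry: models--org--name/{snapshots/<rev>, blobs}
    repo = os.path.join(root, "models--org--name")
    snap = os.path.join(repo, "snapshots", "abc123")
    blobs = os.path.join(repo, "blobs")
    os.makedirs(snap)
    os.makedirs(blobs)

    print(_has_incomplete_blobs(snap))  # False: no partial downloads

    # Simulate an interrupted download.
    open(os.path.join(blobs, "deadbeef.incomplete"), "w").close()
    print(_has_incomplete_blobs(snap))  # True: snapshot treated as incomplete
```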
@@ -291,9 +309,12 @@ def find_local_hf_snapshot_dir(
    local_weight_files: List[str] = []
    try:
        for pattern in allow_patterns:
            local_weight_files.extend(
                glob.glob(os.path.join(found_local_snapshot_dir, pattern))
            )
            matched_files = glob.glob(os.path.join(found_local_snapshot_dir, pattern))
            for f in matched_files:
                # os.path.exists returns False for broken symlinks.
                if not os.path.exists(f):
                    continue
                local_weight_files.append(f)
    except Exception as e:
        logger.warning(
            "Failed to scan local snapshot %s with patterns %s: %s",
@@ -303,6 +324,46 @@
        )
        local_weight_files = []

    # After we have a list of valid files, check for sharded model completeness.
    # Check if all safetensors with name model-{i}-of-{n}.safetensors exists
    checked_sharded_model = False
    for f in local_weight_files:
        if checked_sharded_model:
            break
        base_name = os.path.basename(f)
        # Regex for files like model-00001-of-00009.safetensors
        match = re.match(r"(.*?)-([0-9]+)-of-([0-9]+)\.(.*)", base_name)
        if match:
            prefix = match.group(1)
            shard_id_str = match.group(2)
            total_shards_str = match.group(3)
            suffix = match.group(4)
            total_shards = int(total_shards_str)

            # Check if all shards are present
            missing_shards = []
            for i in range(1, total_shards + 1):
                # Reconstruct shard name, preserving padding of original shard id
                shard_name = (
                    f"{prefix}-{i:0{len(shard_id_str)}d}-of-{total_shards_str}.{suffix}"
                )
                expected_path = os.path.join(found_local_snapshot_dir, shard_name)
                # os.path.exists returns False for broken symlinks, which is desired.
                if not os.path.exists(expected_path):
                    missing_shards.append(shard_name)

            if missing_shards:
                logger.info(
                    "Found incomplete sharded model %s. Missing shards: %s. "
                    "Will attempt download.",
                    model_name_or_path,
                    missing_shards,
                )
                return None

            # If we found and verified one set of shards, we are done.
            checked_sharded_model = True

    if len(local_weight_files) > 0:
        logger.info(
            "Found local HF snapshot for %s at %s; skipping download.",
......
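As a standalone sanity check of the shard-completeness logic above, the same regex and zero-padded shard-name reconstruction can be exercised on made-up filenames:

```python
import re

# Hypothetical snapshot contents: shard 2 of 3 is missing.
present = {"model-00001-of-00003.safetensors", "model-00003-of-00003.safetensors"}

for name in sorted(present):
    match = re.match(r"(.*?)-([0-9]+)-of-([0-9]+)\.(.*)", name)
    if not match:
        continue
    prefix, shard_id_str, total_str, suffix = match.groups()
    missing = []
    for i in range(1, int(total_str) + 1):
        # Rebuild each expected shard name, keeping the original zero padding.
        shard = f"{prefix}-{i:0{len(shard_id_str)}d}-of-{total_str}.{suffix}"
        if shard not in present:
            missing.append(shard)
    print(missing)  # ['model-00002-of-00003.safetensors'] -> snapshot incomplete
    break
```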