Unverified Commit 97d966a7 authored by Mick's avatar Mick Committed by GitHub
Browse files

ci: make find_local_hf_snapshot_dir more robust (#11248)

parent 8e66d87f
...@@ -8,6 +8,7 @@ import hashlib ...@@ -8,6 +8,7 @@ import hashlib
import json import json
import logging import logging
import os import os
import re
import tempfile import tempfile
from collections import defaultdict from collections import defaultdict
from typing import ( from typing import (
...@@ -283,7 +284,24 @@ def find_local_hf_snapshot_dir( ...@@ -283,7 +284,24 @@ def find_local_hf_snapshot_dir(
except Exception as e: except Exception as e:
logger.warning("Failed to find local snapshot in default HF cache: %s", e) logger.warning("Failed to find local snapshot in default HF cache: %s", e)
# if any incomplete file exists, force re-download by returning None
# If a local snapshot was found, first make sure no download into this
# repo's cache is still in flight before trusting it.
if found_local_snapshot_dir:
    # The snapshot dir lives at <repo_folder>/snapshots/<revision>, so two
    # levels up is the repo folder that also contains blobs/.
    # NOTE(review): assumes the standard huggingface_hub cache layout — confirm.
    repo_folder = os.path.abspath(
        os.path.join(found_local_snapshot_dir, "..", "..")
    )
    blobs_dir = os.path.join(repo_folder, "blobs")
    # huggingface_hub leaves "*.incomplete" files in blobs/ while a blob is
    # being downloaded; their presence means the snapshot may be truncated.
    if os.path.isdir(blobs_dir) and glob.glob(
        os.path.join(blobs_dir, "*.incomplete")
    ):
        logger.info(
            "Found .incomplete files in %s for %s. "
            "Considering local snapshot incomplete.",
            blobs_dir,
            model_name_or_path,
        )
        # Returning None forces the caller to fall back to downloading.
        return None
# if local snapshot exists, validate it contains at least one weight file
# matching allow_patterns before skipping download. # matching allow_patterns before skipping download.
if found_local_snapshot_dir is None: if found_local_snapshot_dir is None:
return None return None
...@@ -291,9 +309,12 @@ def find_local_hf_snapshot_dir( ...@@ -291,9 +309,12 @@ def find_local_hf_snapshot_dir(
local_weight_files: List[str] = [] local_weight_files: List[str] = []
try: try:
for pattern in allow_patterns: for pattern in allow_patterns:
local_weight_files.extend( matched_files = glob.glob(os.path.join(found_local_snapshot_dir, pattern))
glob.glob(os.path.join(found_local_snapshot_dir, pattern)) for f in matched_files:
) # os.path.exists returns False for broken symlinks.
if not os.path.exists(f):
continue
local_weight_files.append(f)
except Exception as e: except Exception as e:
logger.warning( logger.warning(
"Failed to scan local snapshot %s with patterns %s: %s", "Failed to scan local snapshot %s with patterns %s: %s",
...@@ -303,6 +324,46 @@ def find_local_hf_snapshot_dir( ...@@ -303,6 +324,46 @@ def find_local_hf_snapshot_dir(
) )
local_weight_files = [] local_weight_files = []
# After we have a list of valid files, check for sharded model completeness.
# Check if all safetensors with name model-{i}-of-{n}.safetensors exists
# Only the first shard set encountered is validated; once verified, the
# flag short-circuits the remaining iterations.
checked_sharded_model = False
for f in local_weight_files:
    if checked_sharded_model:
        break
    base_name = os.path.basename(f)
    # Regex for files like model-00001-of-00009.safetensors
    match = re.match(r"(.*?)-([0-9]+)-of-([0-9]+)\.(.*)", base_name)
    if match:
        prefix = match.group(1)  # e.g. "model"
        shard_id_str = match.group(2)  # zero-padded shard index, e.g. "00001"
        total_shards_str = match.group(3)  # e.g. "00009"
        suffix = match.group(4)  # e.g. "safetensors"
        total_shards = int(total_shards_str)
        # Check if all shards are present
        missing_shards = []
        for i in range(1, total_shards + 1):
            # Reconstruct shard name, preserving padding of original shard id
            shard_name = (
                f"{prefix}-{i:0{len(shard_id_str)}d}-of-{total_shards_str}.{suffix}"
            )
            expected_path = os.path.join(found_local_snapshot_dir, shard_name)
            # os.path.exists returns False for broken symlinks, which is desired.
            if not os.path.exists(expected_path):
                missing_shards.append(shard_name)
        if missing_shards:
            logger.info(
                "Found incomplete sharded model %s. Missing shards: %s. "
                "Will attempt download.",
                model_name_or_path,
                missing_shards,
            )
            # An incomplete shard set invalidates the local snapshot entirely.
            return None
        # If we found and verified one set of shards, we are done.
        checked_sharded_model = True
if len(local_weight_files) > 0: if len(local_weight_files) > 0:
logger.info( logger.info(
"Found local HF snapshot for %s at %s; skipping download.", "Found local HF snapshot for %s at %s; skipping download.",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment