"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "cdc1fa12eb1ba4795d24e97dcffa2018668a9267"
Unverified commit 42d5d705, authored by Lumosis, committed by GitHub
Browse files

[Minor] Sort safetensors files to ensure deterministic loading order (#33491)


Signed-off-by: Lihao Ran <imlihao.ran@gmail.com>
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: mgoin <mgoin64@gmail.com>
parent 116880a5
...@@ -19,6 +19,7 @@ from typing import IO, Any ...@@ -19,6 +19,7 @@ from typing import IO, Any
import filelock import filelock
import huggingface_hub.constants import huggingface_hub.constants
import numpy as np import numpy as np
import regex as re
import torch import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load, load_file, safe_open, save_file from safetensors.torch import load, load_file, safe_open, save_file
...@@ -143,6 +144,15 @@ def atomic_writer( ...@@ -143,6 +144,15 @@ def atomic_writer(
os.remove(temp_path) os.remove(temp_path)
def _natural_sort_key(filepath: str) -> list:
"""Natural sort key for filenames with numeric components, such as
model-00001-of-00005.safetensors -> ['model-', 1, '-of-', 5, '.safetensors']"""
return [
int(s) if s.isdigit() else s
for s in re.split(r"(\d+)", os.path.basename(filepath))
]
def maybe_download_from_modelscope( def maybe_download_from_modelscope(
model: str, model: str,
revision: str | None = None, revision: str | None = None,
...@@ -682,9 +692,8 @@ def safetensors_weights_iterator( ...@@ -682,9 +692,8 @@ def safetensors_weights_iterator(
loading_desc += " (eager)" loading_desc += " (eager)"
leftover_state_dict: dict[str, torch.Tensor] = {} leftover_state_dict: dict[str, torch.Tensor] = {}
for st_file in tqdm( for st_file in tqdm(
hf_weights_files, sorted(hf_weights_files, key=_natural_sort_key),
desc=loading_desc, desc=loading_desc,
disable=not enable_tqdm(use_tqdm_on_load), disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT, bar_format=_BAR_FORMAT,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment