Unverified commit 41329a0f, authored by shengshiqi-google, committed by GitHub
Browse files

[Core] feat: Add --safetensors-load-strategy flag for faster safetensors...


[Core] feat: Add --safetensors-load-strategy flag for faster safetensors loading from Lustre (#24469)
Signed-off-by: Shiqi Sheng <shengshiqi@google.com>
Signed-off-by: shengshiqi-google <160179165+shengshiqi-google@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent ee0bc5e1
......@@ -51,6 +51,15 @@ class LoadConfig:
download_dir: Optional[str] = None
"""Directory to download and load the weights, default to the default
cache directory of Hugging Face."""
safetensors_load_strategy: Optional[str] = "lazy"
"""Specifies the loading strategy for safetensors weights.
- "lazy" (default): Weights are memory-mapped from the file. This enables
on-demand loading and is highly efficient for models on local storage.
- "eager": The entire file is read into CPU memory upfront before loading.
This is recommended for models on network filesystems (e.g., Lustre, NFS)
as it avoids inefficient random reads, significantly speeding up model
initialization. However, it uses more CPU RAM.
"""
model_loader_extra_config: Union[dict, TensorizerConfig] = field(
default_factory=dict)
"""Extra config for model loader. This will be passed to the model loader
......
......@@ -289,6 +289,8 @@ class EngineArgs:
trust_remote_code: bool = ModelConfig.trust_remote_code
allowed_local_media_path: str = ModelConfig.allowed_local_media_path
download_dir: Optional[str] = LoadConfig.download_dir
safetensors_load_strategy: Optional[
str] = LoadConfig.safetensors_load_strategy
load_format: Union[str, LoadFormats] = LoadConfig.load_format
config_format: str = ModelConfig.config_format
dtype: ModelDType = ModelConfig.dtype
......@@ -587,6 +589,8 @@ class EngineArgs:
load_group.add_argument("--load-format", **load_kwargs["load_format"])
load_group.add_argument("--download-dir",
**load_kwargs["download_dir"])
load_group.add_argument("--safetensors-load-strategy",
**load_kwargs["safetensors_load_strategy"])
load_group.add_argument("--model-loader-extra-config",
**load_kwargs["model_loader_extra_config"])
load_group.add_argument("--ignore-patterns",
......@@ -1023,6 +1027,7 @@ class EngineArgs:
return LoadConfig(
load_format=self.load_format,
download_dir=self.download_dir,
safetensors_load_strategy=self.safetensors_load_strategy,
device="cpu"
if is_online_quantization(self.quantization) else None,
model_loader_extra_config=self.model_loader_extra_config,
......
......@@ -189,6 +189,7 @@ class DefaultModelLoader(BaseModelLoader):
weights_iterator = safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.safetensors_load_strategy,
)
else:
if extra_config.get("enable_multithread_load"):
......
......@@ -19,7 +19,7 @@ import huggingface_hub.constants
import numpy as np
import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from safetensors.torch import load, load_file, safe_open, save_file
from tqdm.auto import tqdm
from vllm import envs
......@@ -519,14 +519,24 @@ def np_cache_weights_iterator(
def safetensors_weights_iterator(
hf_weights_files: list[str],
use_tqdm_on_load: bool,
safetensors_load_strategy: Optional[str] = "lazy",
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
loading_desc = "Loading safetensors checkpoint shards"
if safetensors_load_strategy == "eager":
loading_desc += " (eager)"
for st_file in tqdm(
hf_weights_files,
desc="Loading safetensors checkpoint shards",
desc=loading_desc,
disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT,
):
if safetensors_load_strategy == "eager":
with open(st_file, "rb") as f:
state_dict = load(f.read())
yield from state_dict.items()
else:
with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118
param = f.get_tensor(name)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment