Unverified Commit d6249d06 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Fix typing for `safetensors_load_strategy` (#24641)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 25bb9e8c
......@@ -51,7 +51,7 @@ class LoadConfig:
download_dir: Optional[str] = None
"""Directory to download and load the weights, default to the default
cache directory of Hugging Face."""
safetensors_load_strategy: Optional[str] = "lazy"
safetensors_load_strategy: str = "lazy"
"""Specifies the loading strategy for safetensors weights.
- "lazy" (default): Weights are memory-mapped from the file. This enables
on-demand loading and is highly efficient for models on local storage.
......
......@@ -289,8 +289,7 @@ class EngineArgs:
trust_remote_code: bool = ModelConfig.trust_remote_code
allowed_local_media_path: str = ModelConfig.allowed_local_media_path
download_dir: Optional[str] = LoadConfig.download_dir
safetensors_load_strategy: Optional[
str] = LoadConfig.safetensors_load_strategy
safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
load_format: Union[str, LoadFormats] = LoadConfig.load_format
config_format: str = ModelConfig.config_format
dtype: ModelDType = ModelConfig.dtype
......
......@@ -519,7 +519,7 @@ def np_cache_weights_iterator(
def safetensors_weights_iterator(
hf_weights_files: list[str],
use_tqdm_on_load: bool,
safetensors_load_strategy: Optional[str] = "lazy",
safetensors_load_strategy: str = "lazy",
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
loading_desc = "Loading safetensors checkpoint shards"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment