Unverified Commit d6249d06 authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Fix typing for `safetensors_load_strategy` (#24641)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 25bb9e8c
...@@ -51,7 +51,7 @@ class LoadConfig: ...@@ -51,7 +51,7 @@ class LoadConfig:
download_dir: Optional[str] = None download_dir: Optional[str] = None
"""Directory to download and load the weights, default to the default """Directory to download and load the weights, default to the default
cache directory of Hugging Face.""" cache directory of Hugging Face."""
safetensors_load_strategy: Optional[str] = "lazy" safetensors_load_strategy: str = "lazy"
"""Specifies the loading strategy for safetensors weights. """Specifies the loading strategy for safetensors weights.
- "lazy" (default): Weights are memory-mapped from the file. This enables - "lazy" (default): Weights are memory-mapped from the file. This enables
on-demand loading and is highly efficient for models on local storage. on-demand loading and is highly efficient for models on local storage.
......
...@@ -289,8 +289,7 @@ class EngineArgs: ...@@ -289,8 +289,7 @@ class EngineArgs:
trust_remote_code: bool = ModelConfig.trust_remote_code trust_remote_code: bool = ModelConfig.trust_remote_code
allowed_local_media_path: str = ModelConfig.allowed_local_media_path allowed_local_media_path: str = ModelConfig.allowed_local_media_path
download_dir: Optional[str] = LoadConfig.download_dir download_dir: Optional[str] = LoadConfig.download_dir
safetensors_load_strategy: Optional[ safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
str] = LoadConfig.safetensors_load_strategy
load_format: Union[str, LoadFormats] = LoadConfig.load_format load_format: Union[str, LoadFormats] = LoadConfig.load_format
config_format: str = ModelConfig.config_format config_format: str = ModelConfig.config_format
dtype: ModelDType = ModelConfig.dtype dtype: ModelDType = ModelConfig.dtype
......
...@@ -519,7 +519,7 @@ def np_cache_weights_iterator( ...@@ -519,7 +519,7 @@ def np_cache_weights_iterator(
def safetensors_weights_iterator( def safetensors_weights_iterator(
hf_weights_files: list[str], hf_weights_files: list[str],
use_tqdm_on_load: bool, use_tqdm_on_load: bool,
safetensors_load_strategy: Optional[str] = "lazy", safetensors_load_strategy: str = "lazy",
) -> Generator[tuple[str, torch.Tensor], None, None]: ) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files.""" """Iterate over the weights in the model safetensor files."""
loading_desc = "Loading safetensors checkpoint shards" loading_desc = "Loading safetensors checkpoint shards"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment