Unverified commit e5afb88b, authored by Yuhong Guo and committed by GitHub.
Browse files

Support weight loading without mmap (#7469)

parent e5ddeb04
@@ -101,6 +101,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "torchao_config",
     "triton_attention_reduce_in_fp32",
     "num_reserved_decode_tokens",
+    "weight_loader_disable_mmap",
 ]
 # Put some global args for easy access
...
@@ -337,7 +337,14 @@ class DefaultModelLoader(BaseModelLoader):
                 hf_weights_files,
             )
         elif use_safetensors:
-            weights_iterator = safetensors_weights_iterator(hf_weights_files)
+            from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+            weight_loader_disable_mmap = global_server_args_dict.get(
+                "weight_loader_disable_mmap"
+            )
+            weights_iterator = safetensors_weights_iterator(
+                hf_weights_files, disable_mmap=weight_loader_disable_mmap
+            )
         else:
             weights_iterator = pt_weights_iterator(hf_weights_files)
...
@@ -422,6 +422,7 @@ def safetensors_weights_iterator(
     hf_weights_files: List[str],
     is_all_weights_sharded: bool = False,
     decryption_key: Optional[str] = None,
+    disable_mmap: bool = False,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files.
...
@@ -443,7 +444,11 @@ def safetensors_weights_iterator(
         disable=not enable_tqdm,
         bar_format=_BAR_FORMAT,
     ):
-        result = safetensors.torch.load_file(st_file, device="cpu")
+        if disable_mmap:
+            with open(st_file, "rb") as f:
+                result = safetensors.torch.load(f.read())
+        else:
+            result = safetensors.torch.load_file(st_file, device="cpu")
         for name, param in result.items():
             yield name, param
...
@@ -237,6 +237,7 @@ class ServerArgs:
     # For model weight update
     custom_weight_loader: Optional[List[str]] = None
+    weight_loader_disable_mmap: bool = False

     def __post_init__(self):
         # Expert parallelism
...
@@ -1599,6 +1600,11 @@ class ServerArgs:
             default=None,
             help="The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func",
         )
+        parser.add_argument(
+            "--weight-loader-disable-mmap",
+            action="store_true",
+            help="Disable mmap while loading weight using safetensors.",
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.