Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
9f1787fa
Unverified
Commit
9f1787fa
authored
Jun 25, 2025
by
xianzhiT
Committed by
GitHub
Jun 24, 2025
Browse files
Support multi-thread model weight loading (#7277)
parent
8ecad0b1
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
143 additions
and
10 deletions
+143
-10
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+1
-0
python/sglang/srt/model_loader/loader.py
python/sglang/srt/model_loader/loader.py
+45
-10
python/sglang/srt/model_loader/weight_utils.py
python/sglang/srt/model_loader/weight_utils.py
+89
-0
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+8
-0
No files found.
python/sglang/srt/model_executor/model_runner.py
View file @
9f1787fa
...
@@ -547,6 +547,7 @@ class ModelRunner:
...
@@ -547,6 +547,7 @@ class ModelRunner:
self
.
load_config
=
LoadConfig
(
self
.
load_config
=
LoadConfig
(
load_format
=
self
.
server_args
.
load_format
,
load_format
=
self
.
server_args
.
load_format
,
download_dir
=
self
.
server_args
.
download_dir
,
download_dir
=
self
.
server_args
.
download_dir
,
model_loader_extra_config
=
self
.
server_args
.
model_loader_extra_config
,
)
)
if
self
.
server_args
.
load_format
==
"gguf"
:
if
self
.
server_args
.
load_format
==
"gguf"
:
monkey_patch_vllm_gguf_config
()
monkey_patch_vllm_gguf_config
()
...
...
python/sglang/srt/model_loader/loader.py
View file @
9f1787fa
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
# ruff: noqa: SIM117
# ruff: noqa: SIM117
import
collections
import
collections
import
concurrent
import
dataclasses
import
dataclasses
import
fnmatch
import
fnmatch
import
glob
import
glob
...
@@ -11,14 +12,17 @@ import math
...
@@ -11,14 +12,17 @@ import math
import
os
import
os
import
time
import
time
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
concurrent.futures
import
ThreadPoolExecutor
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Any
,
Dict
,
Generator
,
Iterable
,
List
,
Optional
,
Tuple
,
cast
from
typing
import
Any
,
Dict
,
Generator
,
Iterable
,
List
,
Optional
,
Tuple
,
cast
import
huggingface_hub
import
huggingface_hub
import
numpy
as
np
import
numpy
as
np
import
safetensors.torch
import
torch
import
torch
from
huggingface_hub
import
HfApi
,
hf_hub_download
from
huggingface_hub
import
HfApi
,
hf_hub_download
from
torch
import
nn
from
torch
import
nn
from
tqdm.auto
import
tqdm
from
transformers
import
AutoModelForCausalLM
from
transformers
import
AutoModelForCausalLM
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
...
@@ -41,6 +45,7 @@ from sglang.srt.model_loader.utils import (
...
@@ -41,6 +45,7 @@ from sglang.srt.model_loader.utils import (
set_default_torch_dtype
,
set_default_torch_dtype
,
)
)
from
sglang.srt.model_loader.weight_utils
import
(
from
sglang.srt.model_loader.weight_utils
import
(
_BAR_FORMAT
,
download_safetensors_index_file_from_hf
,
download_safetensors_index_file_from_hf
,
download_weights_from_hf
,
download_weights_from_hf
,
filter_duplicate_safetensors_files
,
filter_duplicate_safetensors_files
,
...
@@ -49,6 +54,8 @@ from sglang.srt.model_loader.weight_utils import (
...
@@ -49,6 +54,8 @@ from sglang.srt.model_loader.weight_utils import (
get_quant_config
,
get_quant_config
,
gguf_quant_weights_iterator
,
gguf_quant_weights_iterator
,
initialize_dummy_weights
,
initialize_dummy_weights
,
multi_thread_pt_weights_iterator
,
multi_thread_safetensors_weights_iterator
,
np_cache_weights_iterator
,
np_cache_weights_iterator
,
pt_weights_iterator
,
pt_weights_iterator
,
safetensors_weights_iterator
,
safetensors_weights_iterator
,
...
@@ -181,6 +188,9 @@ class BaseModelLoader(ABC):
...
@@ -181,6 +188,9 @@ class BaseModelLoader(ABC):
class
DefaultModelLoader
(
BaseModelLoader
):
class
DefaultModelLoader
(
BaseModelLoader
):
"""Model loader that can load different file types from disk."""
"""Model loader that can load different file types from disk."""
# default number of thread when enable multithread weight loading
DEFAULT_NUM_THREADS
=
8
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
Source
:
class
Source
:
"""A source for weights."""
"""A source for weights."""
...
@@ -208,10 +218,15 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -208,10 +218,15 @@ class DefaultModelLoader(BaseModelLoader):
def __init__(self, load_config: LoadConfig):
    """Initialize the default loader and validate any extra loader config.

    ``model_loader_extra_config`` may only carry the keys that control
    multi-thread weight loading (``enable_multithread_load`` and
    ``num_threads``); any other key raises ``ValueError`` so typos fail
    loudly instead of being silently ignored.
    """
    super().__init__(load_config)
    extra_config = load_config.model_loader_extra_config
    if extra_config:
        allowed_keys = {"enable_multithread_load", "num_threads"}
        unexpected_keys = set(extra_config.keys()) - allowed_keys
        if unexpected_keys:
            raise ValueError(
                f"Unexpected extra config keys for load format "
                f"{load_config.load_format}: "
                f"{unexpected_keys}"
            )
def
_maybe_download_from_modelscope
(
def
_maybe_download_from_modelscope
(
...
@@ -324,6 +339,7 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -324,6 +339,7 @@ class DefaultModelLoader(BaseModelLoader):
self
,
source
:
"Source"
self
,
source
:
"Source"
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
"""Get an iterator for the model weights based on the load format."""
"""Get an iterator for the model weights based on the load format."""
extra_config
=
self
.
load_config
.
model_loader_extra_config
hf_folder
,
hf_weights_files
,
use_safetensors
=
self
.
_prepare_weights
(
hf_folder
,
hf_weights_files
,
use_safetensors
=
self
.
_prepare_weights
(
source
.
model_or_path
,
source
.
revision
,
source
.
fall_back_to_pt
source
.
model_or_path
,
source
.
revision
,
source
.
fall_back_to_pt
)
)
...
@@ -342,11 +358,30 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -342,11 +358,30 @@ class DefaultModelLoader(BaseModelLoader):
weight_loader_disable_mmap
=
global_server_args_dict
.
get
(
weight_loader_disable_mmap
=
global_server_args_dict
.
get
(
"weight_loader_disable_mmap"
"weight_loader_disable_mmap"
)
)
weights_iterator
=
safetensors_weights_iterator
(
hf_weights_files
,
disable_mmap
=
weight_loader_disable_mmap
if
extra_config
.
get
(
"enable_multithread_load"
):
)
weights_iterator
=
multi_thread_safetensors_weights_iterator
(
hf_weights_files
,
max_workers
=
extra_config
.
get
(
"num_threads"
,
self
.
DEFAULT_NUM_THREADS
),
disable_mmap
=
weight_loader_disable_mmap
,
)
else
:
weights_iterator
=
safetensors_weights_iterator
(
hf_weights_files
,
disable_mmap
=
weight_loader_disable_mmap
)
else
:
else
:
weights_iterator
=
pt_weights_iterator
(
hf_weights_files
)
if
extra_config
.
get
(
"enable_multithread_load"
):
weights_iterator
=
multi_thread_pt_weights_iterator
(
hf_weights_files
,
max_workers
=
extra_config
.
get
(
"num_threads"
,
self
.
DEFAULT_NUM_THREADS
),
)
else
:
weights_iterator
=
pt_weights_iterator
(
hf_weights_files
)
# Apply the prefix.
# Apply the prefix.
return
((
source
.
prefix
+
name
,
tensor
)
for
(
name
,
tensor
)
in
weights_iterator
)
return
((
source
.
prefix
+
name
,
tensor
)
for
(
name
,
tensor
)
in
weights_iterator
)
...
@@ -385,9 +420,9 @@ class DefaultModelLoader(BaseModelLoader):
...
@@ -385,9 +420,9 @@ class DefaultModelLoader(BaseModelLoader):
self
.
load_config
,
self
.
load_config
,
)
)
self
.
load_weights_and_postprocess
(
self
.
load_weights_and_postprocess
(
model
,
self
.
_get_all_weights
(
model_config
,
model
),
target_device
model
,
self
.
_get_all_weights
(
model_config
,
model
),
target_device
)
)
return
model
.
eval
()
return
model
.
eval
()
...
...
python/sglang/srt/model_loader/weight_utils.py
View file @
9f1787fa
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/weight_utils.py
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/model_loader/weight_utils.py
"""Utilities for downloading and initializing model weights."""
"""Utilities for downloading and initializing model weights."""
import
concurrent.futures
import
fnmatch
import
fnmatch
import
glob
import
glob
import
hashlib
import
hashlib
import
json
import
json
import
logging
import
logging
import
os
import
os
import
queue
import
tempfile
import
tempfile
from
collections
import
defaultdict
from
collections
import
defaultdict
from
typing
import
(
from
typing
import
(
...
@@ -453,6 +455,60 @@ def safetensors_weights_iterator(
...
@@ -453,6 +455,60 @@ def safetensors_weights_iterator(
yield
name
,
param
yield
name
,
param
def multi_thread_safetensors_weights_iterator(
    hf_weights_files: List[str],
    is_all_weights_sharded: bool = False,
    decryption_key: Optional[str] = None,
    max_workers: int = 4,
    disable_mmap: bool = False,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over safetensors model weights, loading shard files concurrently.

    Each ``.safetensors`` file is loaded whole into CPU memory on a worker
    thread; tensors are yielded in file-completion order, which is not
    deterministic across runs.

    Args:
        hf_weights_files: Paths of the safetensors shards to load.
        is_all_weights_sharded: Only forwarded to the encrypted-weights
            fallback; the multi-thread path always reads entire files.
        decryption_key: If set, multi-thread loading is skipped and the
            single-threaded encrypted iterator is used instead.
        max_workers: Size of the thread pool.
        disable_mmap: Read each file via ``open``/``read`` instead of the
            memory-mapped ``safetensors.torch.load_file``.

    Yields:
        ``(name, tensor)`` pairs for every tensor in every shard.
    """
    if decryption_key:
        # Encrypted checkpoints must go through the dedicated iterator.
        logger.warning(
            "Multi-Thread loading is not working for encrypted safetensor weights."
        )
        yield from safetensors_encrypted_weights_iterator(
            hf_weights_files, is_all_weights_sharded, decryption_key
        )
        return

    # Show a progress bar only on rank 0 (or when not running distributed).
    enable_tqdm = (
        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    )

    def _load_file(st_file: str):
        if disable_mmap:
            with open(st_file, "rb") as f:
                return safetensors.torch.load(f.read())
        return safetensors.torch.load_file(st_file, device="cpu")

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(_load_file, st_file) for st_file in hf_weights_files
        ]
        futures_iter = concurrent.futures.as_completed(futures)
        if enable_tqdm:
            # Fix: the original passed `disable=not enable_tqdm`, which is
            # always False inside this branch — drop the dead argument and
            # build `as_completed` once for both branches.
            futures_iter = tqdm(
                futures_iter,
                total=len(hf_weights_files),
                desc="Multi-thread loading shards",
                bar_format=_BAR_FORMAT,
            )
        for future in futures_iter:
            state_dict = future.result()
            yield from state_dict.items()
def
pt_weights_iterator
(
def
pt_weights_iterator
(
hf_weights_files
:
List
[
str
],
hf_weights_files
:
List
[
str
],
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
)
->
Generator
[
Tuple
[
str
,
torch
.
Tensor
],
None
,
None
]:
...
@@ -471,6 +527,39 @@ def pt_weights_iterator(
...
@@ -471,6 +527,39 @@ def pt_weights_iterator(
del
state
del
state
def multi_thread_pt_weights_iterator(
    hf_weights_files: List[str],
    max_workers: int = 4,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    """Iterate over PyTorch ``.bin``/``.pt`` checkpoint shards, loading the
    shard files concurrently.

    Args:
        hf_weights_files: Paths of the checkpoint shards to load.
        max_workers: Size of the thread pool used to read shards.

    Yields:
        ``(name, tensor)`` pairs; shards complete in nondeterministic order,
        so the yield order is not stable across runs.
    """
    # Show a progress bar only on rank 0 (or when not running distributed).
    enable_tqdm = (
        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
    )

    def _load_file(bin_file: str):
        # weights_only=True blocks arbitrary-code execution from pickled data.
        return torch.load(bin_file, map_location="cpu", weights_only=True)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(_load_file, bin_file) for bin_file in hf_weights_files
        ]
        futures_iter = concurrent.futures.as_completed(futures)
        if enable_tqdm:
            # Fix: the original passed `disable=not enable_tqdm`, which is
            # always False inside this branch — drop the dead argument and
            # build `as_completed` once for both branches.
            futures_iter = tqdm(
                futures_iter,
                total=len(hf_weights_files),
                desc="Multi-thread loading pt checkpoint shards",
                bar_format=_BAR_FORMAT,
            )
        for future in futures_iter:
            yield from future.result().items()
def
get_gguf_extra_tensor_names
(
def
get_gguf_extra_tensor_names
(
gguf_file
:
str
,
gguf_to_hf_name_map
:
Dict
[
str
,
str
]
gguf_file
:
str
,
gguf_to_hf_name_map
:
Dict
[
str
,
str
]
)
->
List
[
str
]:
)
->
List
[
str
]:
...
...
python/sglang/srt/server_args.py
View file @
9f1787fa
...
@@ -47,6 +47,7 @@ class ServerArgs:
...
@@ -47,6 +47,7 @@ class ServerArgs:
tokenizer_mode
:
str
=
"auto"
tokenizer_mode
:
str
=
"auto"
skip_tokenizer_init
:
bool
=
False
skip_tokenizer_init
:
bool
=
False
load_format
:
str
=
"auto"
load_format
:
str
=
"auto"
model_loader_extra_config
:
str
=
"{}"
trust_remote_code
:
bool
=
False
trust_remote_code
:
bool
=
False
dtype
:
str
=
"auto"
dtype
:
str
=
"auto"
kv_cache_dtype
:
str
=
"auto"
kv_cache_dtype
:
str
=
"auto"
...
@@ -632,6 +633,13 @@ class ServerArgs:
...
@@ -632,6 +633,13 @@ class ServerArgs:
"layer before loading another to make the peak memory envelope "
"layer before loading another to make the peak memory envelope "
"smaller."
,
"smaller."
,
)
)
parser
.
add_argument
(
"--model-loader-extra-config"
,
type
=
str
,
help
=
"Extra config for model loader. "
"This will be passed to the model loader corresponding to the chosen load_format."
,
default
=
ServerArgs
.
model_loader_extra_config
,
)
parser
.
add_argument
(
parser
.
add_argument
(
"--trust-remote-code"
,
"--trust-remote-code"
,
action
=
"store_true"
,
action
=
"store_true"
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment