sglang commit 873ae12c (unverified)

support custom weight loader for model runner (#7122)

Authored by KavioYu on Jun 17, 2025; committed via GitHub on Jun 16, 2025
Co-authored-by: kavioyu <kavioyu@tencent.com>
Parent: c64290dc
Showing 4 changed files with 64 additions and 0 deletions.
python/sglang/srt/model_executor/model_runner.py (+4, -0)
python/sglang/srt/server_args.py (+13, -0)
python/sglang/srt/utils.py (+13, -0)
test/srt/test_update_weights_from_tensor.py (+34, -0)
python/sglang/srt/model_executor/model_runner.py

@@ -93,6 +93,7 @@ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     MultiprocessingSerializer,
     cpu_has_amx_support,
+    dynamic_import,
     enable_show_time_cost,
     get_available_gpu_memory,
     get_bool_env_var,
@@ -761,6 +762,9 @@ class ModelRunner:
         ]
         if load_format == "direct":
             _model_load_weights_direct(self.model, named_tensors)
+        elif load_format in self.server_args.custom_weight_loader:
+            custom_loader = dynamic_import(load_format)
+            custom_loader(self.model, named_tensors)
         elif load_format is None:
             self.model.load_weights(named_tensors)
         else:
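As the second hunk shows, a custom loader is selected when a weight update arrives with a load_format matching one of the registered import paths, and it is invoked as custom_loader(self.model, named_tensors). A minimal sketch of such a loader under that call contract; the transform and all names below are hypothetical, not part of the commit:

import torch


def scaled_weight_loader(model, named_tensors):
    # Hypothetical custom loader: cast each incoming tensor and copy it
    # into the matching parameter by name. A real loader might instead
    # dequantize, remap names, or shard tensors for tensor parallelism.
    params = dict(model.named_parameters())
    for name, tensor in named_tensors:
        with torch.no_grad():
            params[name].copy_(tensor.to(params[name].dtype))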
python/sglang/srt/server_args.py

@@ -234,6 +234,9 @@ class ServerArgs:
     num_reserved_decode_tokens: int = 512  # used for decode kv cache offload in PD
     pdlb_url: Optional[str] = None
 
+    # For model weight update
+    custom_weight_loader: Optional[List[str]] = None
+
     def __post_init__(self):
         # Expert parallelism
         if self.enable_ep_moe:
@@ -538,6 +541,9 @@
             "1" if self.disable_outlines_disk_cache else "0"
         )
 
+        if self.custom_weight_loader is None:
+            self.custom_weight_loader = []
+
     def validate_disagg_tp_size(self, prefill_tp: int, decode_tp: int):
         larger_tp = max(decode_tp, prefill_tp)
         smaller_tp = min(decode_tp, prefill_tp)
@@ -1576,6 +1582,13 @@
             default=None,
             help="The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer.",
         )
+        parser.add_argument(
+            "--custom-weight-loader",
+            type=str,
+            nargs="*",
+            default=None,
+            help="The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
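A usage sketch for the new flag. Because it is declared with nargs="*", several import paths can be registered in one launch; the loader paths below are placeholders:

# CLI sketch (placeholder loader paths):
#
#   python -m sglang.launch_server --model-path <model> \
#       --custom-weight-loader my_package.weight_load_func other_pkg.loader
#
# When the flag is omitted, __post_init__ normalizes the field to an empty
# list, so the membership test on custom_weight_loader in ModelRunner is
# safe even with no loaders registered.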
python/sglang/srt/utils.py

@@ -2340,3 +2340,16 @@ class LazyValue:
         self._value = self._creator()
         self._creator = None
         return self._value
+
+
+def dynamic_import(func_path: str):
+    parts = func_path.split(".")
+    if len(parts) < 2:
+        raise ValueError(
+            "func_path should contain both module name and func name (such as 'module.func')"
+        )
+    module_path = ".".join(parts[:-1])
+    func_name = parts[-1]
+    module = importlib.import_module(module_path)
+    func = getattr(module, func_name)
+    return func
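For illustration (not part of the commit), dynamic_import treats everything before the last dot as the module and the final component as the attribute:

from sglang.srt.utils import dynamic_import

# "os.path.join" -> import_module("os.path"), then getattr for "join".
join = dynamic_import("os.path.join")
print(join("a", "b"))  # "a/b" on POSIX

# A bare name has no module part and is rejected up front.
try:
    dynamic_import("join")
except ValueError as err:
    print(err)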
test/srt/test_update_weights_from_tensor.py

@@ -78,6 +78,40 @@ class TestUpdateWeightsFromTensor(CustomTestCase):
         engine.shutdown()
+
+    def test_update_weights_from_tensor_load_format_custom(self):
+        custom_loader_name = (
+            "sglang.srt.model_executor.model_runner._model_load_weights_direct"
+        )
+        engine = sgl.Engine(
+            model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+            custom_weight_loader=[custom_loader_name],
+        )
+
+        write_param_names = [
+            f"model.layers.{i}.self_attn.qkv_proj.weight" for i in range(6, 16)
+        ]
+        read_param_names = [
+            f"model.layers.{i}.self_attn.k_proj.weight" for i in range(6, 16)
+        ]
+
+        _check_param(
+            engine, read_param_names[0], [-0.0198, 0.0227, 0.0168, 0.0232, -0.0178]
+        )
+
+        new_tensor = torch.full((3072, 2048), 1.5)
+
+        engine.update_weights_from_tensor(
+            [
+                (write_param_name, new_tensor.clone())
+                for write_param_name in write_param_names
+            ],
+            load_format=custom_loader_name,
+        )
+
+        for read_param_name in read_param_names[:3]:
+            _check_param(engine, read_param_name, [1.5] * 5)
+        engine.shutdown()
 
 
 def _check_param(engine, param_name, expect_values):
     actual_values = torch.tensor(engine.get_weights_by_name(param_name))[0, :5]
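The test reuses the existing helper _model_load_weights_direct as the custom loader: it writes the fused qkv_proj weights directly and then verifies the update by reading the k_proj slices back. End to end, a user-supplied loader lives in any importable module and is selected by passing its registered path as load_format; a sketch under those assumptions, with the loader path, model path, and tensor shape all placeholders:

import torch
import sglang as sgl

# Hypothetical path, assumed importable and registered at launch.
LOADER = "my_package.loaders.weight_load_func"

engine = sgl.Engine(model_path="<model>", custom_weight_loader=[LOADER])

# load_format must match one of the registered paths to select the loader;
# with load_format=None, the default model.load_weights path is used instead.
engine.update_weights_from_tensor(
    [("model.embed_tokens.weight", torch.zeros(8, 8))],  # placeholder tensor
    load_format=LOADER,
)
engine.shutdown()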