Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
873ae12c
Unverified
Commit
873ae12c
authored
Jun 17, 2025
by
KavioYu
Committed by
GitHub
Jun 16, 2025
Browse files
support custom weight loader for model runner (#7122)
Co-authored-by:
kavioyu
<
kavioyu@tencent.com
>
parent
c64290dc
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
64 additions
and
0 deletions
+64
-0
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+4
-0
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+13
-0
python/sglang/srt/utils.py
python/sglang/srt/utils.py
+13
-0
test/srt/test_update_weights_from_tensor.py
test/srt/test_update_weights_from_tensor.py
+34
-0
No files found.
python/sglang/srt/model_executor/model_runner.py
View file @
873ae12c
...
@@ -93,6 +93,7 @@ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
...
@@ -93,6 +93,7 @@ from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
MultiprocessingSerializer
,
MultiprocessingSerializer
,
cpu_has_amx_support
,
cpu_has_amx_support
,
dynamic_import
,
enable_show_time_cost
,
enable_show_time_cost
,
get_available_gpu_memory
,
get_available_gpu_memory
,
get_bool_env_var
,
get_bool_env_var
,
...
@@ -761,6 +762,9 @@ class ModelRunner:
...
@@ -761,6 +762,9 @@ class ModelRunner:
]
]
if
load_format
==
"direct"
:
if
load_format
==
"direct"
:
_model_load_weights_direct
(
self
.
model
,
named_tensors
)
_model_load_weights_direct
(
self
.
model
,
named_tensors
)
elif
load_format
in
self
.
server_args
.
custom_weight_loader
:
custom_loader
=
dynamic_import
(
load_format
)
custom_loader
(
self
.
model
,
named_tensors
)
elif
load_format
is
None
:
elif
load_format
is
None
:
self
.
model
.
load_weights
(
named_tensors
)
self
.
model
.
load_weights
(
named_tensors
)
else
:
else
:
...
...
python/sglang/srt/server_args.py
View file @
873ae12c
...
@@ -234,6 +234,9 @@ class ServerArgs:
...
@@ -234,6 +234,9 @@ class ServerArgs:
num_reserved_decode_tokens
:
int
=
512
# used for decode kv cache offload in PD
num_reserved_decode_tokens
:
int
=
512
# used for decode kv cache offload in PD
pdlb_url
:
Optional
[
str
]
=
None
pdlb_url
:
Optional
[
str
]
=
None
# For model weight update
custom_weight_loader
:
Optional
[
List
[
str
]]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
# Expert parallelism
# Expert parallelism
if
self
.
enable_ep_moe
:
if
self
.
enable_ep_moe
:
...
@@ -538,6 +541,9 @@ class ServerArgs:
...
@@ -538,6 +541,9 @@ class ServerArgs:
"1"
if
self
.
disable_outlines_disk_cache
else
"0"
"1"
if
self
.
disable_outlines_disk_cache
else
"0"
)
)
if
self
.
custom_weight_loader
is
None
:
self
.
custom_weight_loader
=
[]
def
validate_disagg_tp_size
(
self
,
prefill_tp
:
int
,
decode_tp
:
int
):
def
validate_disagg_tp_size
(
self
,
prefill_tp
:
int
,
decode_tp
:
int
):
larger_tp
=
max
(
decode_tp
,
prefill_tp
)
larger_tp
=
max
(
decode_tp
,
prefill_tp
)
smaller_tp
=
min
(
decode_tp
,
prefill_tp
)
smaller_tp
=
min
(
decode_tp
,
prefill_tp
)
...
@@ -1576,6 +1582,13 @@ class ServerArgs:
...
@@ -1576,6 +1582,13 @@ class ServerArgs:
default
=
None
,
default
=
None
,
help
=
"The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer."
,
help
=
"The URL of the PD disaggregation load balancer. If set, the prefill/decode server will register with the load balancer."
,
)
)
parser
.
add_argument
(
"--custom-weight-loader"
,
type
=
str
,
nargs
=
"*"
,
default
=
None
,
help
=
"The custom dataloader which used to update the model. Should be set with a valid import path, such as my_package.weight_load_func"
,
)
@
classmethod
@
classmethod
def
from_cli_args
(
cls
,
args
:
argparse
.
Namespace
):
def
from_cli_args
(
cls
,
args
:
argparse
.
Namespace
):
...
...
python/sglang/srt/utils.py
View file @
873ae12c
...
@@ -2340,3 +2340,16 @@ class LazyValue:
...
@@ -2340,3 +2340,16 @@ class LazyValue:
self
.
_value
=
self
.
_creator
()
self
.
_value
=
self
.
_creator
()
self
.
_creator
=
None
self
.
_creator
=
None
return
self
.
_value
return
self
.
_value
def
dynamic_import
(
func_path
:
str
):
parts
=
func_path
.
split
(
"."
)
if
len
(
parts
)
<
2
:
raise
ValueError
(
"func_path should contain both module name and func name (such as 'module.func')"
)
module_path
=
"."
.
join
(
parts
[:
-
1
])
func_name
=
parts
[
-
1
]
module
=
importlib
.
import_module
(
module_path
)
func
=
getattr
(
module
,
func_name
)
return
func
test/srt/test_update_weights_from_tensor.py
View file @
873ae12c
...
@@ -78,6 +78,40 @@ class TestUpdateWeightsFromTensor(CustomTestCase):
...
@@ -78,6 +78,40 @@ class TestUpdateWeightsFromTensor(CustomTestCase):
engine
.
shutdown
()
engine
.
shutdown
()
def
test_update_weights_from_tensor_load_format_custom
(
self
):
custom_loader_name
=
(
"sglang.srt.model_executor.model_runner._model_load_weights_direct"
)
engine
=
sgl
.
Engine
(
model_path
=
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
,
custom_weight_loader
=
[
custom_loader_name
],
)
write_param_names
=
[
f
"model.layers.
{
i
}
.self_attn.qkv_proj.weight"
for
i
in
range
(
6
,
16
)
]
read_param_names
=
[
f
"model.layers.
{
i
}
.self_attn.k_proj.weight"
for
i
in
range
(
6
,
16
)
]
_check_param
(
engine
,
read_param_names
[
0
],
[
-
0.0198
,
0.0227
,
0.0168
,
0.0232
,
-
0.0178
]
)
new_tensor
=
torch
.
full
((
3072
,
2048
),
1.5
)
engine
.
update_weights_from_tensor
(
[
(
write_param_name
,
new_tensor
.
clone
())
for
write_param_name
in
write_param_names
],
load_format
=
custom_loader_name
,
)
for
read_param_name
in
read_param_names
[:
3
]:
_check_param
(
engine
,
read_param_name
,
[
1.5
]
*
5
)
engine
.
shutdown
()
def
_check_param
(
engine
,
param_name
,
expect_values
):
def
_check_param
(
engine
,
param_name
,
expect_values
):
actual_values
=
torch
.
tensor
(
engine
.
get_weights_by_name
(
param_name
))[
0
,
:
5
]
actual_values
=
torch
.
tensor
(
engine
.
get_weights_by_name
(
param_name
))[
0
,
:
5
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment