sglang · Commits · 0626f678

Unverified commit 0626f678, authored Jul 03, 2025 by Zilin Zhu, committed by GitHub Jul 02, 2025

[RL] support update_weights_from_distributed with different group and multiple weights (#7292)

Parent: 09e699bb
Showing 6 changed files with 73 additions and 38 deletions (+73 −38):
python/sglang/srt/entrypoints/engine.py (+13 −4)
python/sglang/srt/managers/io_struct.py (+7 −3)
python/sglang/srt/managers/scheduler.py (+3 −2)
python/sglang/srt/managers/tp_worker.py (+1 −1)
python/sglang/srt/model_executor/model_runner.py (+28 −12)
test/srt/test_update_weights_from_distributed.py (+21 −16)
python/sglang/srt/entrypoints/engine.py

@@ -418,12 +418,21 @@ class Engine(EngineBase):
             self.tokenizer_manager.init_weights_update_group(obj, None)
         )
 
-    def update_weights_from_distributed(self, name: str, dtype, shape):
+    def update_weights_from_distributed(
+        self,
+        names: list[str],
+        dtypes: list[str],
+        shapes: list[list[int]],
+        group_name: str = "weight_update_group",
+        flush_cache: bool = True,
+    ):
         """Update weights from distributed source."""
         obj = UpdateWeightsFromDistributedReqInput(
-            name=name,
-            dtype=dtype,
-            shape=shape,
+            names=names,
+            dtypes=dtypes,
+            shapes=shapes,
+            group_name=group_name,
+            flush_cache=flush_cache,
         )
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
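For orientation, here is a minimal sketch of calling the new batched API from an Engine. The model path, parameter names, and shapes are illustrative assumptions, not values from this commit:

import sglang as sgl

# Hypothetical engine; any sglang-served model works the same way.
engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")

# One call now updates several weights in a single batch, over a named group.
engine.update_weights_from_distributed(
    names=["model.embed_tokens.weight", "lm_head.weight"],
    dtypes=["bfloat16", "bfloat16"],
    shapes=[[128256, 4096], [128256, 4096]],  # illustrative shapes
    group_name="weight_update_group",         # default from the new signature
    flush_cache=True,                         # invalidate the cache once, after the batch
)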
python/sglang/srt/managers/io_struct.py

@@ -752,9 +752,13 @@ class UpdateWeightFromDiskReqOutput:
 
 @dataclass
 class UpdateWeightsFromDistributedReqInput:
-    name: str
-    dtype: str
-    shape: List[int]
+    names: List[str]
+    dtypes: List[str]
+    shapes: List[List[int]]
+    # The group name
+    group_name: str = "weight_update_group"
+    # Whether to flush the cache after updating weights
+    flush_cache: bool = True
 
 
 @dataclass
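As a sketch, the batched request is three parallel lists plus the two new optional fields; the parameter name and shape below are illustrative assumptions:

from sglang.srt.managers.io_struct import UpdateWeightsFromDistributedReqInput

# One entry per parameter: names[i], dtypes[i], shapes[i] describe weight i.
req = UpdateWeightsFromDistributedReqInput(
    names=["model.layers.0.mlp.gate_proj.weight"],
    dtypes=["bfloat16"],
    shapes=[[14336, 4096]],  # illustrative shape
    group_name="weight_update_group",
    flush_cache=True,
)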
python/sglang/srt/managers/scheduler.py

@@ -2303,8 +2303,9 @@ class Scheduler(
         """Update the online model parameter."""
         success, message = self.tp_worker.update_weights_from_distributed(recv_req)
         if success:
-            flush_cache_success = self.flush_cache()
-            assert flush_cache_success, "Cache flush failed after updating weights"
+            if recv_req.flush_cache:
+                flush_cache_success = self.flush_cache()
+                assert flush_cache_success, "Cache flush failed after updating weights"
         else:
             logger.error(message)
         return UpdateWeightsFromDistributedReqOutput(success, message)
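The new recv_req.flush_cache guard lets a caller split one logical update across several requests and flush the cache only once, at the end. A hedged sketch of that pattern; the engine, names, and shape_of bindings are assumptions for illustration:

# Hypothetical chunked update: defer the flush until the final chunk,
# so the cache is invalidated exactly once instead of once per request.
chunk_size = 64
chunks = [names[i : i + chunk_size] for i in range(0, len(names), chunk_size)]
for i, chunk in enumerate(chunks):
    engine.update_weights_from_distributed(
        names=chunk,
        dtypes=["bfloat16"] * len(chunk),
        shapes=[shape_of[n] for n in chunk],
        flush_cache=(i == len(chunks) - 1),
    )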
python/sglang/srt/managers/tp_worker.py

@@ -259,7 +259,7 @@ class TpModelWorker:
         self, recv_req: UpdateWeightsFromDistributedReqInput
     ):
         success, message = self.model_runner.update_weights_from_distributed(
-            recv_req.name, recv_req.dtype, recv_req.shape
+            recv_req.names, recv_req.dtypes, recv_req.shapes, recv_req.group_name
         )
         return success, message
python/sglang/srt/model_executor/model_runner.py

@@ -225,6 +225,7 @@ class ModelRunner:
         self.support_pp = (
             "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
         )
+        self._model_update_group = {}
 
     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args

@@ -744,7 +745,7 @@ class ModelRunner:
         )
 
         try:
-            self._model_update_group = init_custom_process_group(
+            self._model_update_group[group_name] = init_custom_process_group(
                 backend=backend,
                 init_method=f"tcp://{master_address}:{master_port}",
                 world_size=world_size,

@@ -757,7 +758,7 @@ class ModelRunner:
             logger.error(message)
             return False, message
 
-    def update_weights_from_distributed(self, name, dtype, shape):
+    def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
         """
         Update specific parameter in the model weights online
         through `_model_update_group` process group.

@@ -767,19 +768,34 @@ class ModelRunner:
             dtype: the data type of the parameter to be updated.
             shape: the shape of the parameter to be updated.
         """
-        target_dtype = (
-            dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
-        )
-
-        assert (
-            self._model_update_group is not None
-        ), "model update group must be initialized"
+        assert group_name in self._model_update_group, (
+            f"Group {group_name} not in {list(self._model_update_group.keys())}. "
+            "Please call `init_weights_update_group` first."
+        )
 
         try:
-            weights = torch.empty(shape, dtype=target_dtype, device=self.device)
-            torch.distributed.broadcast(weights, src=0, group=self._model_update_group)
-            self.model.load_weights([(name, weights)])
-            return True, f"Succeeded to update parameter {name} online."
+            weights = []
+            handles = []
+            for name, dtype, shape in zip(names, dtypes, shapes):
+                target_dtype = (
+                    dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
+                )
+                weight = torch.empty(shape, dtype=target_dtype, device=self.device)
+                handles.append(
+                    torch.distributed.broadcast(
+                        weight,
+                        src=0,
+                        group=self._model_update_group[group_name],
+                        async_op=True,
+                    )
+                )
+                weights.append((name, weight))
+            for handle in handles:
+                handle.wait()
+
+            self.model.load_weights(weights)
+            return True, f"Succeeded to update parameter online."
 
         except Exception as e:
             error_msg = (
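The core of the model_runner change is the broadcast pattern: rather than one blocking broadcast per weight, all broadcasts are posted with async_op=True and the work handles are awaited afterward, so the transfers can overlap. A self-contained sketch of that receiver-side pattern; the function name and the group argument are assumptions, not this file's API:

import torch
import torch.distributed as dist

def recv_weight_batch(names, dtypes, shapes, group, device="cuda"):
    """Receive a batch of tensors broadcast from rank 0 of `group`."""
    weights, handles = [], []
    for name, dtype, shape in zip(names, dtypes, shapes):
        # Accept either a torch.dtype or its string name, as the diff does.
        target_dtype = dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
        weight = torch.empty(shape, dtype=target_dtype, device=device)
        # async_op=True returns a Work handle instead of blocking, so every
        # broadcast is already in flight before we wait on any of them.
        handles.append(dist.broadcast(weight, src=0, group=group, async_op=True))
        weights.append((name, weight))
    for handle in handles:
        handle.wait()
    return weights  # list of (name, tensor) pairs, ready for load_weights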
test/srt/test_update_weights_from_distributed.py

@@ -294,22 +294,27 @@ def init_process_sgl(
         update_parameters.remove("lm_head.weight")
 
     # Get weights from the training engine and update the inference engine.
-    for parameter_name in update_parameters:
-        if backend == "Engine":
-            engine.update_weights_from_distributed(
-                parameter_name,
-                dtype=torch.bfloat16,
-                shape=state_dict_key_to_shape[parameter_name],
-            )
-        else:
-            requests.post(
-                f"{url}/update_weights_from_distributed",
-                json={
-                    "name": parameter_name,
-                    "dtype": "bfloat16",
-                    "shape": state_dict_key_to_shape[parameter_name],
-                },
-            )
+    names = [parameter_name for parameter_name in update_parameters]
+    dtypes = [torch.bfloat16 if backend == "Engine" else "bfloat16"] * len(names)
+    shapes = [state_dict_key_to_shape[parameter_name] for parameter_name in names]
+    if backend == "Engine":
+        engine.update_weights_from_distributed(
+            names,
+            dtypes=dtypes,
+            shapes=shapes,
+            group_name="test_parameter_update_group",
+        )
+    else:
+        requests.post(
+            f"{url}/update_weights_from_distributed",
+            json={
+                "names": names,
+                "dtypes": dtypes,
+                "shapes": shapes,
+                "group_name": "test_parameter_update_group",
+            },
+        )
 
     torch.cuda.synchronize()
     time_end_update = time.perf_counter()
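For the broadcasts above to complete, rank 0 of the shared custom group (the training process in this test) must issue matching broadcasts in the same order. A hedged sketch of that sender loop; training_model and group are assumed bindings, not code from this test:

# Hypothetical trainer-side counterpart: rank 0 broadcasts each parameter
# over the shared group, in the same order the inference side receives them.
state_dict = training_model.state_dict()
for parameter_name in names:
    tensor = state_dict[parameter_name].to("cuda", torch.bfloat16)
    torch.distributed.broadcast(tensor, src=0, group=group)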