Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
93f75778
Unverified
Commit
93f75778
authored
Sep 19, 2025
by
penguin_wwy
Committed by
GitHub
Sep 19, 2025
Browse files
[RL] Add destroy process group api (#9979)
parent
4039c626
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
109 additions
and
0 deletions
+109
-0
python/sglang/srt/entrypoints/engine.py
python/sglang/srt/entrypoints/engine.py
+14
-0
python/sglang/srt/entrypoints/http_server.py
python/sglang/srt/entrypoints/http_server.py
+15
-0
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+11
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+2
-0
python/sglang/srt/managers/scheduler_update_weights_mixin.py
python/sglang/srt/managers/scheduler_update_weights_mixin.py
+7
-0
python/sglang/srt/managers/tokenizer_communicator_mixin.py
python/sglang/srt/managers/tokenizer_communicator_mixin.py
+21
-0
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+7
-0
python/sglang/srt/managers/tp_worker_overlap_thread.py
python/sglang/srt/managers/tp_worker_overlap_thread.py
+5
-0
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+13
-0
test/srt/rl/test_update_weights_from_distributed.py
test/srt/rl/test_update_weights_from_distributed.py
+14
-0
No files found.
python/sglang/srt/entrypoints/engine.py
View file @
93f75778
...
@@ -47,6 +47,7 @@ from sglang.srt.managers.data_parallel_controller import (
...
@@ -47,6 +47,7 @@ from sglang.srt.managers.data_parallel_controller import (
)
)
from
sglang.srt.managers.detokenizer_manager
import
run_detokenizer_process
from
sglang.srt.managers.detokenizer_manager
import
run_detokenizer_process
from
sglang.srt.managers.io_struct
import
(
from
sglang.srt.managers.io_struct
import
(
DestroyWeightsUpdateGroupReqInput
,
EmbeddingReqInput
,
EmbeddingReqInput
,
GenerateReqInput
,
GenerateReqInput
,
GetWeightsByNameReqInput
,
GetWeightsByNameReqInput
,
...
@@ -433,6 +434,19 @@ class Engine(EngineBase):
...
@@ -433,6 +434,19 @@ class Engine(EngineBase):
self
.
tokenizer_manager
.
init_weights_update_group
(
obj
,
None
)
self
.
tokenizer_manager
.
init_weights_update_group
(
obj
,
None
)
)
)
def
destroy_weights_update_group
(
self
,
group_name
:
str
,
):
"""Destroy parameter update group."""
obj
=
DestroyWeightsUpdateGroupReqInput
(
group_name
=
group_name
,
)
loop
=
asyncio
.
get_event_loop
()
return
loop
.
run_until_complete
(
self
.
tokenizer_manager
.
destroy_weights_update_group
(
obj
,
None
)
)
def
update_weights_from_distributed
(
def
update_weights_from_distributed
(
self
,
self
,
names
:
list
[
str
],
names
:
list
[
str
],
...
...
python/sglang/srt/entrypoints/http_server.py
View file @
93f75778
...
@@ -70,6 +70,7 @@ from sglang.srt.managers.io_struct import (
...
@@ -70,6 +70,7 @@ from sglang.srt.managers.io_struct import (
AbortReq
,
AbortReq
,
CloseSessionReqInput
,
CloseSessionReqInput
,
ConfigureLoggingReq
,
ConfigureLoggingReq
,
DestroyWeightsUpdateGroupReqInput
,
EmbeddingReqInput
,
EmbeddingReqInput
,
GenerateReqInput
,
GenerateReqInput
,
GetWeightsByNameReqInput
,
GetWeightsByNameReqInput
,
...
@@ -729,6 +730,20 @@ async def init_weights_update_group(
...
@@ -729,6 +730,20 @@ async def init_weights_update_group(
return
ORJSONResponse
(
content
,
status_code
=
HTTPStatus
.
BAD_REQUEST
)
return
ORJSONResponse
(
content
,
status_code
=
HTTPStatus
.
BAD_REQUEST
)
@
app
.
post
(
"/destroy_weights_update_group"
)
async
def
destroy_weights_update_group
(
obj
:
DestroyWeightsUpdateGroupReqInput
,
request
:
Request
):
"""Destroy the parameter update group."""
success
,
message
=
(
await
_global_state
.
tokenizer_manager
.
destroy_weights_update_group
(
obj
,
request
)
)
content
=
{
"success"
:
success
,
"message"
:
message
}
return
ORJSONResponse
(
content
,
status_code
=
200
if
success
else
HTTPStatus
.
BAD_REQUEST
)
@
app
.
post
(
"/update_weights_from_tensor"
)
@
app
.
post
(
"/update_weights_from_tensor"
)
async
def
update_weights_from_tensor
(
async
def
update_weights_from_tensor
(
obj
:
UpdateWeightsFromTensorReqInput
,
request
:
Request
obj
:
UpdateWeightsFromTensorReqInput
,
request
:
Request
...
...
python/sglang/srt/managers/io_struct.py
View file @
93f75778
...
@@ -1094,6 +1094,17 @@ class InitWeightsUpdateGroupReqOutput:
...
@@ -1094,6 +1094,17 @@ class InitWeightsUpdateGroupReqOutput:
message
:
str
message
:
str
@
dataclass
class
DestroyWeightsUpdateGroupReqInput
:
group_name
:
str
=
"weight_update_group"
@
dataclass
class
DestroyWeightsUpdateGroupReqOutput
:
success
:
bool
message
:
str
@
dataclass
@
dataclass
class
UpdateWeightVersionReqInput
:
class
UpdateWeightVersionReqInput
:
# The new weight version
# The new weight version
...
...
python/sglang/srt/managers/scheduler.py
View file @
93f75778
...
@@ -72,6 +72,7 @@ from sglang.srt.managers.io_struct import (
...
@@ -72,6 +72,7 @@ from sglang.srt.managers.io_struct import (
ClearHiCacheReqInput
,
ClearHiCacheReqInput
,
ClearHiCacheReqOutput
,
ClearHiCacheReqOutput
,
CloseSessionReqInput
,
CloseSessionReqInput
,
DestroyWeightsUpdateGroupReqInput
,
ExpertDistributionReq
,
ExpertDistributionReq
,
ExpertDistributionReqOutput
,
ExpertDistributionReqOutput
,
FlushCacheReqInput
,
FlushCacheReqInput
,
...
@@ -566,6 +567,7 @@ class Scheduler(
...
@@ -566,6 +567,7 @@ class Scheduler(
(
CloseSessionReqInput
,
self
.
close_session
),
(
CloseSessionReqInput
,
self
.
close_session
),
(
UpdateWeightFromDiskReqInput
,
self
.
update_weights_from_disk
),
(
UpdateWeightFromDiskReqInput
,
self
.
update_weights_from_disk
),
(
InitWeightsUpdateGroupReqInput
,
self
.
init_weights_update_group
),
(
InitWeightsUpdateGroupReqInput
,
self
.
init_weights_update_group
),
(
DestroyWeightsUpdateGroupReqInput
,
self
.
destroy_weights_update_group
),
(
(
InitWeightsSendGroupForRemoteInstanceReqInput
,
InitWeightsSendGroupForRemoteInstanceReqInput
,
self
.
init_weights_send_group_for_remote_instance
,
self
.
init_weights_send_group_for_remote_instance
,
...
...
python/sglang/srt/managers/scheduler_update_weights_mixin.py
View file @
93f75778
...
@@ -5,6 +5,8 @@ import torch
...
@@ -5,6 +5,8 @@ import torch
from
sglang.srt.constants
import
GPU_MEMORY_TYPE_KV_CACHE
,
GPU_MEMORY_TYPE_WEIGHTS
from
sglang.srt.constants
import
GPU_MEMORY_TYPE_KV_CACHE
,
GPU_MEMORY_TYPE_WEIGHTS
from
sglang.srt.managers.io_struct
import
(
from
sglang.srt.managers.io_struct
import
(
DestroyWeightsUpdateGroupReqInput
,
DestroyWeightsUpdateGroupReqOutput
,
GetWeightsByNameReqInput
,
GetWeightsByNameReqInput
,
GetWeightsByNameReqOutput
,
GetWeightsByNameReqOutput
,
InitWeightsUpdateGroupReqInput
,
InitWeightsUpdateGroupReqInput
,
...
@@ -41,6 +43,11 @@ class SchedulerUpdateWeightsMixin:
...
@@ -41,6 +43,11 @@ class SchedulerUpdateWeightsMixin:
success
,
message
=
self
.
tp_worker
.
init_weights_update_group
(
recv_req
)
success
,
message
=
self
.
tp_worker
.
init_weights_update_group
(
recv_req
)
return
InitWeightsUpdateGroupReqOutput
(
success
,
message
)
return
InitWeightsUpdateGroupReqOutput
(
success
,
message
)
def
destroy_weights_update_group
(
self
,
recv_req
:
DestroyWeightsUpdateGroupReqInput
):
"""Destroy the online model parameter update group."""
success
,
message
=
self
.
tp_worker
.
destroy_weights_update_group
(
recv_req
)
return
DestroyWeightsUpdateGroupReqOutput
(
success
,
message
)
def
update_weights_from_distributed
(
def
update_weights_from_distributed
(
self
,
self
,
recv_req
:
UpdateWeightsFromDistributedReqInput
,
recv_req
:
UpdateWeightsFromDistributedReqInput
,
...
...
python/sglang/srt/managers/tokenizer_communicator_mixin.py
View file @
93f75778
...
@@ -24,6 +24,8 @@ import zmq
...
@@ -24,6 +24,8 @@ import zmq
from
sglang.srt.managers.io_struct
import
(
from
sglang.srt.managers.io_struct
import
(
ClearHiCacheReqInput
,
ClearHiCacheReqInput
,
ClearHiCacheReqOutput
,
ClearHiCacheReqOutput
,
DestroyWeightsUpdateGroupReqInput
,
DestroyWeightsUpdateGroupReqOutput
,
ExpertDistributionReq
,
ExpertDistributionReq
,
ExpertDistributionReqOutput
,
ExpertDistributionReqOutput
,
FlushCacheReqInput
,
FlushCacheReqInput
,
...
@@ -149,6 +151,9 @@ class TokenizerCommunicatorMixin:
...
@@ -149,6 +151,9 @@ class TokenizerCommunicatorMixin:
self
.
init_weights_update_group_communicator
=
_Communicator
(
self
.
init_weights_update_group_communicator
=
_Communicator
(
self
.
send_to_scheduler
,
server_args
.
dp_size
self
.
send_to_scheduler
,
server_args
.
dp_size
)
)
self
.
destroy_weights_update_group_communicator
=
_Communicator
(
self
.
send_to_scheduler
,
server_args
.
dp_size
)
self
.
update_weights_from_distributed_communicator
=
_Communicator
(
self
.
update_weights_from_distributed_communicator
=
_Communicator
(
self
.
send_to_scheduler
,
server_args
.
dp_size
self
.
send_to_scheduler
,
server_args
.
dp_size
)
)
...
@@ -207,6 +212,10 @@ class TokenizerCommunicatorMixin:
...
@@ -207,6 +212,10 @@ class TokenizerCommunicatorMixin:
InitWeightsUpdateGroupReqOutput
,
InitWeightsUpdateGroupReqOutput
,
self
.
init_weights_update_group_communicator
.
handle_recv
,
self
.
init_weights_update_group_communicator
.
handle_recv
,
),
),
(
DestroyWeightsUpdateGroupReqOutput
,
self
.
destroy_weights_update_group_communicator
.
handle_recv
,
),
(
(
UpdateWeightsFromDistributedReqOutput
,
UpdateWeightsFromDistributedReqOutput
,
self
.
update_weights_from_distributed_communicator
.
handle_recv
,
self
.
update_weights_from_distributed_communicator
.
handle_recv
,
...
@@ -345,6 +354,18 @@ class TokenizerCommunicatorMixin:
...
@@ -345,6 +354,18 @@ class TokenizerCommunicatorMixin:
result
=
(
await
self
.
init_weights_update_group_communicator
(
obj
))[
0
]
result
=
(
await
self
.
init_weights_update_group_communicator
(
obj
))[
0
]
return
result
.
success
,
result
.
message
return
result
.
success
,
result
.
message
async
def
destroy_weights_update_group
(
self
,
obj
:
DestroyWeightsUpdateGroupReqInput
,
request
:
Optional
[
fastapi
.
Request
]
=
None
,
)
->
Tuple
[
bool
,
str
]:
self
.
auto_create_handle_loop
()
assert
(
self
.
server_args
.
dp_size
==
1
),
"dp_size must be 1 for destroy parameter update group"
result
=
(
await
self
.
destroy_weights_update_group_communicator
(
obj
))[
0
]
return
result
.
success
,
result
.
message
async
def
update_weights_from_distributed
(
async
def
update_weights_from_distributed
(
self
:
TokenizerManager
,
self
:
TokenizerManager
,
obj
:
UpdateWeightsFromDistributedReqInput
,
obj
:
UpdateWeightsFromDistributedReqInput
,
...
...
python/sglang/srt/managers/tp_worker.py
View file @
93f75778
...
@@ -29,6 +29,7 @@ from sglang.srt.hf_transformers_utils import (
...
@@ -29,6 +29,7 @@ from sglang.srt.hf_transformers_utils import (
)
)
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.layers.logits_processor
import
LogitsProcessorOutput
from
sglang.srt.managers.io_struct
import
(
from
sglang.srt.managers.io_struct
import
(
DestroyWeightsUpdateGroupReqInput
,
GetWeightsByNameReqInput
,
GetWeightsByNameReqInput
,
InitWeightsSendGroupForRemoteInstanceReqInput
,
InitWeightsSendGroupForRemoteInstanceReqInput
,
InitWeightsUpdateGroupReqInput
,
InitWeightsUpdateGroupReqInput
,
...
@@ -304,6 +305,12 @@ class TpModelWorker:
...
@@ -304,6 +305,12 @@ class TpModelWorker:
)
)
return
success
,
message
return
success
,
message
def
destroy_weights_update_group
(
self
,
recv_req
:
DestroyWeightsUpdateGroupReqInput
):
success
,
message
=
self
.
model_runner
.
destroy_weights_update_group
(
recv_req
.
group_name
,
)
return
success
,
message
def
init_weights_send_group_for_remote_instance
(
def
init_weights_send_group_for_remote_instance
(
self
,
recv_req
:
InitWeightsSendGroupForRemoteInstanceReqInput
self
,
recv_req
:
InitWeightsSendGroupForRemoteInstanceReqInput
):
):
...
...
python/sglang/srt/managers/tp_worker_overlap_thread.py
View file @
93f75778
...
@@ -25,6 +25,7 @@ import psutil
...
@@ -25,6 +25,7 @@ import psutil
import
torch
import
torch
from
sglang.srt.managers.io_struct
import
(
from
sglang.srt.managers.io_struct
import
(
DestroyWeightsUpdateGroupReqInput
,
GetWeightsByNameReqInput
,
GetWeightsByNameReqInput
,
InitWeightsSendGroupForRemoteInstanceReqInput
,
InitWeightsSendGroupForRemoteInstanceReqInput
,
InitWeightsUpdateGroupReqInput
,
InitWeightsUpdateGroupReqInput
,
...
@@ -278,6 +279,10 @@ class TpModelWorkerClient:
...
@@ -278,6 +279,10 @@ class TpModelWorkerClient:
success
,
message
=
self
.
worker
.
init_weights_update_group
(
recv_req
)
success
,
message
=
self
.
worker
.
init_weights_update_group
(
recv_req
)
return
success
,
message
return
success
,
message
def
destroy_weights_update_group
(
self
,
recv_req
:
DestroyWeightsUpdateGroupReqInput
):
success
,
message
=
self
.
worker
.
destroy_weights_update_group
(
recv_req
)
return
success
,
message
def
init_weights_send_group_for_remote_instance
(
def
init_weights_send_group_for_remote_instance
(
self
,
recv_req
:
InitWeightsSendGroupForRemoteInstanceReqInput
self
,
recv_req
:
InitWeightsSendGroupForRemoteInstanceReqInput
):
):
...
...
python/sglang/srt/model_executor/model_runner.py
View file @
93f75778
...
@@ -1025,6 +1025,19 @@ class ModelRunner:
...
@@ -1025,6 +1025,19 @@ class ModelRunner:
logger
.
error
(
message
)
logger
.
error
(
message
)
return
False
,
message
return
False
,
message
def
destroy_weights_update_group
(
self
,
group_name
):
try
:
if
group_name
in
self
.
_model_update_group
:
pg
=
self
.
_model_update_group
.
pop
(
group_name
)
torch
.
distributed
.
destroy_process_group
(
pg
)
return
True
,
"Succeeded to destroy custom process group."
else
:
return
False
,
"The group to be destroyed does not exist."
except
Exception
as
e
:
message
=
f
"Failed to destroy custom process group:
{
e
}
."
logger
.
error
(
message
)
return
False
,
message
def
update_weights_from_distributed
(
self
,
names
,
dtypes
,
shapes
,
group_name
):
def
update_weights_from_distributed
(
self
,
names
,
dtypes
,
shapes
,
group_name
):
"""
"""
Update specific parameter in the model weights online
Update specific parameter in the model weights online
...
...
test/srt/rl/test_update_weights_from_distributed.py
View file @
93f75778
...
@@ -344,6 +344,20 @@ def init_process_sgl(
...
@@ -344,6 +344,20 @@ def init_process_sgl(
)
)
param_queue
.
put
((
f
"sgl_dp_
{
rank
}
_base_params"
,
base_params
))
param_queue
.
put
((
f
"sgl_dp_
{
rank
}
_base_params"
,
base_params
))
if
backend
==
"Engine"
:
success
,
_
=
engine
.
destroy_weights_update_group
(
group_name
=
"test_parameter_update_group"
,
)
assert
success
is
True
else
:
response
=
requests
.
post
(
f
"
{
url
}
/destroy_weights_update_group"
,
json
=
{
"group_name"
:
"test_parameter_update_group"
,
},
)
assert
response
.
status_code
==
200
# Shutdown the engine or terminate the server process.
# Shutdown the engine or terminate the server process.
if
backend
==
"Engine"
:
if
backend
==
"Engine"
:
engine
.
shutdown
()
engine
.
shutdown
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment