Unverified Commit 93f75778 authored by penguin_wwy, committed by GitHub

[RL] Add destroy process group api (#9979)

parent 4039c626
@@ -47,6 +47,7 @@ from sglang.srt.managers.data_parallel_controller import (
)
from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
from sglang.srt.managers.io_struct import (
    DestroyWeightsUpdateGroupReqInput,
    EmbeddingReqInput,
    GenerateReqInput,
    GetWeightsByNameReqInput,
@@ -433,6 +434,19 @@ class Engine(EngineBase):
            self.tokenizer_manager.init_weights_update_group(obj, None)
        )

    def destroy_weights_update_group(
        self,
        group_name: str,
    ):
        """Destroy parameter update group."""
        obj = DestroyWeightsUpdateGroupReqInput(
            group_name=group_name,
        )
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(
            self.tokenizer_manager.destroy_weights_update_group(obj, None)
        )

    def update_weights_from_distributed(
        self,
        names: list[str],
...
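
With this change, a training loop that created a custom weight-sync group through `Engine.init_weights_update_group` can tear it down explicitly instead of leaving the process group alive until shutdown. A minimal sketch of the intended call order, assuming a single offline `Engine`, the existing `init_weights_update_group` signature, a placeholder model path, and a trainer process that joins the same group from its side:

```python
import sglang as sgl

# Placeholder model path; any model served by the engine works the same way.
engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")

# Create the custom process group used for online weight updates.
# NOTE: this blocks until the trainer side joins the same group.
engine.init_weights_update_group(
    master_address="127.0.0.1",
    master_port=29500,
    rank_offset=1,
    world_size=2,
    group_name="weight_update_group",
    backend="nccl",
)

# ... trainer pushes new weights via update_weights_from_distributed ...

# New in this PR: destroy the group once weight syncing is finished.
success, message = engine.destroy_weights_update_group(
    group_name="weight_update_group"
)
assert success, message

engine.shutdown()
```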
@@ -70,6 +70,7 @@ from sglang.srt.managers.io_struct import (
    AbortReq,
    CloseSessionReqInput,
    ConfigureLoggingReq,
    DestroyWeightsUpdateGroupReqInput,
    EmbeddingReqInput,
    GenerateReqInput,
    GetWeightsByNameReqInput,
@@ -729,6 +730,20 @@ async def init_weights_update_group(
        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)


@app.post("/destroy_weights_update_group")
async def destroy_weights_update_group(
    obj: DestroyWeightsUpdateGroupReqInput, request: Request
):
    """Destroy the parameter update group."""
    success, message = (
        await _global_state.tokenizer_manager.destroy_weights_update_group(obj, request)
    )
    content = {"success": success, "message": message}
    return ORJSONResponse(
        content, status_code=200 if success else HTTPStatus.BAD_REQUEST
    )


@app.post("/update_weights_from_tensor")
async def update_weights_from_tensor(
    obj: UpdateWeightsFromTensorReqInput, request: Request
...
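
The server route mirrors the existing `/init_weights_update_group` endpoint. A rough usage sketch against a running sglang server, assuming a placeholder URL and the default group name:

```python
import requests

# Placeholder address of an already-running sglang server.
url = "http://127.0.0.1:30000"

resp = requests.post(
    f"{url}/destroy_weights_update_group",
    json={"group_name": "weight_update_group"},
)

# 200 on success; 400 (BAD_REQUEST) if the group does not exist or
# destruction fails, with the reason in the "message" field.
print(resp.status_code, resp.json())
```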
@@ -1094,6 +1094,17 @@ class InitWeightsUpdateGroupReqOutput:
    message: str


@dataclass
class DestroyWeightsUpdateGroupReqInput:
    group_name: str = "weight_update_group"


@dataclass
class DestroyWeightsUpdateGroupReqOutput:
    success: bool
    message: str


@dataclass
class UpdateWeightVersionReqInput:
    # The new weight version
...
@@ -72,6 +72,7 @@ from sglang.srt.managers.io_struct import (
    ClearHiCacheReqInput,
    ClearHiCacheReqOutput,
    CloseSessionReqInput,
    DestroyWeightsUpdateGroupReqInput,
    ExpertDistributionReq,
    ExpertDistributionReqOutput,
    FlushCacheReqInput,
@@ -566,6 +567,7 @@ class Scheduler(
                (CloseSessionReqInput, self.close_session),
                (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
                (DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group),
                (
                    InitWeightsSendGroupForRemoteInstanceReqInput,
                    self.init_weights_send_group_for_remote_instance,
...
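
The scheduler routes control requests by their input type, so wiring in the new request only takes the single `(DestroyWeightsUpdateGroupReqInput, self.destroy_weights_update_group)` entry above. A simplified illustration of that dispatch pattern (not the scheduler's actual dispatcher, just the idea):

```python
from dataclasses import dataclass
from typing import Any, Callable, List, Tuple, Type


@dataclass
class DestroyWeightsUpdateGroupReqInput:
    group_name: str = "weight_update_group"


class TypeDispatcher:
    """Toy type-based dispatcher: maps a request's class to its handler."""

    def __init__(self, mapping: List[Tuple[Type, Callable[[Any], Any]]]):
        self._handlers = dict(mapping)

    def __call__(self, recv_req: Any) -> Any:
        handler = self._handlers.get(type(recv_req))
        if handler is None:
            raise ValueError(f"no handler registered for {type(recv_req)}")
        return handler(recv_req)


def destroy_weights_update_group(req: DestroyWeightsUpdateGroupReqInput):
    # Stand-in for Scheduler.destroy_weights_update_group.
    return True, f"destroyed group {req.group_name!r}"


dispatcher = TypeDispatcher(
    [(DestroyWeightsUpdateGroupReqInput, destroy_weights_update_group)]
)
print(dispatcher(DestroyWeightsUpdateGroupReqInput()))
# -> (True, "destroyed group 'weight_update_group'")
```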
@@ -5,6 +5,8 @@ import torch
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
from sglang.srt.managers.io_struct import (
    DestroyWeightsUpdateGroupReqInput,
    DestroyWeightsUpdateGroupReqOutput,
    GetWeightsByNameReqInput,
    GetWeightsByNameReqOutput,
    InitWeightsUpdateGroupReqInput,
@@ -41,6 +43,11 @@ class SchedulerUpdateWeightsMixin:
        success, message = self.tp_worker.init_weights_update_group(recv_req)
        return InitWeightsUpdateGroupReqOutput(success, message)

    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
        """Destroy the online model parameter update group."""
        success, message = self.tp_worker.destroy_weights_update_group(recv_req)
        return DestroyWeightsUpdateGroupReqOutput(success, message)

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
...
@@ -24,6 +24,8 @@ import zmq
from sglang.srt.managers.io_struct import (
    ClearHiCacheReqInput,
    ClearHiCacheReqOutput,
    DestroyWeightsUpdateGroupReqInput,
    DestroyWeightsUpdateGroupReqOutput,
    ExpertDistributionReq,
    ExpertDistributionReqOutput,
    FlushCacheReqInput,
@@ -149,6 +151,9 @@ class TokenizerCommunicatorMixin:
        self.init_weights_update_group_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.destroy_weights_update_group_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
        self.update_weights_from_distributed_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
@@ -207,6 +212,10 @@ class TokenizerCommunicatorMixin:
                    InitWeightsUpdateGroupReqOutput,
                    self.init_weights_update_group_communicator.handle_recv,
                ),
                (
                    DestroyWeightsUpdateGroupReqOutput,
                    self.destroy_weights_update_group_communicator.handle_recv,
                ),
                (
                    UpdateWeightsFromDistributedReqOutput,
                    self.update_weights_from_distributed_communicator.handle_recv,
@@ -345,6 +354,18 @@ class TokenizerCommunicatorMixin:
        result = (await self.init_weights_update_group_communicator(obj))[0]
        return result.success, result.message

    async def destroy_weights_update_group(
        self,
        obj: DestroyWeightsUpdateGroupReqInput,
        request: Optional[fastapi.Request] = None,
    ) -> Tuple[bool, str]:
        self.auto_create_handle_loop()
        assert (
            self.server_args.dp_size == 1
        ), "dp_size must be 1 for destroy parameter update group"
        result = (await self.destroy_weights_update_group_communicator(obj))[0]
        return result.success, result.message

    async def update_weights_from_distributed(
        self: TokenizerManager,
        obj: UpdateWeightsFromDistributedReqInput,
...
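
On the tokenizer-manager side the new request reuses the existing `_Communicator` pattern: send the request to the scheduler, then wait until `handle_recv` has delivered one reply per data-parallel rank. A rough sketch of that request/reply handshake using plain asyncio futures (an illustration only, not the real ZMQ-backed `_Communicator`):

```python
import asyncio
from typing import Any, Callable, List, Optional


class MiniCommunicator:
    """Toy stand-in for _Communicator: send a request, then await
    `fanout` replies delivered through handle_recv()."""

    def __init__(self, send_fn: Callable[[Any], None], fanout: int):
        self._send_fn = send_fn
        self._fanout = fanout
        self._replies: List[Any] = []
        self._done: Optional[asyncio.Future] = None

    async def __call__(self, obj: Any) -> List[Any]:
        self._replies = []
        self._done = asyncio.get_running_loop().create_future()
        self._send_fn(obj)  # e.g. push the request to the scheduler over ZMQ
        await self._done    # resolved once all replies have arrived
        return self._replies

    def handle_recv(self, output: Any) -> None:
        self._replies.append(output)
        if len(self._replies) == self._fanout and not self._done.done():
            self._done.set_result(None)


async def main() -> None:
    comm = MiniCommunicator(send_fn=lambda obj: None, fanout=1)
    pending = asyncio.create_task(comm(object()))  # "destroy" request goes out
    await asyncio.sleep(0)                         # let the request task run
    comm.handle_recv((True, "destroyed"))          # scheduler reply arrives
    print(await pending)                           # [(True, 'destroyed')]


asyncio.run(main())
```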
@@ -29,6 +29,7 @@ from sglang.srt.hf_transformers_utils import (
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    DestroyWeightsUpdateGroupReqInput,
    GetWeightsByNameReqInput,
    InitWeightsSendGroupForRemoteInstanceReqInput,
    InitWeightsUpdateGroupReqInput,
@@ -304,6 +305,12 @@ class TpModelWorker:
        )
        return success, message

    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
        success, message = self.model_runner.destroy_weights_update_group(
            recv_req.group_name,
        )
        return success, message

    def init_weights_send_group_for_remote_instance(
        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
    ):
...
@@ -25,6 +25,7 @@ import psutil
import torch
from sglang.srt.managers.io_struct import (
    DestroyWeightsUpdateGroupReqInput,
    GetWeightsByNameReqInput,
    InitWeightsSendGroupForRemoteInstanceReqInput,
    InitWeightsUpdateGroupReqInput,
@@ -278,6 +279,10 @@ class TpModelWorkerClient:
        success, message = self.worker.init_weights_update_group(recv_req)
        return success, message

    def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
        success, message = self.worker.destroy_weights_update_group(recv_req)
        return success, message

    def init_weights_send_group_for_remote_instance(
        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
    ):
...
@@ -1025,6 +1025,19 @@ class ModelRunner:
            logger.error(message)
            return False, message

    def destroy_weights_update_group(self, group_name):
        try:
            if group_name in self._model_update_group:
                pg = self._model_update_group.pop(group_name)
                torch.distributed.destroy_process_group(pg)
                return True, "Succeeded to destroy custom process group."
            else:
                return False, "The group to be destroyed does not exist."
        except Exception as e:
            message = f"Failed to destroy custom process group: {e}."
            logger.error(message)
            return False, message

    def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
        """
        Update specific parameter in the model weights online
...
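
At the bottom of the stack, the cleanup is a call to `torch.distributed.destroy_process_group` on the group created earlier by `init_weights_update_group`, removed from the runner's group map by name. A self-contained sketch of that create/destroy pairing, using a single-process gloo group as a stand-in for the real NCCL weight-update group:

```python
import torch.distributed as dist

# Single-process illustration with the gloo backend; the real code path uses
# a group created by init_weights_update_group (typically NCCL) and keyed by
# group_name in the runner's group map.
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

# Create a named subgroup, analogous to the custom weight-update group.
custom_group = dist.new_group(ranks=[0])
model_update_group = {"weight_update_group": custom_group}

# Destroying only the custom group keeps the default group usable.
pg = model_update_group.pop("weight_update_group")
dist.destroy_process_group(pg)

# Passing no argument tears down the default group as well.
dist.destroy_process_group()
```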
@@ -344,6 +344,20 @@ def init_process_sgl(
    )
    param_queue.put((f"sgl_dp_{rank}_base_params", base_params))

    if backend == "Engine":
        success, _ = engine.destroy_weights_update_group(
            group_name="test_parameter_update_group",
        )
        assert success is True
    else:
        response = requests.post(
            f"{url}/destroy_weights_update_group",
            json={
                "group_name": "test_parameter_update_group",
            },
        )
        assert response.status_code == 200

    # Shutdown the engine or terminate the server process.
    if backend == "Engine":
        engine.shutdown()
...