Unverified Commit cef91b1e authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

[PD] Add control to slow down a server (#5572)

parent 6450c122
...@@ -62,6 +62,7 @@ from sglang.srt.managers.io_struct import ( ...@@ -62,6 +62,7 @@ from sglang.srt.managers.io_struct import (
ResumeMemoryOccupationReqInput, ResumeMemoryOccupationReqInput,
SeparateReasoningReqInput, SeparateReasoningReqInput,
SetInternalStateReq, SetInternalStateReq,
SlowDownReqInput,
UpdateWeightFromDiskReqInput, UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput, UpdateWeightsFromTensorReqInput,
...@@ -494,6 +495,19 @@ async def resume_memory_occupation( ...@@ -494,6 +495,19 @@ async def resume_memory_occupation(
return _create_error_response(e) return _create_error_response(e)
@app.api_route("/slow_down", methods=["GET", "POST"])
async def slow_down(obj: SlowDownReqInput, request: Request):
    """Deliberately slow the system down. For testing only.

    Example scenario: when we want to test the performance of D in
    large-scale PD disaggregation but do not have enough nodes for P, we can
    use this endpoint to slow D down so that it accumulates enough running
    sequences, and then disable the slowdown to let it run at full batch size.

    Returns an error response if forwarding the request raises; otherwise
    returns nothing.
    """
    try:
        tokenizer_manager = _global_state.tokenizer_manager
        await tokenizer_manager.slow_down(obj, request)
    except Exception as exc:
        return _create_error_response(exc)
    return None
@app.api_route("/open_session", methods=["GET", "POST"]) @app.api_route("/open_session", methods=["GET", "POST"])
async def open_session(obj: OpenSessionReqInput, request: Request): async def open_session(obj: OpenSessionReqInput, request: Request):
"""Open a session, and return its unique session id.""" """Open a session, and return its unique session id."""
......
...@@ -790,6 +790,16 @@ class ResumeMemoryOccupationReqOutput: ...@@ -790,6 +790,16 @@ class ResumeMemoryOccupationReqOutput:
pass pass
@dataclass
class SlowDownReqInput:
    """Request to set or clear the scheduler's artificial per-forward delay."""

    # Seconds the scheduler sleeps before running each forward pass.
    # None, or any value <= 0, turns the delay off.
    forward_sleep_time: Optional[float]
@dataclass
class SlowDownReqOutput:
    """Empty acknowledgement returned after a SlowDownReqInput is handled."""
@dataclass @dataclass
class AbortReq: class AbortReq:
# The request id # The request id
......
...@@ -87,6 +87,8 @@ from sglang.srt.managers.io_struct import ( ...@@ -87,6 +87,8 @@ from sglang.srt.managers.io_struct import (
RpcReqOutput, RpcReqOutput,
SetInternalStateReq, SetInternalStateReq,
SetInternalStateReqOutput, SetInternalStateReqOutput,
SlowDownReqInput,
SlowDownReqOutput,
TokenizedEmbeddingReqInput, TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput, TokenizedGenerateReqInput,
UpdateWeightFromDiskReqInput, UpdateWeightFromDiskReqInput,
...@@ -417,6 +419,8 @@ class Scheduler( ...@@ -417,6 +419,8 @@ class Scheduler(
self.profiler_id: Optional[str] = None self.profiler_id: Optional[str] = None
self.profiler_target_forward_ct: Optional[int] = None self.profiler_target_forward_ct: Optional[int] = None
self.forward_sleep_time = None
# Init metrics stats # Init metrics stats
self.init_metrics() self.init_metrics()
...@@ -439,6 +443,7 @@ class Scheduler( ...@@ -439,6 +443,7 @@ class Scheduler(
(GetWeightsByNameReqInput, self.get_weights_by_name), (GetWeightsByNameReqInput, self.get_weights_by_name),
(ReleaseMemoryOccupationReqInput, self.release_memory_occupation), (ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation), (ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
(SlowDownReqInput, self.slow_down),
(ProfileReq, self.profile), (ProfileReq, self.profile),
(GetInternalStateReq, self.get_internal_state), (GetInternalStateReq, self.get_internal_state),
(SetInternalStateReq, self.set_internal_state), (SetInternalStateReq, self.set_internal_state),
...@@ -1526,6 +1531,10 @@ class Scheduler( ...@@ -1526,6 +1531,10 @@ class Scheduler(
): ):
self.stop_profile() self.stop_profile()
if self.forward_sleep_time is not None:
logger.info(f"Scheduler.run_batch sleep {self.forward_sleep_time}s")
time.sleep(self.forward_sleep_time)
# Run forward # Run forward
if self.is_generation: if self.is_generation:
if self.spec_algorithm.is_none(): if self.spec_algorithm.is_none():
...@@ -2001,6 +2010,13 @@ class Scheduler( ...@@ -2001,6 +2010,13 @@ class Scheduler(
del self.stashed_model_static_state del self.stashed_model_static_state
return ResumeMemoryOccupationReqOutput() return ResumeMemoryOccupationReqOutput()
def slow_down(self, recv_req: SlowDownReqInput):
    """Set or clear the delay stored in ``self.forward_sleep_time``.

    A requested time of ``None`` or any value <= 0 disables the delay
    (stored as ``None``); any other value is stored unchanged.

    Returns:
        An empty ``SlowDownReqOutput`` acknowledging the request.
    """
    requested = recv_req.forward_sleep_time
    # Normalise non-positive durations to None so readers of the attribute
    # only need an "is not None" check.
    self.forward_sleep_time = (
        None if requested is not None and requested <= 0 else requested
    )
    return SlowDownReqOutput()
def profile(self, recv_req: ProfileReq): def profile(self, recv_req: ProfileReq):
if recv_req.type == ProfileReqType.START_PROFILE: if recv_req.type == ProfileReqType.START_PROFILE:
return self.start_profile( return self.start_profile(
......
...@@ -90,6 +90,8 @@ from sglang.srt.managers.io_struct import ( ...@@ -90,6 +90,8 @@ from sglang.srt.managers.io_struct import (
ResumeMemoryOccupationReqInput, ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput, ResumeMemoryOccupationReqOutput,
SessionParams, SessionParams,
SlowDownReqInput,
SlowDownReqOutput,
TokenizedEmbeddingReqInput, TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput, TokenizedGenerateReqInput,
UpdateWeightFromDiskReqInput, UpdateWeightFromDiskReqInput,
...@@ -259,6 +261,9 @@ class TokenizerManager: ...@@ -259,6 +261,9 @@ class TokenizerManager:
self.resume_memory_occupation_communicator = _Communicator( self.resume_memory_occupation_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size self.send_to_scheduler, server_args.dp_size
) )
self.slow_down_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.flush_cache_communicator = _Communicator( self.flush_cache_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size self.send_to_scheduler, server_args.dp_size
) )
...@@ -312,6 +317,10 @@ class TokenizerManager: ...@@ -312,6 +317,10 @@ class TokenizerManager:
ResumeMemoryOccupationReqOutput, ResumeMemoryOccupationReqOutput,
self.resume_memory_occupation_communicator.handle_recv, self.resume_memory_occupation_communicator.handle_recv,
), ),
(
SlowDownReqOutput,
self.slow_down_communicator.handle_recv,
),
( (
FlushCacheReqOutput, FlushCacheReqOutput,
self.flush_cache_communicator.handle_recv, self.flush_cache_communicator.handle_recv,
...@@ -870,6 +879,14 @@ class TokenizerManager: ...@@ -870,6 +879,14 @@ class TokenizerManager:
self.auto_create_handle_loop() self.auto_create_handle_loop()
await self.resume_memory_occupation_communicator(obj) await self.resume_memory_occupation_communicator(obj)
async def slow_down(
    self, obj: SlowDownReqInput, request: Optional[fastapi.Request] = None
):
    """Send a slow-down request through ``slow_down_communicator`` and await it.

    ``request`` is accepted but not used by this method.
    """
    self.auto_create_handle_loop()
    await self.slow_down_communicator(obj)
async def open_session( async def open_session(
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
): ):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment