"vscode:/vscode.git/clone" did not exist on "a2d35df2e4d1013f3a58b2726f7b92074d842c78"
Unverified commit 11951820, authored by fzyzcjy and committed by GitHub
Browse files

Tiny add Engine.flush_cache API (#5241)

parent 5239d795
...@@ -279,6 +279,10 @@ class Engine(EngineBase): ...@@ -279,6 +279,10 @@ class Engine(EngineBase):
self.shutdown() self.shutdown()
return False return False
def flush_cache(self):
loop = asyncio.get_event_loop()
return loop.run_until_complete(self.tokenizer_manager.flush_cache())
def start_profile(self): def start_profile(self):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
loop.run_until_complete(self.tokenizer_manager.start_profile()) loop.run_until_complete(self.tokenizer_manager.start_profile())
......
...@@ -315,11 +315,11 @@ async def classify_request(obj: EmbeddingReqInput, request: Request): ...@@ -315,11 +315,11 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
@app.api_route("/flush_cache", methods=["GET", "POST"]) @app.api_route("/flush_cache", methods=["GET", "POST"])
async def flush_cache(): async def flush_cache():
"""Flush the radix cache.""" """Flush the radix cache."""
_global_state.tokenizer_manager.flush_cache() ret = await _global_state.tokenizer_manager.flush_cache()
return Response( return Response(
content="Cache flushed.\nPlease check backend logs for more details. " content="Cache flushed.\nPlease check backend logs for more details. "
"(When there are running or waiting requests, the operation will not be performed.)\n", "(When there are running or waiting requests, the operation will not be performed.)\n",
status_code=200, status_code=200 if ret.success else HTTPStatus.BAD_REQUEST,
) )
......
...@@ -671,10 +671,15 @@ class BatchEmbeddingOut: ...@@ -671,10 +671,15 @@ class BatchEmbeddingOut:
@dataclass @dataclass
class FlushCacheReq: class FlushCacheReqInput:
pass pass
@dataclass
class FlushCacheReqOutput:
success: bool
@dataclass @dataclass
class UpdateWeightFromDiskReqInput: class UpdateWeightFromDiskReqInput:
# The model path with the new weights # The model path with the new weights
......
...@@ -60,7 +60,8 @@ from sglang.srt.managers.io_struct import ( ...@@ -60,7 +60,8 @@ from sglang.srt.managers.io_struct import (
CloseSessionReqInput, CloseSessionReqInput,
ExpertDistributionReq, ExpertDistributionReq,
ExpertDistributionReqOutput, ExpertDistributionReqOutput,
FlushCacheReq, FlushCacheReqInput,
FlushCacheReqOutput,
GetInternalStateReq, GetInternalStateReq,
GetInternalStateReqOutput, GetInternalStateReqOutput,
GetWeightsByNameReqInput, GetWeightsByNameReqInput,
...@@ -402,7 +403,7 @@ class Scheduler( ...@@ -402,7 +403,7 @@ class Scheduler(
[ [
(TokenizedGenerateReqInput, self.handle_generate_request), (TokenizedGenerateReqInput, self.handle_generate_request),
(TokenizedEmbeddingReqInput, self.handle_embedding_request), (TokenizedEmbeddingReqInput, self.handle_embedding_request),
(FlushCacheReq, self.flush_cache_wrapped), (FlushCacheReqInput, self.flush_cache_wrapped),
(AbortReq, self.abort_request), (AbortReq, self.abort_request),
(OpenSessionReqInput, self.open_session), (OpenSessionReqInput, self.open_session),
(CloseSessionReqInput, self.close_session), (CloseSessionReqInput, self.close_session),
...@@ -1596,8 +1597,9 @@ class Scheduler( ...@@ -1596,8 +1597,9 @@ class Scheduler(
time.sleep(5) time.sleep(5)
self.parent_process.send_signal(signal.SIGQUIT) self.parent_process.send_signal(signal.SIGQUIT)
def flush_cache_wrapped(self, recv_req: FlushCacheReq): def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
self.flush_cache() success = self.flush_cache()
return FlushCacheReqOutput(success=success)
def flush_cache(self): def flush_cache(self):
"""Flush the memory pool and cache.""" """Flush the memory pool and cache."""
......
...@@ -66,7 +66,8 @@ from sglang.srt.managers.io_struct import ( ...@@ -66,7 +66,8 @@ from sglang.srt.managers.io_struct import (
EmbeddingReqInput, EmbeddingReqInput,
ExpertDistributionReq, ExpertDistributionReq,
ExpertDistributionReqOutput, ExpertDistributionReqOutput,
FlushCacheReq, FlushCacheReqInput,
FlushCacheReqOutput,
GenerateReqInput, GenerateReqInput,
GetInternalStateReq, GetInternalStateReq,
GetInternalStateReqOutput, GetInternalStateReqOutput,
...@@ -264,6 +265,9 @@ class TokenizerManager: ...@@ -264,6 +265,9 @@ class TokenizerManager:
self.resume_memory_occupation_communicator = _Communicator( self.resume_memory_occupation_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size self.send_to_scheduler, server_args.dp_size
) )
self.flush_cache_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.start_profile_communicator = _Communicator( self.start_profile_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size self.send_to_scheduler, server_args.dp_size
) )
...@@ -314,6 +318,10 @@ class TokenizerManager: ...@@ -314,6 +318,10 @@ class TokenizerManager:
ResumeMemoryOccupationReqOutput, ResumeMemoryOccupationReqOutput,
self.resume_memory_occupation_communicator.handle_recv, self.resume_memory_occupation_communicator.handle_recv,
), ),
(
FlushCacheReqOutput,
self.flush_cache_communicator.handle_recv,
),
( (
ProfileReqOutput, ProfileReqOutput,
self.start_profile_communicator.handle_recv, self.start_profile_communicator.handle_recv,
...@@ -707,9 +715,8 @@ class TokenizerManager: ...@@ -707,9 +715,8 @@ class TokenizerManager:
except StopAsyncIteration: except StopAsyncIteration:
pass pass
def flush_cache(self): async def flush_cache(self) -> FlushCacheReqOutput:
req = FlushCacheReq() return await self.flush_cache_communicator(FlushCacheReqInput())
self.send_to_scheduler.send_pyobj(req)
def abort_request(self, rid: str): def abort_request(self, rid: str):
if rid not in self.rid_to_state: if rid not in self.rid_to_state:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment