Unverified Commit 2fce449b authored by Ying Sheng, committed by GitHub

[API] add get memory pool size (#1760)


Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
parent ad4125d1
@@ -27,6 +27,7 @@ from sglang.srt.managers.io_struct import (
    BatchEmbeddingOut,
    BatchStrOut,
    BatchTokenIDOut,
    GetMemPoolSizeReqOutput,
    UpdateWeightReqOutput,
)
from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
@@ -111,6 +112,9 @@ class DetokenizerManager:
                # If it is a weight update request, no detokenization is needed.
                self.send_to_tokenizer.send_pyobj(recv_obj)
                continue
            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
                self.send_to_tokenizer.send_pyobj(recv_obj)
                continue
            elif self.tokenizer is None:
                # If the tokenizer is skipped, no detokenization is needed
                self.send_to_tokenizer.send_pyobj(recv_obj)
...
@@ -353,3 +353,13 @@ class AbortReq:
class ProfileReq(Enum):
    START_PROFILE = 1
    STOP_PROFILE = 2


@dataclass
class GetMemPoolSizeReq:
    pass


@dataclass
class GetMemPoolSizeReqOutput:
    size: int
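
These two dataclasses are plain picklable messages, so they can be passed between the manager processes with pyzmq's send_pyobj/recv_pyobj helpers, the same calls that appear throughout this diff. The sketch below is not part of the commit; the PAIR sockets and the inproc address are invented purely to illustrate that round trip.

# Illustrative sketch only: a GetMemPoolSizeReq / GetMemPoolSizeReqOutput pair
# travelling over ZeroMQ via pyzmq's pickle-based send_pyobj/recv_pyobj.
# The inproc address and socket wiring are invented for this demo.
from dataclasses import dataclass

import zmq


@dataclass
class GetMemPoolSizeReq:
    pass


@dataclass
class GetMemPoolSizeReqOutput:
    size: int


ctx = zmq.Context.instance()
server = ctx.socket(zmq.PAIR)
server.bind("inproc://mem_pool_demo")
client = ctx.socket(zmq.PAIR)
client.connect("inproc://mem_pool_demo")

client.send_pyobj(GetMemPoolSizeReq())          # pickled and sent as one message
req = server.recv_pyobj()                       # unpickled back into the dataclass
assert isinstance(req, GetMemPoolSizeReq)
server.send_pyobj(GetMemPoolSizeReqOutput(size=163840))  # hypothetical token count
print(client.recv_pyobj().size)
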
@@ -38,6 +38,8 @@ from sglang.srt.managers.io_struct import (
    BatchEmbeddingOut,
    BatchTokenIDOut,
    FlushCacheReq,
    GetMemPoolSizeReq,
    GetMemPoolSizeReqOutput,
    ProfileReq,
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
@@ -363,6 +365,10 @@ class Scheduler:
                    self.start_profile()
                else:
                    self.stop_profile()
            elif isinstance(recv_req, GetMemPoolSizeReq):
                self.send_to_detokenizer.send_pyobj(
                    GetMemPoolSizeReqOutput(self.max_total_num_tokens)
                )
            else:
                raise ValueError(f"Invalid request: {recv_req}")
...
@@ -46,6 +46,8 @@ from sglang.srt.managers.io_struct import (
    EmbeddingReqInput,
    FlushCacheReq,
    GenerateReqInput,
    GetMemPoolSizeReq,
    GetMemPoolSizeReqOutput,
    ProfileReq,
    RewardReqInput,
    TokenizedEmbeddingReqInput,
@@ -531,6 +533,15 @@ class TokenizerManager:
        req = ProfileReq.STOP_PROFILE
        self.send_to_scheduler.send_pyobj(req)

    async def get_memory_pool_size(self):
        if self.to_create_loop:
            self.create_handle_loop()

        req = GetMemPoolSizeReq()
        self.send_to_scheduler.send_pyobj(req)

        self.mem_pool_size = asyncio.Future()
        return await self.mem_pool_size

    async def update_weights(
        self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
    ):
@@ -590,6 +601,9 @@ class TokenizerManager:
            if isinstance(recv_obj, UpdateWeightReqOutput):
                self.model_update_result.set_result(recv_obj)
                continue
            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
                self.mem_pool_size.set_result(recv_obj)
                continue

            assert isinstance(
                recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
...
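
The new get_memory_pool_size coroutine uses the same Future-based handshake as the existing update_weights path: the caller parks on an asyncio.Future, and the background handle loop resolves it when the GetMemPoolSizeReqOutput is relayed back from the detokenizer. Below is a minimal self-contained sketch of that pattern; the class name, method names, and the reply value are invented for illustration and are not the real implementation.

# Minimal sketch of the Future-based handshake used above; MiniTokenizerManager
# and the simulated reply value are invented for illustration.
import asyncio
from dataclasses import dataclass


@dataclass
class GetMemPoolSizeReqOutput:
    size: int


class MiniTokenizerManager:
    def __init__(self):
        self.mem_pool_size = None

    async def get_memory_pool_size(self):
        # The real code also sends a GetMemPoolSizeReq to the scheduler here.
        self.mem_pool_size = asyncio.get_running_loop().create_future()
        return await self.mem_pool_size

    def on_recv(self, recv_obj):
        # Stand-in for the handle loop that drains the detokenizer socket.
        if isinstance(recv_obj, GetMemPoolSizeReqOutput):
            self.mem_pool_size.set_result(recv_obj)


async def main():
    mgr = MiniTokenizerManager()
    task = asyncio.create_task(mgr.get_memory_pool_size())
    await asyncio.sleep(0)  # let the coroutine run far enough to create the future
    mgr.on_recv(GetMemPoolSizeReqOutput(size=163840))  # simulated scheduler reply
    print((await task).size)


asyncio.run(main())
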
@@ -172,6 +172,18 @@ async def stop_profile():
    )


@app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
async def get_memory_pool_size():
    """Get the memory pool size in number of tokens"""
    try:
        ret = await tokenizer_manager.get_memory_pool_size()
        return ret.size
    except Exception as e:
        return JSONResponse(
            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
        )


@app.post("/update_weights")
async def update_weights(obj: UpdateWeightReqInput, request: Request):
    """Update the weights inplace without re-launching the server."""
...
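
Once a server is running, the new route accepts either GET or POST and returns the pool size as a bare integer (number of tokens the KV cache memory pool can hold). A small client example follows; the host and port are just the common launch defaults, not something defined by this diff.

# Example client call; assumes an SGLang server is already listening on
# 127.0.0.1:30000 (adjust to your --host/--port).
import requests

resp = requests.post("http://127.0.0.1:30000/get_memory_pool_size")
resp.raise_for_status()
print("memory pool size (tokens):", resp.json())  # body is a plain integer
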
@@ -119,6 +119,10 @@ class TestSRTEndpoint(unittest.TestCase):
            [x[-1] for x in res["meta_info"]["output_token_logprobs"]]
        )

    def test_get_memory_pool_size(self):
        response = requests.post(self.base_url + "/get_memory_pool_size")
        assert isinstance(response.json(), int)


if __name__ == "__main__":
    unittest.main()