Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
2fce449b
Unverified
Commit
2fce449b
authored
Oct 23, 2024
by
Ying Sheng
Committed by
GitHub
Oct 23, 2024
Browse files
[API] add get memory pool size (#1760)
Co-authored-by:
Byron Hsu
<
byronhsu1230@gmail.com
>
parent
ad4125d1
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
50 additions
and
0 deletions
+50
-0
python/sglang/srt/managers/detokenizer_manager.py
python/sglang/srt/managers/detokenizer_manager.py
+4
-0
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+10
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+6
-0
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+14
-0
python/sglang/srt/server.py
python/sglang/srt/server.py
+12
-0
test/srt/test_srt_endpoint.py
test/srt/test_srt_endpoint.py
+4
-0
No files found.
python/sglang/srt/managers/detokenizer_manager.py
View file @
2fce449b
...
@@ -27,6 +27,7 @@ from sglang.srt.managers.io_struct import (
...
@@ -27,6 +27,7 @@ from sglang.srt.managers.io_struct import (
BatchEmbeddingOut
,
BatchEmbeddingOut
,
BatchStrOut
,
BatchStrOut
,
BatchTokenIDOut
,
BatchTokenIDOut
,
GetMemPoolSizeReqOutput
,
UpdateWeightReqOutput
,
UpdateWeightReqOutput
,
)
)
from
sglang.srt.managers.schedule_batch
import
FINISH_MATCHED_STR
,
FINISH_MATCHED_TOKEN
from
sglang.srt.managers.schedule_batch
import
FINISH_MATCHED_STR
,
FINISH_MATCHED_TOKEN
...
@@ -111,6 +112,9 @@ class DetokenizerManager:
...
@@ -111,6 +112,9 @@ class DetokenizerManager:
# If it is a weight update request, no detokenization is needed.
# If it is a weight update request, no detokenization is needed.
self
.
send_to_tokenizer
.
send_pyobj
(
recv_obj
)
self
.
send_to_tokenizer
.
send_pyobj
(
recv_obj
)
continue
continue
elif
isinstance
(
recv_obj
,
GetMemPoolSizeReqOutput
):
self
.
send_to_tokenizer
.
send_pyobj
(
recv_obj
)
continue
elif
self
.
tokenizer
is
None
:
elif
self
.
tokenizer
is
None
:
# If the tokenizer is skipped, no detokenization is needed
# If the tokenizer is skipped, no detokenization is needed
self
.
send_to_tokenizer
.
send_pyobj
(
recv_obj
)
self
.
send_to_tokenizer
.
send_pyobj
(
recv_obj
)
...
...
python/sglang/srt/managers/io_struct.py
View file @
2fce449b
...
@@ -353,3 +353,13 @@ class AbortReq:
...
@@ -353,3 +353,13 @@ class AbortReq:
class
ProfileReq
(
Enum
):
class
ProfileReq
(
Enum
):
START_PROFILE
=
1
START_PROFILE
=
1
STOP_PROFILE
=
2
STOP_PROFILE
=
2
@dataclass
class GetMemPoolSizeReq:
    """Request for the memory pool size; the message carries no fields."""
@dataclass
class GetMemPoolSizeReqOutput:
    """Reply to a ``GetMemPoolSizeReq``, carrying the reported pool size."""

    # The scheduler fills this from its ``max_total_num_tokens``; the HTTP
    # endpoint describes the value as a number of tokens.
    size: int
python/sglang/srt/managers/scheduler.py
View file @
2fce449b
...
@@ -38,6 +38,8 @@ from sglang.srt.managers.io_struct import (
...
@@ -38,6 +38,8 @@ from sglang.srt.managers.io_struct import (
BatchEmbeddingOut
,
BatchEmbeddingOut
,
BatchTokenIDOut
,
BatchTokenIDOut
,
FlushCacheReq
,
FlushCacheReq
,
GetMemPoolSizeReq
,
GetMemPoolSizeReqOutput
,
ProfileReq
,
ProfileReq
,
TokenizedEmbeddingReqInput
,
TokenizedEmbeddingReqInput
,
TokenizedGenerateReqInput
,
TokenizedGenerateReqInput
,
...
@@ -363,6 +365,10 @@ class Scheduler:
...
@@ -363,6 +365,10 @@ class Scheduler:
self
.
start_profile
()
self
.
start_profile
()
else
:
else
:
self
.
stop_profile
()
self
.
stop_profile
()
elif
isinstance
(
recv_req
,
GetMemPoolSizeReq
):
self
.
send_to_detokenizer
.
send_pyobj
(
GetMemPoolSizeReqOutput
(
self
.
max_total_num_tokens
)
)
else
:
else
:
raise
ValueError
(
f
"Invalid request:
{
recv_req
}
"
)
raise
ValueError
(
f
"Invalid request:
{
recv_req
}
"
)
...
...
python/sglang/srt/managers/tokenizer_manager.py
View file @
2fce449b
...
@@ -46,6 +46,8 @@ from sglang.srt.managers.io_struct import (
...
@@ -46,6 +46,8 @@ from sglang.srt.managers.io_struct import (
EmbeddingReqInput
,
EmbeddingReqInput
,
FlushCacheReq
,
FlushCacheReq
,
GenerateReqInput
,
GenerateReqInput
,
GetMemPoolSizeReq
,
GetMemPoolSizeReqOutput
,
ProfileReq
,
ProfileReq
,
RewardReqInput
,
RewardReqInput
,
TokenizedEmbeddingReqInput
,
TokenizedEmbeddingReqInput
,
...
@@ -531,6 +533,15 @@ class TokenizerManager:
...
@@ -531,6 +533,15 @@ class TokenizerManager:
req
=
ProfileReq
.
STOP_PROFILE
req
=
ProfileReq
.
STOP_PROFILE
self
.
send_to_scheduler
.
send_pyobj
(
req
)
self
.
send_to_scheduler
.
send_pyobj
(
req
)
async def get_memory_pool_size(self):
    """Ask the scheduler for the memory pool size and wait for the reply.

    The reply arrives through the handle loop, which resolves
    ``self.mem_pool_size`` with the received ``GetMemPoolSizeReqOutput``.

    Returns:
        The ``GetMemPoolSizeReqOutput`` object delivered by the handle loop.
    """
    if self.to_create_loop:
        self.create_handle_loop()

    # Only one future is stored on the instance, so overlapping callers must
    # share it: replacing a pending future would leave the earlier caller
    # waiting forever, because the handle loop only resolves the latest one.
    pending = getattr(self, "mem_pool_size", None)
    if pending is None or pending.done():
        # Register the future before sending so a reply can never be
        # processed while no future is in place.
        pending = asyncio.get_running_loop().create_future()
        self.mem_pool_size = pending
        self.send_to_scheduler.send_pyobj(GetMemPoolSizeReq())

    # shield() keeps a cancelled caller from cancelling the shared future;
    # otherwise the handle loop's later set_result() would raise
    # InvalidStateError and other waiters would be cancelled as well.
    return await asyncio.shield(pending)
async
def
update_weights
(
async
def
update_weights
(
self
,
obj
:
UpdateWeightReqInput
,
request
:
Optional
[
fastapi
.
Request
]
=
None
self
,
obj
:
UpdateWeightReqInput
,
request
:
Optional
[
fastapi
.
Request
]
=
None
):
):
...
@@ -590,6 +601,9 @@ class TokenizerManager:
...
@@ -590,6 +601,9 @@ class TokenizerManager:
if
isinstance
(
recv_obj
,
UpdateWeightReqOutput
):
if
isinstance
(
recv_obj
,
UpdateWeightReqOutput
):
self
.
model_update_result
.
set_result
(
recv_obj
)
self
.
model_update_result
.
set_result
(
recv_obj
)
continue
continue
elif
isinstance
(
recv_obj
,
GetMemPoolSizeReqOutput
):
self
.
mem_pool_size
.
set_result
(
recv_obj
)
continue
assert
isinstance
(
assert
isinstance
(
recv_obj
,
(
BatchStrOut
,
BatchEmbeddingOut
,
BatchTokenIDOut
)
recv_obj
,
(
BatchStrOut
,
BatchEmbeddingOut
,
BatchTokenIDOut
)
...
...
python/sglang/srt/server.py
View file @
2fce449b
...
@@ -172,6 +172,18 @@ async def stop_profile():
...
@@ -172,6 +172,18 @@ async def stop_profile():
)
)
@app.api_route("/get_memory_pool_size", methods=["GET", "POST"])
async def get_memory_pool_size():
    """Return the memory pool size, expressed as a number of tokens.

    On success the response body is the ``size`` attribute of the object
    returned by the tokenizer manager. Any exception is reported as a JSON
    error payload with HTTP status 400.
    """
    try:
        pool_info = await tokenizer_manager.get_memory_pool_size()
        pool_size = pool_info.size
    except Exception as exc:
        error_body = {"error": {"message": str(exc)}}
        return JSONResponse(error_body, status_code=HTTPStatus.BAD_REQUEST)
    return pool_size
@
app
.
post
(
"/update_weights"
)
@
app
.
post
(
"/update_weights"
)
async
def
update_weights
(
obj
:
UpdateWeightReqInput
,
request
:
Request
):
async
def
update_weights
(
obj
:
UpdateWeightReqInput
,
request
:
Request
):
"""Update the weights inplace without re-launching the server."""
"""Update the weights inplace without re-launching the server."""
...
...
test/srt/test_srt_endpoint.py
View file @
2fce449b
...
@@ -119,6 +119,10 @@ class TestSRTEndpoint(unittest.TestCase):
...
@@ -119,6 +119,10 @@ class TestSRTEndpoint(unittest.TestCase):
[
x
[
-
1
]
for
x
in
res
[
"meta_info"
][
"output_token_logprobs"
]]
[
x
[
-
1
]
for
x
in
res
[
"meta_info"
][
"output_token_logprobs"
]]
)
)
def test_get_memory_pool_size(self):
    """The /get_memory_pool_size endpoint answers with a bare JSON integer."""
    response = requests.post(self.base_url + "/get_memory_pool_size")
    # Check the status first: the endpoint reports failures as a JSON error
    # object with a non-200 status, and the body is included in the message
    # to make such failures easy to diagnose.
    self.assertEqual(response.status_code, 200, response.text)
    # unittest assertions are not stripped under ``python -O`` and report the
    # offending value, unlike a bare ``assert``.
    self.assertIsInstance(response.json(), int)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment