Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
680cad20
Unverified
Commit
680cad20
authored
Oct 28, 2024
by
Byron Hsu
Committed by
GitHub
Oct 28, 2024
Browse files
fix get_memory_pool_size deadlock for DP (#1830)
parent
0a24eb85
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
34 additions
and
5 deletions
+34
-5
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+23
-4
python/sglang/srt/server.py
python/sglang/srt/server.py
+2
-1
test/srt/test_data_parallelism.py
test/srt/test_data_parallelism.py
+9
-0
No files found.
python/sglang/srt/managers/tokenizer_manager.py
View file @
680cad20
...
...
@@ -539,9 +539,22 @@ class TokenizerManager:
self
.
create_handle_loop
()
req
=
GetMemPoolSizeReq
()
self
.
send_to_scheduler
.
send_pyobj
(
req
)
self
.
mem_pool_size
=
asyncio
.
Future
()
return
await
self
.
mem_pool_size
ret
=
None
if
self
.
server_args
.
dp_size
==
1
:
self
.
send_to_scheduler
.
send_pyobj
(
req
)
self
.
mem_pool_size
=
asyncio
.
Future
()
res
=
await
self
.
mem_pool_size
ret
=
res
.
size
else
:
# self.server_args.dp_size > 1
self
.
send_to_scheduler
.
send_pyobj
(
req
)
self
.
mem_pool_size
=
asyncio
.
Future
()
self
.
mem_pool_size_tmp
=
[]
res
=
await
self
.
mem_pool_size
ret
=
[
r
.
size
for
r
in
res
]
return
ret
async
def
update_weights
(
self
,
obj
:
UpdateWeightReqInput
,
request
:
Optional
[
fastapi
.
Request
]
=
None
...
...
@@ -634,7 +647,13 @@ class TokenizerManager:
self
.
model_update_result
.
set_result
(
self
.
model_update_tmp
)
continue
elif
isinstance
(
recv_obj
,
GetMemPoolSizeReqOutput
):
self
.
mem_pool_size
.
set_result
(
recv_obj
)
if
self
.
server_args
.
dp_size
==
1
:
self
.
mem_pool_size
.
set_result
(
recv_obj
)
else
:
# self.server_args.dp_size > 1
self
.
mem_pool_size_tmp
.
append
(
recv_obj
)
# set the future once all results have been received
if
len
(
self
.
mem_pool_size_tmp
)
==
self
.
server_args
.
dp_size
:
self
.
mem_pool_size
.
set_result
(
self
.
mem_pool_size_tmp
)
continue
assert
isinstance
(
...
...
python/sglang/srt/server.py
View file @
680cad20
...
...
@@ -177,7 +177,8 @@ async def get_memory_pool_size():
"""Get the memory pool size in number of tokens"""
try
:
ret
=
await
tokenizer_manager
.
get_memory_pool_size
()
return
ret
.
size
return
ret
except
Exception
as
e
:
return
JSONResponse
(
{
"error"
:
{
"message"
:
str
(
e
)}},
status_code
=
HTTPStatus
.
BAD_REQUEST
...
...
test/srt/test_data_parallelism.py
View file @
680cad20
...
...
@@ -62,6 +62,15 @@ class TestDataParallelism(unittest.TestCase):
# check if the response is 200
assert
response
.
status_code
==
200
def test_get_memory_pool_size(self):
    """Request /get_memory_pool_size twice and expect HTTP 200 each time.

    The two requests are separated by a five-second pause, so the test
    also covers a repeated call to the endpoint on the same server.
    """
    endpoint = self.base_url + "/get_memory_pool_size"

    first_reply = requests.get(endpoint)
    assert first_reply.status_code == 200

    # Wait before asking again; the follow-up request must succeed as well.
    time.sleep(5)

    second_reply = requests.get(endpoint)
    assert second_reply.status_code == 200
# Run the unittest suite when this file is executed directly as a script.
if
__name__
==
"__main__"
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment