Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
0a24eb85
Unverified
Commit
0a24eb85
authored
Oct 28, 2024
by
Byron Hsu
Committed by
GitHub
Oct 28, 2024
Browse files
Fix update_weights deadlock for DP (#1825)
parent
3839be29
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
67 additions
and
13 deletions
+67
-13
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+44
-13
test/srt/test_data_parallelism.py
test/srt/test_data_parallelism.py
+23
-0
No files found.
python/sglang/srt/managers/tokenizer_manager.py
View file @
0a24eb85
...
@@ -554,18 +554,43 @@ class TokenizerManager:
...
@@ -554,18 +554,43 @@ class TokenizerManager:
obj
.
load_format
=
self
.
server_args
.
load_format
obj
.
load_format
=
self
.
server_args
.
load_format
if
not
self
.
model_update_lock
.
locked
():
if
not
self
.
model_update_lock
.
locked
():
async
with
self
.
model_update_lock
:
# wait for the previous generation requests to finish
if
self
.
server_args
.
dp_size
==
1
:
while
len
(
self
.
rid_to_state
)
>
0
:
async
with
self
.
model_update_lock
:
await
asyncio
.
sleep
(
0.001
)
# wait for the previous generation requests to finish
self
.
send_to_scheduler
.
send_pyobj
(
obj
)
while
len
(
self
.
rid_to_state
)
>
0
:
self
.
model_update_result
=
asyncio
.
Future
()
await
asyncio
.
sleep
(
0.001
)
result
=
await
self
.
model_update_result
self
.
send_to_scheduler
.
send_pyobj
(
obj
)
if
result
.
success
:
self
.
model_update_result
=
asyncio
.
Future
()
self
.
server_args
.
model_path
=
obj
.
model_path
result
=
await
self
.
model_update_result
self
.
server_args
.
load_format
=
obj
.
load_format
if
result
.
success
:
self
.
model_path
=
obj
.
model_path
self
.
server_args
.
model_path
=
obj
.
model_path
return
result
.
success
,
result
.
message
self
.
server_args
.
load_format
=
obj
.
load_format
self
.
model_path
=
obj
.
model_path
return
result
.
success
,
result
.
message
else
:
# self.server_args.dp_size > 1
# There will be dp_size number of response from the detokenizer
async
with
self
.
model_update_lock
:
# wait for the previous generation requests to finish
while
len
(
self
.
rid_to_state
)
>
0
:
await
asyncio
.
sleep
(
0.001
)
self
.
send_to_scheduler
.
send_pyobj
(
obj
)
self
.
model_update_result
=
asyncio
.
Future
()
self
.
model_update_tmp
=
[]
result
=
await
self
.
model_update_result
all_success
=
all
([
r
.
success
for
r
in
result
])
if
all_success
is
True
:
self
.
server_args
.
model_path
=
obj
.
model_path
self
.
server_args
.
load_format
=
obj
.
load_format
self
.
model_path
=
obj
.
model_path
all_message
=
[
r
.
message
for
r
in
result
]
all_message
=
" | "
.
join
(
all_message
)
return
all_success
,
all_message
else
:
else
:
return
False
,
"Another update is in progress. Please try again later."
return
False
,
"Another update is in progress. Please try again later."
...
@@ -600,7 +625,13 @@ class TokenizerManager:
...
@@ -600,7 +625,13 @@ class TokenizerManager:
]
=
await
self
.
recv_from_detokenizer
.
recv_pyobj
()
]
=
await
self
.
recv_from_detokenizer
.
recv_pyobj
()
if
isinstance
(
recv_obj
,
UpdateWeightReqOutput
):
if
isinstance
(
recv_obj
,
UpdateWeightReqOutput
):
self
.
model_update_result
.
set_result
(
recv_obj
)
if
self
.
server_args
.
dp_size
==
1
:
self
.
model_update_result
.
set_result
(
recv_obj
)
else
:
# self.server_args.dp_size > 1
self
.
model_update_tmp
.
append
(
recv_obj
)
# set future if the all results are recevied
if
len
(
self
.
model_update_tmp
)
==
self
.
server_args
.
dp_size
:
self
.
model_update_result
.
set_result
(
self
.
model_update_tmp
)
continue
continue
elif
isinstance
(
recv_obj
,
GetMemPoolSizeReqOutput
):
elif
isinstance
(
recv_obj
,
GetMemPoolSizeReqOutput
):
self
.
mem_pool_size
.
set_result
(
recv_obj
)
self
.
mem_pool_size
.
set_result
(
recv_obj
)
...
...
test/srt/test_data_parallelism.py
View file @
0a24eb85
import
time
import
unittest
import
unittest
from
types
import
SimpleNamespace
from
types
import
SimpleNamespace
import
requests
from
sglang.srt.utils
import
kill_child_process
from
sglang.srt.utils
import
kill_child_process
from
sglang.test.run_eval
import
run_eval
from
sglang.test.run_eval
import
run_eval
from
sglang.test.test_utils
import
(
from
sglang.test.test_utils
import
(
...
@@ -39,6 +42,26 @@ class TestDataParallelism(unittest.TestCase):
...
@@ -39,6 +42,26 @@ class TestDataParallelism(unittest.TestCase):
metrics
=
run_eval
(
args
)
metrics
=
run_eval
(
args
)
assert
metrics
[
"score"
]
>=
0.65
assert
metrics
[
"score"
]
>=
0.65
def
test_update_weight
(
self
):
response
=
requests
.
post
(
self
.
base_url
+
"/update_weights"
,
json
=
{
"model_path"
:
DEFAULT_MODEL_NAME_FOR_TEST
},
)
# check if the response is 200
assert
response
.
status_code
==
200
# pause a few seconds then send again
time
.
sleep
(
5
)
response
=
requests
.
post
(
self
.
base_url
+
"/update_weights"
,
json
=
{
"model_path"
:
DEFAULT_MODEL_NAME_FOR_TEST
},
)
# check if the response is 200
assert
response
.
status_code
==
200
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment