"docs/source/vscode:/vscode.git/clone" did not exist on "d453d72dd158d76a41229722cb6e108c006bd78a"
Unverified Commit 0a24eb85 authored by Byron Hsu's avatar Byron Hsu Committed by GitHub
Browse files

Fix update_weights deadlock for DP (#1825)

parent 3839be29
...@@ -554,6 +554,8 @@ class TokenizerManager: ...@@ -554,6 +554,8 @@ class TokenizerManager:
obj.load_format = self.server_args.load_format obj.load_format = self.server_args.load_format
if not self.model_update_lock.locked(): if not self.model_update_lock.locked():
if self.server_args.dp_size == 1:
async with self.model_update_lock: async with self.model_update_lock:
# wait for the previous generation requests to finish # wait for the previous generation requests to finish
while len(self.rid_to_state) > 0: while len(self.rid_to_state) > 0:
...@@ -566,6 +568,29 @@ class TokenizerManager: ...@@ -566,6 +568,29 @@ class TokenizerManager:
self.server_args.load_format = obj.load_format self.server_args.load_format = obj.load_format
self.model_path = obj.model_path self.model_path = obj.model_path
return result.success, result.message return result.success, result.message
else: # self.server_args.dp_size > 1
# There will be dp_size responses from the detokenizer
async with self.model_update_lock:
# wait for the previous generation requests to finish
while len(self.rid_to_state) > 0:
await asyncio.sleep(0.001)
self.send_to_scheduler.send_pyobj(obj)
self.model_update_result = asyncio.Future()
self.model_update_tmp = []
result = await self.model_update_result
all_success = all([r.success for r in result])
if all_success is True:
self.server_args.model_path = obj.model_path
self.server_args.load_format = obj.load_format
self.model_path = obj.model_path
all_message = [r.message for r in result]
all_message = " | ".join(all_message)
return all_success, all_message
else: else:
return False, "Another update is in progress. Please try again later." return False, "Another update is in progress. Please try again later."
...@@ -600,7 +625,13 @@ class TokenizerManager: ...@@ -600,7 +625,13 @@ class TokenizerManager:
] = await self.recv_from_detokenizer.recv_pyobj() ] = await self.recv_from_detokenizer.recv_pyobj()
if isinstance(recv_obj, UpdateWeightReqOutput): if isinstance(recv_obj, UpdateWeightReqOutput):
if self.server_args.dp_size == 1:
self.model_update_result.set_result(recv_obj) self.model_update_result.set_result(recv_obj)
else: # self.server_args.dp_size > 1
self.model_update_tmp.append(recv_obj)
# set the future once all dp_size results are received
if len(self.model_update_tmp) == self.server_args.dp_size:
self.model_update_result.set_result(self.model_update_tmp)
continue continue
elif isinstance(recv_obj, GetMemPoolSizeReqOutput): elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
self.mem_pool_size.set_result(recv_obj) self.mem_pool_size.set_result(recv_obj)
......
import time
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
import requests
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
...@@ -39,6 +42,26 @@ class TestDataParallelism(unittest.TestCase): ...@@ -39,6 +42,26 @@ class TestDataParallelism(unittest.TestCase):
metrics = run_eval(args) metrics = run_eval(args)
assert metrics["score"] >= 0.65 assert metrics["score"] >= 0.65
def test_update_weight(self):
    """Issue two consecutive /update_weights requests and expect both to succeed.

    The second request (sent after a short pause) guards against the DP
    update-weights deadlock this change fixes: a completed update must not
    leave the model-update lock held, so a follow-up update still returns 200.
    """
    for attempt in range(2):
        if attempt > 0:
            # pause a few seconds so the first update fully settles
            # before re-sending the same request
            time.sleep(5)
        response = requests.post(
            self.base_url + "/update_weights",
            json={"model_path": DEFAULT_MODEL_NAME_FOR_TEST},
        )
        # include the body in the failure message for easier debugging
        assert response.status_code == 200, response.text
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment