Unverified Commit 0a24eb85 authored by Byron Hsu's avatar Byron Hsu Committed by GitHub
Browse files

Fix update_weights deadlock for DP (#1825)

parent 3839be29
...@@ -554,18 +554,43 @@ class TokenizerManager: ...@@ -554,18 +554,43 @@ class TokenizerManager:
obj.load_format = self.server_args.load_format obj.load_format = self.server_args.load_format
if not self.model_update_lock.locked(): if not self.model_update_lock.locked():
async with self.model_update_lock:
# wait for the previous generation requests to finish if self.server_args.dp_size == 1:
while len(self.rid_to_state) > 0: async with self.model_update_lock:
await asyncio.sleep(0.001) # wait for the previous generation requests to finish
self.send_to_scheduler.send_pyobj(obj) while len(self.rid_to_state) > 0:
self.model_update_result = asyncio.Future() await asyncio.sleep(0.001)
result = await self.model_update_result self.send_to_scheduler.send_pyobj(obj)
if result.success: self.model_update_result = asyncio.Future()
self.server_args.model_path = obj.model_path result = await self.model_update_result
self.server_args.load_format = obj.load_format if result.success:
self.model_path = obj.model_path self.server_args.model_path = obj.model_path
return result.success, result.message self.server_args.load_format = obj.load_format
self.model_path = obj.model_path
return result.success, result.message
else: # self.server_args.dp_size > 1
# There will be dp_size responses from the detokenizer
async with self.model_update_lock:
# wait for the previous generation requests to finish
while len(self.rid_to_state) > 0:
await asyncio.sleep(0.001)
self.send_to_scheduler.send_pyobj(obj)
self.model_update_result = asyncio.Future()
self.model_update_tmp = []
result = await self.model_update_result
all_success = all([r.success for r in result])
if all_success is True:
self.server_args.model_path = obj.model_path
self.server_args.load_format = obj.load_format
self.model_path = obj.model_path
all_message = [r.message for r in result]
all_message = " | ".join(all_message)
return all_success, all_message
else: else:
return False, "Another update is in progress. Please try again later." return False, "Another update is in progress. Please try again later."
...@@ -600,7 +625,13 @@ class TokenizerManager: ...@@ -600,7 +625,13 @@ class TokenizerManager:
] = await self.recv_from_detokenizer.recv_pyobj() ] = await self.recv_from_detokenizer.recv_pyobj()
if isinstance(recv_obj, UpdateWeightReqOutput): if isinstance(recv_obj, UpdateWeightReqOutput):
self.model_update_result.set_result(recv_obj) if self.server_args.dp_size == 1:
self.model_update_result.set_result(recv_obj)
else: # self.server_args.dp_size > 1
self.model_update_tmp.append(recv_obj)
# set the future once all results are received
if len(self.model_update_tmp) == self.server_args.dp_size:
self.model_update_result.set_result(self.model_update_tmp)
continue continue
elif isinstance(recv_obj, GetMemPoolSizeReqOutput): elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
self.mem_pool_size.set_result(recv_obj) self.mem_pool_size.set_result(recv_obj)
......
import time
import unittest import unittest
from types import SimpleNamespace from types import SimpleNamespace
import requests
from sglang.srt.utils import kill_child_process from sglang.srt.utils import kill_child_process
from sglang.test.run_eval import run_eval from sglang.test.run_eval import run_eval
from sglang.test.test_utils import ( from sglang.test.test_utils import (
...@@ -39,6 +42,26 @@ class TestDataParallelism(unittest.TestCase): ...@@ -39,6 +42,26 @@ class TestDataParallelism(unittest.TestCase):
metrics = run_eval(args) metrics = run_eval(args)
assert metrics["score"] >= 0.65 assert metrics["score"] >= 0.65
def test_update_weight(self):
    """Post to ``/update_weights`` twice and require HTTP 200 both times.

    The same model path (``DEFAULT_MODEL_NAME_FOR_TEST``) is sent in both
    requests, with a short pause before the second one.

    NOTE(review): presumably this guards against the update_weights
    deadlock that the accompanying change addresses -- confirm with the
    author. Because a deadlocked server never answers, every request
    carries a timeout so a regression fails the test instead of hanging
    the whole test run.
    """
    # Generous upper bound for a single weight reload; only meant to turn
    # an indefinite hang into a test failure, not to measure latency.
    request_timeout_s = 300
    pause_between_updates_s = 5

    for attempt in range(2):
        if attempt > 0:
            # pause a few seconds then send again
            time.sleep(pause_between_updates_s)

        response = requests.post(
            self.base_url + "/update_weights",
            json={"model_path": DEFAULT_MODEL_NAME_FOR_TEST},
            timeout=request_timeout_s,
        )

        # check if the response is 200; include the body to ease debugging
        assert response.status_code == 200, (
            f"update_weights attempt {attempt + 1} returned "
            f"{response.status_code}: {response.text}"
        )
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment