"git@developer.sourcefind.cn:OpenDAS/torch-sparce.git" did not exist on "fcf88282f205afc420b9656bf2cc57a16ffa0fea"
Unverified Commit 680cad20 authored by Byron Hsu's avatar Byron Hsu Committed by GitHub
Browse files

fix get_memory_pool_size deadlock for DP (#1830)

parent 0a24eb85
...@@ -539,9 +539,22 @@ class TokenizerManager: ...@@ -539,9 +539,22 @@ class TokenizerManager:
self.create_handle_loop() self.create_handle_loop()
req = GetMemPoolSizeReq() req = GetMemPoolSizeReq()
self.send_to_scheduler.send_pyobj(req) ret = None
self.mem_pool_size = asyncio.Future()
return await self.mem_pool_size if self.server_args.dp_size == 1:
self.send_to_scheduler.send_pyobj(req)
self.mem_pool_size = asyncio.Future()
res = await self.mem_pool_size
ret = res.size
else: # self.server_args.dp_size > 1
self.send_to_scheduler.send_pyobj(req)
self.mem_pool_size = asyncio.Future()
self.mem_pool_size_tmp = []
res = await self.mem_pool_size
ret = [r.size for r in res]
return ret
async def update_weights( async def update_weights(
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
...@@ -634,7 +647,13 @@ class TokenizerManager: ...@@ -634,7 +647,13 @@ class TokenizerManager:
self.model_update_result.set_result(self.model_update_tmp) self.model_update_result.set_result(self.model_update_tmp)
continue continue
elif isinstance(recv_obj, GetMemPoolSizeReqOutput): elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
self.mem_pool_size.set_result(recv_obj) if self.server_args.dp_size == 1:
self.mem_pool_size.set_result(recv_obj)
else: # self.server_args.dp_size > 1
self.mem_pool_size_tmp.append(recv_obj)
# set the future once all results have been received
if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
self.mem_pool_size.set_result(self.mem_pool_size_tmp)
continue continue
assert isinstance( assert isinstance(
......
...@@ -177,7 +177,8 @@ async def get_memory_pool_size(): ...@@ -177,7 +177,8 @@ async def get_memory_pool_size():
"""Get the memory pool size in number of tokens""" """Get the memory pool size in number of tokens"""
try: try:
ret = await tokenizer_manager.get_memory_pool_size() ret = await tokenizer_manager.get_memory_pool_size()
return ret.size
return ret
except Exception as e: except Exception as e:
return JSONResponse( return JSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
......
...@@ -62,6 +62,15 @@ class TestDataParallelism(unittest.TestCase): ...@@ -62,6 +62,15 @@ class TestDataParallelism(unittest.TestCase):
# check if the response is 200 # check if the response is 200
assert response.status_code == 200 assert response.status_code == 200
def test_get_memory_pool_size(self):
    """Check that ``/get_memory_pool_size`` answers with HTTP 200 twice.

    The endpoint is queried, then queried again after a short pause, so
    the test covers a repeated call and not just the first one.
    """
    url = self.base_url + "/get_memory_pool_size"

    # Bound every request: without a timeout a server that never answers
    # would block the test run forever instead of failing this test.
    response = requests.get(url, timeout=60)
    assert response.status_code == 200

    time.sleep(5)

    response = requests.get(url, timeout=60)
    assert response.status_code == 200
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment