Unverified commit 438526a8, authored by Byron Hsu and committed by GitHub
Browse files

Refactor tokenizer manager (#1846)

parent f7102fbd
...@@ -549,22 +549,18 @@ class TokenizerManager: ...@@ -549,22 +549,18 @@ class TokenizerManager:
self.create_handle_loop() self.create_handle_loop()
req = GetMemPoolSizeReq() req = GetMemPoolSizeReq()
ret = None
self.send_to_scheduler.send_pyobj(req)
self.mem_pool_size = asyncio.Future()
if self.server_args.dp_size == 1: if self.server_args.dp_size == 1:
self.send_to_scheduler.send_pyobj(req)
self.mem_pool_size = asyncio.Future()
res = await self.mem_pool_size res = await self.mem_pool_size
ret = res.size return res.size
else: # self.server_args.dp_size > 1 else: # self.server_args.dp_size > 1
self.send_to_scheduler.send_pyobj(req)
self.mem_pool_size = asyncio.Future()
self.mem_pool_size_tmp = [] self.mem_pool_size_tmp = []
res = await self.mem_pool_size res = await self.mem_pool_size
ret = [r.size for r in res] ret = [r.size for r in res]
return ret
return ret
async def update_weights( async def update_weights(
self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
...@@ -578,29 +574,21 @@ class TokenizerManager: ...@@ -578,29 +574,21 @@ class TokenizerManager:
if not self.model_update_lock.locked(): if not self.model_update_lock.locked():
if self.server_args.dp_size == 1: async with self.model_update_lock:
async with self.model_update_lock: # wait for the previous generation requests to finish
# wait for the previous generation requests to finish while len(self.rid_to_state) > 0:
while len(self.rid_to_state) > 0: await asyncio.sleep(0.001)
await asyncio.sleep(0.001) self.send_to_scheduler.send_pyobj(obj)
self.send_to_scheduler.send_pyobj(obj) self.model_update_result = asyncio.Future()
self.model_update_result = asyncio.Future()
if self.server_args.dp_size == 1:
result = await self.model_update_result result = await self.model_update_result
if result.success: if result.success:
self.server_args.model_path = obj.model_path self.server_args.model_path = obj.model_path
self.server_args.load_format = obj.load_format self.server_args.load_format = obj.load_format
self.model_path = obj.model_path self.model_path = obj.model_path
return result.success, result.message return result.success, result.message
else: # self.server_args.dp_size > 1
else: # self.server_args.dp_size > 1
# There will be dp_size number of response from the detokenizer
async with self.model_update_lock:
# wait for the previous generation requests to finish
while len(self.rid_to_state) > 0:
await asyncio.sleep(0.001)
self.send_to_scheduler.send_pyobj(obj)
self.model_update_result = asyncio.Future()
self.model_update_tmp = [] self.model_update_tmp = []
result = await self.model_update_result result = await self.model_update_result
...@@ -611,8 +599,7 @@ class TokenizerManager: ...@@ -611,8 +599,7 @@ class TokenizerManager:
self.model_path = obj.model_path self.model_path = obj.model_path
all_message = [r.message for r in result] all_message = [r.message for r in result]
all_message = " | ".join(all_message) all_message = " | ".join(all_message)
return all_success, all_message
return all_success, all_message
else: else:
return False, "Another update is in progress. Please try again later." return False, "Another update is in progress. Please try again later."
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment