Unverified Commit b2ed5c8e authored by fzyzcjy, committed by GitHub

Tiny code cleanup in tokenizer_manager.py (#2586)

parent f46f394f
@@ -22,7 +22,7 @@ import signal
 import sys
 import time
 import uuid
-from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
+from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union

 import fastapi
 import uvloop
@@ -173,6 +173,15 @@ class TokenizerManager:
         # Others
         self.gracefully_exit = False
+        self.init_weights_update_group_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
+        self.update_weights_from_distributed_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
+        self.get_weights_by_name_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )

         # Metrics
         if self.enable_metrics:
@@ -190,8 +199,7 @@ class TokenizerManager:
     ):
         created_time = time.time()

-        if self.to_create_loop:
-            self.create_handle_loop()
+        self.auto_create_handle_loop()

         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
@@ -440,8 +448,7 @@ class TokenizerManager:
         obj: UpdateWeightFromDiskReqInput,
         request: Optional[fastapi.Request] = None,
     ) -> Tuple[bool, str]:
-        if self.to_create_loop:
-            self.create_handle_loop()
+        self.auto_create_handle_loop()

         # default the load format to the server_args
         if obj.load_format is None:
@@ -456,7 +463,7 @@ class TokenizerManager:
     async def _wait_for_model_update_from_disk(
         self, obj: UpdateWeightFromDiskReqInput
-    ) -> Tuple[bool, str, int]:
+    ) -> Tuple[bool, str]:
         self.send_to_scheduler.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         if self.server_args.dp_size == 1:
@@ -485,15 +492,11 @@ class TokenizerManager:
         obj: InitWeightsUpdateGroupReqInput,
         request: Optional[fastapi.Request] = None,
     ) -> Tuple[bool, str]:
-        if self.to_create_loop:
-            self.create_handle_loop()
-        self.send_to_scheduler.send_pyobj(obj)
-        self.init_weights_update_group_result = asyncio.Future()
+        self.auto_create_handle_loop()
         assert (
             self.server_args.dp_size == 1
         ), "dp_size must be 1 for init parameter update group"
-        result = await self.init_weights_update_group_result
+        result = (await self.init_weights_update_group_communicator(obj))[0]
         return result.success, result.message

     async def update_weights_from_distributed(
@@ -501,44 +504,32 @@ class TokenizerManager:
         obj: UpdateWeightsFromDistributedReqInput,
         request: Optional[fastapi.Request] = None,
     ) -> Tuple[bool, str]:
-        if self.to_create_loop:
-            self.create_handle_loop()
+        self.auto_create_handle_loop()
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be for update weights from distributed"

         # This means that weight sync
         # cannot run while requests are in progress.
         async with self.model_update_lock.writer_lock:
-            self.send_to_scheduler.send_pyobj(obj)
-            self.parameter_update_result: Awaitable[
-                UpdateWeightsFromDistributedReqOutput
-            ] = asyncio.Future()
-            assert (
-                self.server_args.dp_size == 1
-            ), "dp_size must be for update weights from distributed"
-            result = await self.parameter_update_result
+            result = (await self.update_weights_from_distributed_communicator(obj))[0]
             return result.success, result.message

     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
     ):
-        if self.to_create_loop:
-            self.create_handle_loop()
-
-        self.send_to_scheduler.send_pyobj(obj)
-        self.get_weights_by_name_result = asyncio.Future()
+        self.auto_create_handle_loop()
+        results = await self.get_weights_by_name_communicator(obj)
+        all_parameters = [r.parameter for r in results]
         if self.server_args.dp_size == 1:
-            result = await self.get_weights_by_name_result
-            return result.parameter
+            return all_parameters[0]
         else:
-            self.get_weights_by_name_tmp = []
-            result = await self.get_weights_by_name_result
-            all_parameters = [r.parameter for r in result]
             return all_parameters

     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
-        if self.to_create_loop:
-            self.create_handle_loop()
+        self.auto_create_handle_loop()

         session_id = uuid.uuid4().hex
         obj.session_id = session_id
@@ -568,7 +559,7 @@ class TokenizerManager:
             background_tasks.add_task(abort_request)
         return background_tasks

-    def create_handle_loop(self):
+    def auto_create_handle_loop(self):
         if not self.to_create_loop:
             return
@@ -711,21 +702,14 @@ class TokenizerManager:
                assert (
                    self.server_args.dp_size == 1
                ), "dp_size must be 1 for init parameter update group"
-               self.init_weights_update_group_result.set_result(recv_obj)
+               self.init_weights_update_group_communicator.handle_recv(recv_obj)
            elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
                assert (
                    self.server_args.dp_size == 1
                ), "dp_size must be 1 for update weights from distributed"
-               self.parameter_update_result.set_result(recv_obj)
+               self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
            elif isinstance(recv_obj, GetWeightsByNameReqOutput):
-               if self.server_args.dp_size == 1:
-                   self.get_weights_by_name_result.set_result(recv_obj)
-               else:
-                   self.get_weights_by_name_tmp.append(recv_obj)
-                   if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
-                       self.get_weights_by_name_result.set_result(
-                           self.get_weights_by_name_tmp
-                       )
+               self.get_weights_by_name_communicator.handle_recv(recv_obj)
            else:
                raise ValueError(f"Invalid object: {recv_obj=}")
@@ -809,3 +793,28 @@ class SignalHandler:
             f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
         )
         self.tokenizer_manager.gracefully_exit = True
+
+
+T = TypeVar("T")
+
+
+class _Communicator(Generic[T]):
+    def __init__(self, sender, fan_out: int):
+        self._sender = sender
+        self._fan_out = fan_out
+        self._result_future: Optional[asyncio.Future] = None
+        self._result_values: Optional[List[T]] = None
+
+    async def __call__(self, obj):
+        self._sender.send_pyobj(obj)
+        self._result_future = asyncio.Future()
+        self._result_values = []
+        await self._result_future
+        result_values = self._result_values
+        self._result_future = self._result_values = None
+        return result_values
+
+    def handle_recv(self, recv_obj: T):
+        self._result_values.append(recv_obj)
+        if len(self._result_values) == self._fan_out:
+            self._result_future.set_result(None)
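For context, the class above centralizes a send-then-gather pattern: the caller sends one request object, the detached handle loop feeds replies back through handle_recv, and the awaiting caller resumes once fan_out replies (one per data-parallel scheduler) have arrived. Below is a minimal, self-contained sketch of that pattern; the FakeSender stub, the simulated replies, and the timings are illustrative only and not part of the patch.

import asyncio
from typing import Generic, List, Optional, TypeVar

T = TypeVar("T")


class _Communicator(Generic[T]):
    # Mirrors the class added in the diff above.
    def __init__(self, sender, fan_out: int):
        self._sender = sender
        self._fan_out = fan_out
        self._result_future: Optional[asyncio.Future] = None
        self._result_values: Optional[List[T]] = None

    async def __call__(self, obj) -> List[T]:
        self._sender.send_pyobj(obj)           # one request out
        self._result_future = asyncio.Future()
        self._result_values = []
        await self._result_future              # resolved by handle_recv()
        result_values = self._result_values
        self._result_future = self._result_values = None
        return result_values

    def handle_recv(self, recv_obj: T):
        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_out:
            self._result_future.set_result(None)


class FakeSender:
    # Stand-in for the ZMQ socket held by TokenizerManager (illustrative only).
    def send_pyobj(self, obj):
        print(f"sent: {obj!r}")


async def main():
    comm = _Communicator(FakeSender(), fan_out=2)  # fan_out plays the role of dp_size
    loop = asyncio.get_running_loop()
    # Simulate the handle loop delivering one reply per data-parallel rank.
    loop.call_later(0.01, comm.handle_recv, "reply from rank 0")
    loop.call_later(0.02, comm.handle_recv, "reply from rank 1")
    print(await comm("dummy request"))  # -> ['reply from rank 0', 'reply from rank 1']


asyncio.run(main())

Note that a single _Communicator instance tracks one outstanding request at a time, which is enough for these control-plane calls (weight updates, parameter queries) as wired up in the __init__ hunk earlier in the diff.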
@@ -245,16 +245,11 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
     try:
         ret = await tokenizer_manager.get_weights_by_name(obj, request)
         if ret is None:
-            return ORJSONResponse(
-                {"error": {"message": "Get parameter by name failed"}},
-                status_code=HTTPStatus.BAD_REQUEST,
-            )
+            return _create_error_response("Get parameter by name failed")
         else:
             return ORJSONResponse(ret, status_code=200)
     except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)


 @app.api_route("/open_session", methods=["GET", "POST"])
@@ -264,9 +259,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
         session_id = await tokenizer_manager.open_session(obj, request)
         return session_id
     except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)


 @app.api_route("/close_session", methods=["GET", "POST"])
@@ -276,9 +269,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
         await tokenizer_manager.close_session(obj, request)
         return Response(status_code=200)
     except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)


 # fastapi implicitly converts json in the request to obj (dataclass)
@@ -312,9 +303,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
         return ret
     except ValueError as e:
         logger.error(f"Error: {e}")
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)


 @app.api_route("/encode", methods=["POST", "PUT"])
@@ -325,9 +314,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)


 @app.api_route("/classify", methods=["POST", "PUT"])
@@ -338,9 +325,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)


 ##### OpenAI-compatible API endpoints #####
@@ -416,6 +401,12 @@ async def retrieve_file_content(file_id: str):
     return await v1_retrieve_file_content(file_id)


+def _create_error_response(e):
+    return ORJSONResponse(
+        {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+    )
+
+
 def launch_engine(
     server_args: ServerArgs,
 ):
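The endpoint hunks above all funnel their except branches into this new helper. For reference, a minimal standalone sketch of the helper and the payload it produces; the imports and the sample call are illustrative, not part of the patch, and ORJSONResponse requires the orjson package to be installed (as it is for this server).

from http import HTTPStatus

from fastapi.responses import ORJSONResponse


def _create_error_response(e):
    # Accepts either an Exception instance or a plain string; str() covers both.
    return ORJSONResponse(
        {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
    )


# e.g. _create_error_response(ValueError("bad input")) yields a 400 response
# whose body serializes to {"error": {"message": "bad input"}}.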
@@ -849,12 +840,10 @@ class Engine:
             group_name=group_name,
             backend=backend,
         )

-        async def _init_group():
-            return await tokenizer_manager.init_weights_update_group(obj, None)
-
         loop = asyncio.get_event_loop()
-        return loop.run_until_complete(_init_group())
+        return loop.run_until_complete(
+            tokenizer_manager.init_weights_update_group(obj, None)
+        )

     def update_weights_from_distributed(self, name, dtype, shape):
         """Update weights from distributed source."""
@@ -863,22 +852,16 @@ class Engine:
             dtype=dtype,
             shape=shape,
         )

-        async def _update_weights():
-            return await tokenizer_manager.update_weights_from_distributed(obj, None)
-
         loop = asyncio.get_event_loop()
-        return loop.run_until_complete(_update_weights())
+        return loop.run_until_complete(
+            tokenizer_manager.update_weights_from_distributed(obj, None)
+        )

     def get_weights_by_name(self, name, truncate_size=100):
         """Get weights by parameter name."""
         obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)

-        async def _get_weights():
-            return await tokenizer_manager.get_weights_by_name(obj, None)
-
         loop = asyncio.get_event_loop()
-        return loop.run_until_complete(_get_weights())
+        return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None))

 class Runtime:
...