Unverified commit b2ed5c8e authored by fzyzcjy, committed by GitHub

Tiny code cleanup in tokenizer_manager.py (#2586)

parent f46f394f
@@ -22,7 +22,7 @@ import signal
 import sys
 import time
 import uuid
-from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
+from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union

 import fastapi
 import uvloop
@@ -173,6 +173,15 @@ class TokenizerManager:
         # Others
         self.gracefully_exit = False
+        self.init_weights_update_group_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
+        self.update_weights_from_distributed_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
+        self.get_weights_by_name_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )

         # Metrics
         if self.enable_metrics:
@@ -190,8 +199,7 @@
     ):
         created_time = time.time()

-        if self.to_create_loop:
-            self.create_handle_loop()
+        self.auto_create_handle_loop()

         if isinstance(obj, EmbeddingReqInput) and self.is_generation:
             raise ValueError(
@@ -440,8 +448,7 @@
         obj: UpdateWeightFromDiskReqInput,
         request: Optional[fastapi.Request] = None,
     ) -> Tuple[bool, str]:
-        if self.to_create_loop:
-            self.create_handle_loop()
+        self.auto_create_handle_loop()

         # default the load format to the server_args
         if obj.load_format is None:
@@ -456,7 +463,7 @@
     async def _wait_for_model_update_from_disk(
         self, obj: UpdateWeightFromDiskReqInput
-    ) -> Tuple[bool, str, int]:
+    ) -> Tuple[bool, str]:
         self.send_to_scheduler.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         if self.server_args.dp_size == 1:
@@ -485,15 +492,11 @@
         obj: InitWeightsUpdateGroupReqInput,
         request: Optional[fastapi.Request] = None,
     ) -> Tuple[bool, str]:
-        if self.to_create_loop:
-            self.create_handle_loop()
-        self.send_to_scheduler.send_pyobj(obj)
-        self.init_weights_update_group_result = asyncio.Future()
+        self.auto_create_handle_loop()
         assert (
             self.server_args.dp_size == 1
         ), "dp_size must be 1 for init parameter update group"
-        result = await self.init_weights_update_group_result
+        result = (await self.init_weights_update_group_communicator(obj))[0]
         return result.success, result.message

     async def update_weights_from_distributed(
@@ -501,44 +504,32 @@
         obj: UpdateWeightsFromDistributedReqInput,
         request: Optional[fastapi.Request] = None,
     ) -> Tuple[bool, str]:
-        if self.to_create_loop:
-            self.create_handle_loop()
+        self.auto_create_handle_loop()
+        assert (
+            self.server_args.dp_size == 1
+        ), "dp_size must be 1 for update weights from distributed"

         # This means that weight sync
         # cannot run while requests are in progress.
         async with self.model_update_lock.writer_lock:
-            self.send_to_scheduler.send_pyobj(obj)
-            self.parameter_update_result: Awaitable[
-                UpdateWeightsFromDistributedReqOutput
-            ] = asyncio.Future()
-            assert (
-                self.server_args.dp_size == 1
-            ), "dp_size must be 1 for update weights from distributed"
-            result = await self.parameter_update_result
+            result = (await self.update_weights_from_distributed_communicator(obj))[0]
             return result.success, result.message

     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
     ):
-        if self.to_create_loop:
-            self.create_handle_loop()
-        self.send_to_scheduler.send_pyobj(obj)
-        self.get_weights_by_name_result = asyncio.Future()
+        self.auto_create_handle_loop()
+        results = await self.get_weights_by_name_communicator(obj)
+        all_parameters = [r.parameter for r in results]
         if self.server_args.dp_size == 1:
-            result = await self.get_weights_by_name_result
-            return result.parameter
+            return all_parameters[0]
         else:
-            self.get_weights_by_name_tmp = []
-            result = await self.get_weights_by_name_result
-            all_parameters = [r.parameter for r in result]
             return all_parameters

     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
-        if self.to_create_loop:
-            self.create_handle_loop()
+        self.auto_create_handle_loop()

         session_id = uuid.uuid4().hex
         obj.session_id = session_id
@@ -568,7 +559,7 @@
             background_tasks.add_task(abort_request)
         return background_tasks

-    def create_handle_loop(self):
+    def auto_create_handle_loop(self):
         if not self.to_create_loop:
             return
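
The rename from `create_handle_loop` to `auto_create_handle_loop` matches the new calling convention: the `to_create_loop` guard lives inside the method, so every entry point calls it unconditionally instead of repeating the check. A minimal runnable sketch of this idempotent-initializer shape; only the flag and the method name come from the diff, while the task attribute and loop body are assumed:

```python
import asyncio


class Manager:
    def __init__(self):
        self.to_create_loop = True
        self.handle_loop_task = None  # assumed attribute; not shown in the diff

    def auto_create_handle_loop(self):
        # The guard is internal now: a no-op after the first call.
        if not self.to_create_loop:
            return
        self.to_create_loop = False
        loop = asyncio.get_event_loop()
        self.handle_loop_task = loop.create_task(self.handle_loop())

    async def handle_loop(self):
        # Assumed body; the real loop consumes scheduler replies forever.
        await asyncio.sleep(0)


async def main():
    m = Manager()
    m.auto_create_handle_loop()
    m.auto_create_handle_loop()  # second call returns immediately
    await m.handle_loop_task


asyncio.run(main())
```
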
......@@ -711,21 +702,14 @@ class TokenizerManager:
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for init parameter update group"
self.init_weights_update_group_result.set_result(recv_obj)
self.init_weights_update_group_communicator.handle_recv(recv_obj)
elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
assert (
self.server_args.dp_size == 1
), "dp_size must be 1 for update weights from distributed"
self.parameter_update_result.set_result(recv_obj)
self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
elif isinstance(recv_obj, GetWeightsByNameReqOutput):
if self.server_args.dp_size == 1:
self.get_weights_by_name_result.set_result(recv_obj)
else:
self.get_weights_by_name_tmp.append(recv_obj)
if len(self.get_weights_by_name_tmp) == self.server_args.dp_size:
self.get_weights_by_name_result.set_result(
self.get_weights_by_name_tmp
)
self.get_weights_by_name_communicator.handle_recv(recv_obj)
else:
raise ValueError(f"Invalid object: {recv_obj=}")
@@ -809,3 +793,28 @@ class SignalHandler:
             f"SIGTERM received. {signum=} {frame=}. Draining requests and shutting down..."
         )
         self.tokenizer_manager.gracefully_exit = True
+
+
+T = TypeVar("T")
+
+
+class _Communicator(Generic[T]):
+    def __init__(self, sender, fan_out: int):
+        self._sender = sender
+        self._fan_out = fan_out
+        self._result_future: Optional[asyncio.Future] = None
+        self._result_values: Optional[List[T]] = None
+
+    async def __call__(self, obj):
+        self._sender.send_pyobj(obj)
+        self._result_future = asyncio.Future()
+        self._result_values = []
+        await self._result_future
+        result_values = self._result_values
+        self._result_future = self._result_values = None
+        return result_values
+
+    def handle_recv(self, recv_obj: T):
+        self._result_values.append(recv_obj)
+        if len(self._result_values) == self._fan_out:
+            self._result_future.set_result(None)
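
`_Communicator` generalizes the ad-hoc future-plus-accumulator pattern it replaces: `__call__` sends one request and parks on a future until `handle_recv` has been called `fan_out` times (once per data-parallel scheduler), then returns the collected replies in arrival order. Below is a runnable sketch under stated assumptions: `FakeSender` stands in for the real ZMQ socket, `Reply` for output objects such as `GetWeightsByNameReqOutput`, and the payload string is made up.

```python
import asyncio
from dataclasses import dataclass
from typing import Any, Generic, List, Optional, TypeVar

T = TypeVar("T")


@dataclass
class Reply:  # hypothetical stand-in for e.g. GetWeightsByNameReqOutput
    parameter: Any


class _Communicator(Generic[T]):
    """Send one request; gather `fan_out` replies; return them as a list."""

    def __init__(self, sender, fan_out: int):
        self._sender = sender
        self._fan_out = fan_out
        self._result_future: Optional[asyncio.Future] = None
        self._result_values: Optional[List[T]] = None

    async def __call__(self, obj) -> List[T]:
        self._sender.send_pyobj(obj)
        self._result_future = asyncio.Future()
        self._result_values = []
        await self._result_future
        result_values = self._result_values
        self._result_future = self._result_values = None  # ready for reuse
        return result_values

    def handle_recv(self, recv_obj: T):
        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_out:
            self._result_future.set_result(None)


class FakeSender:  # hypothetical; the real sender is a ZMQ socket
    def send_pyobj(self, obj):
        print(f"sent: {obj!r}")


async def main():
    comm: _Communicator[Reply] = _Communicator(FakeSender(), fan_out=2)
    task = asyncio.ensure_future(comm("get_weights:lm_head.weight"))
    await asyncio.sleep(0)  # let __call__ send and park on its future
    comm.handle_recv(Reply(parameter=[0.1]))  # reply from scheduler 0
    comm.handle_recv(Reply(parameter=[0.2]))  # reply from scheduler 1
    print([r.parameter for r in await task])  # [[0.1], [0.2]]


asyncio.run(main())
```

One implicit contract worth noting: each instance supports a single in-flight request, since a second `__call__` would clobber the pending future; the manager sidesteps this by constructing a dedicated communicator per request type.
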
@@ -245,16 +245,11 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
     try:
         ret = await tokenizer_manager.get_weights_by_name(obj, request)
         if ret is None:
-            return ORJSONResponse(
-                {"error": {"message": "Get parameter by name failed"}},
-                status_code=HTTPStatus.BAD_REQUEST,
-            )
+            return _create_error_response("Get parameter by name failed")
         else:
             return ORJSONResponse(ret, status_code=200)
     except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)


@app.api_route("/open_session", methods=["GET", "POST"])
@@ -264,9 +259,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
     try:
         session_id = await tokenizer_manager.open_session(obj, request)
         return session_id
     except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)


@app.api_route("/close_session", methods=["GET", "POST"])
@@ -276,9 +269,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
     try:
         await tokenizer_manager.close_session(obj, request)
         return Response(status_code=200)
     except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)


# fastapi implicitly converts json in the request to obj (dataclass)
@@ -312,9 +303,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
         return ret
     except ValueError as e:
         logger.error(f"Error: {e}")
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)


@app.api_route("/encode", methods=["POST", "PUT"])
@@ -325,9 +314,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)


@app.api_route("/classify", methods=["POST", "PUT"])
@@ -338,9 +325,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)


##### OpenAI-compatible API endpoints #####
@@ -416,6 +401,12 @@ async def retrieve_file_content(file_id: str):
     return await v1_retrieve_file_content(file_id)


+def _create_error_response(e):
+    return ORJSONResponse(
+        {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+    )
+
+
 def launch_engine(
     server_args: ServerArgs,
 ):
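
The new `_create_error_response` helper deduplicates the `ORJSONResponse` construction repeated across the endpoints above; since the body is built with `str(e)`, it accepts either an exception or a plain message string. A self-contained sketch of its behavior (the demo call and printed values are illustrative):

```python
from http import HTTPStatus

from fastapi.responses import ORJSONResponse  # requires `orjson` installed


def _create_error_response(e):
    # `e` may be an Exception or a plain string; str() covers both.
    return ORJSONResponse(
        {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
    )


resp = _create_error_response(ValueError("bad input"))
print(int(resp.status_code))  # 400
print(resp.body)              # b'{"error":{"message":"bad input"}}'
```
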
@@ -849,12 +840,10 @@ class Engine:
             group_name=group_name,
             backend=backend,
         )
-
-        async def _init_group():
-            return await tokenizer_manager.init_weights_update_group(obj, None)
-
         loop = asyncio.get_event_loop()
-        return loop.run_until_complete(_init_group())
+        return loop.run_until_complete(
+            tokenizer_manager.init_weights_update_group(obj, None)
+        )

     def update_weights_from_distributed(self, name, dtype, shape):
         """Update weights from distributed source."""
@@ -863,22 +852,16 @@ class Engine:
             dtype=dtype,
             shape=shape,
         )
-
-        async def _update_weights():
-            return await tokenizer_manager.update_weights_from_distributed(obj, None)
-
         loop = asyncio.get_event_loop()
-        return loop.run_until_complete(_update_weights())
+        return loop.run_until_complete(
+            tokenizer_manager.update_weights_from_distributed(obj, None)
+        )

     def get_weights_by_name(self, name, truncate_size=100):
         """Get weights by parameter name."""
         obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
-
-        async def _get_weights():
-            return await tokenizer_manager.get_weights_by_name(obj, None)
-
         loop = asyncio.get_event_loop()
-        return loop.run_until_complete(_get_weights())
+        return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None))


 class Runtime:
......
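
The `Engine` changes drop the one-line `async def` wrappers: `loop.run_until_complete` accepts a coroutine object directly, so wrapping `tokenizer_manager.<method>(obj, None)` in an inner coroutine added nothing. A runnable sketch of the equivalence, using a stand-in coroutine and a fresh loop for self-containment:

```python
import asyncio


async def fetch(name: str) -> str:
    # Stand-in for e.g. tokenizer_manager.get_weights_by_name(obj, None)
    await asyncio.sleep(0)
    return f"weights:{name}"


loop = asyncio.new_event_loop()

# Before: an unnecessary wrapper coroutine.
async def _wrapper():
    return await fetch("lm_head.weight")

assert loop.run_until_complete(_wrapper()) == "weights:lm_head.weight"

# After: pass the coroutine object directly; behavior is identical.
assert loop.run_until_complete(fetch("lm_head.weight")) == "weights:lm_head.weight"

loop.close()
```
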