Unverified Commit 64840dfa authored by 科英's avatar 科英 Committed by GitHub
Browse files

[Frontend] MQLLMEngine supports profiling. (#8761)

parent 28e1299e
...@@ -107,7 +107,13 @@ class RPCStartupResponse: ...@@ -107,7 +107,13 @@ class RPCStartupResponse:
tracing_enabled: bool tracing_enabled: bool
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest] class RPCUProfileRequest(Enum):
START_PROFILE = 1
STOP_PROFILE = 2
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
RPCUProfileRequest]
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError] REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]
......
...@@ -21,7 +21,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, ...@@ -21,7 +21,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T, IPC_OUTPUT_EXT, RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest, VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCProcessRequest, RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse) RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
# yapf: enable # yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.inputs import PromptType from vllm.inputs import PromptType
...@@ -38,10 +39,10 @@ logger = init_logger(__name__) ...@@ -38,10 +39,10 @@ logger = init_logger(__name__)
class MQClientClosedError(Exception): class MQClientClosedError(Exception):
"""Exception class raised when the client is used post-close. """Exception class raised when the client is used post-close.
The client can be closed, which closes the ZMQ context. This normally The client can be closed, which closes the ZMQ context. This normally
happens on server shutdown. In some cases, methods like abort and happens on server shutdown. In some cases, methods like abort and
do_log_stats will still be called and then try to open a socket, which do_log_stats will still be called and then try to open a socket, which
causes a ZMQError and creates a huge stack trace. causes a ZMQError and creates a huge stack trace.
So, we throw this error such that we can suppress it. So, we throw this error such that we can suppress it.
""" """
...@@ -345,7 +346,7 @@ class MQLLMEngineClient: ...@@ -345,7 +346,7 @@ class MQLLMEngineClient:
async def check_health(self): async def check_health(self):
""" """
The check health loop probes the health status of the The check health loop probes the health status of the
Engine's health every N seconds and sets _errored_with Engine's health every N seconds and sets _errored_with
if the engine is unhealthy. if the engine is unhealthy.
""" """
if self._errored_with is not None: if self._errored_with is not None:
...@@ -561,3 +562,15 @@ class MQLLMEngineClient: ...@@ -561,3 +562,15 @@ class MQLLMEngineClient:
await self.abort(request_id) await self.abort(request_id)
finally: finally:
self.output_queues.pop(request_id) self.output_queues.pop(request_id)
async def start_profile(self) -> None:
"""Start profiling the engine"""
await self._send_one_way_rpc_request(
request=RPCUProfileRequest.START_PROFILE, socket=self.input_socket)
async def stop_profile(self) -> None:
"""Stop profiling the engine"""
await self._send_one_way_rpc_request(
request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
...@@ -18,9 +18,11 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, ...@@ -18,9 +18,11 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest, VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCProcessRequest, RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse) RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
# yapf: enable # yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT from vllm.envs import VLLM_RPC_TIMEOUT
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext from vllm.usage.usage_lib import UsageContext
...@@ -249,6 +251,11 @@ class MQLLMEngine: ...@@ -249,6 +251,11 @@ class MQLLMEngine:
self._handle_process_request(request) self._handle_process_request(request)
elif isinstance(request, RPCAbortRequest): elif isinstance(request, RPCAbortRequest):
self._handle_abort_request(request) self._handle_abort_request(request)
elif isinstance(request, RPCUProfileRequest):
if request == RPCUProfileRequest.START_PROFILE:
self.start_profile()
else:
self.stop_profile()
else: else:
raise ValueError("Unknown RPCRequest Type: " raise ValueError("Unknown RPCRequest Type: "
f"{type(request)}") f"{type(request)}")
...@@ -356,6 +363,18 @@ class MQLLMEngine: ...@@ -356,6 +363,18 @@ class MQLLMEngine:
def _alive(self): def _alive(self):
self._last_alive_time = time.time() self._last_alive_time = time.time()
def start_profile(self) -> None:
if type(self.engine.model_executor) is GPUExecutor:
self.engine.model_executor.start_profile()
else:
self.engine.model_executor._run_workers("start_profile")
def stop_profile(self) -> None:
if type(self.engine.model_executor) is GPUExecutor:
self.engine.model_executor.stop_profile()
else:
self.engine.model_executor._run_workers("stop_profile")
def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
ipc_path: str): ipc_path: str):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment