Unverified commit 74bc397b, authored by Jun Duan, committed by GitHub
Browse files

[Core] Expose API endpoint `/is_sleeping` (#14312)


Signed-off-by: Jun Duan <jun.duan.phd@outlook.com>
parent f58aea00
...@@ -28,5 +28,12 @@ def test_sleep_mode(): ...@@ -28,5 +28,12 @@ def test_sleep_mode():
response = requests.post(remote_server.url_for("/sleep"), response = requests.post(remote_server.url_for("/sleep"),
data={"level": "1"}) data={"level": "1"})
assert response.status_code == 200 assert response.status_code == 200
response = requests.get(remote_server.url_for("/is_sleeping"))
assert response.status_code == 200
assert response.json().get("is_sleeping") is True
response = requests.post(remote_server.url_for("/wake_up")) response = requests.post(remote_server.url_for("/wake_up"))
assert response.status_code == 200 assert response.status_code == 200
response = requests.get(remote_server.url_for("/is_sleeping"))
assert response.status_code == 200
assert response.json().get("is_sleeping") is False
...@@ -1225,6 +1225,9 @@ class AsyncLLMEngine(EngineClient): ...@@ -1225,6 +1225,9 @@ class AsyncLLMEngine(EngineClient):
async def wake_up(self) -> None: async def wake_up(self) -> None:
self.engine.wake_up() self.engine.wake_up()
async def is_sleeping(self) -> bool:
return self.engine.is_sleeping()
async def add_lora(self, lora_request: LoRARequest) -> None: async def add_lora(self, lora_request: LoRARequest) -> None:
self.engine.add_lora(lora_request) self.engine.add_lora(lora_request)
......
...@@ -1948,6 +1948,9 @@ class LLMEngine: ...@@ -1948,6 +1948,9 @@ class LLMEngine:
"Sleep mode is not enabled in the model config") "Sleep mode is not enabled in the model config")
self.model_executor.wake_up() self.model_executor.wake_up()
def is_sleeping(self) -> bool:
return self.model_executor.is_sleeping
def check_health(self) -> None: def check_health(self) -> None:
if self.tokenizer: if self.tokenizer:
self.tokenizer.check_health() self.tokenizer.check_health()
......
...@@ -136,6 +136,18 @@ class RPCWakeUpRequest(Enum): ...@@ -136,6 +136,18 @@ class RPCWakeUpRequest(Enum):
WAKE_UP = 1 WAKE_UP = 1
@dataclass
class RPCIsSleepingRequest:
# Set the default value of request_id to a new UUID
request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
@dataclass
class RPCIsSleepingResponse:
request_id: str
is_sleeping: bool
@dataclass @dataclass
class RPCLoadAdapterRequest: class RPCLoadAdapterRequest:
lora_request: LoRARequest lora_request: LoRARequest
...@@ -151,10 +163,10 @@ class RPCAdapterLoadedResponse: ...@@ -151,10 +163,10 @@ class RPCAdapterLoadedResponse:
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest, RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
RPCUProfileRequest, RPCLoadAdapterRequest, RPCUProfileRequest, RPCLoadAdapterRequest,
RPCResetPrefixCacheRequest, RPCSleepRequest, RPCResetPrefixCacheRequest, RPCSleepRequest,
RPCWakeUpRequest] RPCWakeUpRequest, RPCIsSleepingRequest]
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse, REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse,
RPCError] RPCIsSleepingResponse, RPCError]
def ENGINE_DEAD_ERROR( def ENGINE_DEAD_ERROR(
......
...@@ -27,6 +27,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, ...@@ -27,6 +27,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T, IPC_OUTPUT_EXT, RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest, VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCAdapterLoadedResponse, RPCError, RPCAdapterLoadedResponse, RPCError,
RPCIsSleepingRequest,
RPCIsSleepingResponse,
RPCLoadAdapterRequest, RPCLoadAdapterRequest,
RPCProcessRequest, RPCProcessRequest,
RPCResetPrefixCacheRequest, RPCResetPrefixCacheRequest,
...@@ -246,7 +248,9 @@ class MQLLMEngineClient(EngineClient): ...@@ -246,7 +248,9 @@ class MQLLMEngineClient(EngineClient):
if queue is not None: if queue is not None:
queue.put_nowait(exception) queue.put_nowait(exception)
# Put each output into the appropriate queue. # Put each output into the appropriate queue.
elif isinstance(request_outputs, RPCAdapterLoadedResponse): elif isinstance(
request_outputs,
(RPCAdapterLoadedResponse, RPCIsSleepingResponse)):
self._add_output(request_outputs) self._add_output(request_outputs)
else: else:
for request_output in request_outputs: for request_output in request_outputs:
...@@ -256,7 +260,8 @@ class MQLLMEngineClient(EngineClient): ...@@ -256,7 +260,8 @@ class MQLLMEngineClient(EngineClient):
logger.debug("Shutting down MQLLMEngineClient output handler.") logger.debug("Shutting down MQLLMEngineClient output handler.")
def _add_output(self, request_output: Union[RequestOutput, def _add_output(self, request_output: Union[RequestOutput,
RPCAdapterLoadedResponse]): RPCAdapterLoadedResponse,
RPCIsSleepingResponse]):
queue = self.output_queues.get(request_output.request_id) queue = self.output_queues.get(request_output.request_id)
if queue is not None: if queue is not None:
queue.put_nowait(request_output) queue.put_nowait(request_output)
...@@ -696,6 +701,24 @@ class MQLLMEngineClient(EngineClient): ...@@ -696,6 +701,24 @@ class MQLLMEngineClient(EngineClient):
return await self._send_one_way_rpc_request( return await self._send_one_way_rpc_request(
request=RPCWakeUpRequest.WAKE_UP, socket=self.input_socket) request=RPCWakeUpRequest.WAKE_UP, socket=self.input_socket)
async def is_sleeping(self) -> bool:
    """Check whether the engine is sleeping.

    Sends an ``RPCIsSleepingRequest`` over the input socket and waits for
    the reply to be placed on a queue registered under the request's id.

    Returns:
        The ``is_sleeping`` flag carried by the response.

    Raises:
        BaseException: Re-raises whatever exception object was delivered to
            this request's queue in place of a response.
    """
    request = RPCIsSleepingRequest()

    # Register the reply queue before sending: outputs whose request_id has
    # no registered queue are not delivered anywhere.
    queue: asyncio.Queue[Union[BaseException,
                               RPCIsSleepingResponse]] = asyncio.Queue()
    self.output_queues[request.request_id] = queue

    try:
        request_bytes = pickle.dumps(request)
        await self.input_socket.send_multipart((request_bytes, ), copy=False)
        request_output = await queue.get()
    finally:
        # Always unregister the queue, including when the send fails or the
        # awaiting task is cancelled; otherwise the entry stays in
        # ``output_queues`` forever.
        self.output_queues.pop(request.request_id, None)

    if isinstance(request_output, BaseException):
        raise request_output
    return request_output.is_sleeping
async def add_lora(self, lora_request: LoRARequest) -> None: async def add_lora(self, lora_request: LoRARequest) -> None:
"""Load a new LoRA adapter into the engine for future requests.""" """Load a new LoRA adapter into the engine for future requests."""
# Uses the same I/O as generate requests # Uses the same I/O as generate requests
......
...@@ -18,6 +18,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, ...@@ -18,6 +18,8 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T, IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest, VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCAdapterLoadedResponse, RPCError, RPCAdapterLoadedResponse, RPCError,
RPCIsSleepingRequest,
RPCIsSleepingResponse,
RPCLoadAdapterRequest, RPCLoadAdapterRequest,
RPCProcessRequest, RPCProcessRequest,
RPCResetPrefixCacheRequest, RPCResetPrefixCacheRequest,
...@@ -271,6 +273,8 @@ class MQLLMEngine: ...@@ -271,6 +273,8 @@ class MQLLMEngine:
self.sleep(request.value) self.sleep(request.value)
elif isinstance(request, RPCWakeUpRequest): elif isinstance(request, RPCWakeUpRequest):
self.wake_up() self.wake_up()
elif isinstance(request, RPCIsSleepingRequest):
self._handle_is_sleeping_request(request)
else: else:
raise ValueError("Unknown RPCRequest Type: " raise ValueError("Unknown RPCRequest Type: "
f"{type(request)}") f"{type(request)}")
...@@ -337,6 +341,12 @@ class MQLLMEngine: ...@@ -337,6 +341,12 @@ class MQLLMEngine:
self._send_outputs( self._send_outputs(
RPCAdapterLoadedResponse(request_id=request.request_id)) RPCAdapterLoadedResponse(request_id=request.request_id))
def _handle_is_sleeping_request(self, request: RPCIsSleepingRequest):
is_sleeping = self.is_sleeping()
self._send_outputs(
RPCIsSleepingResponse(request_id=request.request_id,
is_sleeping=is_sleeping))
def _health_check(self): def _health_check(self):
# Send unhealthy if engine has already errored # Send unhealthy if engine has already errored
if self._errored_with is not None: if self._errored_with is not None:
...@@ -406,6 +416,9 @@ class MQLLMEngine: ...@@ -406,6 +416,9 @@ class MQLLMEngine:
def wake_up(self) -> None: def wake_up(self) -> None:
self.engine.wake_up() self.engine.wake_up()
def is_sleeping(self) -> bool:
return self.engine.is_sleeping()
def signal_handler(*_) -> None: def signal_handler(*_) -> None:
raise KeyboardInterrupt("MQLLMEngine terminated") raise KeyboardInterrupt("MQLLMEngine terminated")
......
...@@ -288,6 +288,11 @@ class EngineClient(ABC): ...@@ -288,6 +288,11 @@ class EngineClient(ABC):
"""Wake up the engine""" """Wake up the engine"""
... ...
@abstractmethod
async def is_sleeping(self) -> bool:
"""Check whether the engine is sleeping"""
...
@abstractmethod @abstractmethod
async def add_lora(self, lora_request: LoRARequest) -> None: async def add_lora(self, lora_request: LoRARequest) -> None:
"""Load a new LoRA adapter into the engine for future requests.""" """Load a new LoRA adapter into the engine for future requests."""
......
...@@ -694,6 +694,12 @@ if envs.VLLM_SERVER_DEV_MODE: ...@@ -694,6 +694,12 @@ if envs.VLLM_SERVER_DEV_MODE:
# is sent but does not finish yet when we return a response. # is sent but does not finish yet when we return a response.
return Response(status_code=200) return Response(status_code=200)
@router.get("/is_sleeping")
async def is_sleeping(raw_request: Request):
logger.info("check whether the engine is sleeping")
is_sleeping = await engine_client(raw_request).is_sleeping()
return JSONResponse(content={"is_sleeping": is_sleeping})
@router.post("/invocations", dependencies=[Depends(validate_json_request)]) @router.post("/invocations", dependencies=[Depends(validate_json_request)])
async def invocations(raw_request: Request): async def invocations(raw_request: Request):
......
...@@ -407,6 +407,9 @@ class AsyncLLM(EngineClient): ...@@ -407,6 +407,9 @@ class AsyncLLM(EngineClient):
async def wake_up(self) -> None: async def wake_up(self) -> None:
await self.engine_core.wake_up_async() await self.engine_core.wake_up_async()
async def is_sleeping(self) -> bool:
return await self.engine_core.is_sleeping_async()
async def add_lora(self, lora_request: LoRARequest) -> bool: async def add_lora(self, lora_request: LoRARequest) -> bool:
"""Load a new LoRA adapter into the engine for future requests.""" """Load a new LoRA adapter into the engine for future requests."""
return await self.engine_core.add_lora_async(lora_request) return await self.engine_core.add_lora_async(lora_request)
......
...@@ -253,6 +253,9 @@ class EngineCore: ...@@ -253,6 +253,9 @@ class EngineCore:
def wake_up(self): def wake_up(self):
self.model_executor.wake_up() self.model_executor.wake_up()
def is_sleeping(self) -> bool:
return self.model_executor.is_sleeping
def execute_dummy_batch(self): def execute_dummy_batch(self):
self.model_executor.collective_rpc("execute_dummy_batch") self.model_executor.collective_rpc("execute_dummy_batch")
......
...@@ -89,6 +89,9 @@ class EngineCoreClient(ABC): ...@@ -89,6 +89,9 @@ class EngineCoreClient(ABC):
def wake_up(self) -> None: def wake_up(self) -> None:
raise NotImplementedError raise NotImplementedError
def is_sleeping(self) -> bool:
raise NotImplementedError
def execute_dummy_batch(self) -> None: def execute_dummy_batch(self) -> None:
raise NotImplementedError raise NotImplementedError
...@@ -128,6 +131,9 @@ class EngineCoreClient(ABC): ...@@ -128,6 +131,9 @@ class EngineCoreClient(ABC):
async def wake_up_async(self) -> None: async def wake_up_async(self) -> None:
raise NotImplementedError raise NotImplementedError
async def is_sleeping_async(self) -> bool:
raise NotImplementedError
async def abort_requests_async(self, request_ids: list[str]) -> None: async def abort_requests_async(self, request_ids: list[str]) -> None:
raise NotImplementedError raise NotImplementedError
...@@ -182,6 +188,9 @@ class InprocClient(EngineCoreClient): ...@@ -182,6 +188,9 @@ class InprocClient(EngineCoreClient):
def wake_up(self) -> None: def wake_up(self) -> None:
self.engine_core.wake_up() self.engine_core.wake_up()
def is_sleeping(self) -> bool:
return self.engine_core.is_sleeping()
def execute_dummy_batch(self) -> None: def execute_dummy_batch(self) -> None:
self.engine_core.execute_dummy_batch() self.engine_core.execute_dummy_batch()
...@@ -433,6 +442,9 @@ class SyncMPClient(MPClient): ...@@ -433,6 +442,9 @@ class SyncMPClient(MPClient):
def wake_up(self) -> None: def wake_up(self) -> None:
self._call_utility("wake_up") self._call_utility("wake_up")
def is_sleeping(self) -> bool:
return self._call_utility("is_sleeping")
def execute_dummy_batch(self) -> None: def execute_dummy_batch(self) -> None:
self._call_utility("execute_dummy_batch") self._call_utility("execute_dummy_batch")
...@@ -523,6 +535,9 @@ class AsyncMPClient(MPClient): ...@@ -523,6 +535,9 @@ class AsyncMPClient(MPClient):
async def wake_up_async(self) -> None: async def wake_up_async(self) -> None:
await self._call_utility_async("wake_up") await self._call_utility_async("wake_up")
async def is_sleeping_async(self) -> bool:
return await self._call_utility_async("is_sleeping")
async def execute_dummy_batch_async(self) -> None: async def execute_dummy_batch_async(self) -> None:
await self._call_utility_async("execute_dummy_batch") await self._call_utility_async("execute_dummy_batch")
......
...@@ -235,6 +235,9 @@ class LLMEngine: ...@@ -235,6 +235,9 @@ class LLMEngine:
def wake_up(self): def wake_up(self):
self.engine_core.wake_up() self.engine_core.wake_up()
def is_sleeping(self) -> bool:
return self.engine_core.is_sleeping()
def get_tokenizer_group( def get_tokenizer_group(
self, self,
group_type: type[_G] = BaseTokenizerGroup, group_type: type[_G] = BaseTokenizerGroup,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment