Unverified Commit ba811639 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[core] add sleep and wake up endpoint and v1 support (#12987)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
Signed-off-by: default avatarcennn <2523403608@qq.com>
Co-authored-by: default avatarcennn <2523403608@qq.com>
parent 0d243f2a
...@@ -118,14 +118,16 @@ def test_cumem_with_cudagraph(): ...@@ -118,14 +118,16 @@ def test_cumem_with_cudagraph():
@fork_new_process_for_each_test @fork_new_process_for_each_test
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model", "model, use_v1",
[ [
# sleep mode with safetensors # sleep mode with safetensors
f"{MODEL_WEIGHTS_S3_BUCKET}/Llama-3.2-1B", (f"{MODEL_WEIGHTS_S3_BUCKET}/Llama-3.2-1B", True),
# sleep mode with pytorch checkpoint # sleep mode with pytorch checkpoint
"facebook/opt-125m" ("facebook/opt-125m", False),
]) ])
def test_end_to_end(model): def test_end_to_end(model: str, use_v1: bool):
import os
os.environ["VLLM_USE_V1"] = "1" if use_v1 else "0"
free, total = torch.cuda.mem_get_info() free, total = torch.cuda.mem_get_info()
used_bytes_baseline = total - free # in case other process is running used_bytes_baseline = total - free # in case other process is running
load_format = LoadFormat.AUTO load_format = LoadFormat.AUTO
...@@ -152,3 +154,5 @@ def test_end_to_end(model): ...@@ -152,3 +154,5 @@ def test_end_to_end(model):
# cmp output # cmp output
assert output[0].outputs[0].text == output2[0].outputs[0].text assert output[0].outputs[0].text == output2[0].outputs[0].text
del os.environ["VLLM_USE_V1"]
# SPDX-License-Identifier: Apache-2.0
import requests
from ...utils import RemoteOpenAIServer
MODEL_NAME = "meta-llama/Llama-3.2-1B"
def test_sleep_mode():
# dtype, max-len etc set so that this can run in CI
args = [
"--dtype",
"bfloat16",
"--max-model-len",
"8192",
"--max-num-seqs",
"128",
"--enable-sleep-mode",
]
with RemoteOpenAIServer(MODEL_NAME,
args,
env_dict={
"VLLM_SERVER_DEV_MODE": "1",
"CUDA_VISIBLE_DEVICES": "0"
}) as remote_server:
response = requests.post(remote_server.url_for("/sleep"),
data={"level": "1"})
assert response.status_code == 200
response = requests.post(remote_server.url_for("/wake_up"))
assert response.status_code == 200
...@@ -1187,6 +1187,12 @@ class AsyncLLMEngine(EngineClient): ...@@ -1187,6 +1187,12 @@ class AsyncLLMEngine(EngineClient):
async def reset_prefix_cache(self) -> None: async def reset_prefix_cache(self) -> None:
self.engine.reset_prefix_cache() self.engine.reset_prefix_cache()
async def sleep(self, level: int = 1) -> None:
self.engine.sleep(level)
async def wake_up(self) -> None:
self.engine.wake_up()
async def add_lora(self, lora_request: LoRARequest) -> None: async def add_lora(self, lora_request: LoRARequest) -> None:
self.engine.add_lora(lora_request) self.engine.add_lora(lora_request)
......
...@@ -127,6 +127,15 @@ class RPCResetPrefixCacheRequest(Enum): ...@@ -127,6 +127,15 @@ class RPCResetPrefixCacheRequest(Enum):
RESET_PREFIX_CACHE = 1 RESET_PREFIX_CACHE = 1
class RPCSleepRequest(Enum):
SLEEP_LEVEL_1 = 1
SLEEP_LEVEL_2 = 2
class RPCWakeUpRequest(Enum):
WAKE_UP = 1
@dataclass @dataclass
class RPCLoadAdapterRequest: class RPCLoadAdapterRequest:
lora_request: LoRARequest lora_request: LoRARequest
...@@ -141,7 +150,8 @@ class RPCAdapterLoadedResponse: ...@@ -141,7 +150,8 @@ class RPCAdapterLoadedResponse:
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest, RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
RPCUProfileRequest, RPCLoadAdapterRequest, RPCUProfileRequest, RPCLoadAdapterRequest,
RPCResetPrefixCacheRequest] RPCResetPrefixCacheRequest, RPCSleepRequest,
RPCWakeUpRequest]
REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse, REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCAdapterLoadedResponse,
RPCError] RPCError]
......
...@@ -31,8 +31,9 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, ...@@ -31,8 +31,9 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCLoadAdapterRequest, RPCLoadAdapterRequest,
RPCProcessRequest, RPCProcessRequest,
RPCResetPrefixCacheRequest, RPCResetPrefixCacheRequest,
RPCStartupRequest, RPCStartupResponse, RPCSleepRequest, RPCStartupRequest,
RPCUProfileRequest) RPCStartupResponse,
RPCUProfileRequest, RPCWakeUpRequest)
from vllm.engine.protocol import EngineClient from vllm.engine.protocol import EngineClient
# yapf: enable # yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT from vllm.envs import VLLM_RPC_TIMEOUT
...@@ -685,6 +686,16 @@ class MQLLMEngineClient(EngineClient): ...@@ -685,6 +686,16 @@ class MQLLMEngineClient(EngineClient):
request=RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE, request=RPCResetPrefixCacheRequest.RESET_PREFIX_CACHE,
socket=self.input_socket) socket=self.input_socket)
async def sleep(self, level: int = 1) -> None:
"""Sleep the engine for a given level"""
return await self._send_one_way_rpc_request(
request=RPCSleepRequest(level), socket=self.input_socket)
async def wake_up(self) -> None:
"""Wake up the engine"""
return await self._send_one_way_rpc_request(
request=RPCWakeUpRequest.WAKE_UP, socket=self.input_socket)
async def add_lora(self, lora_request: LoRARequest) -> None: async def add_lora(self, lora_request: LoRARequest) -> None:
"""Load a new LoRA adapter into the engine for future requests.""" """Load a new LoRA adapter into the engine for future requests."""
# Uses the same I/O as generate requests # Uses the same I/O as generate requests
......
...@@ -20,8 +20,9 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT, ...@@ -20,8 +20,9 @@ from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
RPCLoadAdapterRequest, RPCLoadAdapterRequest,
RPCProcessRequest, RPCProcessRequest,
RPCResetPrefixCacheRequest, RPCResetPrefixCacheRequest,
RPCStartupRequest, RPCStartupResponse, RPCSleepRequest, RPCStartupRequest,
RPCUProfileRequest) RPCStartupResponse,
RPCUProfileRequest, RPCWakeUpRequest)
# yapf: enable # yapf: enable
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.outputs import RequestOutput from vllm.outputs import RequestOutput
...@@ -242,6 +243,10 @@ class MQLLMEngine: ...@@ -242,6 +243,10 @@ class MQLLMEngine:
self._handle_load_adapter_request(request) self._handle_load_adapter_request(request)
elif isinstance(request, RPCResetPrefixCacheRequest): elif isinstance(request, RPCResetPrefixCacheRequest):
self.reset_prefix_cache() self.reset_prefix_cache()
elif isinstance(request, RPCSleepRequest):
self.sleep(request.value)
elif isinstance(request, RPCWakeUpRequest):
self.wake_up()
else: else:
raise ValueError("Unknown RPCRequest Type: " raise ValueError("Unknown RPCRequest Type: "
f"{type(request)}") f"{type(request)}")
...@@ -369,6 +374,12 @@ class MQLLMEngine: ...@@ -369,6 +374,12 @@ class MQLLMEngine:
def reset_prefix_cache(self) -> bool: def reset_prefix_cache(self) -> bool:
return self.engine.reset_prefix_cache() return self.engine.reset_prefix_cache()
def sleep(self, level: int = 1) -> None:
self.engine.sleep(level)
def wake_up(self) -> None:
self.engine.wake_up()
def signal_handler(*_) -> None: def signal_handler(*_) -> None:
raise KeyboardInterrupt("MQLLMEngine terminated") raise KeyboardInterrupt("MQLLMEngine terminated")
......
...@@ -278,6 +278,16 @@ class EngineClient(ABC): ...@@ -278,6 +278,16 @@ class EngineClient(ABC):
"""Reset the prefix cache""" """Reset the prefix cache"""
... ...
@abstractmethod
async def sleep(self, level: int = 1) -> None:
"""Sleep the engine"""
...
@abstractmethod
async def wake_up(self) -> None:
"""Wake up the engine"""
...
@abstractmethod @abstractmethod
async def add_lora(self, lora_request: LoRARequest) -> None: async def add_lora(self, lora_request: LoRARequest) -> None:
"""Load a new LoRA adapter into the engine for future requests.""" """Load a new LoRA adapter into the engine for future requests."""
......
...@@ -625,6 +625,24 @@ if envs.VLLM_SERVER_DEV_MODE: ...@@ -625,6 +625,24 @@ if envs.VLLM_SERVER_DEV_MODE:
await engine_client(raw_request).reset_prefix_cache() await engine_client(raw_request).reset_prefix_cache()
return Response(status_code=200) return Response(status_code=200)
@router.post("/sleep")
async def sleep(raw_request: Request):
# get POST params
level = raw_request.query_params.get("level", "1")
logger.info("sleep the engine with level %s", level)
await engine_client(raw_request).sleep(int(level))
# FIXME: in v0 with frontend multiprocessing, the sleep command
# is sent but does not finish yet when we return a response.
return Response(status_code=200)
@router.post("/wake_up")
async def wake_up(raw_request: Request):
logger.info("wake up the engine")
await engine_client(raw_request).wake_up()
# FIXME: in v0 with frontend multiprocessing, the wake-up command
# is sent but does not finish yet when we return a response.
return Response(status_code=200)
@router.post("/invocations", dependencies=[Depends(validate_json_request)]) @router.post("/invocations", dependencies=[Depends(validate_json_request)])
async def invocations(raw_request: Request): async def invocations(raw_request: Request):
......
...@@ -295,6 +295,7 @@ class OpenAIServingTranscription(OpenAIServing): ...@@ -295,6 +295,7 @@ class OpenAIServingTranscription(OpenAIServing):
# TODO(rob): figure out a way to pipe streaming in. # TODO(rob): figure out a way to pipe streaming in.
# Non-streaming response. # Non-streaming response.
try: try:
assert result_generator is not None
async for op in result_generator: async for op in result_generator:
result = op result = op
return TranscriptionResponse(text=result.outputs[0].text) return TranscriptionResponse(text=result.outputs[0].text)
......
...@@ -361,6 +361,12 @@ class AsyncLLM(EngineClient): ...@@ -361,6 +361,12 @@ class AsyncLLM(EngineClient):
async def reset_prefix_cache(self) -> None: async def reset_prefix_cache(self) -> None:
await self.engine_core.reset_prefix_cache_async() await self.engine_core.reset_prefix_cache_async()
async def sleep(self, level: int = 1) -> None:
await self.engine_core.sleep_async(level)
async def wake_up(self) -> None:
await self.engine_core.wake_up_async()
async def add_lora(self, lora_request: LoRARequest) -> None: async def add_lora(self, lora_request: LoRARequest) -> None:
"""Load a new LoRA adapter into the engine for future requests.""" """Load a new LoRA adapter into the engine for future requests."""
await self.engine_core.add_lora_async(lora_request) await self.engine_core.add_lora_async(lora_request)
......
...@@ -213,6 +213,12 @@ class EngineCore: ...@@ -213,6 +213,12 @@ class EngineCore:
def reset_prefix_cache(self): def reset_prefix_cache(self):
self.scheduler.reset_prefix_cache() self.scheduler.reset_prefix_cache()
def sleep(self, level: int = 1):
self.model_executor.sleep(level)
def wake_up(self):
self.model_executor.wake_up()
def add_lora(self, lora_request: LoRARequest) -> None: def add_lora(self, lora_request: LoRARequest) -> None:
self.model_executor.add_lora(lora_request) self.model_executor.add_lora(lora_request)
......
...@@ -81,6 +81,12 @@ class EngineCoreClient(ABC): ...@@ -81,6 +81,12 @@ class EngineCoreClient(ABC):
def reset_prefix_cache(self) -> None: def reset_prefix_cache(self) -> None:
raise NotImplementedError raise NotImplementedError
def sleep(self, level: int = 1) -> None:
raise NotImplementedError
def wake_up(self) -> None:
raise NotImplementedError
def abort_requests(self, request_ids: List[str]) -> None: def abort_requests(self, request_ids: List[str]) -> None:
raise NotImplementedError raise NotImplementedError
...@@ -99,6 +105,12 @@ class EngineCoreClient(ABC): ...@@ -99,6 +105,12 @@ class EngineCoreClient(ABC):
async def reset_prefix_cache_async(self) -> None: async def reset_prefix_cache_async(self) -> None:
raise NotImplementedError raise NotImplementedError
async def sleep_async(self, level: int = 1) -> None:
raise NotImplementedError
async def wake_up_async(self) -> None:
raise NotImplementedError
async def abort_requests_async(self, request_ids: List[str]) -> None: async def abort_requests_async(self, request_ids: List[str]) -> None:
raise NotImplementedError raise NotImplementedError
...@@ -138,6 +150,12 @@ class InprocClient(EngineCoreClient): ...@@ -138,6 +150,12 @@ class InprocClient(EngineCoreClient):
def reset_prefix_cache(self) -> None: def reset_prefix_cache(self) -> None:
self.engine_core.reset_prefix_cache() self.engine_core.reset_prefix_cache()
def sleep(self, level: int = 1) -> None:
self.engine_core.sleep(level)
def wake_up(self) -> None:
self.engine_core.wake_up()
def add_lora(self, lora_request: LoRARequest) -> None: def add_lora(self, lora_request: LoRARequest) -> None:
self.engine_core.add_lora(lora_request) self.engine_core.add_lora(lora_request)
...@@ -307,6 +325,12 @@ class SyncMPClient(MPClient): ...@@ -307,6 +325,12 @@ class SyncMPClient(MPClient):
def add_lora(self, lora_request: LoRARequest) -> None: def add_lora(self, lora_request: LoRARequest) -> None:
self._call_utility("add_lora", lora_request) self._call_utility("add_lora", lora_request)
def sleep(self, level: int = 1) -> None:
self._call_utility("sleep", level)
def wake_up(self) -> None:
self._call_utility("wake_up")
class AsyncMPClient(MPClient): class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore.""" """Asyncio-compatible client for multi-proc EngineCore."""
...@@ -384,5 +408,11 @@ class AsyncMPClient(MPClient): ...@@ -384,5 +408,11 @@ class AsyncMPClient(MPClient):
async def reset_prefix_cache_async(self) -> None: async def reset_prefix_cache_async(self) -> None:
await self._call_utility_async("reset_prefix_cache") await self._call_utility_async("reset_prefix_cache")
async def sleep_async(self, level: int = 1) -> None:
await self._call_utility_async("sleep", level)
async def wake_up_async(self) -> None:
await self._call_utility_async("wake_up")
async def add_lora_async(self, lora_request: LoRARequest) -> None: async def add_lora_async(self, lora_request: LoRARequest) -> None:
await self._call_utility_async("add_lora", lora_request) await self._call_utility_async("add_lora", lora_request)
...@@ -169,6 +169,12 @@ class LLMEngine: ...@@ -169,6 +169,12 @@ class LLMEngine:
def reset_prefix_cache(self): def reset_prefix_cache(self):
self.engine_core.reset_prefix_cache() self.engine_core.reset_prefix_cache()
def sleep(self, level: int = 1):
self.engine_core.sleep(level)
def wake_up(self):
self.engine_core.wake_up()
def get_tokenizer_group( def get_tokenizer_group(
self, self,
group_type: Type[_G] = BaseTokenizerGroup, group_type: Type[_G] = BaseTokenizerGroup,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment