Unverified Commit ddb94c26 authored by Eric Tang's avatar Eric Tang Committed by GitHub
Browse files

[core] Add tags parameter to wake_up() (#15500)


Signed-off-by: default avatarEric <erictang000@gmail.com>
parent 90969fb3
...@@ -155,6 +155,24 @@ def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool): ...@@ -155,6 +155,24 @@ def test_end_to_end(monkeypatch: pytest.MonkeyPatch, model: str, use_v1: bool):
llm.wake_up() llm.wake_up()
output2 = llm.generate(prompt, sampling_params) output2 = llm.generate(prompt, sampling_params)
# cmp output # cmp output
assert output[0].outputs[0].text == output2[0].outputs[0].text assert output[0].outputs[0].text == output2[0].outputs[0].text
llm.sleep(level=1)
llm.wake_up(tags=["weights"])
free_gpu_bytes_wake_up_w, total = torch.cuda.mem_get_info()
used_bytes = total - free_gpu_bytes_wake_up_w - used_bytes_baseline
# should just reallocate memory for weights (1B model, ~2GiB weights)
if use_v1:
assert used_bytes < 10 * GiB_bytes
else:
assert used_bytes < 6 * GiB_bytes
# now allocate kv cache memory
llm.wake_up(tags=["kv_cache"])
output3 = llm.generate(prompt, sampling_params)
# cmp output
assert output[0].outputs[0].text == output3[0].outputs[0].text
...@@ -25,16 +25,37 @@ def test_sleep_mode(): ...@@ -25,16 +25,37 @@ def test_sleep_mode():
"VLLM_SERVER_DEV_MODE": "1", "VLLM_SERVER_DEV_MODE": "1",
"CUDA_VISIBLE_DEVICES": "0" "CUDA_VISIBLE_DEVICES": "0"
}) as remote_server: }) as remote_server:
response = requests.post(remote_server.url_for("sleep"),
params={"level": "1"})
assert response.status_code == 200
response = requests.get(remote_server.url_for("is_sleeping"))
assert response.status_code == 200
assert response.json().get("is_sleeping") is True
response = requests.post(remote_server.url_for("wake_up"))
assert response.status_code == 200
response = requests.get(remote_server.url_for("is_sleeping"))
assert response.status_code == 200
assert response.json().get("is_sleeping") is False
response = requests.post(remote_server.url_for("/sleep"), # test wake up with tags
response = requests.post(remote_server.url_for("sleep"),
params={"level": "1"}) params={"level": "1"})
assert response.status_code == 200 assert response.status_code == 200
response = requests.get(remote_server.url_for("/is_sleeping"))
response = requests.post(remote_server.url_for("wake_up"),
params={"tags": ["weights"]})
assert response.status_code == 200
# is sleeping should be false after waking up any part of the engine
response = requests.get(remote_server.url_for("is_sleeping"))
assert response.status_code == 200 assert response.status_code == 200
assert response.json().get("is_sleeping") is True assert response.json().get("is_sleeping") is True
response = requests.post(remote_server.url_for("/wake_up")) response = requests.post(remote_server.url_for("wake_up"),
params={"tags": ["kv_cache"]})
assert response.status_code == 200 assert response.status_code == 200
response = requests.get(remote_server.url_for("/is_sleeping"))
response = requests.get(remote_server.url_for("is_sleeping"))
assert response.status_code == 200 assert response.status_code == 200
assert response.json().get("is_sleeping") is False assert response.json().get("is_sleeping") is False
...@@ -208,22 +208,28 @@ class CuMemAllocator: ...@@ -208,22 +208,28 @@ class CuMemAllocator:
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
def wake_up(self): def wake_up(self, tags: Optional[list[str]] = None) -> None:
""" """
Wake up the allocator from sleep mode. Wake up the allocator from sleep mode.
All data that is previously offloaded will be loaded back to GPU All data that is previously offloaded will be loaded back to GPU
memory, and the rest of the data will have empty memory.""" memory, and the rest of the data will have empty memory.
:param tags: The tags of the memory allocation that will be loaded
back to GPU memory. If None, all memory allocation will be loaded
back to GPU memory.
"""
for ptr, data in self.pointer_to_data.items(): for ptr, data in self.pointer_to_data.items():
handle = data.handle if tags is None or data.tag in tags:
create_and_map(handle) handle = data.handle
if data.cpu_backup_tensor is not None: create_and_map(handle)
cpu_backup_tensor = data.cpu_backup_tensor if data.cpu_backup_tensor is not None:
if cpu_backup_tensor is not None: cpu_backup_tensor = data.cpu_backup_tensor
size_in_bytes = cpu_backup_tensor.numel( if cpu_backup_tensor is not None:
) * cpu_backup_tensor.element_size() size_in_bytes = cpu_backup_tensor.numel(
cpu_ptr = cpu_backup_tensor.data_ptr() ) * cpu_backup_tensor.element_size()
libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes) cpu_ptr = cpu_backup_tensor.data_ptr()
data.cpu_backup_tensor = None libcudart.cudaMemcpy(ptr, cpu_ptr, size_in_bytes)
data.cpu_backup_tensor = None
@contextmanager @contextmanager
def use_memory_pool(self, tag: Optional[str] = None): def use_memory_pool(self, tag: Optional[str] = None):
......
...@@ -1225,8 +1225,8 @@ class AsyncLLMEngine(EngineClient): ...@@ -1225,8 +1225,8 @@ class AsyncLLMEngine(EngineClient):
async def sleep(self, level: int = 1) -> None: async def sleep(self, level: int = 1) -> None:
self.engine.sleep(level) self.engine.sleep(level)
async def wake_up(self) -> None: async def wake_up(self, tags: Optional[list[str]] = None) -> None:
self.engine.wake_up() self.engine.wake_up(tags)
async def is_sleeping(self) -> bool: async def is_sleeping(self) -> bool:
return self.engine.is_sleeping() return self.engine.is_sleeping()
......
...@@ -1938,10 +1938,10 @@ class LLMEngine: ...@@ -1938,10 +1938,10 @@ class LLMEngine:
"Sleep mode is not enabled in the model config") "Sleep mode is not enabled in the model config")
self.model_executor.sleep(level=level) self.model_executor.sleep(level=level)
def wake_up(self) -> None: def wake_up(self, tags: Optional[list[str]] = None) -> None:
assert self.vllm_config.model_config.enable_sleep_mode, ( assert self.vllm_config.model_config.enable_sleep_mode, (
"Sleep mode is not enabled in the model config") "Sleep mode is not enabled in the model config")
self.model_executor.wake_up() self.model_executor.wake_up(tags)
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
return self.model_executor.is_sleeping return self.model_executor.is_sleeping
......
...@@ -133,8 +133,9 @@ class RPCSleepRequest(Enum): ...@@ -133,8 +133,9 @@ class RPCSleepRequest(Enum):
SLEEP_LEVEL_2 = 2 SLEEP_LEVEL_2 = 2
class RPCWakeUpRequest(Enum): @dataclass
WAKE_UP = 1 class RPCWakeUpRequest:
tags: Optional[list[str]] = None
@dataclass @dataclass
......
...@@ -697,10 +697,10 @@ class MQLLMEngineClient(EngineClient): ...@@ -697,10 +697,10 @@ class MQLLMEngineClient(EngineClient):
return await self._send_one_way_rpc_request( return await self._send_one_way_rpc_request(
request=RPCSleepRequest(level), socket=self.input_socket) request=RPCSleepRequest(level), socket=self.input_socket)
async def wake_up(self) -> None: async def wake_up(self, tags: Optional[list[str]] = None) -> None:
"""Wake up the engine""" """Wake up the engine"""
return await self._send_one_way_rpc_request( return await self._send_one_way_rpc_request(
request=RPCWakeUpRequest.WAKE_UP, socket=self.input_socket) request=RPCWakeUpRequest(tags), socket=self.input_socket)
async def is_sleeping(self) -> bool: async def is_sleeping(self) -> bool:
"""Check whether the engine is sleeping""" """Check whether the engine is sleeping"""
......
...@@ -274,7 +274,7 @@ class MQLLMEngine: ...@@ -274,7 +274,7 @@ class MQLLMEngine:
elif isinstance(request, RPCSleepRequest): elif isinstance(request, RPCSleepRequest):
self.sleep(request.value) self.sleep(request.value)
elif isinstance(request, RPCWakeUpRequest): elif isinstance(request, RPCWakeUpRequest):
self.wake_up() self.wake_up(request.tags)
elif isinstance(request, RPCIsSleepingRequest): elif isinstance(request, RPCIsSleepingRequest):
self._handle_is_sleeping_request(request) self._handle_is_sleeping_request(request)
else: else:
...@@ -415,8 +415,8 @@ class MQLLMEngine: ...@@ -415,8 +415,8 @@ class MQLLMEngine:
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
self.engine.sleep(level) self.engine.sleep(level)
def wake_up(self) -> None: def wake_up(self, tags: Optional[list[str]] = None) -> None:
self.engine.wake_up() self.engine.wake_up(tags)
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
return self.engine.is_sleeping() return self.engine.is_sleeping()
......
...@@ -282,7 +282,7 @@ class EngineClient(ABC): ...@@ -282,7 +282,7 @@ class EngineClient(ABC):
... ...
@abstractmethod @abstractmethod
async def wake_up(self) -> None: async def wake_up(self, tags: Optional[list[str]] = None) -> None:
"""Wake up the engine""" """Wake up the engine"""
... ...
......
...@@ -1200,26 +1200,35 @@ class LLM: ...@@ -1200,26 +1200,35 @@ class LLM:
The caller should guarantee that no requests are being processed The caller should guarantee that no requests are being processed
during the sleep period, before `wake_up` is called. during the sleep period, before `wake_up` is called.
:param level: The sleep level. Level 1 sleep will offload the model Args:
weights and discard the kv cache. The content of kv cache is level: The sleep level. Level 1 sleep will offload the model
forgotten. Level 1 sleep is good for sleeping and waking up the weights and discard the kv cache. The content of kv cache
engine to run the same model again. The model weights are backed is forgotten. Level 1 sleep is good for sleeping and waking
up in CPU memory. Please make sure there's enough CPU memory to up the engine to run the same model again. The model weights
store the model weights. Level 2 sleep will discard both the model are backed up in CPU memory. Please make sure there's enough
weights and the kv cache. The content of both the model weights CPU memory to store the model weights. Level 2 sleep will
and kv cache is forgotten. Level 2 sleep is good for sleeping and discard both the model weights and the kv cache. The content
waking up the engine to run a different model or update the model, of both the model weights and kv cache is forgotten. Level 2
where previous model weights are not needed. It reduces CPU memory sleep is good for sleeping and waking up the engine to run a
pressure. different model or update the model, where previous model
weights are not needed. It reduces CPU memory pressure.
""" """
self.reset_prefix_cache() self.reset_prefix_cache()
self.llm_engine.sleep(level=level) self.llm_engine.sleep(level=level)
def wake_up(self): def wake_up(self, tags: Optional[list[str]] = None):
""" """
Wake up the engine from sleep mode. See the :meth:`sleep` method Wake up the engine from sleep mode. See the :meth:`sleep` method
for more details.""" for more details.
self.llm_engine.wake_up()
Args:
tags: An optional list of tags to reallocate the engine memory
for specific memory allocations. Values must be in
("weights", "kv_cache",). If None, all memory is reallocated.
wake_up should be called with all tags (or None) before the
engine is used again.
"""
self.llm_engine.wake_up(tags)
# LEGACY # LEGACY
def _convert_v1_inputs( def _convert_v1_inputs(
......
...@@ -705,7 +705,6 @@ if envs.VLLM_SERVER_DEV_MODE: ...@@ -705,7 +705,6 @@ if envs.VLLM_SERVER_DEV_MODE:
async def sleep(raw_request: Request): async def sleep(raw_request: Request):
# get POST params # get POST params
level = raw_request.query_params.get("level", "1") level = raw_request.query_params.get("level", "1")
logger.info("sleep the engine with level %s", level)
await engine_client(raw_request).sleep(int(level)) await engine_client(raw_request).sleep(int(level))
# FIXME: in v0 with frontend multiprocessing, the sleep command # FIXME: in v0 with frontend multiprocessing, the sleep command
# is sent but does not finish yet when we return a response. # is sent but does not finish yet when we return a response.
...@@ -713,8 +712,12 @@ if envs.VLLM_SERVER_DEV_MODE: ...@@ -713,8 +712,12 @@ if envs.VLLM_SERVER_DEV_MODE:
@router.post("/wake_up") @router.post("/wake_up")
async def wake_up(raw_request: Request): async def wake_up(raw_request: Request):
logger.info("wake up the engine") tags = raw_request.query_params.getlist("tags")
await engine_client(raw_request).wake_up() if tags == []:
# set to None to wake up all tags if no tags are provided
tags = None
logger.info("wake up the engine with tags: %s", tags)
await engine_client(raw_request).wake_up(tags)
# FIXME: in v0 with frontend multiprocessing, the wake-up command # FIXME: in v0 with frontend multiprocessing, the wake-up command
# is sent but does not finish yet when we return a response. # is sent but does not finish yet when we return a response.
return Response(status_code=200) return Response(status_code=200)
......
...@@ -51,6 +51,7 @@ class ExecutorBase(ABC): ...@@ -51,6 +51,7 @@ class ExecutorBase(ABC):
self.observability_config = vllm_config.observability_config self.observability_config = vllm_config.observability_config
self._init_executor() self._init_executor()
self.is_sleeping = False self.is_sleeping = False
self.sleeping_tags: set[str] = set()
@abstractmethod @abstractmethod
def _init_executor(self) -> None: def _init_executor(self) -> None:
...@@ -204,20 +205,34 @@ class ExecutorBase(ABC): ...@@ -204,20 +205,34 @@ class ExecutorBase(ABC):
time_before_sleep = time.perf_counter() time_before_sleep = time.perf_counter()
self.collective_rpc("sleep", kwargs=dict(level=level)) self.collective_rpc("sleep", kwargs=dict(level=level))
time_after_sleep = time.perf_counter() time_after_sleep = time.perf_counter()
self.sleeping_tags = {"weights", "kv_cache"}
self.is_sleeping = True self.is_sleeping = True
logger.info("It took %.6f seconds to fall asleep.", logger.info("It took %.6f seconds to fall asleep.",
time_after_sleep - time_before_sleep) time_after_sleep - time_before_sleep)
def wake_up(self): def wake_up(self, tags: Optional[list[str]] = None):
if not self.is_sleeping: if not self.is_sleeping:
logger.warning("Executor is not sleeping.") logger.warning("Executor is not sleeping.")
return return
if tags:
for tag in tags:
if tag not in self.sleeping_tags:
logger.warning("Tag %s is not in sleeping tags %s", tag,
self.sleeping_tags)
return
time_before_wakeup = time.perf_counter() time_before_wakeup = time.perf_counter()
self.collective_rpc("wake_up") self.collective_rpc("wake_up", kwargs=dict(tags=tags))
time_after_wakeup = time.perf_counter() time_after_wakeup = time.perf_counter()
self.is_sleeping = False logger.info("It took %.6f seconds to wake up tags %s.",
logger.info("It took %.6f seconds to wake up.", time_after_wakeup - time_before_wakeup,
time_after_wakeup - time_before_wakeup) tags if tags is not None else self.sleeping_tags)
if tags:
for tag in tags:
self.sleeping_tags.remove(tag)
else:
self.sleeping_tags.clear()
if not self.sleeping_tags:
self.is_sleeping = False
def save_sharded_state( def save_sharded_state(
self, self,
......
...@@ -424,8 +424,8 @@ class AsyncLLM(EngineClient): ...@@ -424,8 +424,8 @@ class AsyncLLM(EngineClient):
async def sleep(self, level: int = 1) -> None: async def sleep(self, level: int = 1) -> None:
await self.engine_core.sleep_async(level) await self.engine_core.sleep_async(level)
async def wake_up(self) -> None: async def wake_up(self, tags: Optional[list[str]] = None) -> None:
await self.engine_core.wake_up_async() await self.engine_core.wake_up_async(tags)
async def is_sleeping(self) -> bool: async def is_sleeping(self) -> bool:
return await self.engine_core.is_sleeping_async() return await self.engine_core.is_sleeping_async()
......
...@@ -264,8 +264,8 @@ class EngineCore: ...@@ -264,8 +264,8 @@ class EngineCore:
def sleep(self, level: int = 1): def sleep(self, level: int = 1):
self.model_executor.sleep(level) self.model_executor.sleep(level)
def wake_up(self): def wake_up(self, tags: Optional[list[str]] = None):
self.model_executor.wake_up() self.model_executor.wake_up(tags)
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
return self.model_executor.is_sleeping return self.model_executor.is_sleeping
......
...@@ -92,7 +92,7 @@ class EngineCoreClient(ABC): ...@@ -92,7 +92,7 @@ class EngineCoreClient(ABC):
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
raise NotImplementedError raise NotImplementedError
def wake_up(self) -> None: def wake_up(self, tags: Optional[list[str]] = None) -> None:
raise NotImplementedError raise NotImplementedError
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
...@@ -141,7 +141,7 @@ class EngineCoreClient(ABC): ...@@ -141,7 +141,7 @@ class EngineCoreClient(ABC):
async def sleep_async(self, level: int = 1) -> None: async def sleep_async(self, level: int = 1) -> None:
raise NotImplementedError raise NotImplementedError
async def wake_up_async(self) -> None: async def wake_up_async(self, tags: Optional[list[str]] = None) -> None:
raise NotImplementedError raise NotImplementedError
async def is_sleeping_async(self) -> bool: async def is_sleeping_async(self) -> bool:
...@@ -206,8 +206,8 @@ class InprocClient(EngineCoreClient): ...@@ -206,8 +206,8 @@ class InprocClient(EngineCoreClient):
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
self.engine_core.sleep(level) self.engine_core.sleep(level)
def wake_up(self) -> None: def wake_up(self, tags: Optional[list[str]] = None) -> None:
self.engine_core.wake_up() self.engine_core.wake_up(tags)
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
return self.engine_core.is_sleeping() return self.engine_core.is_sleeping()
...@@ -520,8 +520,8 @@ class SyncMPClient(MPClient): ...@@ -520,8 +520,8 @@ class SyncMPClient(MPClient):
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
self.call_utility("sleep", level) self.call_utility("sleep", level)
def wake_up(self) -> None: def wake_up(self, tags: Optional[list[str]] = None) -> None:
self.call_utility("wake_up") self.call_utility("wake_up", tags)
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
return self.call_utility("is_sleeping") return self.call_utility("is_sleeping")
...@@ -647,8 +647,8 @@ class AsyncMPClient(MPClient): ...@@ -647,8 +647,8 @@ class AsyncMPClient(MPClient):
async def sleep_async(self, level: int = 1) -> None: async def sleep_async(self, level: int = 1) -> None:
await self.call_utility_async("sleep", level) await self.call_utility_async("sleep", level)
async def wake_up_async(self) -> None: async def wake_up_async(self, tags: Optional[list[str]] = None) -> None:
await self.call_utility_async("wake_up") await self.call_utility_async("wake_up", tags)
async def is_sleeping_async(self) -> bool: async def is_sleeping_async(self) -> bool:
return await self.call_utility_async("is_sleeping") return await self.call_utility_async("is_sleeping")
......
...@@ -245,8 +245,8 @@ class LLMEngine: ...@@ -245,8 +245,8 @@ class LLMEngine:
def sleep(self, level: int = 1): def sleep(self, level: int = 1):
self.engine_core.sleep(level) self.engine_core.sleep(level)
def wake_up(self): def wake_up(self, tags: Optional[list[str]] = None):
self.engine_core.wake_up() self.engine_core.wake_up(tags)
def is_sleeping(self) -> bool: def is_sleeping(self) -> bool:
return self.engine_core.is_sleeping() return self.engine_core.is_sleeping()
......
...@@ -83,9 +83,9 @@ class Worker(WorkerBase): ...@@ -83,9 +83,9 @@ class Worker(WorkerBase):
"%.2f GiB memory is still in use.", freed_bytes / GiB_bytes, "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
used_bytes / GiB_bytes) used_bytes / GiB_bytes)
def wake_up(self) -> None: def wake_up(self, tags: Optional[list[str]] = None) -> None:
allocator = CuMemAllocator.get_instance() allocator = CuMemAllocator.get_instance()
allocator.wake_up() allocator.wake_up(tags)
def init_device(self): def init_device(self):
if self.device_config.device.type == "cuda": if self.device_config.device.type == "cuda":
......
...@@ -135,9 +135,9 @@ class Worker(LocalOrDistributedWorkerBase): ...@@ -135,9 +135,9 @@ class Worker(LocalOrDistributedWorkerBase):
"%.2f GiB memory is still in use.", freed_bytes / GiB_bytes, "%.2f GiB memory is still in use.", freed_bytes / GiB_bytes,
used_bytes / GiB_bytes) used_bytes / GiB_bytes)
def wake_up(self) -> None: def wake_up(self, tags: Optional[list[str]] = None) -> None:
allocator = CuMemAllocator.get_instance() allocator = CuMemAllocator.get_instance()
allocator.wake_up() allocator.wake_up(tags=tags)
def init_device(self) -> None: def init_device(self) -> None:
if self.device_config.device.type == "cuda": if self.device_config.device.type == "cuda":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment