Unverified Commit 7206ce4c authored by Cody Yu's avatar Cody Yu Committed by GitHub
Browse files

[Core] Support `reset_prefix_cache` (#12284)

parent 96f6a759
...@@ -285,6 +285,33 @@ class KVCacheManager: ...@@ -285,6 +285,33 @@ class KVCacheManager:
if block.ref_cnt == 0: if block.ref_cnt == 0:
self.free_block_queue.append(block) self.free_block_queue.append(block)
def reset_prefix_cache(self) -> bool:
"""Reset prefix cache. This function may be used in RLHF
flows to invalid prefix caching after the weights are updated,
or used for resetting prefix caching status for benchmarking.
Returns:
bool: True if the prefix cache is successfully reset,
False otherwise.
"""
num_used_blocks = (self.num_gpu_blocks -
self.free_block_queue.num_free_blocks)
if num_used_blocks > 0:
logger.warning(
"Failed to reset prefix cache because some "
"blocks (%d) are not freed yet", num_used_blocks)
return False
# Remove all hashes so that no new blocks will hit.
self.cached_block_hash_to_block = defaultdict(dict)
# Remove all hashes from all blocks.
for block in self.block_pool:
block.reset_hash()
logger.info("Successfully reset prefix cache")
return True
def get_num_common_prefix_blocks( def get_num_common_prefix_blocks(
self, self,
request: Request, request: Request,
......
...@@ -529,6 +529,9 @@ class Scheduler: ...@@ -529,6 +529,9 @@ class Scheduler:
def has_unfinished_requests(self) -> bool: def has_unfinished_requests(self) -> bool:
return self.get_num_unfinished_requests() > 0 return self.get_num_unfinished_requests() > 0
def reset_prefix_cache(self) -> bool:
return self.kv_cache_manager.reset_prefix_cache()
def make_stats(self) -> SchedulerStats: def make_stats(self) -> SchedulerStats:
return SchedulerStats( return SchedulerStats(
num_running_reqs=len(self.running), num_running_reqs=len(self.running),
......
...@@ -66,6 +66,11 @@ class EngineCoreProfile: ...@@ -66,6 +66,11 @@ class EngineCoreProfile:
is_start: bool is_start: bool
@dataclass
class EngineCoreResetPrefixCache:
pass
class EngineCoreRequestType(enum.Enum): class EngineCoreRequestType(enum.Enum):
""" """
Request types defined as hex byte strings, so it can be sent over sockets Request types defined as hex byte strings, so it can be sent over sockets
...@@ -74,6 +79,8 @@ class EngineCoreRequestType(enum.Enum): ...@@ -74,6 +79,8 @@ class EngineCoreRequestType(enum.Enum):
ADD = b'\x00' ADD = b'\x00'
ABORT = b'\x01' ABORT = b'\x01'
PROFILE = b'\x02' PROFILE = b'\x02'
RESET_PREFIX_CACHE = b'\x03'
EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile, List[str]] EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile,
EngineCoreResetPrefixCache, List[str]]
...@@ -321,6 +321,9 @@ class AsyncLLM(EngineClient): ...@@ -321,6 +321,9 @@ class AsyncLLM(EngineClient):
async def stop_profile(self) -> None: async def stop_profile(self) -> None:
await self.engine_core.profile_async(False) await self.engine_core.profile_async(False)
async def reset_prefix_cache(self) -> None:
await self.engine_core.reset_prefix_cache_async()
@property @property
def is_running(self) -> bool: def is_running(self) -> bool:
return True return True
......
...@@ -20,7 +20,7 @@ from vllm.v1.core.kv_cache_utils import get_kv_cache_config ...@@ -20,7 +20,7 @@ from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.core.scheduler import Scheduler from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile, from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
EngineCoreRequest, EngineCoreRequestType, EngineCoreRequest, EngineCoreRequestType,
EngineCoreRequestUnion) EngineCoreRequestUnion, EngineCoreResetPrefixCache)
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.request import Request, RequestStatus from vllm.v1.request import Request, RequestStatus
...@@ -135,6 +135,9 @@ class EngineCore: ...@@ -135,6 +135,9 @@ class EngineCore:
def profile(self, is_start: bool = True): def profile(self, is_start: bool = True):
self.model_executor.profile(is_start) self.model_executor.profile(is_start)
def reset_prefix_cache(self):
self.scheduler.reset_prefix_cache()
class EngineCoreProc(EngineCore): class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process.""" """ZMQ-wrapper for running EngineCore in background process."""
...@@ -247,6 +250,8 @@ class EngineCoreProc(EngineCore): ...@@ -247,6 +250,8 @@ class EngineCoreProc(EngineCore):
self.add_request(request) self.add_request(request)
elif isinstance(request, EngineCoreProfile): elif isinstance(request, EngineCoreProfile):
self.model_executor.profile(request.is_start) self.model_executor.profile(request.is_start)
elif isinstance(request, EngineCoreResetPrefixCache):
self.reset_prefix_cache()
else: else:
# TODO: make an EngineCoreAbort wrapper # TODO: make an EngineCoreAbort wrapper
assert isinstance(request, list) assert isinstance(request, list)
...@@ -271,7 +276,9 @@ class EngineCoreProc(EngineCore): ...@@ -271,7 +276,9 @@ class EngineCoreProc(EngineCore):
request = decoder_add_req.decode(request_data) request = decoder_add_req.decode(request_data)
elif request_type == EngineCoreRequestType.ABORT.value: elif request_type == EngineCoreRequestType.ABORT.value:
request = decoder_abort_req.decode(request_data) request = decoder_abort_req.decode(request_data)
elif request_type == EngineCoreRequestType.PROFILE.value: elif request_type in (
EngineCoreRequestType.PROFILE.value,
EngineCoreRequestType.RESET_PREFIX_CACHE.value):
request = pickle.loads(request_data) request = pickle.loads(request_data)
else: else:
raise ValueError(f"Unknown RequestType: {request_type}") raise ValueError(f"Unknown RequestType: {request_type}")
......
...@@ -14,7 +14,7 @@ from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree, ...@@ -14,7 +14,7 @@ from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
make_zmq_socket) make_zmq_socket)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile, from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
EngineCoreRequest, EngineCoreRequestType, EngineCoreRequest, EngineCoreRequestType,
EngineCoreRequestUnion) EngineCoreRequestUnion, EngineCoreResetPrefixCache)
from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.executor.abstract import Executor from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import PickleEncoder from vllm.v1.serial_utils import PickleEncoder
...@@ -69,6 +69,9 @@ class EngineCoreClient(ABC): ...@@ -69,6 +69,9 @@ class EngineCoreClient(ABC):
def profile(self, is_start: bool = True) -> None: def profile(self, is_start: bool = True) -> None:
raise NotImplementedError raise NotImplementedError
def reset_prefix_cache(self) -> None:
raise NotImplementedError
def abort_requests(self, request_ids: List[str]) -> None: def abort_requests(self, request_ids: List[str]) -> None:
raise NotImplementedError raise NotImplementedError
...@@ -81,6 +84,9 @@ class EngineCoreClient(ABC): ...@@ -81,6 +84,9 @@ class EngineCoreClient(ABC):
async def profile_async(self, is_start: bool = True) -> None: async def profile_async(self, is_start: bool = True) -> None:
raise NotImplementedError raise NotImplementedError
async def reset_prefix_cache_async(self) -> None:
raise NotImplementedError
async def abort_requests_async(self, request_ids: List[str]) -> None: async def abort_requests_async(self, request_ids: List[str]) -> None:
raise NotImplementedError raise NotImplementedError
...@@ -108,12 +114,15 @@ class InprocClient(EngineCoreClient): ...@@ -108,12 +114,15 @@ class InprocClient(EngineCoreClient):
if len(request_ids) > 0: if len(request_ids) > 0:
self.engine_core.abort_requests(request_ids) self.engine_core.abort_requests(request_ids)
def shutdown(self): def shutdown(self) -> None:
self.engine_core.shutdown() self.engine_core.shutdown()
def profile(self, is_start: bool = True) -> None: def profile(self, is_start: bool = True) -> None:
self.engine_core.profile(is_start) self.engine_core.profile(is_start)
def reset_prefix_cache(self) -> None:
self.engine_core.reset_prefix_cache()
class MPClient(EngineCoreClient): class MPClient(EngineCoreClient):
""" """
...@@ -229,6 +238,10 @@ class SyncMPClient(MPClient): ...@@ -229,6 +238,10 @@ class SyncMPClient(MPClient):
self._send_input(EngineCoreRequestType.PROFILE, self._send_input(EngineCoreRequestType.PROFILE,
EngineCoreProfile(is_start)) EngineCoreProfile(is_start))
def reset_prefix_cache(self) -> None:
self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE,
EngineCoreResetPrefixCache())
class AsyncMPClient(MPClient): class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore.""" """Asyncio-compatible client for multi-proc EngineCore."""
...@@ -266,3 +279,7 @@ class AsyncMPClient(MPClient): ...@@ -266,3 +279,7 @@ class AsyncMPClient(MPClient):
async def profile_async(self, is_start: bool = True) -> None: async def profile_async(self, is_start: bool = True) -> None:
await self._send_input(EngineCoreRequestType.PROFILE, await self._send_input(EngineCoreRequestType.PROFILE,
EngineCoreProfile(is_start)) EngineCoreProfile(is_start))
async def reset_prefix_cache_async(self) -> None:
await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE,
EngineCoreResetPrefixCache())
...@@ -162,6 +162,9 @@ class LLMEngine: ...@@ -162,6 +162,9 @@ class LLMEngine:
def stop_profile(self): def stop_profile(self):
self.engine_core.profile(False) self.engine_core.profile(False)
def reset_prefix_cache(self):
self.engine_core.reset_prefix_cache()
def get_tokenizer_group( def get_tokenizer_group(
self, self,
group_type: Type[_G] = BaseTokenizerGroup, group_type: Type[_G] = BaseTokenizerGroup,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment