Unverified Commit 7206ce4c authored by Cody Yu's avatar Cody Yu Committed by GitHub
Browse files

[Core] Support `reset_prefix_cache` (#12284)

parent 96f6a759
......@@ -285,6 +285,33 @@ class KVCacheManager:
if block.ref_cnt == 0:
self.free_block_queue.append(block)
def reset_prefix_cache(self) -> bool:
    """Invalidate every prefix-cache entry tracked by this manager.

    This may be used in RLHF flows to invalidate prefix caching after the
    weights are updated, or to reset the prefix caching status for
    benchmarking.

    The reset is refused while any block is still in use, i.e. while the
    free block queue reports fewer free blocks than ``num_gpu_blocks``.

    Returns:
        bool: True if the prefix cache was reset, False if it was left
        untouched because some blocks are not freed yet.
    """
    blocks_in_use = (self.num_gpu_blocks -
                     self.free_block_queue.num_free_blocks)
    if blocks_in_use > 0:
        logger.warning(
            "Failed to reset prefix cache because some "
            "blocks (%d) are not freed yet", blocks_in_use)
        return False

    # Start from an empty mapping so that no later lookup can hit an
    # old hash.
    self.cached_block_hash_to_block = defaultdict(dict)

    # The blocks carry their own hash as well; clear it on each of them.
    for pooled_block in self.block_pool:
        pooled_block.reset_hash()

    logger.info("Successfully reset prefix cache")
    return True
def get_num_common_prefix_blocks(
self,
request: Request,
......
......@@ -529,6 +529,9 @@ class Scheduler:
def has_unfinished_requests(self) -> bool:
    """Return True when at least one request is still unfinished."""
    unfinished_count = self.get_num_unfinished_requests()
    return unfinished_count > 0
def reset_prefix_cache(self) -> bool:
    """Reset the KV cache manager's prefix cache.

    Returns:
        bool: Whatever the KV cache manager reports for the reset.
    """
    was_reset = self.kv_cache_manager.reset_prefix_cache()
    return was_reset
def make_stats(self) -> SchedulerStats:
return SchedulerStats(
num_running_reqs=len(self.running),
......
......@@ -66,6 +66,11 @@ class EngineCoreProfile:
is_start: bool
@dataclass
class EngineCoreResetPrefixCache:
    """Field-less request asking the engine core to reset its prefix cache."""
class EngineCoreRequestType(enum.Enum):
"""
Request types defined as hex byte strings, so it can be sent over sockets
......@@ -74,6 +79,8 @@ class EngineCoreRequestType(enum.Enum):
ADD = b'\x00'
ABORT = b'\x01'
PROFILE = b'\x02'
RESET_PREFIX_CACHE = b'\x03'
EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile, List[str]]
EngineCoreRequestUnion = Union[EngineCoreRequest, EngineCoreProfile,
EngineCoreResetPrefixCache, List[str]]
......@@ -321,6 +321,9 @@ class AsyncLLM(EngineClient):
async def stop_profile(self) -> None:
    """Ask the engine core client to stop profiling."""
    start_requested = False
    await self.engine_core.profile_async(start_requested)
async def reset_prefix_cache(self) -> None:
    """Ask the engine core client to reset the prefix cache."""
    core_client = self.engine_core
    await core_client.reset_prefix_cache_async()
@property
def is_running(self) -> bool:
    """Report the running state; this implementation always returns True."""
    return True
......
......@@ -20,7 +20,7 @@ from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
EngineCoreRequest, EngineCoreRequestType,
EngineCoreRequestUnion)
EngineCoreRequestUnion, EngineCoreResetPrefixCache)
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.request import Request, RequestStatus
......@@ -135,6 +135,9 @@ class EngineCore:
def profile(self, is_start: bool = True):
    """Forward a profiling toggle to the model executor.

    Args:
        is_start: Flag passed through unchanged to the executor's
            ``profile`` method. Defaults to True.
    """
    executor = self.model_executor
    executor.profile(is_start)
def reset_prefix_cache(self) -> bool:
    """Reset the scheduler's prefix cache and report the outcome.

    The scheduler's result is propagated instead of being dropped, so a
    caller can tell a refused reset from a successful one.

    Returns:
        bool: The value returned by the scheduler's
        ``reset_prefix_cache``.
    """
    return self.scheduler.reset_prefix_cache()
class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process."""
......@@ -247,6 +250,8 @@ class EngineCoreProc(EngineCore):
self.add_request(request)
elif isinstance(request, EngineCoreProfile):
self.model_executor.profile(request.is_start)
elif isinstance(request, EngineCoreResetPrefixCache):
self.reset_prefix_cache()
else:
# TODO: make an EngineCoreAbort wrapper
assert isinstance(request, list)
......@@ -271,7 +276,9 @@ class EngineCoreProc(EngineCore):
request = decoder_add_req.decode(request_data)
elif request_type == EngineCoreRequestType.ABORT.value:
request = decoder_abort_req.decode(request_data)
elif request_type == EngineCoreRequestType.PROFILE.value:
elif request_type in (
EngineCoreRequestType.PROFILE.value,
EngineCoreRequestType.RESET_PREFIX_CACHE.value):
request = pickle.loads(request_data)
else:
raise ValueError(f"Unknown RequestType: {request_type}")
......
......@@ -14,7 +14,7 @@ from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
make_zmq_socket)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
EngineCoreRequest, EngineCoreRequestType,
EngineCoreRequestUnion)
EngineCoreRequestUnion, EngineCoreResetPrefixCache)
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import PickleEncoder
......@@ -69,6 +69,9 @@ class EngineCoreClient(ABC):
def profile(self, is_start: bool = True) -> None:
    """Base-class stub for toggling profiling; always raises.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()
def reset_prefix_cache(self) -> None:
    """Base-class stub for resetting the prefix cache; always raises.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()
def abort_requests(self, request_ids: List[str]) -> None:
    """Base-class stub for aborting requests by id; always raises.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()
......@@ -81,6 +84,9 @@ class EngineCoreClient(ABC):
async def profile_async(self, is_start: bool = True) -> None:
    """Base-class stub for the async profiling toggle; always raises.

    Raises:
        NotImplementedError: Always, once the coroutine is awaited.
    """
    raise NotImplementedError()
async def reset_prefix_cache_async(self) -> None:
    """Base-class stub for the async prefix-cache reset; always raises.

    Raises:
        NotImplementedError: Always, once the coroutine is awaited.
    """
    raise NotImplementedError()
async def abort_requests_async(self, request_ids: List[str]) -> None:
    """Base-class stub for aborting requests asynchronously; always raises.

    Raises:
        NotImplementedError: Always, once the coroutine is awaited.
    """
    raise NotImplementedError()
......@@ -108,12 +114,15 @@ class InprocClient(EngineCoreClient):
if len(request_ids) > 0:
self.engine_core.abort_requests(request_ids)
def shutdown(self):
def shutdown(self) -> None:
    """Shut down the engine core held by this client."""
    core = self.engine_core
    core.shutdown()
def profile(self, is_start: bool = True) -> None:
    """Pass a profiling toggle straight to the engine core.

    Args:
        is_start: Flag forwarded unchanged to the engine core's
            ``profile`` method. Defaults to True.
    """
    core = self.engine_core
    core.profile(is_start)
def reset_prefix_cache(self) -> None:
    """Ask the engine core held by this client to reset its prefix cache."""
    core = self.engine_core
    core.reset_prefix_cache()
class MPClient(EngineCoreClient):
"""
......@@ -229,6 +238,10 @@ class SyncMPClient(MPClient):
self._send_input(EngineCoreRequestType.PROFILE,
EngineCoreProfile(is_start))
def reset_prefix_cache(self) -> None:
    """Send a RESET_PREFIX_CACHE request through ``_send_input``."""
    reset_request = EngineCoreResetPrefixCache()
    self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, reset_request)
class AsyncMPClient(MPClient):
"""Asyncio-compatible client for multi-proc EngineCore."""
......@@ -266,3 +279,7 @@ class AsyncMPClient(MPClient):
async def profile_async(self, is_start: bool = True) -> None:
    """Send a PROFILE request through ``_send_input`` and await it.

    Args:
        is_start: Flag wrapped into the ``EngineCoreProfile`` payload.
            Defaults to True.
    """
    profile_request = EngineCoreProfile(is_start)
    await self._send_input(EngineCoreRequestType.PROFILE, profile_request)
async def reset_prefix_cache_async(self) -> None:
    """Send a RESET_PREFIX_CACHE request through ``_send_input`` and await it."""
    reset_request = EngineCoreResetPrefixCache()
    await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE,
                           reset_request)
......@@ -162,6 +162,9 @@ class LLMEngine:
def stop_profile(self):
    """Ask the engine core client to stop profiling."""
    start_requested = False
    self.engine_core.profile(start_requested)
def reset_prefix_cache(self):
    """Ask the engine core client to reset the prefix cache."""
    core_client = self.engine_core
    core_client.reset_prefix_cache()
def get_tokenizer_group(
self,
group_type: Type[_G] = BaseTokenizerGroup,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment