Unverified commit adb31506, authored by Tova Movshovitz and committed by GitHub
Browse files

[KVConnector][Feature] Support KV connector cache reset via /reset_prefix_cache (#27170)


Signed-off-by: tovam <tovam@pliops.com>
Signed-off-by: Tova Movshovitz <tovam@pliops.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 4e26d3b0
...@@ -573,3 +573,17 @@ class KVConnectorBase_V1(ABC): ...@@ -573,3 +573,17 @@ class KVConnectorBase_V1(ABC):
expose connector transfer stats via Prometheus. expose connector transfer stats via Prometheus.
""" """
return None return None
def reset_cache(self) -> bool | None:
    """
    Reset the connector's internal cache.

    This default implementation does not reset anything: it emits a
    debug log naming the concrete connector class and returns None.

    Returns:
        bool | None: True if the cache was successfully reset, False
        otherwise. None is what this default implementation returns,
        i.e. the connector does not implement reset_cache().
    """
    # Name the concrete class so the log identifies which connector
    # lacks a reset implementation.
    connector_name = type(self).__name__
    logger.debug(
        "Connector cache reset requested, but %s does not implement reset_cache().",
        connector_name,
    )
    return None
...@@ -452,3 +452,7 @@ class MultiConnector(KVConnectorBase_V1): ...@@ -452,3 +452,7 @@ class MultiConnector(KVConnectorBase_V1):
per_engine_labelvalues, per_engine_labelvalues,
prom_metrics, prom_metrics,
) )
def reset_cache(self) -> bool:
    """
    Reset the cache of every wrapped connector.

    Every connector in ``self._connectors`` is asked to reset, even
    after one of them has reported failure (no short-circuiting).

    Returns:
        bool: False if any connector's reset_cache() returned exactly
        False; True otherwise (including when there are no connectors
        or when a connector returned None).
    """
    all_ok = True
    for connector in self._connectors:
        # Only an explicit False counts as failure; None is tolerated.
        if connector.reset_cache() is False:
            all_ok = False
    return all_ok
...@@ -116,8 +116,10 @@ class EngineClient(ABC): ...@@ -116,8 +116,10 @@ class EngineClient(ABC):
... ...
@abstractmethod @abstractmethod
async def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: async def reset_prefix_cache(
"""Reset the prefix cache""" self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
"""Reset the prefix cache and optionally any configured connector cache"""
... ...
@abstractmethod @abstractmethod
......
...@@ -1491,8 +1491,12 @@ class LLM: ...@@ -1491,8 +1491,12 @@ class LLM:
def stop_profile(self) -> None: def stop_profile(self) -> None:
self.llm_engine.stop_profile() self.llm_engine.stop_profile()
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: def reset_prefix_cache(
return self.llm_engine.reset_prefix_cache(reset_running_requests) self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
return self.llm_engine.reset_prefix_cache(
reset_running_requests, reset_connector
)
def sleep(self, level: int = 1): def sleep(self, level: int = 1):
""" """
......
...@@ -663,14 +663,27 @@ if envs.VLLM_SERVER_DEV_MODE: ...@@ -663,14 +663,27 @@ if envs.VLLM_SERVER_DEV_MODE:
@router.post("/reset_prefix_cache") @router.post("/reset_prefix_cache")
async def reset_prefix_cache( async def reset_prefix_cache(
raw_request: Request, reset_running_requests: bool = Query(default=False) raw_request: Request,
reset_running_requests: bool = Query(default=False),
reset_external: bool = Query(default=False),
): ):
""" """
Reset the prefix cache. Note that we currently do not check if the Reset the local prefix cache.
prefix cache is successfully reset in the API server.
Optionally, if the query parameter `reset_external=true`
also resets the external (connector-managed) prefix cache.
Note that we currently do not check if the prefix cache
is successfully reset in the API server.
Example:
POST /reset_prefix_cache?reset_external=true
""" """
logger.info("Resetting prefix cache...") logger.info("Resetting prefix cache...")
await engine_client(raw_request).reset_prefix_cache(reset_running_requests)
await engine_client(raw_request).reset_prefix_cache(
reset_running_requests, reset_external
)
return Response(status_code=200) return Response(status_code=200)
@router.post("/reset_mm_cache") @router.post("/reset_mm_cache")
......
...@@ -152,7 +152,9 @@ class SchedulerInterface(ABC): ...@@ -152,7 +152,9 @@ class SchedulerInterface(ABC):
return self.has_unfinished_requests() or self.has_finished_requests() return self.has_unfinished_requests() or self.has_finished_requests()
@abstractmethod @abstractmethod
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
"""Reset the prefix cache for KV cache. """Reset the prefix cache for KV cache.
This is particularly required when the model weights are live-updated. This is particularly required when the model weights are live-updated.
......
...@@ -1380,7 +1380,9 @@ class Scheduler(SchedulerInterface): ...@@ -1380,7 +1380,9 @@ class Scheduler(SchedulerInterface):
def has_finished_requests(self) -> bool: def has_finished_requests(self) -> bool:
return len(self.finished_req_ids) > 0 return len(self.finished_req_ids) > 0
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
"""Reset the KV prefix cache. """Reset the KV prefix cache.
If reset_running_requests is True, all the running requests will be If reset_running_requests is True, all the running requests will be
...@@ -1418,8 +1420,26 @@ class Scheduler(SchedulerInterface): ...@@ -1418,8 +1420,26 @@ class Scheduler(SchedulerInterface):
"the presence of running requests waiting for remote KV transfer, " "the presence of running requests waiting for remote KV transfer, "
"which is not supported yet." "which is not supported yet."
) )
if reset_connector:
reset_successful = self.reset_connector_cache() and reset_successful
return reset_successful return reset_successful
def reset_connector_cache(self) -> bool:
    """
    Reset the cache held by the configured KV connector.

    Returns:
        bool: False when no connector is configured (a warning is
        logged) or when the connector's reset_cache() returned exactly
        False; True otherwise. On success with ``self.log_stats`` set,
        ``self.connector_prefix_cache_stats.reset`` is set to True.
    """
    connector = self.connector
    if connector is None:
        logger.warning("reset_connector called but no KV connector is configured.")
        return False

    # Only an explicit False from the connector is treated as failure.
    reset_ok = connector.reset_cache() is not False
    if reset_ok and self.log_stats:
        stats = self.connector_prefix_cache_stats
        assert stats is not None
        stats.reset = True
    return reset_ok
def make_stats( def make_stats(
self, self,
spec_decoding_stats: SpecDecodingStats | None = None, spec_decoding_stats: SpecDecodingStats | None = None,
......
...@@ -749,8 +749,12 @@ class AsyncLLM(EngineClient): ...@@ -749,8 +749,12 @@ class AsyncLLM(EngineClient):
self.input_processor.clear_mm_cache() self.input_processor.clear_mm_cache()
await self.engine_core.reset_mm_cache_async() await self.engine_core.reset_mm_cache_async()
async def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: async def reset_prefix_cache(
return await self.engine_core.reset_prefix_cache_async(reset_running_requests) self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
return await self.engine_core.reset_prefix_cache_async(
reset_running_requests, reset_connector
)
async def sleep(self, level: int = 1) -> None: async def sleep(self, level: int = 1) -> None:
await self.reset_prefix_cache() await self.reset_prefix_cache()
......
...@@ -503,8 +503,12 @@ class EngineCore: ...@@ -503,8 +503,12 @@ class EngineCore:
self.model_executor.reset_mm_cache() self.model_executor.reset_mm_cache()
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: def reset_prefix_cache(
return self.scheduler.reset_prefix_cache(reset_running_requests) self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
return self.scheduler.reset_prefix_cache(
reset_running_requests, reset_connector
)
def sleep(self, level: int = 1): def sleep(self, level: int = 1):
self.model_executor.sleep(level) self.model_executor.sleep(level)
......
...@@ -138,7 +138,9 @@ class EngineCoreClient(ABC): ...@@ -138,7 +138,9 @@ class EngineCoreClient(ABC):
def reset_mm_cache(self) -> None: def reset_mm_cache(self) -> None:
raise NotImplementedError raise NotImplementedError
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
raise NotImplementedError raise NotImplementedError
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
...@@ -209,7 +211,7 @@ class EngineCoreClient(ABC): ...@@ -209,7 +211,7 @@ class EngineCoreClient(ABC):
raise NotImplementedError raise NotImplementedError
async def reset_prefix_cache_async( async def reset_prefix_cache_async(
self, reset_running_requests: bool = False self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool: ) -> bool:
raise NotImplementedError raise NotImplementedError
...@@ -289,8 +291,12 @@ class InprocClient(EngineCoreClient): ...@@ -289,8 +291,12 @@ class InprocClient(EngineCoreClient):
def reset_mm_cache(self) -> None: def reset_mm_cache(self) -> None:
self.engine_core.reset_mm_cache() self.engine_core.reset_mm_cache()
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: def reset_prefix_cache(
return self.engine_core.reset_prefix_cache(reset_running_requests) self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
return self.engine_core.reset_prefix_cache(
reset_running_requests, reset_connector
)
def sleep(self, level: int = 1) -> None: def sleep(self, level: int = 1) -> None:
self.engine_core.sleep(level) self.engine_core.sleep(level)
...@@ -753,8 +759,12 @@ class SyncMPClient(MPClient): ...@@ -753,8 +759,12 @@ class SyncMPClient(MPClient):
def reset_mm_cache(self) -> None: def reset_mm_cache(self) -> None:
self.call_utility("reset_mm_cache") self.call_utility("reset_mm_cache")
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: def reset_prefix_cache(
return self.call_utility("reset_prefix_cache", reset_running_requests) self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
return self.call_utility(
"reset_prefix_cache", reset_running_requests, reset_connector
)
def add_lora(self, lora_request: LoRARequest) -> bool: def add_lora(self, lora_request: LoRARequest) -> bool:
return self.call_utility("add_lora", lora_request) return self.call_utility("add_lora", lora_request)
...@@ -958,10 +968,10 @@ class AsyncMPClient(MPClient): ...@@ -958,10 +968,10 @@ class AsyncMPClient(MPClient):
await self.call_utility_async("reset_mm_cache") await self.call_utility_async("reset_mm_cache")
async def reset_prefix_cache_async( async def reset_prefix_cache_async(
self, reset_running_requests: bool = False self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool: ) -> bool:
return await self.call_utility_async( return await self.call_utility_async(
"reset_prefix_cache", reset_running_requests "reset_prefix_cache", reset_running_requests, reset_connector
) )
async def sleep_async(self, level: int = 1) -> None: async def sleep_async(self, level: int = 1) -> None:
......
...@@ -328,8 +328,12 @@ class LLMEngine: ...@@ -328,8 +328,12 @@ class LLMEngine:
self.input_processor.clear_mm_cache() self.input_processor.clear_mm_cache()
self.engine_core.reset_mm_cache() self.engine_core.reset_mm_cache()
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool: def reset_prefix_cache(
return self.engine_core.reset_prefix_cache(reset_running_requests) self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
return self.engine_core.reset_prefix_cache(
reset_running_requests, reset_connector
)
def sleep(self, level: int = 1): def sleep(self, level: int = 1):
self.engine_core.sleep(level) self.engine_core.sleep(level)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment