Unverified Commit adb31506 authored by Tova Movshovitz's avatar Tova Movshovitz Committed by GitHub
Browse files

[KVConnector][Feature] Support KV connector cache reset via /reset_prefix_cache (#27170)


Signed-off-by: default avatartovam <tovam@pliops.com>
Signed-off-by: default avatarTova Movshovitz <tovam@pliops.com>
Co-authored-by: default avatargemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
parent 4e26d3b0
......@@ -573,3 +573,17 @@ class KVConnectorBase_V1(ABC):
expose connector transfer stats via Prometheus.
"""
return None
def reset_cache(self) -> bool | None:
"""
Reset the connector's internal cache.
Returns:
bool: True if the cache was successfully reset, False otherwise.
"""
logger.debug(
"Connector cache reset requested, but %s does not implement reset_cache().",
type(self).__name__,
)
return None
......@@ -452,3 +452,7 @@ class MultiConnector(KVConnectorBase_V1):
per_engine_labelvalues,
prom_metrics,
)
def reset_cache(self) -> bool:
results = [c.reset_cache() is not False for c in self._connectors]
return all(results)
......@@ -116,8 +116,10 @@ class EngineClient(ABC):
...
@abstractmethod
async def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
"""Reset the prefix cache"""
async def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
"""Reset the prefix cache and optionally any configured connector cache"""
...
@abstractmethod
......
......@@ -1491,8 +1491,12 @@ class LLM:
def stop_profile(self) -> None:
self.llm_engine.stop_profile()
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
return self.llm_engine.reset_prefix_cache(reset_running_requests)
def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
return self.llm_engine.reset_prefix_cache(
reset_running_requests, reset_connector
)
def sleep(self, level: int = 1):
"""
......
......@@ -663,14 +663,27 @@ if envs.VLLM_SERVER_DEV_MODE:
@router.post("/reset_prefix_cache")
async def reset_prefix_cache(
raw_request: Request, reset_running_requests: bool = Query(default=False)
raw_request: Request,
reset_running_requests: bool = Query(default=False),
reset_external: bool = Query(default=False),
):
"""
Reset the prefix cache. Note that we currently do not check if the
prefix cache is successfully reset in the API server.
Reset the local prefix cache.
Optionally, if the query parameter `reset_external=true`
also resets the external (connector-managed) prefix cache.
Note that we currently do not check if the prefix cache
is successfully reset in the API server.
Example:
POST /reset_prefix_cache?reset_external=true
"""
logger.info("Resetting prefix cache...")
await engine_client(raw_request).reset_prefix_cache(reset_running_requests)
await engine_client(raw_request).reset_prefix_cache(
reset_running_requests, reset_external
)
return Response(status_code=200)
@router.post("/reset_mm_cache")
......
......@@ -152,7 +152,9 @@ class SchedulerInterface(ABC):
return self.has_unfinished_requests() or self.has_finished_requests()
@abstractmethod
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
"""Reset the prefix cache for KV cache.
This is particularly required when the model weights are live-updated.
......
......@@ -1380,7 +1380,9 @@ class Scheduler(SchedulerInterface):
def has_finished_requests(self) -> bool:
return len(self.finished_req_ids) > 0
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
"""Reset the KV prefix cache.
If reset_running_requests is True, all the running requests will be
......@@ -1418,8 +1420,26 @@ class Scheduler(SchedulerInterface):
"the presence of running requests waiting for remote KV transfer, "
"which is not supported yet."
)
if reset_connector:
reset_successful = self.reset_connector_cache() and reset_successful
return reset_successful
def reset_connector_cache(self) -> bool:
if self.connector is None:
logger.warning("reset_connector called but no KV connector is configured.")
return False
if self.connector.reset_cache() is False:
return False
if self.log_stats:
assert self.connector_prefix_cache_stats is not None
self.connector_prefix_cache_stats.reset = True
return True
def make_stats(
self,
spec_decoding_stats: SpecDecodingStats | None = None,
......
......@@ -749,8 +749,12 @@ class AsyncLLM(EngineClient):
self.input_processor.clear_mm_cache()
await self.engine_core.reset_mm_cache_async()
async def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
return await self.engine_core.reset_prefix_cache_async(reset_running_requests)
async def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
return await self.engine_core.reset_prefix_cache_async(
reset_running_requests, reset_connector
)
async def sleep(self, level: int = 1) -> None:
await self.reset_prefix_cache()
......
......@@ -503,8 +503,12 @@ class EngineCore:
self.model_executor.reset_mm_cache()
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
return self.scheduler.reset_prefix_cache(reset_running_requests)
def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
return self.scheduler.reset_prefix_cache(
reset_running_requests, reset_connector
)
def sleep(self, level: int = 1):
self.model_executor.sleep(level)
......
......@@ -138,7 +138,9 @@ class EngineCoreClient(ABC):
def reset_mm_cache(self) -> None:
raise NotImplementedError
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
raise NotImplementedError
def sleep(self, level: int = 1) -> None:
......@@ -209,7 +211,7 @@ class EngineCoreClient(ABC):
raise NotImplementedError
async def reset_prefix_cache_async(
self, reset_running_requests: bool = False
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
raise NotImplementedError
......@@ -289,8 +291,12 @@ class InprocClient(EngineCoreClient):
def reset_mm_cache(self) -> None:
self.engine_core.reset_mm_cache()
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
return self.engine_core.reset_prefix_cache(reset_running_requests)
def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
return self.engine_core.reset_prefix_cache(
reset_running_requests, reset_connector
)
def sleep(self, level: int = 1) -> None:
self.engine_core.sleep(level)
......@@ -753,8 +759,12 @@ class SyncMPClient(MPClient):
def reset_mm_cache(self) -> None:
self.call_utility("reset_mm_cache")
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
return self.call_utility("reset_prefix_cache", reset_running_requests)
def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
return self.call_utility(
"reset_prefix_cache", reset_running_requests, reset_connector
)
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.call_utility("add_lora", lora_request)
......@@ -958,10 +968,10 @@ class AsyncMPClient(MPClient):
await self.call_utility_async("reset_mm_cache")
async def reset_prefix_cache_async(
self, reset_running_requests: bool = False
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
return await self.call_utility_async(
"reset_prefix_cache", reset_running_requests
"reset_prefix_cache", reset_running_requests, reset_connector
)
async def sleep_async(self, level: int = 1) -> None:
......
......@@ -328,8 +328,12 @@ class LLMEngine:
self.input_processor.clear_mm_cache()
self.engine_core.reset_mm_cache()
def reset_prefix_cache(self, reset_running_requests: bool = False) -> bool:
return self.engine_core.reset_prefix_cache(reset_running_requests)
def reset_prefix_cache(
self, reset_running_requests: bool = False, reset_connector: bool = False
) -> bool:
return self.engine_core.reset_prefix_cache(
reset_running_requests, reset_connector
)
def sleep(self, level: int = 1):
self.engine_core.sleep(level)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment