Unverified commit b12cb383, authored by chunxiaozheng, committed by GitHub
Browse files

Implement `register_kv_caches` in the LMCache connector (#31397)


Signed-off-by: idellzheng <idellzheng@tencent.com>
parent 5bc66411
...@@ -107,6 +107,22 @@ class LMCacheConnectorV1(KVConnectorBase_V1): ...@@ -107,6 +107,22 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
# ============================== # ==============================
# Worker-side methods # Worker-side methods
# ============================== # ==============================
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
    """
    Hand the KV caches to the underlying LMCache engine.

    Useful for pre-registering the KV caches in the KVConnector
    (e.g. for NIXL). The call is forwarded only when the engine exposes
    a ``register_kv_caches`` attribute; otherwise a warning is logged
    and nothing else happens.

    Args:
        kv_caches: dictionary mapping layer names to their kv cache
    """
    engine = self._lmcache_engine

    # Engines without this hook cannot take the registration; warn
    # instead of failing.
    if not hasattr(engine, "register_kv_caches"):
        logger.warning(
            "LMCache engine does not support register_kv_caches, "
            "please check and use the latest version"
        )
        return

    engine.register_kv_caches(kv_caches)
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
""" """
Start loading the KV cache from the connector to vLLM's paged Start loading the KV cache from the connector to vLLM's paged
......
...@@ -782,6 +782,16 @@ class LMCacheConnectorV1Impl: ...@@ -782,6 +782,16 @@ class LMCacheConnectorV1Impl:
#################### ####################
# Worker side APIs # Worker side APIs
#################### ####################
@_lmcache_nvtx_annotate
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
    """
    Store the given KV caches and pass them to the LMCache engine.

    The caches are saved on ``self.kv_caches``. If ``self.lmcache_engine``
    is set, its ``post_init`` is called with the cache values as a list.

    Args:
        kv_caches: dictionary mapping layer names to their kv cache;
            must be non-empty, and ``self.kv_caches`` must still be
            empty when this is called.
    """
    logger.info("Registering KV caches")
    # TODO(chunxiaozheng): `_init_kv_caches_from_forward_context` is
    # not called; we should consider removing it.

    # Nothing may be registered yet, and there must be something to
    # register.
    assert len(self.kv_caches) == 0
    assert len(kv_caches) > 0

    self.kv_caches = kv_caches

    engine = self.lmcache_engine
    if engine is None:
        return

    cache_list = list(self.kv_caches.values())
    engine.post_init(kvcaches=cache_list)
@_lmcache_nvtx_annotate @_lmcache_nvtx_annotate
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment