Commit ed333105 (unverified), authored by Roy Huang, committed by GitHub
Browse files

[KVConnector][LMCache] Propagate cache_salt through MP connector for per-user...


[KVConnector][LMCache] Propagate cache_salt through MP connector for per-user cache isolation (#39837)
Signed-off-by: royyhuang <royyhuang@gmail.com>
Signed-off-by: royyhuang <roy.y.huang@gmail.com>
parent 3cc328a4
......@@ -215,8 +215,11 @@ class LMCacheMPRequestTracker:
# Main state
state: LMCacheMPRequestState = LMCacheMPRequestState.PREFETCHING
cache_salt: str = ""
def __init__(self, request: "Request"):
self.request_id = request.request_id
self.cache_salt: str = request.cache_salt or ""
self.all_token_ids = request.all_token_ids
self.block_hashes = ConstantList(request.block_hashes)
self.allocated_block_ids = []
......@@ -289,6 +292,7 @@ class LMCacheMPRequestMetadata:
request_id: str
direction: Literal["STORE", "RETRIEVE"]
op: LoadStoreOp
cache_salt: str = ""
@staticmethod
def GetStoreMetadata(
......@@ -355,6 +359,7 @@ class LMCacheMPRequestMetadata:
request_id=tracker.request_id,
direction="STORE",
op=op,
cache_salt=tracker.cache_salt,
)
# Update the request tracker
......@@ -421,6 +426,7 @@ class LMCacheMPRequestMetadata:
request_id=tracker.request_id,
direction="RETRIEVE",
op=op,
cache_salt=tracker.cache_salt,
)
return ret
......@@ -569,12 +575,14 @@ class LMCacheMPConnector(KVConnectorBase_V1):
request_ids = []
ops = []
cache_salts = []
for meta in metadata.requests:
if meta.direction != "RETRIEVE":
continue
request_ids.append(meta.request_id)
ops.append(meta.op)
cache_salts.append(meta.cache_salt)
if len(request_ids) == 0:
return
......@@ -583,7 +591,9 @@ class LMCacheMPConnector(KVConnectorBase_V1):
event = torch.cuda.Event(interprocess=True)
event.record()
self.worker_adapter.batched_submit_retrieve_requests(request_ids, ops, event)
self.worker_adapter.batched_submit_retrieve_requests(
request_ids, ops, event, cache_salts=cache_salts
)
def wait_for_layer_load(self, layer_name: str) -> None:
"""
......@@ -640,11 +650,13 @@ class LMCacheMPConnector(KVConnectorBase_V1):
request_ids = []
ops = []
cache_salts = []
for meta in metadata.requests:
if meta.direction != "STORE":
continue
request_ids.append(meta.request_id)
ops.append(meta.op)
cache_salts.append(meta.cache_salt)
if len(request_ids) == 0:
return
......@@ -653,7 +665,9 @@ class LMCacheMPConnector(KVConnectorBase_V1):
event = torch.cuda.Event(interprocess=True)
event.record()
self.worker_adapter.batched_submit_store_requests(request_ids, ops, event)
self.worker_adapter.batched_submit_store_requests(
request_ids, ops, event, cache_salts=cache_salts
)
def get_finished(
self, finished_req_ids: set[str]
......@@ -755,6 +769,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
self.scheduler_adapter.maybe_submit_lookup_request(
request.request_id,
token_ids=list(request.all_token_ids),
cache_salt=tracker.cache_salt,
)
ret = self.scheduler_adapter.check_lookup_result(request.request_id)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment