Unverified commit ed333105, authored by Roy Huang, committed by GitHub
Browse files

[KVConnector][LMCache] Propagate cache_salt through MP connector for per-user cache isolation


[KVConnector][LMCache] Propagate cache_salt through MP connector for per-user cache isolation (#39837)
Signed-off-by: royyhuang <royyhuang@gmail.com>
Signed-off-by: royyhuang <roy.y.huang@gmail.com>
parent 3cc328a4
...@@ -215,8 +215,11 @@ class LMCacheMPRequestTracker: ...@@ -215,8 +215,11 @@ class LMCacheMPRequestTracker:
# Main state # Main state
state: LMCacheMPRequestState = LMCacheMPRequestState.PREFETCHING state: LMCacheMPRequestState = LMCacheMPRequestState.PREFETCHING
cache_salt: str = ""
def __init__(self, request: "Request"): def __init__(self, request: "Request"):
self.request_id = request.request_id self.request_id = request.request_id
self.cache_salt: str = request.cache_salt or ""
self.all_token_ids = request.all_token_ids self.all_token_ids = request.all_token_ids
self.block_hashes = ConstantList(request.block_hashes) self.block_hashes = ConstantList(request.block_hashes)
self.allocated_block_ids = [] self.allocated_block_ids = []
...@@ -289,6 +292,7 @@ class LMCacheMPRequestMetadata: ...@@ -289,6 +292,7 @@ class LMCacheMPRequestMetadata:
request_id: str request_id: str
direction: Literal["STORE", "RETRIEVE"] direction: Literal["STORE", "RETRIEVE"]
op: LoadStoreOp op: LoadStoreOp
cache_salt: str = ""
@staticmethod @staticmethod
def GetStoreMetadata( def GetStoreMetadata(
...@@ -355,6 +359,7 @@ class LMCacheMPRequestMetadata: ...@@ -355,6 +359,7 @@ class LMCacheMPRequestMetadata:
request_id=tracker.request_id, request_id=tracker.request_id,
direction="STORE", direction="STORE",
op=op, op=op,
cache_salt=tracker.cache_salt,
) )
# Update the request tracker # Update the request tracker
...@@ -421,6 +426,7 @@ class LMCacheMPRequestMetadata: ...@@ -421,6 +426,7 @@ class LMCacheMPRequestMetadata:
request_id=tracker.request_id, request_id=tracker.request_id,
direction="RETRIEVE", direction="RETRIEVE",
op=op, op=op,
cache_salt=tracker.cache_salt,
) )
return ret return ret
...@@ -569,12 +575,14 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -569,12 +575,14 @@ class LMCacheMPConnector(KVConnectorBase_V1):
request_ids = [] request_ids = []
ops = [] ops = []
cache_salts = []
for meta in metadata.requests: for meta in metadata.requests:
if meta.direction != "RETRIEVE": if meta.direction != "RETRIEVE":
continue continue
request_ids.append(meta.request_id) request_ids.append(meta.request_id)
ops.append(meta.op) ops.append(meta.op)
cache_salts.append(meta.cache_salt)
if len(request_ids) == 0: if len(request_ids) == 0:
return return
...@@ -583,7 +591,9 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -583,7 +591,9 @@ class LMCacheMPConnector(KVConnectorBase_V1):
event = torch.cuda.Event(interprocess=True) event = torch.cuda.Event(interprocess=True)
event.record() event.record()
self.worker_adapter.batched_submit_retrieve_requests(request_ids, ops, event) self.worker_adapter.batched_submit_retrieve_requests(
request_ids, ops, event, cache_salts=cache_salts
)
def wait_for_layer_load(self, layer_name: str) -> None: def wait_for_layer_load(self, layer_name: str) -> None:
""" """
...@@ -640,11 +650,13 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -640,11 +650,13 @@ class LMCacheMPConnector(KVConnectorBase_V1):
request_ids = [] request_ids = []
ops = [] ops = []
cache_salts = []
for meta in metadata.requests: for meta in metadata.requests:
if meta.direction != "STORE": if meta.direction != "STORE":
continue continue
request_ids.append(meta.request_id) request_ids.append(meta.request_id)
ops.append(meta.op) ops.append(meta.op)
cache_salts.append(meta.cache_salt)
if len(request_ids) == 0: if len(request_ids) == 0:
return return
...@@ -653,7 +665,9 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -653,7 +665,9 @@ class LMCacheMPConnector(KVConnectorBase_V1):
event = torch.cuda.Event(interprocess=True) event = torch.cuda.Event(interprocess=True)
event.record() event.record()
self.worker_adapter.batched_submit_store_requests(request_ids, ops, event) self.worker_adapter.batched_submit_store_requests(
request_ids, ops, event, cache_salts=cache_salts
)
def get_finished( def get_finished(
self, finished_req_ids: set[str] self, finished_req_ids: set[str]
...@@ -755,6 +769,7 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -755,6 +769,7 @@ class LMCacheMPConnector(KVConnectorBase_V1):
self.scheduler_adapter.maybe_submit_lookup_request( self.scheduler_adapter.maybe_submit_lookup_request(
request.request_id, request.request_id,
token_ids=list(request.all_token_ids), token_ids=list(request.all_token_ids),
cache_salt=tracker.cache_salt,
) )
ret = self.scheduler_adapter.check_lookup_result(request.request_id) ret = self.scheduler_adapter.check_lookup_result(request.request_id)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment