Unverified commit e94ec597, authored by Yuwei An, committed by GitHub
Browse files

[LMCache] Token Base IPC API (#34175)


Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
parent 13397841
...@@ -20,16 +20,42 @@ from lmcache.v1.multiprocess.protocol import RequestType, get_response_class ...@@ -20,16 +20,42 @@ from lmcache.v1.multiprocess.protocol import RequestType, get_response_class
logger = init_logger(__name__) logger = init_logger(__name__)
def wrap_kv_caches(kv_caches: dict[str, KVCache]) -> KVCache: def wrap_kv_caches(kv_caches: dict[str, torch.Tensor]) -> KVCache:
logger.info("KV caches keys are %s", list(kv_caches.keys())) logger.info("KV caches keys are %s", list(kv_caches.keys()))
return [CudaIPCWrapper(tensor) for tensor in kv_caches.values()] return [CudaIPCWrapper(tensor) for tensor in kv_caches.values()]
def striding_block_hashes(
block_hashes: list[bytes], blocks_in_chunk: int
) -> Iterable[bytes]:
"""Extract chunk-level hashes from block hashes by striding.
In hash-based vLLM, each vLLM block has its own hash. LMCache chunks
span ``blocks_in_chunk`` consecutive blocks. The representative hash
for a chunk is the hash of the **last** block in that chunk (because
each block hash already encodes its prefix). So we start at index
``blocks_in_chunk - 1`` and stride by ``blocks_in_chunk``.
"""
return islice(block_hashes, blocks_in_chunk - 1, None, blocks_in_chunk)
def send_lmcache_request( def send_lmcache_request(
mq_client: MessageQueueClient, mq_client: MessageQueueClient,
request_type: RequestType, request_type: RequestType,
payloads: list[Any], payloads: list[Any],
) -> MessagingFuture[Any]: ) -> MessagingFuture[Any]:
"""
Helper function to send the request to the LMCache multiprocess server
Args:
mq_client: The LMCache multiprocess mode message queue client
request_type: The request type
payloads: The request payloads
Returns:
A messaging future for the request
"""
future = mq_client.submit_request( future = mq_client.submit_request(
request_type, payloads, get_response_class(request_type) request_type, payloads, get_response_class(request_type)
) )
...@@ -39,40 +65,44 @@ def send_lmcache_request( ...@@ -39,40 +65,44 @@ def send_lmcache_request(
def get_lmcache_chunk_size( def get_lmcache_chunk_size(
mq_client: MessageQueueClient, mq_client: MessageQueueClient,
) -> int: ) -> int:
future = send_lmcache_request(mq_client, RequestType.GET_CHUNK_SIZE, []) """
chunk_size = future.result() Helper function to get the LMCache chunk size from the server
return chunk_size
Args:
mq_client: The LMCache multiprocess mode message queue client
def striding_block_hashes( Returns:
block_hashes: list[bytes], An integer representing the LMCache chunk size
blocks_in_chunk,
) -> Iterable[bytes]:
"""Striding the block hashes to get the block hashes for each chunk.
For example, if blocks_in_chunk is 16, then we will get the block hashes
for the 16th, 32nd, 48th, ... blocks.
""" """
return islice(block_hashes, blocks_in_chunk - 1, None, blocks_in_chunk) future = send_lmcache_request(mq_client, RequestType.GET_CHUNK_SIZE, [])
chunk_size = future.result()
return chunk_size
@dataclass @dataclass
class LoadStoreOp: class LoadStoreOp:
block_hashes: list[bytes]
block_ids: list[int] block_ids: list[int]
"""Block ids for the load/store operation"""
def __len__(self) -> int: token_ids: list[int] | None = None
return len(self.block_hashes) """Token IDs for the load/store operation (token mode)"""
def __post_init__(self): block_hashes: list[bytes] | None = None
assert len(self.block_hashes) == len(self.block_ids), ( """Block hashes for the load/store operation (hash mode)"""
"The number of block hashes should be equal to the number of block ids "
f"But got {len(self.block_hashes)} and {len(self.block_ids)}" start: int = 0
) """Start token index (token mode only)"""
end: int = 0
"""End token index (token mode only)"""
def __len__(self) -> int:
return len(self.block_ids)
StoreResult = bool StoreResult = bool
RetrieveResult = list[bool] RetrieveResult = list[bool]
LookupResult = list[bool] LookupResult = int
class LMCacheMPSchedulerAdapter: class LMCacheMPSchedulerAdapter:
...@@ -95,10 +125,6 @@ class LMCacheMPSchedulerAdapter: ...@@ -95,10 +125,6 @@ class LMCacheMPSchedulerAdapter:
kv_rank: The kv rank used for LMCache keys kv_rank: The kv rank used for LMCache keys
vllm_block_size: The block size used in vLLM vllm_block_size: The block size used in vLLM
""" """
logger.warning(
"Importing LMCacheMPSchedulerAdapter is deprecated. "
"Please update your LMCache to the latest version."
)
self.mq_client = MessageQueueClient(server_url, context) self.mq_client = MessageQueueClient(server_url, context)
# Request futures # Request futures
...@@ -116,22 +142,89 @@ class LMCacheMPSchedulerAdapter: ...@@ -116,22 +142,89 @@ class LMCacheMPSchedulerAdapter:
self.blocks_in_chunk = self.chunk_size // vllm_block_size self.blocks_in_chunk = self.chunk_size // vllm_block_size
@_lmcache_nvtx_annotate @_lmcache_nvtx_annotate
def maybe_submit_lookup_request(self, request_id: str, block_hashes: list[bytes]): def maybe_submit_lookup_request(
self,
request_id: str,
block_hashes: list[bytes] | None = None,
token_ids: list[int] | None = None,
) -> None:
"""
Submit a new lookup request to LMCache if there is no ongoing request.
Supports both token-based and hash-based vLLM:
- token_ids: token IDs (token-based vLLM) -> single token-mode key
- block_hashes: block hashes (hash-based vLLM) -> strided hash-mode keys
Exactly one of block_hashes or token_ids must be provided.
Args:
request_id: The ID of the lookup request. The same ID indicates it's
from the same request
block_hashes: Block hashes to lookup from LMCache (hash mode)
token_ids: Token IDs to lookup from LMCache (token mode)
Returns:
None
Notes:
This function will have a side-effect: submitting a look up request to
LMCache, which will essentially 'lock' the KV cache chunks in the LMCache
for later retrieve operations.
In the meantime, this function will record the lookup request, and the
status of the look up request can be checked by `check_lookup_result`.
"""
if request_id in self.lookup_futures: if request_id in self.lookup_futures:
# Skip if there is already a lookup request # Skip if there is already a lookup request
return return
s = striding_block_hashes(block_hashes, self.blocks_in_chunk) assert (block_hashes is None) != (token_ids is None), (
keys = [self._create_key(block_hash) for block_hash in s] "Exactly one of block_hashes or token_ids must be provided"
)
if block_hashes is not None:
# Hash mode: stride block hashes -> N hash-mode keys
chunk_hashes = list(
striding_block_hashes(block_hashes, self.blocks_in_chunk)
)
keys = [
self._create_hash_key(ch, request_id=request_id) for ch in chunk_hashes
]
else:
# Token mode: truncate to chunk-aligned length
assert token_ids is not None
aligned_end = (len(token_ids) // self.chunk_size) * self.chunk_size
if aligned_end == 0:
return
keys = [
self._create_key(
token_ids,
start=0,
end=aligned_end,
request_id=request_id,
).no_worker_id_version()
]
future = send_lmcache_request( future = send_lmcache_request(
self.mq_client, self.mq_client,
RequestType.LOOKUP, RequestType.LOOKUP,
[keys, True], [keys],
) )
self.lookup_futures[request_id] = future self.lookup_futures[request_id] = future
@_lmcache_nvtx_annotate @_lmcache_nvtx_annotate
def check_lookup_result(self, request_id: str) -> int | None: def check_lookup_result(self, request_id: str) -> int | None:
"""
Check the result of a previously submitted lookup request.
Args:
request_id: The ID of the lookup request submitted in
`maybe_submit_lookup_request`
Returns:
An integer representing the total number of tokens matched
in LMCache (prefix matching), or
None if the lookup request is not finished yet.
"""
assert request_id in self.lookup_futures, ( assert request_id in self.lookup_futures, (
f"Lookup request for request_id={request_id} has not been submitted" f"Lookup request for request_id={request_id} has not been submitted"
) )
...@@ -141,7 +234,7 @@ class LMCacheMPSchedulerAdapter: ...@@ -141,7 +234,7 @@ class LMCacheMPSchedulerAdapter:
return None return None
result = future.result() result = future.result()
num_chunks = sum(result) num_chunks = result
return num_chunks * self.chunk_size return num_chunks * self.chunk_size
def num_blocks_per_chunk(self) -> int: def num_blocks_per_chunk(self) -> int:
...@@ -159,14 +252,47 @@ class LMCacheMPSchedulerAdapter: ...@@ -159,14 +252,47 @@ class LMCacheMPSchedulerAdapter:
""" """
self.lookup_futures.pop(request_id, None) self.lookup_futures.pop(request_id, None)
def end_session(self, request_id: str) -> None:
"""
Notify LMCache server to remove the session for a finished request.
Args:
request_id: The ID of the finished request.
"""
send_lmcache_request(
self.mq_client,
RequestType.END_SESSION,
[request_id],
)
# Helper functions # Helper functions
def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey: def _create_key(
"""Convert a block hash to an IPC cache engine key""" self,
token_ids: list[int],
start: int = 0,
end: int = 0,
request_id: str | None = None,
) -> IPCCacheEngineKey:
"""Convert token IDs to an IPC cache engine key"""
return IPCCacheEngineKey( return IPCCacheEngineKey(
model_name=self.model_name, model_name=self.model_name,
world_size=self.world_size, world_size=self.world_size,
worker_id=self.worker_id, worker_id=self.worker_id,
chunk_hash=block_hash, token_ids=tuple(token_ids),
start=start,
end=end,
request_id=request_id,
)
def _create_hash_key(
self, chunk_hash: bytes, request_id: str | None = None
) -> IPCCacheEngineKey:
"""Create a hash-mode IPC cache engine key"""
return IPCCacheEngineKey(
model_name=self.model_name,
world_size=self.world_size,
worker_id=None,
chunk_hash=chunk_hash,
request_id=request_id,
) )
...@@ -180,10 +306,6 @@ class LMCacheMPWorkerAdapter: ...@@ -180,10 +306,6 @@ class LMCacheMPWorkerAdapter:
kv_rank: int, kv_rank: int,
vllm_block_size: int, vllm_block_size: int,
): ):
logger.warning(
"Importing LMCacheMPWorkerAdapter is deprecated. "
"Please update your LMCache to the latest version."
)
self.mq_client = MessageQueueClient(server_url, context) self.mq_client = MessageQueueClient(server_url, context)
# Instance id for GPU worker # Instance id for GPU worker
...@@ -201,7 +323,10 @@ class LMCacheMPWorkerAdapter: ...@@ -201,7 +323,10 @@ class LMCacheMPWorkerAdapter:
str, tuple[MessagingFuture[RetrieveResult], list[str]] str, tuple[MessagingFuture[RetrieveResult], list[str]]
] = {} ] = {}
# The store requests that have finished execution in LMCache
self.finished_stores: set[str] = set() self.finished_stores: set[str] = set()
# The finished request ids that are passed via vLLM and also
# have corresponding store requests submitted to LMCache before
self.previously_finished: set[str] = set() self.previously_finished: set[str] = set()
self.model_name = model_name self.model_name = model_name
...@@ -215,7 +340,14 @@ class LMCacheMPWorkerAdapter: ...@@ -215,7 +340,14 @@ class LMCacheMPWorkerAdapter:
) )
self.blocks_in_chunk = chunk_size // vllm_block_size self.blocks_in_chunk = chunk_size // vllm_block_size
def register_kv_caches(self, kv_caches: dict[str, KVCache]): def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""
Register the kv caches with LMCache server
Args:
kv_caches: A dict of kv caches to register. The keys are the
layer names and the values are the corresponding tensors.
"""
# Register kv cache and send the request # Register kv cache and send the request
self.kv_caches = kv_caches self.kv_caches = kv_caches
logger.info("Registering kv caches") logger.info("Registering kv caches")
...@@ -230,7 +362,29 @@ class LMCacheMPWorkerAdapter: ...@@ -230,7 +362,29 @@ class LMCacheMPWorkerAdapter:
def submit_store_request( def submit_store_request(
self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
): ):
keys = self._block_hashes_to_keys(op.block_hashes) """
Submit a KV cache store request to LMCache
Args:
request_id: The ID of the request
op: The LoadStoreOp describing the store operation.
event: The CUDA event that is recorded after the current
model inference step
"""
if op.block_hashes is not None:
# Hash mode
chunk_hashes = list(
striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
)
keys = [
self._create_hash_key(ch, request_id=request_id) for ch in chunk_hashes
]
else:
# Token mode
assert op.token_ids is not None
keys = [
self._create_key(op.token_ids, op.start, op.end, request_id=request_id)
]
future = send_lmcache_request( future = send_lmcache_request(
self.mq_client, self.mq_client,
RequestType.STORE, RequestType.STORE,
...@@ -242,7 +396,29 @@ class LMCacheMPWorkerAdapter: ...@@ -242,7 +396,29 @@ class LMCacheMPWorkerAdapter:
def submit_retrieve_request( def submit_retrieve_request(
self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
): ):
keys = self._block_hashes_to_keys(op.block_hashes) """
Submit a KV cache retrieve request to LMCache
Args:
request_id: The ID of the request
op: The LoadStoreOp describing the retrieve operation.
event: The CUDA event that is recorded after the current
model inference step
"""
if op.block_hashes is not None:
# Hash mode
chunk_hashes = list(
striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
)
keys = [
self._create_hash_key(ch, request_id=request_id) for ch in chunk_hashes
]
else:
# Token mode
assert op.token_ids is not None
keys = [
self._create_key(op.token_ids, op.start, op.end, request_id=request_id)
]
future = send_lmcache_request( future = send_lmcache_request(
self.mq_client, self.mq_client,
RequestType.RETRIEVE, RequestType.RETRIEVE,
...@@ -257,17 +433,47 @@ class LMCacheMPWorkerAdapter: ...@@ -257,17 +433,47 @@ class LMCacheMPWorkerAdapter:
ops: list[LoadStoreOp], ops: list[LoadStoreOp],
event: torch.cuda.Event, event: torch.cuda.Event,
): ):
keys = [] """
block_ids = [] Submit a batched store request to LMCache
for op in ops:
keys.extend(self._block_hashes_to_keys(op.block_hashes)) Args:
request_ids: The IDs of the requests
ops: The LoadStoreOps describing the store operations. Should have
the same length as request_ids
event: The CUDA event that is recorded after the current
model inference step
"""
all_keys: list[IPCCacheEngineKey] = []
block_ids: list[int] = []
for request_id, op in zip(request_ids, ops, strict=False):
if op.block_hashes is not None:
chunk_hashes = list(
striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
)
keys = [
self._create_hash_key(ch, request_id=request_id)
for ch in chunk_hashes
]
all_keys.extend(keys)
else:
assert op.token_ids is not None
all_keys.append(
self._create_key(
op.token_ids, op.start, op.end, request_id=request_id
)
)
block_ids.extend(op.block_ids) block_ids.extend(op.block_ids)
future = send_lmcache_request( future = send_lmcache_request(
self.mq_client, self.mq_client,
RequestType.STORE, RequestType.STORE,
[keys, self.instance_id, block_ids, event.ipc_handle()], [
all_keys,
self.instance_id,
block_ids,
event.ipc_handle(),
],
).to_cuda_future() ).to_cuda_future()
self.store_futures[request_ids[0]] = (future, request_ids[1:]) self.store_futures[request_ids[0]] = (future, list(request_ids[1:]))
@_lmcache_nvtx_annotate @_lmcache_nvtx_annotate
def batched_submit_retrieve_requests( def batched_submit_retrieve_requests(
...@@ -276,34 +482,83 @@ class LMCacheMPWorkerAdapter: ...@@ -276,34 +482,83 @@ class LMCacheMPWorkerAdapter:
ops: list[LoadStoreOp], ops: list[LoadStoreOp],
event: torch.cuda.Event, event: torch.cuda.Event,
): ):
keys = [] """
block_ids = [] Submit a batched retrieve request to LMCache
for op in ops: Args:
keys.extend(self._block_hashes_to_keys(op.block_hashes)) request_ids: The IDs of the requests
ops: The LoadStoreOps describing the retrieve operations. Should have
the same length as request_ids
event: The CUDA event that is recorded after the current
model inference step
"""
all_keys: list[IPCCacheEngineKey] = []
block_ids: list[int] = []
for request_id, op in zip(request_ids, ops, strict=False):
if op.block_hashes is not None:
chunk_hashes = list(
striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
)
keys = [
self._create_hash_key(ch, request_id=request_id)
for ch in chunk_hashes
]
all_keys.extend(keys)
else:
assert op.token_ids is not None
all_keys.append(
self._create_key(
op.token_ids, op.start, op.end, request_id=request_id
)
)
block_ids.extend(op.block_ids) block_ids.extend(op.block_ids)
future = send_lmcache_request( future = send_lmcache_request(
self.mq_client, self.mq_client,
RequestType.RETRIEVE, RequestType.RETRIEVE,
[keys, self.instance_id, block_ids, event.ipc_handle()], [
all_keys,
self.instance_id,
block_ids,
event.ipc_handle(),
],
).to_cuda_future() ).to_cuda_future()
self.retrieve_futures[request_ids[0]] = (future, request_ids[1:]) self.retrieve_futures[request_ids[0]] = (future, list(request_ids[1:]))
@_lmcache_nvtx_annotate @_lmcache_nvtx_annotate
def get_finished( def get_finished(
self, finished_req_ids: set[str] self, finished_req_ids_from_engine: set[str]
) -> tuple[set[str] | None, set[str] | None]: ) -> tuple[set[str] | None, set[str] | None]:
"""
Check and get the finished store and retrieve requests.
Args:
finished_req_ids_from_engine: the set of request ids that are
reported as finished from the vLLM engine side.
Returns:
A tuple of two sets:
- The first set contains the finished store request ids. The returned
store request ids MUST be seen before in the
`finished_req_ids_from_engine`.
- The second set contains the finished retrieve request ids.
Notes:
When enabling async scheduling in vLLM, the same request ID may appear
multiple times in `finished_req_ids_from_engine`. The adapter should
take care of deduplicating the request IDs and only return the request
IDs that have not been returned before.
"""
finished_stores = set() finished_stores = set()
finished_retrieves = set() finished_retrieves = set()
for request_id, (future, other_reqs) in self.store_futures.items(): for request_id, (s_future, other_reqs) in self.store_futures.items():
if not future.query(): if not s_future.query():
continue continue
result = future.result() s_result = s_future.result()
finished_stores.add(request_id) finished_stores.add(request_id)
finished_stores.update(other_reqs) finished_stores.update(other_reqs)
if not result: if not s_result:
# TODO: add error handling here # TODO: add error handling here
logger.error( logger.error(
"Something went wrong when processing the " "Something went wrong when processing the "
...@@ -311,21 +566,21 @@ class LMCacheMPWorkerAdapter: ...@@ -311,21 +566,21 @@ class LMCacheMPWorkerAdapter:
request_id, request_id,
) )
for request_id, (future, other_reqs) in self.retrieve_futures.items(): for request_id, (r_future, other_reqs) in self.retrieve_futures.items():
if not future.query(): if not r_future.query():
continue continue
result = future.result() r_result = r_future.result()
finished_retrieves.add(request_id) finished_retrieves.add(request_id)
finished_retrieves.update(other_reqs) finished_retrieves.update(other_reqs)
if not all(result): if not all(r_result):
# TODO: add error handing here # TODO: add error handing here
logger.error( logger.error(
"Something went wrong when processing the " "Something went wrong when processing the "
"retrieve request for request_id=%s, result=%s", "retrieve request for request_id=%s, result=%s",
request_id, request_id,
result, r_result,
) )
# Remove the finished requests from the tracking dicts # Remove the finished requests from the tracking dicts
...@@ -338,7 +593,7 @@ class LMCacheMPWorkerAdapter: ...@@ -338,7 +593,7 @@ class LMCacheMPWorkerAdapter:
self.finished_stores.update(finished_stores) self.finished_stores.update(finished_stores)
ret_stores = set() ret_stores = set()
for req_id in finished_req_ids: for req_id in finished_req_ids_from_engine:
if req_id in self.finished_stores or req_id in self.store_futures: if req_id in self.finished_stores or req_id in self.store_futures:
self.previously_finished.add(req_id) self.previously_finished.add(req_id)
else: else:
...@@ -357,7 +612,9 @@ class LMCacheMPWorkerAdapter: ...@@ -357,7 +612,9 @@ class LMCacheMPWorkerAdapter:
return self.blocks_in_chunk return self.blocks_in_chunk
def shutdown(self): def shutdown(self):
# Unregister kv cache """
Shutdown the LMCache MP worker adapter
"""
logger.info("Unregistering kv caches") logger.info("Unregistering kv caches")
send_lmcache_request( send_lmcache_request(
self.mq_client, RequestType.UNREGISTER_KV_CACHE, [self.instance_id] self.mq_client, RequestType.UNREGISTER_KV_CACHE, [self.instance_id]
...@@ -378,18 +635,32 @@ class LMCacheMPWorkerAdapter: ...@@ -378,18 +635,32 @@ class LMCacheMPWorkerAdapter:
return safe_finished_s return safe_finished_s
def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey: def _create_key(
"""Convert a block hash to an IPC cache engine key""" self,
token_ids: list[int],
start: int = 0,
end: int = 0,
request_id: str | None = None,
) -> IPCCacheEngineKey:
"""Convert token IDs to an IPC cache engine key"""
return IPCCacheEngineKey( return IPCCacheEngineKey(
model_name=self.model_name, model_name=self.model_name,
world_size=self.world_size, world_size=self.world_size,
worker_id=self.worker_id, worker_id=self.worker_id,
chunk_hash=block_hash, token_ids=tuple(token_ids),
start=start,
end=end,
request_id=request_id,
) )
def _block_hashes_to_keys( def _create_hash_key(
self, block_hashes: list[bytes] self, chunk_hash: bytes, request_id: str | None = None
) -> list[IPCCacheEngineKey]: ) -> IPCCacheEngineKey:
"""Convert block hashes to IPC cache engine keys""" """Create a hash-mode IPC cache engine key"""
s = striding_block_hashes(block_hashes, self.blocks_in_chunk) return IPCCacheEngineKey(
return [self._create_key(block_hash) for block_hash in s] model_name=self.model_name,
world_size=self.world_size,
worker_id=self.worker_id,
chunk_hash=chunk_hash,
request_id=request_id,
)
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
import enum import enum
from collections.abc import Iterable from collections.abc import Iterable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal, cast from typing import TYPE_CHECKING, Any, Literal
import torch import torch
import zmq import zmq
...@@ -130,12 +130,6 @@ def create_worker_adapter( ...@@ -130,12 +130,6 @@ def create_worker_adapter(
) )
def convert_block_hashes_to_bytes(
block_hashes: list["BlockHash"],
) -> list[bytes]:
return cast(list[bytes], block_hashes)
class LMCacheMPRequestState(enum.Enum): class LMCacheMPRequestState(enum.Enum):
""" """
State machine: State machine:
...@@ -266,6 +260,7 @@ class LMCacheMPRequestMetadata: ...@@ -266,6 +260,7 @@ class LMCacheMPRequestMetadata:
Args: Args:
tracker: The request tracker to generate the metadata from. tracker: The request tracker to generate the metadata from.
blocks_in_chunk: the number of blocks in a LMCache data chunk blocks_in_chunk: the number of blocks in a LMCache data chunk
vllm_block_size: the block size used in vLLM
""" """
# Store the blocks that has block hashes # Store the blocks that has block hashes
# NOTE: the invariant here is that `num_stored_blocks` should # NOTE: the invariant here is that `num_stored_blocks` should
...@@ -282,15 +277,21 @@ class LMCacheMPRequestMetadata: ...@@ -282,15 +277,21 @@ class LMCacheMPRequestMetadata:
if num_chunks >= 1: if num_chunks >= 1:
start = tracker.num_stored_blocks start = tracker.num_stored_blocks
end = start + num_chunks * blocks_in_chunk end = start + num_chunks * blocks_in_chunk
block_hashes = convert_block_hashes_to_bytes(
tracker.block_hashes[start:end]
)
block_ids = tracker.allocated_block_ids[start:end] block_ids = tracker.allocated_block_ids[start:end]
start_token_idx = start * vllm_block_size
end_token_idx = end * vllm_block_size
token_ids = list(tracker.all_token_ids)
op = LoadStoreOp(
token_ids=token_ids,
block_ids=block_ids,
start=start_token_idx,
end=end_token_idx,
)
ret = LMCacheMPRequestMetadata( ret = LMCacheMPRequestMetadata(
request_id=tracker.request_id, request_id=tracker.request_id,
direction="STORE", direction="STORE",
op=LoadStoreOp(block_hashes=block_hashes, block_ids=block_ids), op=op,
) )
# Update the request tracker # Update the request tracker
...@@ -303,6 +304,7 @@ class LMCacheMPRequestMetadata: ...@@ -303,6 +304,7 @@ class LMCacheMPRequestMetadata:
def GetRetrieveMetadata( def GetRetrieveMetadata(
tracker: LMCacheMPRequestTracker, tracker: LMCacheMPRequestTracker,
blocks_in_chunk: int, blocks_in_chunk: int,
vllm_block_size: int,
) -> "LMCacheMPRequestMetadata | None": ) -> "LMCacheMPRequestMetadata | None":
""" """
Generate the retrieve metadata for the current request tracker. Generate the retrieve metadata for the current request tracker.
...@@ -310,6 +312,7 @@ class LMCacheMPRequestMetadata: ...@@ -310,6 +312,7 @@ class LMCacheMPRequestMetadata:
Args: Args:
tracker: The request tracker to generate the metadata from. tracker: The request tracker to generate the metadata from.
blocks_in_chunk: the number of blocks in a LMCache data chunk blocks_in_chunk: the number of blocks in a LMCache data chunk
vllm_block_size: the block size used in vLLM
""" """
if not tracker.is_ready_for_retrieving(): if not tracker.is_ready_for_retrieving():
return None return None
...@@ -330,15 +333,21 @@ class LMCacheMPRequestMetadata: ...@@ -330,15 +333,21 @@ class LMCacheMPRequestMetadata:
"number of LMCache hit blocks. " "number of LMCache hit blocks. "
) )
if end > start: if end > start:
block_hashes = convert_block_hashes_to_bytes(
tracker.block_hashes[start:end]
)
block_ids = tracker.allocated_block_ids[start:end] block_ids = tracker.allocated_block_ids[start:end]
start_token_idx = start * vllm_block_size
end_token_idx = end * vllm_block_size
token_ids = list(tracker.all_token_ids)
op = LoadStoreOp(
token_ids=token_ids,
block_ids=block_ids,
start=start_token_idx,
end=end_token_idx,
)
ret = LMCacheMPRequestMetadata( ret = LMCacheMPRequestMetadata(
request_id=tracker.request_id, request_id=tracker.request_id,
direction="RETRIEVE", direction="RETRIEVE",
op=LoadStoreOp(block_hashes=block_hashes, block_ids=block_ids), op=op,
) )
return ret return ret
...@@ -643,7 +652,8 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -643,7 +652,8 @@ class LMCacheMPConnector(KVConnectorBase_V1):
return 0, False return 0, False
self.scheduler_adapter.maybe_submit_lookup_request( self.scheduler_adapter.maybe_submit_lookup_request(
request.request_id, convert_block_hashes_to_bytes(request.block_hashes) request.request_id,
token_ids=list(request.all_token_ids),
) )
ret = self.scheduler_adapter.check_lookup_result(request.request_id) ret = self.scheduler_adapter.check_lookup_result(request.request_id)
...@@ -766,6 +776,9 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -766,6 +776,9 @@ class LMCacheMPConnector(KVConnectorBase_V1):
""" """
# Clean up request tracker to prevent memory leak # Clean up request tracker to prevent memory leak
self._cleanup_request_tracker(request.request_id) self._cleanup_request_tracker(request.request_id)
# Notify LMCache to end the session for this request
self.scheduler_adapter.end_session(request.request_id)
return True, None return True, None
def take_events(self) -> Iterable["KVCacheEvent"]: def take_events(self) -> Iterable["KVCacheEvent"]:
...@@ -846,7 +859,9 @@ class LMCacheMPConnector(KVConnectorBase_V1): ...@@ -846,7 +859,9 @@ class LMCacheMPConnector(KVConnectorBase_V1):
if request_tracker.state != LMCacheMPRequestState.WAITING_FOR_LOAD: if request_tracker.state != LMCacheMPRequestState.WAITING_FOR_LOAD:
continue continue
r_metadata = LMCacheMPRequestMetadata.GetRetrieveMetadata( r_metadata = LMCacheMPRequestMetadata.GetRetrieveMetadata(
request_tracker, blocks_per_chunk request_tracker,
blocks_per_chunk,
vllm_block_size=self.vllm_block_size,
) )
if r_metadata is not None: if r_metadata is not None:
metadata.add_request_metadata(r_metadata) metadata.add_request_metadata(r_metadata)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment