Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e94ec597
Unverified
Commit
e94ec597
authored
Feb 09, 2026
by
Yuwei An
Committed by
GitHub
Feb 10, 2026
Browse files
[LMCache] Token Base IPC API (#34175)
Signed-off-by:
Oasis-Git
<
ayw.sirius19@gmail.com
>
parent
13397841
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
376 additions
and
90 deletions
+376
-90
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
...connector/v1/lmcache_integration/multi_process_adapter.py
+344
-73
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
...buted/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
+32
-17
No files found.
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
View file @
e94ec597
...
...
@@ -20,16 +20,42 @@ from lmcache.v1.multiprocess.protocol import RequestType, get_response_class
logger
=
init_logger
(
__name__
)
def
wrap_kv_caches
(
kv_caches
:
dict
[
str
,
KVCache
])
->
KVCache
:
def
wrap_kv_caches
(
kv_caches
:
dict
[
str
,
torch
.
Tensor
])
->
KVCache
:
logger
.
info
(
"KV caches keys are %s"
,
list
(
kv_caches
.
keys
()))
return
[
CudaIPCWrapper
(
tensor
)
for
tensor
in
kv_caches
.
values
()]
def
striding_block_hashes
(
block_hashes
:
list
[
bytes
],
blocks_in_chunk
:
int
)
->
Iterable
[
bytes
]:
"""Extract chunk-level hashes from block hashes by striding.
In hash-based vLLM, each vLLM block has its own hash. LMCache chunks
span ``blocks_in_chunk`` consecutive blocks. The representative hash
for a chunk is the hash of the **last** block in that chunk (because
each block hash already encodes its prefix). So we start at index
``blocks_in_chunk - 1`` and stride by ``blocks_in_chunk``.
"""
return
islice
(
block_hashes
,
blocks_in_chunk
-
1
,
None
,
blocks_in_chunk
)
def
send_lmcache_request
(
mq_client
:
MessageQueueClient
,
request_type
:
RequestType
,
payloads
:
list
[
Any
],
)
->
MessagingFuture
[
Any
]:
"""
Helper function to send the request to the LMCache multiprocess server
Args:
mq_client: The LMCache multiprocess mode message queue client
request_type: The request type
payloads: The request payloads
Returns:
A messaging future for the request
"""
future
=
mq_client
.
submit_request
(
request_type
,
payloads
,
get_response_class
(
request_type
)
)
...
...
@@ -39,40 +65,44 @@ def send_lmcache_request(
def
get_lmcache_chunk_size
(
mq_client
:
MessageQueueClient
,
)
->
int
:
future
=
send_lmcache_request
(
mq_client
,
RequestType
.
GET_CHUNK_SIZE
,
[])
chunk_size
=
future
.
result
()
return
chunk_size
"""
Helper function to get the LMCache chunk size from the server
Args:
mq_client: The LMCache multiprocess mode message queue client
def
striding_block_hashes
(
block_hashes
:
list
[
bytes
],
blocks_in_chunk
,
)
->
Iterable
[
bytes
]:
"""Striding the block hashes to get the block hashes for each chunk.
For example, if blocks_in_chunk is 16, then we will get the block hashes
for the 16th, 32nd, 48th, ... blocks.
Returns:
An integer representing the LMCache chunk size
"""
return
islice
(
block_hashes
,
blocks_in_chunk
-
1
,
None
,
blocks_in_chunk
)
future
=
send_lmcache_request
(
mq_client
,
RequestType
.
GET_CHUNK_SIZE
,
[])
chunk_size
=
future
.
result
()
return
chunk_size
@
dataclass
class
LoadStoreOp
:
block_hashes
:
list
[
bytes
]
block_ids
:
list
[
int
]
"""Block ids for the load/store operation"""
def
__len__
(
self
)
->
int
:
return
len
(
self
.
block_hashes
)
token_ids
:
list
[
int
]
|
None
=
None
"""Token IDs for the load/store operation (token mode)"""
def
__post_init__
(
self
):
assert
len
(
self
.
block_hashes
)
==
len
(
self
.
block_ids
),
(
"The number of block hashes should be equal to the number of block ids "
f
"But got
{
len
(
self
.
block_hashes
)
}
and
{
len
(
self
.
block_ids
)
}
"
)
block_hashes
:
list
[
bytes
]
|
None
=
None
"""Block hashes for the load/store operation (hash mode)"""
start
:
int
=
0
"""Start token index (token mode only)"""
end
:
int
=
0
"""End token index (token mode only)"""
def
__len__
(
self
)
->
int
:
return
len
(
self
.
block_ids
)
StoreResult
=
bool
RetrieveResult
=
list
[
bool
]
LookupResult
=
list
[
bool
]
LookupResult
=
int
class
LMCacheMPSchedulerAdapter
:
...
...
@@ -95,10 +125,6 @@ class LMCacheMPSchedulerAdapter:
kv_rank: The kv rank used for LMCache keys
vllm_block_size: The block size used in vLLM
"""
logger
.
warning
(
"Importing LMCacheMPSchedulerAdapter is deprecated. "
"Please update your LMCache to the latest version."
)
self
.
mq_client
=
MessageQueueClient
(
server_url
,
context
)
# Request futures
...
...
@@ -116,22 +142,89 @@ class LMCacheMPSchedulerAdapter:
self
.
blocks_in_chunk
=
self
.
chunk_size
//
vllm_block_size
@
_lmcache_nvtx_annotate
def
maybe_submit_lookup_request
(
self
,
request_id
:
str
,
block_hashes
:
list
[
bytes
]):
def
maybe_submit_lookup_request
(
self
,
request_id
:
str
,
block_hashes
:
list
[
bytes
]
|
None
=
None
,
token_ids
:
list
[
int
]
|
None
=
None
,
)
->
None
:
"""
Submit a new lookup request to LMCache if there is no ongoing request.
Supports both token-based and hash-based vLLM:
- token_ids: token IDs (token-based vLLM) -> single token-mode key
- block_hashes: block hashes (hash-based vLLM) -> strided hash-mode keys
Exactly one of block_hashes or token_ids must be provided.
Args:
request_id: The ID of the lookup request. The same ID indicates it's
from the same request
block_hashes: Block hashes to lookup from LMCache (hash mode)
token_ids: Token IDs to lookup from LMCache (token mode)
Returns:
None
Notes:
This function will have a side-effect: submitting a look up request to
LMCache, which will essentially 'lock' the KV cache chunks in the LMCache
for later retrieve operations.
In the meantime, this function will record the lookup request, and the
status of the look up request can be checked by `check_lookup_result`.
"""
if
request_id
in
self
.
lookup_futures
:
# Skip if there is already a lookup request
return
s
=
striding_block_hashes
(
block_hashes
,
self
.
blocks_in_chunk
)
keys
=
[
self
.
_create_key
(
block_hash
)
for
block_hash
in
s
]
assert
(
block_hashes
is
None
)
!=
(
token_ids
is
None
),
(
"Exactly one of block_hashes or token_ids must be provided"
)
if
block_hashes
is
not
None
:
# Hash mode: stride block hashes -> N hash-mode keys
chunk_hashes
=
list
(
striding_block_hashes
(
block_hashes
,
self
.
blocks_in_chunk
)
)
keys
=
[
self
.
_create_hash_key
(
ch
,
request_id
=
request_id
)
for
ch
in
chunk_hashes
]
else
:
# Token mode: truncate to chunk-aligned length
assert
token_ids
is
not
None
aligned_end
=
(
len
(
token_ids
)
//
self
.
chunk_size
)
*
self
.
chunk_size
if
aligned_end
==
0
:
return
keys
=
[
self
.
_create_key
(
token_ids
,
start
=
0
,
end
=
aligned_end
,
request_id
=
request_id
,
).
no_worker_id_version
()
]
future
=
send_lmcache_request
(
self
.
mq_client
,
RequestType
.
LOOKUP
,
[
keys
,
True
],
[
keys
],
)
self
.
lookup_futures
[
request_id
]
=
future
@
_lmcache_nvtx_annotate
def
check_lookup_result
(
self
,
request_id
:
str
)
->
int
|
None
:
"""
Check the result of a previously submitted lookup request.
Args:
request_id: The ID of the lookup request submitted in
`maybe_submit_lookup_request`
Returns:
An integer representing the total number of tokens matched
in LMCache (prefix matching), or
None if the lookup request is not finished yet.
"""
assert
request_id
in
self
.
lookup_futures
,
(
f
"Lookup request for request_id=
{
request_id
}
has not been submitted"
)
...
...
@@ -141,7 +234,7 @@ class LMCacheMPSchedulerAdapter:
return
None
result
=
future
.
result
()
num_chunks
=
sum
(
result
)
num_chunks
=
result
return
num_chunks
*
self
.
chunk_size
def
num_blocks_per_chunk
(
self
)
->
int
:
...
...
@@ -159,14 +252,47 @@ class LMCacheMPSchedulerAdapter:
"""
self
.
lookup_futures
.
pop
(
request_id
,
None
)
def
end_session
(
self
,
request_id
:
str
)
->
None
:
"""
Notify LMCache server to remove the session for a finished request.
Args:
request_id: The ID of the finished request.
"""
send_lmcache_request
(
self
.
mq_client
,
RequestType
.
END_SESSION
,
[
request_id
],
)
# Helper functions
def
_create_key
(
self
,
block_hash
:
bytes
)
->
IPCCacheEngineKey
:
"""Convert a block hash to an IPC cache engine key"""
def
_create_key
(
self
,
token_ids
:
list
[
int
],
start
:
int
=
0
,
end
:
int
=
0
,
request_id
:
str
|
None
=
None
,
)
->
IPCCacheEngineKey
:
"""Convert token IDs to an IPC cache engine key"""
return
IPCCacheEngineKey
(
model_name
=
self
.
model_name
,
world_size
=
self
.
world_size
,
worker_id
=
self
.
worker_id
,
chunk_hash
=
block_hash
,
token_ids
=
tuple
(
token_ids
),
start
=
start
,
end
=
end
,
request_id
=
request_id
,
)
def
_create_hash_key
(
self
,
chunk_hash
:
bytes
,
request_id
:
str
|
None
=
None
)
->
IPCCacheEngineKey
:
"""Create a hash-mode IPC cache engine key"""
return
IPCCacheEngineKey
(
model_name
=
self
.
model_name
,
world_size
=
self
.
world_size
,
worker_id
=
None
,
chunk_hash
=
chunk_hash
,
request_id
=
request_id
,
)
...
...
@@ -180,10 +306,6 @@ class LMCacheMPWorkerAdapter:
kv_rank
:
int
,
vllm_block_size
:
int
,
):
logger
.
warning
(
"Importing LMCacheMPWorkerAdapter is deprecated. "
"Please update your LMCache to the latest version."
)
self
.
mq_client
=
MessageQueueClient
(
server_url
,
context
)
# Instance id for GPU worker
...
...
@@ -201,7 +323,10 @@ class LMCacheMPWorkerAdapter:
str
,
tuple
[
MessagingFuture
[
RetrieveResult
],
list
[
str
]]
]
=
{}
# The store requests that have finished execution in LMCache
self
.
finished_stores
:
set
[
str
]
=
set
()
# The finished request ids that are passed via vLLM and also
# have corresponding store requests submitted to LMCache before
self
.
previously_finished
:
set
[
str
]
=
set
()
self
.
model_name
=
model_name
...
...
@@ -215,7 +340,14 @@ class LMCacheMPWorkerAdapter:
)
self
.
blocks_in_chunk
=
chunk_size
//
vllm_block_size
def
register_kv_caches
(
self
,
kv_caches
:
dict
[
str
,
KVCache
]):
def
register_kv_caches
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
]):
"""
Register the kv caches with LMCache server
Args:
kv_caches: A dict of kv caches to register. The keys are the
layer names and the values are the corresponding tensors.
"""
# Register kv cache and send the request
self
.
kv_caches
=
kv_caches
logger
.
info
(
"Registering kv caches"
)
...
...
@@ -230,7 +362,29 @@ class LMCacheMPWorkerAdapter:
def
submit_store_request
(
self
,
request_id
:
str
,
op
:
LoadStoreOp
,
event
:
torch
.
cuda
.
Event
):
keys
=
self
.
_block_hashes_to_keys
(
op
.
block_hashes
)
"""
Submit a KV cache store request to LMCache
Args:
request_id: The ID of the request
op: The LoadStoreOp describing the store operation.
event: The CUDA event that is recorded after the current
model inference step
"""
if
op
.
block_hashes
is
not
None
:
# Hash mode
chunk_hashes
=
list
(
striding_block_hashes
(
op
.
block_hashes
,
self
.
blocks_in_chunk
)
)
keys
=
[
self
.
_create_hash_key
(
ch
,
request_id
=
request_id
)
for
ch
in
chunk_hashes
]
else
:
# Token mode
assert
op
.
token_ids
is
not
None
keys
=
[
self
.
_create_key
(
op
.
token_ids
,
op
.
start
,
op
.
end
,
request_id
=
request_id
)
]
future
=
send_lmcache_request
(
self
.
mq_client
,
RequestType
.
STORE
,
...
...
@@ -242,7 +396,29 @@ class LMCacheMPWorkerAdapter:
def
submit_retrieve_request
(
self
,
request_id
:
str
,
op
:
LoadStoreOp
,
event
:
torch
.
cuda
.
Event
):
keys
=
self
.
_block_hashes_to_keys
(
op
.
block_hashes
)
"""
Submit a KV cache retrieve request to LMCache
Args:
request_id: The ID of the request
op: The LoadStoreOp describing the retrieve operation.
event: The CUDA event that is recorded after the current
model inference step
"""
if
op
.
block_hashes
is
not
None
:
# Hash mode
chunk_hashes
=
list
(
striding_block_hashes
(
op
.
block_hashes
,
self
.
blocks_in_chunk
)
)
keys
=
[
self
.
_create_hash_key
(
ch
,
request_id
=
request_id
)
for
ch
in
chunk_hashes
]
else
:
# Token mode
assert
op
.
token_ids
is
not
None
keys
=
[
self
.
_create_key
(
op
.
token_ids
,
op
.
start
,
op
.
end
,
request_id
=
request_id
)
]
future
=
send_lmcache_request
(
self
.
mq_client
,
RequestType
.
RETRIEVE
,
...
...
@@ -257,17 +433,47 @@ class LMCacheMPWorkerAdapter:
ops
:
list
[
LoadStoreOp
],
event
:
torch
.
cuda
.
Event
,
):
keys
=
[]
block_ids
=
[]
for
op
in
ops
:
keys
.
extend
(
self
.
_block_hashes_to_keys
(
op
.
block_hashes
))
"""
Submit a batched store request to LMCache
Args:
request_ids: The IDs of the requests
ops: The LoadStoreOps describing the store operations. Should have
the same length as request_ids
event: The CUDA event that is recorded after the current
model inference step
"""
all_keys
:
list
[
IPCCacheEngineKey
]
=
[]
block_ids
:
list
[
int
]
=
[]
for
request_id
,
op
in
zip
(
request_ids
,
ops
,
strict
=
False
):
if
op
.
block_hashes
is
not
None
:
chunk_hashes
=
list
(
striding_block_hashes
(
op
.
block_hashes
,
self
.
blocks_in_chunk
)
)
keys
=
[
self
.
_create_hash_key
(
ch
,
request_id
=
request_id
)
for
ch
in
chunk_hashes
]
all_keys
.
extend
(
keys
)
else
:
assert
op
.
token_ids
is
not
None
all_keys
.
append
(
self
.
_create_key
(
op
.
token_ids
,
op
.
start
,
op
.
end
,
request_id
=
request_id
)
)
block_ids
.
extend
(
op
.
block_ids
)
future
=
send_lmcache_request
(
self
.
mq_client
,
RequestType
.
STORE
,
[
keys
,
self
.
instance_id
,
block_ids
,
event
.
ipc_handle
()],
[
all_keys
,
self
.
instance_id
,
block_ids
,
event
.
ipc_handle
(),
],
).
to_cuda_future
()
self
.
store_futures
[
request_ids
[
0
]]
=
(
future
,
request_ids
[
1
:])
self
.
store_futures
[
request_ids
[
0
]]
=
(
future
,
list
(
request_ids
[
1
:])
)
@
_lmcache_nvtx_annotate
def
batched_submit_retrieve_requests
(
...
...
@@ -276,34 +482,83 @@ class LMCacheMPWorkerAdapter:
ops
:
list
[
LoadStoreOp
],
event
:
torch
.
cuda
.
Event
,
):
keys
=
[]
block_ids
=
[]
"""
Submit a batched retrieve request to LMCache
for
op
in
ops
:
keys
.
extend
(
self
.
_block_hashes_to_keys
(
op
.
block_hashes
))
Args:
request_ids: The IDs of the requests
ops: The LoadStoreOps describing the retrieve operations. Should have
the same length as request_ids
event: The CUDA event that is recorded after the current
model inference step
"""
all_keys
:
list
[
IPCCacheEngineKey
]
=
[]
block_ids
:
list
[
int
]
=
[]
for
request_id
,
op
in
zip
(
request_ids
,
ops
,
strict
=
False
):
if
op
.
block_hashes
is
not
None
:
chunk_hashes
=
list
(
striding_block_hashes
(
op
.
block_hashes
,
self
.
blocks_in_chunk
)
)
keys
=
[
self
.
_create_hash_key
(
ch
,
request_id
=
request_id
)
for
ch
in
chunk_hashes
]
all_keys
.
extend
(
keys
)
else
:
assert
op
.
token_ids
is
not
None
all_keys
.
append
(
self
.
_create_key
(
op
.
token_ids
,
op
.
start
,
op
.
end
,
request_id
=
request_id
)
)
block_ids
.
extend
(
op
.
block_ids
)
future
=
send_lmcache_request
(
self
.
mq_client
,
RequestType
.
RETRIEVE
,
[
keys
,
self
.
instance_id
,
block_ids
,
event
.
ipc_handle
()],
[
all_keys
,
self
.
instance_id
,
block_ids
,
event
.
ipc_handle
(),
],
).
to_cuda_future
()
self
.
retrieve_futures
[
request_ids
[
0
]]
=
(
future
,
request_ids
[
1
:])
self
.
retrieve_futures
[
request_ids
[
0
]]
=
(
future
,
list
(
request_ids
[
1
:])
)
@
_lmcache_nvtx_annotate
def
get_finished
(
self
,
finished_req_ids
:
set
[
str
]
self
,
finished_req_ids
_from_engine
:
set
[
str
]
)
->
tuple
[
set
[
str
]
|
None
,
set
[
str
]
|
None
]:
"""
Check and get the finished store and retrieve requests.
Args:
finished_req_ids_from_engine: the set of request ids that are
reported as finished from the vLLM engine side.
Returns:
A tuple of two sets:
- The first set contains the finished store request ids. The returned
store request ids MUST be seen before in the
`finished_req_ids_from_engine`.
- The second set contains the finished retrieve request ids.
Notes:
When enabling async scheduling in vLLM, the same request ID may appear
multiple times in `finished_req_ids_from_engine`. The adapter should
take care of deduplicating the request IDs and only return the request
IDs that have not been returned before.
"""
finished_stores
=
set
()
finished_retrieves
=
set
()
for
request_id
,
(
future
,
other_reqs
)
in
self
.
store_futures
.
items
():
if
not
future
.
query
():
for
request_id
,
(
s_
future
,
other_reqs
)
in
self
.
store_futures
.
items
():
if
not
s_
future
.
query
():
continue
result
=
future
.
result
()
s_
result
=
s_
future
.
result
()
finished_stores
.
add
(
request_id
)
finished_stores
.
update
(
other_reqs
)
if
not
result
:
if
not
s_
result
:
# TODO: add error handling here
logger
.
error
(
"Something went wrong when processing the "
...
...
@@ -311,21 +566,21 @@ class LMCacheMPWorkerAdapter:
request_id
,
)
for
request_id
,
(
future
,
other_reqs
)
in
self
.
retrieve_futures
.
items
():
if
not
future
.
query
():
for
request_id
,
(
r_
future
,
other_reqs
)
in
self
.
retrieve_futures
.
items
():
if
not
r_
future
.
query
():
continue
result
=
future
.
result
()
r_
result
=
r_
future
.
result
()
finished_retrieves
.
add
(
request_id
)
finished_retrieves
.
update
(
other_reqs
)
if
not
all
(
result
):
if
not
all
(
r_
result
):
# TODO: add error handing here
logger
.
error
(
"Something went wrong when processing the "
"retrieve request for request_id=%s, result=%s"
,
request_id
,
result
,
r_
result
,
)
# Remove the finished requests from the tracking dicts
...
...
@@ -338,7 +593,7 @@ class LMCacheMPWorkerAdapter:
self
.
finished_stores
.
update
(
finished_stores
)
ret_stores
=
set
()
for
req_id
in
finished_req_ids
:
for
req_id
in
finished_req_ids
_from_engine
:
if
req_id
in
self
.
finished_stores
or
req_id
in
self
.
store_futures
:
self
.
previously_finished
.
add
(
req_id
)
else
:
...
...
@@ -357,7 +612,9 @@ class LMCacheMPWorkerAdapter:
return
self
.
blocks_in_chunk
def
shutdown
(
self
):
# Unregister kv cache
"""
Shutdown the LMCache MP worker adapter
"""
logger
.
info
(
"Unregistering kv caches"
)
send_lmcache_request
(
self
.
mq_client
,
RequestType
.
UNREGISTER_KV_CACHE
,
[
self
.
instance_id
]
...
...
@@ -378,18 +635,32 @@ class LMCacheMPWorkerAdapter:
return
safe_finished_s
def
_create_key
(
self
,
block_hash
:
bytes
)
->
IPCCacheEngineKey
:
"""Convert a block hash to an IPC cache engine key"""
def
_create_key
(
self
,
token_ids
:
list
[
int
],
start
:
int
=
0
,
end
:
int
=
0
,
request_id
:
str
|
None
=
None
,
)
->
IPCCacheEngineKey
:
"""Convert token IDs to an IPC cache engine key"""
return
IPCCacheEngineKey
(
model_name
=
self
.
model_name
,
world_size
=
self
.
world_size
,
worker_id
=
self
.
worker_id
,
chunk_hash
=
block_hash
,
token_ids
=
tuple
(
token_ids
),
start
=
start
,
end
=
end
,
request_id
=
request_id
,
)
def
_block_hashes_to_keys
(
self
,
block_hashes
:
list
[
bytes
]
)
->
list
[
IPCCacheEngineKey
]:
"""Convert block hashes to IPC cache engine keys"""
s
=
striding_block_hashes
(
block_hashes
,
self
.
blocks_in_chunk
)
return
[
self
.
_create_key
(
block_hash
)
for
block_hash
in
s
]
def
_create_hash_key
(
self
,
chunk_hash
:
bytes
,
request_id
:
str
|
None
=
None
)
->
IPCCacheEngineKey
:
"""Create a hash-mode IPC cache engine key"""
return
IPCCacheEngineKey
(
model_name
=
self
.
model_name
,
world_size
=
self
.
world_size
,
worker_id
=
self
.
worker_id
,
chunk_hash
=
chunk_hash
,
request_id
=
request_id
,
)
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
View file @
e94ec597
...
...
@@ -3,7 +3,7 @@
import
enum
from
collections.abc
import
Iterable
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
cast
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
import
torch
import
zmq
...
...
@@ -130,12 +130,6 @@ def create_worker_adapter(
)
def
convert_block_hashes_to_bytes
(
block_hashes
:
list
[
"BlockHash"
],
)
->
list
[
bytes
]:
return
cast
(
list
[
bytes
],
block_hashes
)
class
LMCacheMPRequestState
(
enum
.
Enum
):
"""
State machine:
...
...
@@ -266,6 +260,7 @@ class LMCacheMPRequestMetadata:
Args:
tracker: The request tracker to generate the metadata from.
blocks_in_chunk: the number of blocks in a LMCache data chunk
vllm_block_size: the block size used in vLLM
"""
# Store the blocks that has block hashes
# NOTE: the invariant here is that `num_stored_blocks` should
...
...
@@ -282,15 +277,21 @@ class LMCacheMPRequestMetadata:
if
num_chunks
>=
1
:
start
=
tracker
.
num_stored_blocks
end
=
start
+
num_chunks
*
blocks_in_chunk
block_hashes
=
convert_block_hashes_to_bytes
(
tracker
.
block_hashes
[
start
:
end
]
)
block_ids
=
tracker
.
allocated_block_ids
[
start
:
end
]
start_token_idx
=
start
*
vllm_block_size
end_token_idx
=
end
*
vllm_block_size
token_ids
=
list
(
tracker
.
all_token_ids
)
op
=
LoadStoreOp
(
token_ids
=
token_ids
,
block_ids
=
block_ids
,
start
=
start_token_idx
,
end
=
end_token_idx
,
)
ret
=
LMCacheMPRequestMetadata
(
request_id
=
tracker
.
request_id
,
direction
=
"STORE"
,
op
=
LoadStoreOp
(
block_hashes
=
block_hashes
,
block_ids
=
block_ids
)
,
op
=
op
,
)
# Update the request tracker
...
...
@@ -303,6 +304,7 @@ class LMCacheMPRequestMetadata:
def
GetRetrieveMetadata
(
tracker
:
LMCacheMPRequestTracker
,
blocks_in_chunk
:
int
,
vllm_block_size
:
int
,
)
->
"LMCacheMPRequestMetadata | None"
:
"""
Generate the retrieve metadata for the current request tracker.
...
...
@@ -310,6 +312,7 @@ class LMCacheMPRequestMetadata:
Args:
tracker: The request tracker to generate the metadata from.
blocks_in_chunk: the number of blocks in a LMCache data chunk
vllm_block_size: the block size used in vLLM
"""
if
not
tracker
.
is_ready_for_retrieving
():
return
None
...
...
@@ -330,15 +333,21 @@ class LMCacheMPRequestMetadata:
"number of LMCache hit blocks. "
)
if
end
>
start
:
block_hashes
=
convert_block_hashes_to_bytes
(
tracker
.
block_hashes
[
start
:
end
]
)
block_ids
=
tracker
.
allocated_block_ids
[
start
:
end
]
start_token_idx
=
start
*
vllm_block_size
end_token_idx
=
end
*
vllm_block_size
token_ids
=
list
(
tracker
.
all_token_ids
)
op
=
LoadStoreOp
(
token_ids
=
token_ids
,
block_ids
=
block_ids
,
start
=
start_token_idx
,
end
=
end_token_idx
,
)
ret
=
LMCacheMPRequestMetadata
(
request_id
=
tracker
.
request_id
,
direction
=
"RETRIEVE"
,
op
=
LoadStoreOp
(
block_hashes
=
block_hashes
,
block_ids
=
block_ids
)
,
op
=
op
,
)
return
ret
...
...
@@ -643,7 +652,8 @@ class LMCacheMPConnector(KVConnectorBase_V1):
return
0
,
False
self
.
scheduler_adapter
.
maybe_submit_lookup_request
(
request
.
request_id
,
convert_block_hashes_to_bytes
(
request
.
block_hashes
)
request
.
request_id
,
token_ids
=
list
(
request
.
all_token_ids
),
)
ret
=
self
.
scheduler_adapter
.
check_lookup_result
(
request
.
request_id
)
...
...
@@ -766,6 +776,9 @@ class LMCacheMPConnector(KVConnectorBase_V1):
"""
# Clean up request tracker to prevent memory leak
self
.
_cleanup_request_tracker
(
request
.
request_id
)
# Notify LMCache to end the session for this request
self
.
scheduler_adapter
.
end_session
(
request
.
request_id
)
return
True
,
None
def
take_events
(
self
)
->
Iterable
[
"KVCacheEvent"
]:
...
...
@@ -846,7 +859,9 @@ class LMCacheMPConnector(KVConnectorBase_V1):
if
request_tracker
.
state
!=
LMCacheMPRequestState
.
WAITING_FOR_LOAD
:
continue
r_metadata
=
LMCacheMPRequestMetadata
.
GetRetrieveMetadata
(
request_tracker
,
blocks_per_chunk
request_tracker
,
blocks_per_chunk
,
vllm_block_size
=
self
.
vllm_block_size
,
)
if
r_metadata
is
not
None
:
metadata
.
add_request_metadata
(
r_metadata
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment