Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
94a9ebcf
Unverified
Commit
94a9ebcf
authored
Nov 12, 2025
by
Yihua Cheng
Committed by
GitHub
Nov 12, 2025
Browse files
[KV connector][WIP] KV cache proxy based on LMCache multi-process mode (#27902)
Signed-off-by:
ApostaC
<
yihua98@uchicago.edu
>
parent
a39dd7bb
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
1265 additions
and
2 deletions
+1265
-2
vllm/distributed/kv_transfer/kv_connector/factory.py
vllm/distributed/kv_transfer/kv_connector/factory.py
+6
-0
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/__init__.py
..._transfer/kv_connector/v1/lmcache_integration/__init__.py
+13
-2
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
...connector/v1/lmcache_integration/multi_process_adapter.py
+379
-0
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
...buted/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
+867
-0
No files found.
vllm/distributed/kv_transfer/kv_connector/factory.py
View file @
94a9ebcf
...
...
@@ -161,6 +161,12 @@ KVConnectorFactory.register_connector(
"LMCacheConnectorV1"
,
)
KVConnectorFactory
.
register_connector
(
"LMCacheMPConnector"
,
"vllm.distributed.kv_transfer.kv_connector.v1.lmcache_mp_connector"
,
"LMCacheMPConnector"
,
)
KVConnectorFactory
.
register_connector
(
"NixlConnector"
,
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector"
,
...
...
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/__init__.py
View file @
94a9ebcf
...
...
@@ -2,6 +2,17 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
.
import
vllm_v1_adapter
from
.
import
multi_process_adapter
,
vllm_v1_adapter
from
.multi_process_adapter
import
(
LMCacheMPSchedulerAdapter
,
LMCacheMPWorkerAdapter
,
LoadStoreOp
,
)
__all__
=
[
"vllm_v1_adapter"
]
__all__
=
[
"vllm_v1_adapter"
,
"multi_process_adapter"
,
"LMCacheMPSchedulerAdapter"
,
"LMCacheMPWorkerAdapter"
,
"LoadStoreOp"
,
]
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py
0 → 100644
View file @
94a9ebcf
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
os
from
collections.abc
import
Iterable
from
dataclasses
import
dataclass
from
itertools
import
islice
from
typing
import
Any
import
torch
import
zmq
from
lmcache.utils
import
_lmcache_nvtx_annotate
,
init_logger
from
lmcache.v1.multiprocess.custom_types
import
(
CudaIPCWrapper
,
IPCCacheEngineKey
,
KVCache
,
)
from
lmcache.v1.multiprocess.mq
import
MessageQueueClient
,
MessagingFuture
from
lmcache.v1.multiprocess.protocol
import
RequestType
,
get_response_class
logger
=
init_logger
(
__name__
)
def
wrap_kv_caches
(
kv_caches
:
dict
[
str
,
KVCache
])
->
KVCache
:
logger
.
info
(
"KV caches keys are %s"
,
list
(
kv_caches
.
keys
()))
return
[
CudaIPCWrapper
(
tensor
)
for
tensor
in
kv_caches
.
values
()]
def
send_lmcache_request
(
mq_client
:
MessageQueueClient
,
request_type
:
RequestType
,
payloads
:
list
[
Any
],
)
->
MessagingFuture
[
Any
]:
future
=
mq_client
.
submit_request
(
request_type
,
payloads
,
get_response_class
(
request_type
)
)
return
future
def
get_lmcache_chunk_size
(
mq_client
:
MessageQueueClient
,
)
->
int
:
future
=
send_lmcache_request
(
mq_client
,
RequestType
.
GET_CHUNK_SIZE
,
[])
chunk_size
=
future
.
result
()
return
chunk_size
def
striding_block_hashes
(
block_hashes
:
list
[
bytes
],
blocks_in_chunk
,
)
->
Iterable
[
bytes
]:
"""Striding the block hashes to get the block hashes for each chunk.
For example, if blocks_in_chunk is 16, then we will get the block hashes
for the 16th, 32nd, 48th, ... blocks.
"""
return
islice
(
block_hashes
,
blocks_in_chunk
-
1
,
None
,
blocks_in_chunk
)
@
dataclass
class
LoadStoreOp
:
block_hashes
:
list
[
bytes
]
block_ids
:
list
[
int
]
def
__len__
(
self
)
->
int
:
return
len
(
self
.
block_hashes
)
def
__post_init__
(
self
):
assert
len
(
self
.
block_hashes
)
==
len
(
self
.
block_ids
),
(
"The number of block hashes should be equal to the number of block ids "
f
"But got
{
len
(
self
.
block_hashes
)
}
and
{
len
(
self
.
block_ids
)
}
"
)
StoreResult
=
bool
RetrieveResult
=
list
[
bool
]
LookupResult
=
list
[
bool
]
class
LMCacheMPSchedulerAdapter
:
def
__init__
(
self
,
server_url
:
str
,
context
:
zmq
.
Context
,
model_name
:
str
,
world_size
:
int
,
kv_rank
:
int
,
vllm_block_size
:
int
,
):
"""
Args:
server_url: The server URL for the LMCache message queue
context: The ZMQ context
model_name: The model name used for LMCache keys
world_size: The world size used for LMCache keys
kv_rank: The kv rank used for LMCache keys
vllm_block_size: The block size used in vLLM
"""
self
.
mq_client
=
MessageQueueClient
(
server_url
,
context
)
# Request futures
self
.
lookup_futures
:
dict
[
str
,
MessagingFuture
[
LookupResult
]]
=
{}
self
.
model_name
=
model_name
self
.
world_size
=
world_size
self
.
worker_id
=
kv_rank
# Read chunk size from lmcache
self
.
chunk_size
=
get_lmcache_chunk_size
(
self
.
mq_client
)
assert
self
.
chunk_size
%
vllm_block_size
==
0
,
(
"LMCache chunk size should be a multiple of vLLM block size"
)
self
.
blocks_in_chunk
=
self
.
chunk_size
//
vllm_block_size
@
_lmcache_nvtx_annotate
def
maybe_submit_lookup_request
(
self
,
request_id
:
str
,
block_hashes
:
list
[
bytes
]):
if
request_id
in
self
.
lookup_futures
:
# Skip if there is already a lookup request
return
s
=
striding_block_hashes
(
block_hashes
,
self
.
blocks_in_chunk
)
keys
=
[
self
.
_create_key
(
block_hash
)
for
block_hash
in
s
]
future
=
send_lmcache_request
(
self
.
mq_client
,
RequestType
.
LOOKUP
,
[
keys
,
True
],
)
self
.
lookup_futures
[
request_id
]
=
future
@
_lmcache_nvtx_annotate
def
check_lookup_result
(
self
,
request_id
:
str
)
->
int
|
None
:
assert
request_id
in
self
.
lookup_futures
,
(
f
"Lookup request for request_id=
{
request_id
}
has not been submitted"
)
future
=
self
.
lookup_futures
[
request_id
]
if
not
future
.
query
():
return
None
result
=
future
.
result
()
num_chunks
=
sum
(
result
)
return
num_chunks
*
self
.
chunk_size
def
num_blocks_per_chunk
(
self
)
->
int
:
"""
Returns:
The number of vllm blocks in a LMCache data chunk
"""
return
self
.
blocks_in_chunk
# Helper functions
def
_create_key
(
self
,
block_hash
:
bytes
)
->
IPCCacheEngineKey
:
"""Convert a block hash to an IPC cache engine key"""
return
IPCCacheEngineKey
(
model_name
=
self
.
model_name
,
world_size
=
self
.
world_size
,
worker_id
=
self
.
worker_id
,
chunk_hash
=
block_hash
,
)
class
LMCacheMPWorkerAdapter
:
def
__init__
(
self
,
server_url
:
str
,
context
:
zmq
.
Context
,
model_name
:
str
,
world_size
:
int
,
kv_rank
:
int
,
vllm_block_size
:
int
,
):
self
.
mq_client
=
MessageQueueClient
(
server_url
,
context
)
# Instance id for GPU worker
self
.
instance_id
=
os
.
getpid
()
# Registered kv caches from vLLM
self
.
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
# Request futures
# request_id -> (future, other merged requests)
self
.
store_futures
:
dict
[
str
,
tuple
[
MessagingFuture
[
StoreResult
],
list
[
str
]]
]
=
{}
self
.
retrieve_futures
:
dict
[
str
,
tuple
[
MessagingFuture
[
RetrieveResult
],
list
[
str
]]
]
=
{}
self
.
finished_stores
:
set
[
str
]
=
set
()
self
.
previously_finished
:
set
[
str
]
=
set
()
self
.
model_name
=
model_name
self
.
world_size
=
world_size
self
.
worker_id
=
kv_rank
# Read chunk size from lmcache
chunk_size
=
get_lmcache_chunk_size
(
self
.
mq_client
)
assert
chunk_size
%
vllm_block_size
==
0
,
(
"LMCache chunk size should be a multiple of vLLM block size"
)
self
.
blocks_in_chunk
=
chunk_size
//
vllm_block_size
def
register_kv_caches
(
self
,
kv_caches
:
dict
[
str
,
KVCache
]):
# Register kv cache and send the request
self
.
kv_caches
=
kv_caches
logger
.
info
(
"Registering kv caches"
)
future
=
send_lmcache_request
(
self
.
mq_client
,
RequestType
.
REGISTER_KV_CACHE
,
[
self
.
instance_id
,
wrap_kv_caches
(
kv_caches
)],
)
future
.
result
()
@
_lmcache_nvtx_annotate
def
submit_store_request
(
self
,
request_id
:
str
,
op
:
LoadStoreOp
,
event
:
torch
.
cuda
.
Event
):
keys
=
self
.
_block_hashes_to_keys
(
op
.
block_hashes
)
future
=
send_lmcache_request
(
self
.
mq_client
,
RequestType
.
STORE
,
[
keys
,
self
.
instance_id
,
op
.
block_ids
,
event
.
ipc_handle
()],
).
to_cuda_future
()
self
.
store_futures
[
request_id
]
=
(
future
,
[])
@
_lmcache_nvtx_annotate
def
submit_retrieve_request
(
self
,
request_id
:
str
,
op
:
LoadStoreOp
,
event
:
torch
.
cuda
.
Event
):
keys
=
self
.
_block_hashes_to_keys
(
op
.
block_hashes
)
future
=
send_lmcache_request
(
self
.
mq_client
,
RequestType
.
RETRIEVE
,
[
keys
,
self
.
instance_id
,
op
.
block_ids
,
event
.
ipc_handle
()],
).
to_cuda_future
()
self
.
retrieve_futures
[
request_id
]
=
(
future
,
[])
@
_lmcache_nvtx_annotate
def
batched_submit_store_requests
(
self
,
request_ids
:
list
[
str
],
ops
:
list
[
LoadStoreOp
],
event
:
torch
.
cuda
.
Event
,
):
keys
=
[]
block_ids
=
[]
for
op
in
ops
:
keys
.
extend
(
self
.
_block_hashes_to_keys
(
op
.
block_hashes
))
block_ids
.
extend
(
op
.
block_ids
)
future
=
send_lmcache_request
(
self
.
mq_client
,
RequestType
.
STORE
,
[
keys
,
self
.
instance_id
,
block_ids
,
event
.
ipc_handle
()],
).
to_cuda_future
()
self
.
store_futures
[
request_ids
[
0
]]
=
(
future
,
request_ids
[
1
:])
@
_lmcache_nvtx_annotate
def
batched_submit_retrieve_requests
(
self
,
request_ids
:
list
[
str
],
ops
:
list
[
LoadStoreOp
],
event
:
torch
.
cuda
.
Event
,
):
keys
=
[]
block_ids
=
[]
for
op
in
ops
:
keys
.
extend
(
self
.
_block_hashes_to_keys
(
op
.
block_hashes
))
block_ids
.
extend
(
op
.
block_ids
)
future
=
send_lmcache_request
(
self
.
mq_client
,
RequestType
.
RETRIEVE
,
[
keys
,
self
.
instance_id
,
block_ids
,
event
.
ipc_handle
()],
).
to_cuda_future
()
self
.
retrieve_futures
[
request_ids
[
0
]]
=
(
future
,
request_ids
[
1
:])
@
_lmcache_nvtx_annotate
def
get_finished
(
self
,
finished_req_ids
:
set
[
str
]
)
->
tuple
[
set
[
str
]
|
None
,
set
[
str
]
|
None
]:
finished_stores
=
set
()
finished_retrieves
=
set
()
for
request_id
,
(
future
,
other_reqs
)
in
self
.
store_futures
.
items
():
if
not
future
.
query
():
continue
result
=
future
.
result
()
finished_stores
.
add
(
request_id
)
finished_stores
.
update
(
other_reqs
)
if
not
result
:
# TODO: add error handling here
logger
.
error
(
"Something went wrong when processing the "
"store request for request_id=%s"
,
request_id
,
)
for
request_id
,
(
future
,
other_reqs
)
in
self
.
retrieve_futures
.
items
():
if
not
future
.
query
():
continue
result
=
future
.
result
()
finished_retrieves
.
add
(
request_id
)
finished_retrieves
.
update
(
other_reqs
)
if
not
all
(
result
):
# TODO: add error handing here
logger
.
error
(
"Something went wrong when processing the "
"retrieve request for request_id=%s, result=%s"
,
request_id
,
result
,
)
logger
.
info
(
"Retrieve request for request_id=%s finished"
,
request_id
)
# Remove the finished requests from the tracking dicts
for
request_id
in
finished_stores
:
self
.
store_futures
.
pop
(
request_id
,
None
)
for
request_id
in
finished_retrieves
:
self
.
retrieve_futures
.
pop
(
request_id
,
None
)
# Update the internal states
self
.
finished_stores
.
update
(
finished_stores
)
ret_stores
=
set
()
for
req_id
in
finished_req_ids
:
if
req_id
in
self
.
finished_stores
or
req_id
in
self
.
store_futures
:
self
.
previously_finished
.
add
(
req_id
)
else
:
ret_stores
.
add
(
req_id
)
# Calculate the final finished stores
ret_stores
.
update
(
self
.
_update_and_get_finished_store
())
return
ret_stores
,
finished_retrieves
def
num_blocks_per_chunk
(
self
)
->
int
:
"""
Returns:
The number of vllm blocks in a LMCache data chunk
"""
return
self
.
blocks_in_chunk
def
shutdown
(
self
):
# Unregister kv cache
logger
.
info
(
"Unregistering kv caches"
)
send_lmcache_request
(
self
.
mq_client
,
RequestType
.
UNREGISTER_KV_CACHE
,
[
self
.
instance_id
]
).
result
()
self
.
mq_client
.
close
()
# Helper functions
def
_update_and_get_finished_store
(
self
,
)
->
set
[
str
]:
"""Converge the internal states about finished stores
and returns the 'safe finished store request ids' back
"""
safe_finished_s
=
self
.
finished_stores
.
intersection
(
self
.
previously_finished
)
self
.
finished_stores
.
difference_update
(
self
.
previously_finished
)
self
.
previously_finished
.
difference_update
(
safe_finished_s
)
return
safe_finished_s
def
_create_key
(
self
,
block_hash
:
bytes
)
->
IPCCacheEngineKey
:
"""Convert a block hash to an IPC cache engine key"""
return
IPCCacheEngineKey
(
model_name
=
self
.
model_name
,
world_size
=
self
.
world_size
,
worker_id
=
self
.
worker_id
,
chunk_hash
=
block_hash
,
)
def
_block_hashes_to_keys
(
self
,
block_hashes
:
list
[
bytes
]
)
->
list
[
IPCCacheEngineKey
]:
"""Convert block hashes to IPC cache engine keys"""
s
=
striding_block_hashes
(
block_hashes
,
self
.
blocks_in_chunk
)
return
[
self
.
_create_key
(
block_hash
)
for
block_hash
in
s
]
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py
0 → 100644
View file @
94a9ebcf
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
enum
from
collections.abc
import
Iterable
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Any
,
Literal
,
Optional
,
cast
import
torch
import
zmq
from
lmcache.utils
import
init_logger
as
lmcache_init_logger
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
,
)
from
vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration
import
(
LMCacheMPSchedulerAdapter
,
LMCacheMPWorkerAdapter
,
LoadStoreOp
,
)
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.v1.outputs
import
KVConnectorOutput
from
vllm.v1.utils
import
ConstantList
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_events
import
KVCacheEvent
from
vllm.distributed.kv_transfer.kv_connector.v1.metrics
import
(
KVConnectorPromMetrics
,
KVConnectorStats
,
PromMetric
,
PromMetricT
,
)
from
vllm.forward_context
import
ForwardContext
from
vllm.v1.core.kv_cache_manager
import
KVCacheBlocks
from
vllm.v1.core.kv_cache_utils
import
BlockHash
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.request
import
Request
logger
=
lmcache_init_logger
(
__name__
)
# Helper functions
def
reformat_block_ids
(
block_ids
:
tuple
[
list
[
int
],
...]
|
None
)
->
list
[
int
]:
if
block_ids
is
None
:
return
[]
assert
isinstance
(
block_ids
,
tuple
),
(
f
"Expected block_ids to be a tuple of lists, but got
{
type
(
block_ids
)
}
"
)
if
len
(
block_ids
)
>
1
:
raise
RuntimeError
(
"LMCacheMPConnector only works without hybrid kv cache manager. "
"Please pass --disable-hybrid-kv-cache-manager when starting vllm"
)
return
block_ids
[
0
]
def
create_scheduler_adapter
(
server_url
:
str
,
zmq_context
:
zmq
.
Context
,
vllm_config
:
VllmConfig
)
->
LMCacheMPSchedulerAdapter
:
# TODO: have a helper function to calculate the correct rank and
# world size for the MLA and other models
return
LMCacheMPSchedulerAdapter
(
server_url
,
zmq_context
,
vllm_config
.
model_config
.
model
,
vllm_config
.
parallel_config
.
world_size
,
vllm_config
.
parallel_config
.
rank
,
vllm_config
.
cache_config
.
block_size
,
)
def
create_worker_adapter
(
server_url
:
str
,
zmq_context
:
zmq
.
Context
,
vllm_config
:
VllmConfig
)
->
LMCacheMPWorkerAdapter
:
# TODO: have a helper function to calculate the correct rank and
# world size for the MLA and other models
return
LMCacheMPWorkerAdapter
(
server_url
,
zmq_context
,
vllm_config
.
model_config
.
model
,
vllm_config
.
parallel_config
.
world_size
,
vllm_config
.
parallel_config
.
rank
,
vllm_config
.
cache_config
.
block_size
,
)
def
convert_block_hashes_to_bytes
(
block_hashes
:
list
[
"BlockHash"
],
)
->
list
[
bytes
]:
return
cast
(
list
[
bytes
],
block_hashes
)
class
LMCacheMPRequestState
(
enum
.
Enum
):
"""
State machine:
PREFETCHING -- update_state_after_alloc --> WAITING_FOR_LOAD
WAITING_FOR_LOAD -- process_loading_requests --> READY
"""
PREFETCHING
=
enum
.
auto
()
WAITING_FOR_LOAD
=
enum
.
auto
()
READY
=
enum
.
auto
()
@
dataclass
class
LMCacheMPRequestTracker
:
# NOTE: this class used vLLM data structures, should be part of
# vLLM integration code
request_id
:
str
# Read-only lists to track the token ids and block hashes
all_token_ids
:
ConstantList
[
int
]
block_hashes
:
ConstantList
[
"BlockHash"
]
# Block ids and hashes will be updated at update_states_after_alloc and
# during the generation
allocated_block_ids
:
list
[
int
]
=
field
(
default_factory
=
list
)
# Number of scheduled tokens in this request. We keep tracking this to
# avoid saving half-full blocks.
num_scheduled_tokens
:
int
=
0
# Number of blocks stored will be initialized when lookup the external
# hit tokens and will be updated when processing new requests and cached
# requests.
num_stored_blocks
:
int
=
0
# Staging load operation -- save vllm and lmcache hit tokens during lookup
num_vllm_hit_blocks
:
int
=
0
num_lmcache_hit_blocks
:
int
=
0
# Main state
state
:
LMCacheMPRequestState
=
LMCacheMPRequestState
.
PREFETCHING
def
__init__
(
self
,
request
:
"Request"
):
self
.
request_id
=
request
.
request_id
self
.
all_token_ids
=
request
.
all_token_ids
self
.
block_hashes
=
ConstantList
(
request
.
block_hashes
)
self
.
allocated_block_ids
=
[]
self
.
num_stored_blocks
=
0
self
.
num_vllm_hit_blocks
=
0
self
.
num_lmcache_hit_blocks
=
0
self
.
state
=
LMCacheMPRequestState
.
PREFETCHING
####
# Check the state of the request
####
def
needs_retrieve
(
self
)
->
bool
:
"""Check whether the current request needs retrieve, will be used
update_stage_after_alloc"""
return
(
self
.
num_lmcache_hit_blocks
>
self
.
num_vllm_hit_blocks
and
self
.
state
!=
LMCacheMPRequestState
.
READY
)
def
is_ready_for_retrieving
(
self
)
->
bool
:
"""Check whether the current request is ready for retrieving,
will be used in process_loading_requests"""
return
(
self
.
state
==
LMCacheMPRequestState
.
WAITING_FOR_LOAD
and
self
.
needs_retrieve
()
)
####
# Update internal states
####
def
increase_num_scheduled_tokens
(
self
,
num_new_tokens
:
int
):
self
.
num_scheduled_tokens
+=
num_new_tokens
def
increase_num_stored_blocks
(
self
,
num_new_blocks
:
int
):
"""Increase the number of stored blocks for the current request
This function will be called when processing the cached requests.
"""
self
.
num_stored_blocks
+=
num_new_blocks
def
update_block_ids
(
self
,
new_block_ids
:
list
[
int
],
):
"""Update the block ids for the current request
This function will be called when processing the cached requests.
"""
self
.
allocated_block_ids
.
extend
(
new_block_ids
)
####
# For debugging
####
def
__repr__
(
self
)
->
str
:
return
(
f
"LMCacheMPRequestTracker(request_id=
{
self
.
request_id
}
, "
f
"num_tokens=
{
len
(
self
.
all_token_ids
)
}
, "
f
"num_block_hashes=
{
len
(
self
.
block_hashes
)
}
, "
f
"num_allocated_blocks=
{
len
(
self
.
allocated_block_ids
)
}
, "
f
"num_stored_blocks=
{
self
.
num_stored_blocks
}
, "
f
"vllm_hit_blocks=
{
self
.
num_vllm_hit_blocks
}
, "
f
"lmcache_hit_blocks=
{
self
.
num_lmcache_hit_blocks
}
, "
f
"state=
{
self
.
state
}
)"
)
def
__str__
(
self
)
->
str
:
return
self
.
__repr__
()
@
dataclass
class
LMCacheMPRequestMetadata
:
request_id
:
str
direction
:
Literal
[
"STORE"
,
"RETRIEVE"
]
op
:
LoadStoreOp
@
staticmethod
def
GetStoreMetadata
(
tracker
:
LMCacheMPRequestTracker
,
blocks_in_chunk
:
int
,
vllm_block_size
:
int
,
)
->
"LMCacheMPRequestMetadata | None"
:
"""
Generate the store metadata for the current request tracker.
Args:
tracker: The request tracker to generate the metadata from.
blocks_in_chunk: the number of blocks in a LMCache data chunk
"""
# Store the blocks that has block hashes
# NOTE: the invariant here is that `num_stored_blocks` should
# always be a multiple of `blocks_in_chunk`
# TODO: This should be checked everytime we update the num_stored_blocks
min_available_blocks
=
min
(
len
(
tracker
.
block_hashes
),
len
(
tracker
.
allocated_block_ids
),
tracker
.
num_scheduled_tokens
//
vllm_block_size
,
)
num_staging_blocks
=
min_available_blocks
-
tracker
.
num_stored_blocks
num_chunks
=
num_staging_blocks
//
blocks_in_chunk
if
num_chunks
>=
1
:
start
=
tracker
.
num_stored_blocks
end
=
start
+
num_chunks
*
blocks_in_chunk
block_hashes
=
convert_block_hashes_to_bytes
(
tracker
.
block_hashes
[
start
:
end
]
)
block_ids
=
tracker
.
allocated_block_ids
[
start
:
end
]
ret
=
LMCacheMPRequestMetadata
(
request_id
=
tracker
.
request_id
,
direction
=
"STORE"
,
op
=
LoadStoreOp
(
block_hashes
=
block_hashes
,
block_ids
=
block_ids
),
)
# Update the request tracker
tracker
.
increase_num_stored_blocks
(
end
-
start
)
return
ret
return
None
@
staticmethod
def
GetRetrieveMetadata
(
tracker
:
LMCacheMPRequestTracker
,
blocks_in_chunk
:
int
,
)
->
"LMCacheMPRequestMetadata | None"
:
"""
Generate the retrieve metadata for the current request tracker.
Args:
tracker: The request tracker to generate the metadata from.
blocks_in_chunk: the number of blocks in a LMCache data chunk
"""
if
not
tracker
.
is_ready_for_retrieving
():
return
None
# |---------------------|-----------------|----------------|
# | num_vllm_hit_blocks |
# | lmcache chunk 1 | lmcache chunk 2 |
# | need to retrieve |
start
=
tracker
.
num_vllm_hit_blocks
//
blocks_in_chunk
*
blocks_in_chunk
end
=
tracker
.
num_lmcache_hit_blocks
assert
end
%
blocks_in_chunk
==
0
,
(
"The number of LMCache hit blocks should be a multiple of the "
"number of blocks in a lmcache chunk. "
)
assert
len
(
tracker
.
block_hashes
)
>=
end
,
(
"The number of block hashes should be greater than or equal to the "
"number of LMCache hit blocks. "
)
if
end
>
start
:
block_hashes
=
convert_block_hashes_to_bytes
(
tracker
.
block_hashes
[
start
:
end
]
)
block_ids
=
tracker
.
allocated_block_ids
[
start
:
end
]
ret
=
LMCacheMPRequestMetadata
(
request_id
=
tracker
.
request_id
,
direction
=
"RETRIEVE"
,
op
=
LoadStoreOp
(
block_hashes
=
block_hashes
,
block_ids
=
block_ids
),
)
return
ret
return
None
class
LMCacheMPConnectorMetadata
(
KVConnectorMetadata
):
def
__init__
(
self
):
super
().
__init__
()
self
.
requests
:
list
[
LMCacheMPRequestMetadata
]
=
[]
def
add_request_metadata
(
self
,
request_metadata
:
LMCacheMPRequestMetadata
):
self
.
requests
.
append
(
request_metadata
)
def
__len__
(
self
):
return
len
(
self
.
requests
)
# For debugging
def
__str__
(
self
):
request_strs
=
[]
for
req_meta
in
self
.
requests
:
request_strs
.
append
(
f
"RequestMetadata(request_id=
{
req_meta
.
request_id
}
, "
f
"direction=
{
req_meta
.
direction
}
, "
f
"num_blocks=
{
len
(
req_meta
.
op
)
}
, "
f
"block_ids=
{
req_meta
.
op
.
block_ids
}
)"
)
return
"["
+
"
\n
"
.
join
(
request_strs
)
+
"]"
def
__repr__
(
self
):
return
self
.
__str__
()
class
LMCacheMPConnector
(
KVConnectorBase_V1
):
"""
The connector for LMCache multi-process mode.
Extra configs (kv_transfer_config.extra_config):
- lmcache.mp.host: the host of the LMCache server.
- lmcache.mp.port: the port of the LMCache server.
"""
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
,
kv_cache_config
:
Optional
[
"KVCacheConfig"
]
=
None
,
):
super
().
__init__
(
vllm_config
,
role
,
kv_cache_config
)
assert
vllm_config
.
kv_transfer_config
is
not
None
server_host
=
vllm_config
.
kv_transfer_config
.
get_from_extra_config
(
"lmcache.mp.host"
,
"tcp://localhost"
)
server_port
=
vllm_config
.
kv_transfer_config
.
get_from_extra_config
(
"lmcache.mp.port"
,
5555
)
server_url
=
f
"
{
server_host
}
:
{
server_port
}
"
zmq_context
=
zmq
.
Context
.
instance
()
if
self
.
role
==
KVConnectorRole
.
SCHEDULER
:
self
.
scheduler_adapter
=
create_scheduler_adapter
(
server_url
,
zmq_context
,
vllm_config
)
self
.
request_trackers
:
dict
[
str
,
LMCacheMPRequestTracker
]
=
{}
elif
self
.
role
==
KVConnectorRole
.
WORKER
:
self
.
worker_adapter
=
create_worker_adapter
(
server_url
,
zmq_context
,
vllm_config
)
else
:
raise
ValueError
(
f
"Unknown KVConnectorRole:
{
self
.
role
}
"
)
self
.
vllm_block_size
=
vllm_config
.
cache_config
.
block_size
@
property
def
role
(
self
)
->
KVConnectorRole
:
return
self
.
_role
# ==============================
# Worker-side methods
# ==============================
def
_get_connector_metadata
(
self
)
->
KVConnectorMetadata
:
"""Get the connector metadata.
This function should only be called inside the connector.
Returns:
ConnectorMetadata: the connector metadata.
"""
# Should only be called while set to valid metadata.
assert
self
.
_connector_metadata
is
not
None
return
self
.
_connector_metadata
def
register_kv_caches
(
self
,
kv_caches
:
dict
[
str
,
torch
.
Tensor
]):
"""
Initialize with the KV caches. Useful for pre-registering the
KV Caches in the KVConnector (e.g. for NIXL).
Args:
kv_caches: dictionary of layer names, kv cache
"""
logger
.
info
(
"Registering kv caches!"
)
self
.
worker_adapter
.
register_kv_caches
(
kv_caches
)
return
def
start_load_kv
(
self
,
forward_context
:
"ForwardContext"
,
**
kwargs
:
Any
)
->
None
:
"""
Start loading the KV cache from the connector to vLLM's paged
KV buffer. This is called from the forward context before the
forward pass to enable async loading during model execution.
Args:
forward_context (ForwardContext): the forward context.
**kwargs: additional arguments for the load operation
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
metadata
=
self
.
_get_connector_metadata
()
assert
isinstance
(
metadata
,
LMCacheMPConnectorMetadata
)
with
torch
.
cuda
.
stream
(
torch
.
cuda
.
current_stream
()):
event
=
torch
.
cuda
.
Event
(
interprocess
=
True
)
event
.
record
()
request_ids
=
[]
ops
=
[]
for
meta
in
metadata
.
requests
:
if
meta
.
direction
!=
"RETRIEVE"
:
continue
request_ids
.
append
(
meta
.
request_id
)
ops
.
append
(
meta
.
op
)
if
len
(
request_ids
)
>
0
:
logger
.
info
(
"HERE! SUBMITTING THE BATCHED RETRIEVE REQUESTS %s"
,
request_ids
)
self
.
worker_adapter
.
batched_submit_retrieve_requests
(
request_ids
,
ops
,
event
)
def
wait_for_layer_load
(
self
,
layer_name
:
str
)
->
None
:
"""
Block until the KV for a specific layer is loaded into vLLM's
paged buffer. This is called from within attention layer to ensure
async copying from start_load_kv is complete.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
return
def
save_kv_layer
(
self
,
layer_name
:
str
,
kv_layer
:
torch
.
Tensor
,
attn_metadata
:
"AttentionMetadata"
,
**
kwargs
:
Any
,
)
->
None
:
"""
Start saving a layer of KV cache from vLLM's paged buffer
to the connector. This is called from within attention layer to
enable async copying during execution.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
**kwargs: additional arguments for the save operation.
"""
return
def
wait_for_save
(
self
):
"""
Block until all the save operations is done. This is called
as the forward context exits to ensure that the async saving
from save_kv_layer is complete before finishing the forward.
This prevents overwrites of paged KV buffer before saving done.
"""
metadata
=
self
.
_get_connector_metadata
()
assert
isinstance
(
metadata
,
LMCacheMPConnectorMetadata
)
with
torch
.
cuda
.
stream
(
torch
.
cuda
.
current_stream
()):
event
=
torch
.
cuda
.
Event
(
interprocess
=
True
)
event
.
record
()
request_ids
=
[]
ops
=
[]
for
meta
in
metadata
.
requests
:
if
meta
.
direction
!=
"STORE"
:
continue
request_ids
.
append
(
meta
.
request_id
)
ops
.
append
(
meta
.
op
)
if
len
(
request_ids
)
>
0
:
self
.
worker_adapter
.
batched_submit_store_requests
(
request_ids
,
ops
,
event
)
def
get_finished
(
self
,
finished_req_ids
:
set
[
str
]
)
->
tuple
[
set
[
str
]
|
None
,
set
[
str
]
|
None
]:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens on the worker.
The scheduler process (via the Executors) will use this output
to track which workers are done.
Returns:
ids of requests that have finished asynchronous transfer
(requests that previously returned True from request_finished()),
tuple of (sending/saving ids, recving/loading ids).
The finished saves/sends req ids must belong to a set provided in a
call to this method (this call or a prior one).
"""
val
=
self
.
worker_adapter
.
get_finished
(
finished_req_ids
)
# logger.error("Finished req ids: %s, %s", val[0], val[1])
return
val
def
get_block_ids_with_load_errors
(
self
)
->
set
[
int
]:
"""
Get the set of block IDs that failed to load.
Returns:
Set of block IDs that encountered load errors.
Empty set if no load errors occurred.
Notes:
- Applies to both sync- and async-loading requests.
- Async loading: failed blocks may be reported in any forward pass
up to and including the pass where the request ID is returned by
`get_finished()`. Even if failures occur, the request must still
be reported via `get_finished()`, and the failed block IDs must
appear here no later than that same pass.
- Sync loading: failed blocks should be reported in the forward
pass in which they are detected.
"""
# TODO: add error tracking
return
set
()
def
shutdown
(
self
):
"""
Shutdown the connector. This is called when the worker process
is shutting down to ensure that all the async operations are
completed and the connector is cleaned up properly.
"""
if
hasattr
(
self
,
"worker_adapter"
):
self
.
worker_adapter
.
shutdown
()
return
None
def
get_kv_connector_stats
(
self
)
->
Optional
[
"KVConnectorStats"
]:
"""
Get the KV connector stats collected during the last interval.
"""
return
None
# ==============================
# Scheduler-side methods
# ==============================
def
get_num_new_matched_tokens
(
self
,
request
:
"Request"
,
num_computed_tokens
:
int
,
)
->
tuple
[
int
|
None
,
bool
]:
"""
Get number of new tokens that can be loaded from the
external KV cache beyond the num_computed_tokens.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
A tuple with the following elements:
- An optional number of tokens that can be loaded from the
external KV cache beyond what is already computed.
If None, it means that the connector needs more time to
determine the number of matched tokens, and the scheduler
should query for this request again later.
- `True` if external KV cache tokens will be loaded
asynchronously (between scheduler steps). Must be
'False' if the first element is 0.
Notes:
The connector should only consider the largest prefix of prompt-
tokens for which KV cache is actually available at the time of the
call. If the cache cannot be loaded for some tokens (e.g., due to
connectivity issues or eviction), those tokens must not be taken
into account.
"""
tracker
=
self
.
_get_or_create_request_tracker
(
request
)
self
.
scheduler_adapter
.
maybe_submit_lookup_request
(
request
.
request_id
,
convert_block_hashes_to_bytes
(
request
.
block_hashes
)
)
ret
=
self
.
scheduler_adapter
.
check_lookup_result
(
request
.
request_id
)
if
ret
is
None
:
return
None
,
True
if
ret
==
0
:
return
0
,
False
assert
(
ret
%
(
self
.
scheduler_adapter
.
num_blocks_per_chunk
()
*
self
.
vllm_block_size
)
==
0
)
# Update num stored blocks for the tracker
num_vllm_blocks
=
num_computed_tokens
//
self
.
vllm_block_size
num_lmcache_blocks
=
ret
//
self
.
vllm_block_size
tracker
.
increase_num_stored_blocks
(
num_lmcache_blocks
)
# Save the vllm and lmcache hit tokens
tracker
.
num_vllm_hit_blocks
=
num_vllm_blocks
tracker
.
num_lmcache_hit_blocks
=
num_lmcache_blocks
need_to_load
=
max
(
0
,
ret
-
num_computed_tokens
)
logger
.
debug
(
"vLLM hit is: %d, Need to load is %d"
,
num_computed_tokens
,
need_to_load
)
return
need_to_load
,
need_to_load
>
0
def
update_state_after_alloc
(
self
,
request
:
"Request"
,
blocks
:
"KVCacheBlocks"
,
num_external_tokens
:
int
):
"""
Update KVConnector state after block allocation.
If get_num_new_matched_tokens previously returned True for a
request, this function may be called twice for that same request -
first when blocks are allocated for the connector tokens to be
asynchronously loaded into, and second when any additional blocks
are allocated, after the load/transfer is complete.
Args:
request (Request): the request object.
blocks (KVCacheBlocks): the blocks allocated for the request.
num_external_tokens (int): the number of tokens that will be
loaded from the external KV cache.
"""
# NOTE: the `blocks` are NEW BLOCKS allocated for this request.
tracker
=
self
.
_get_request_tracker
(
request
.
request_id
)
block_ids
=
reformat_block_ids
(
blocks
.
get_block_ids
())
# No matter we need to retrieve or not, we need to update
# the block ids into the tracker
tracker
.
update_block_ids
(
block_ids
)
# Update the state of the tracker
condition
=
tracker
.
needs_retrieve
()
if
tracker
.
state
==
LMCacheMPRequestState
.
PREFETCHING
:
# If need to retrieve, change to WAITING_FOR_LOAD
# Otherwise, change to READY
tracker
.
state
=
(
LMCacheMPRequestState
.
WAITING_FOR_LOAD
if
condition
else
LMCacheMPRequestState
.
READY
)
def
build_connector_meta
(
self
,
scheduler_output
:
SchedulerOutput
)
->
KVConnectorMetadata
:
"""
Build the connector metadata for this step.
This function should NOT modify fields in the scheduler_output.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
metadata
=
LMCacheMPConnectorMetadata
()
self
.
_process_retrieve_requests
(
metadata
)
self
.
_process_new_requests
(
scheduler_output
,
metadata
)
self
.
_process_cached_requests
(
scheduler_output
,
metadata
)
if
len
(
metadata
)
>
0
:
logger
.
debug
(
"Final connector metadata: %s"
,
metadata
)
return
metadata
def
update_connector_output
(
self
,
connector_output
:
KVConnectorOutput
):
"""
Update KVConnector state from worker-side connectors output.
Args:
connector_output (KVConnectorOutput): the worker-side
connectors output.
"""
return
def
request_finished
(
self
,
request
:
"Request"
,
block_ids
:
list
[
int
],
)
->
tuple
[
bool
,
dict
[
str
,
Any
]
|
None
]:
"""
Called exactly once when a request has finished, before its blocks are
freed.
The connector may assumes responsibility for freeing the blocks
asynchronously by returning True.
Returns:
True if the request is being saved/sent asynchronously and blocks
should not be freed until the request_id is returned from
get_finished().
Optional KVTransferParams to be included in the request outputs
returned by the engine.
"""
return
True
,
None
def
take_events
(
self
)
->
Iterable
[
"KVCacheEvent"
]:
"""
Take the KV cache events from the connector.
Yields:
New KV cache events since the last call.
"""
return
()
@
classmethod
def
get_required_kvcache_layout
(
cls
,
vllm_config
:
"VllmConfig"
)
->
str
|
None
:
"""
Get the required KV cache layout for this connector.
Args:
vllm_config (VllmConfig): the vllm config.
Returns:
str: the required KV cache layout. e.g. HND, or NHD.
None if the connector does not require a specific layout.
"""
if
cls
is
KVConnectorBase_V1
:
raise
TypeError
(
"get_required_kvcache_layout should not be called "
"on the abstract base class"
)
return
None
def
get_finished_count
(
self
)
->
int
|
None
:
"""
Get the count of requests expected to complete send/receive operations
via this connector. This method is used to initialize the
KVOutputAggregator, overwriting the default world_size.
Returns:
int: expected sending or receiving completion count.
"""
return
None
@
classmethod
def
build_kv_connector_stats
(
cls
,
data
:
dict
[
str
,
Any
]
|
None
=
None
)
->
Optional
[
"KVConnectorStats"
]:
"""
KVConnectorStats resolution method. This method allows dynamically
registered connectors to return their own KVConnectorStats object,
which can implement custom aggregation logic on the data dict.
"""
return
None
@
classmethod
def
build_prom_metrics
(
cls
,
vllm_config
:
"VllmConfig"
,
metric_types
:
dict
[
type
[
"PromMetric"
],
type
[
"PromMetricT"
]],
labelnames
:
list
[
str
],
per_engine_labelvalues
:
dict
[
int
,
list
[
str
]],
)
->
Optional
[
"KVConnectorPromMetrics"
]:
"""
Create a KVConnectorPromMetrics subclass which should register
per-connector Prometheus metrics and implement observe() to
expose connector transfer stats via Prometheus.
"""
return
None
##############################
# Helper functions
##############################
def
_process_retrieve_requests
(
self
,
metadata
:
LMCacheMPConnectorMetadata
,
)
->
None
:
blocks_per_chunk
=
self
.
scheduler_adapter
.
num_blocks_per_chunk
()
for
request_tracker
in
self
.
request_trackers
.
values
():
if
request_tracker
.
state
!=
LMCacheMPRequestState
.
WAITING_FOR_LOAD
:
continue
r_metadata
=
LMCacheMPRequestMetadata
.
GetRetrieveMetadata
(
request_tracker
,
blocks_per_chunk
)
if
r_metadata
is
not
None
:
metadata
.
add_request_metadata
(
r_metadata
)
request_tracker
.
state
=
LMCacheMPRequestState
.
READY
def
_process_new_requests
(
self
,
scheduler_output
:
SchedulerOutput
,
metadata
:
LMCacheMPConnectorMetadata
,
)
->
None
:
blocks_per_chunk
=
self
.
scheduler_adapter
.
num_blocks_per_chunk
()
for
new_request
in
scheduler_output
.
scheduled_new_reqs
:
request_tracker
=
self
.
_get_request_tracker
(
new_request
.
req_id
)
num_new_tokens
=
scheduler_output
.
num_scheduled_tokens
[
new_request
.
req_id
]
request_tracker
.
increase_num_scheduled_tokens
(
num_new_tokens
)
r_meta
=
LMCacheMPRequestMetadata
.
GetStoreMetadata
(
request_tracker
,
blocks_per_chunk
,
self
.
vllm_block_size
)
if
r_meta
is
not
None
:
metadata
.
add_request_metadata
(
r_meta
)
def
_process_cached_requests
(
self
,
scheduler_output
:
SchedulerOutput
,
metadata
:
LMCacheMPConnectorMetadata
,
)
->
None
:
blocks_per_chunk
=
self
.
scheduler_adapter
.
num_blocks_per_chunk
()
cached_reqs
=
scheduler_output
.
scheduled_cached_reqs
for
idx
,
request_id
in
enumerate
(
cached_reqs
.
req_ids
):
request_tracker
=
self
.
_get_request_tracker
(
request_id
)
# Update block ids
new_block_ids
=
reformat_block_ids
(
cached_reqs
.
new_block_ids
[
idx
])
request_tracker
.
update_block_ids
(
new_block_ids
)
# Update new scheduled tokens
num_new_tokens
=
cached_reqs
.
num_computed_tokens
[
idx
]
request_tracker
.
increase_num_scheduled_tokens
(
num_new_tokens
)
r_meta
=
LMCacheMPRequestMetadata
.
GetStoreMetadata
(
request_tracker
,
blocks_per_chunk
,
self
.
vllm_block_size
)
if
r_meta
is
not
None
:
metadata
.
add_request_metadata
(
r_meta
)
def
_get_request_tracker
(
self
,
request_id
:
str
)
->
LMCacheMPRequestTracker
:
assert
request_id
in
self
.
request_trackers
,
(
f
"Request tracker for request_id
{
request_id
}
not found. "
)
return
self
.
request_trackers
[
request_id
]
def
_get_or_create_request_tracker
(
self
,
request
:
"Request"
)
->
LMCacheMPRequestTracker
:
request_id
=
request
.
request_id
if
request_id
not
in
self
.
request_trackers
:
new_tracker
=
LMCacheMPRequestTracker
(
request
)
self
.
request_trackers
[
request_id
]
=
new_tracker
return
self
.
request_trackers
[
request_id
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment