Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
83f478bb
Unverified
Commit
83f478bb
authored
Oct 24, 2025
by
Yihua Cheng
Committed by
GitHub
Oct 25, 2025
Browse files
[KVConnector] Migrate the LMCache integration code to be vLLM native (#25542)
Signed-off-by:
ApostaC
<
yihua98@uchicago.edu
>
parent
269c4db0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
1637 additions
and
2 deletions
+1637
-2
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
...tributed/kv_transfer/kv_connector/v1/lmcache_connector.py
+18
-2
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/__init__.py
..._transfer/kv_connector/v1/lmcache_integration/__init__.py
+2
-0
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py
.../kv_transfer/kv_connector/v1/lmcache_integration/utils.py
+221
-0
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
...er/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
+1396
-0
No files found.
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py
View file @
83f478bb
...
...
@@ -3,7 +3,9 @@
from
typing
import
TYPE_CHECKING
,
Any
import
torch
from
lmcache.integration.vllm.vllm_v1_adapter
import
LMCacheConnectorV1Impl
from
lmcache.integration.vllm.vllm_v1_adapter
import
(
LMCacheConnectorV1Impl
as
LMCacheConnectorLatestImpl
,
)
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
...
...
@@ -11,6 +13,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorMetadata
,
KVConnectorRole
,
)
from
vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration
import
(
vllm_v1_adapter
as
_adapter
,
)
from
vllm.logger
import
init_logger
from
vllm.v1.core.sched.output
import
SchedulerOutput
...
...
@@ -26,7 +31,18 @@ logger = init_logger(__name__)
class
LMCacheConnectorV1
(
KVConnectorBase_V1
):
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
):
super
().
__init__
(
vllm_config
=
vllm_config
,
role
=
role
)
self
.
_lmcache_engine
=
LMCacheConnectorV1Impl
(
vllm_config
,
role
,
self
)
assert
vllm_config
.
kv_transfer_config
is
not
None
use_native
=
vllm_config
.
kv_transfer_config
.
get_from_extra_config
(
"use_native"
,
False
)
if
use_native
:
logger
.
info
(
"Initializing native LMCache connector"
)
cls
=
_adapter
.
LMCacheConnectorV1Impl
else
:
logger
.
info
(
"Initializing latest dev LMCache connector"
)
cls
=
LMCacheConnectorLatestImpl
self
.
_lmcache_engine
=
cls
(
vllm_config
,
role
,
self
)
# ==============================
# Worker-side methods
...
...
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/__init__.py
0 → 100644
View file @
83f478bb
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py
0 → 100644
View file @
83f478bb
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Standard
import
os
import
threading
from
typing
import
TYPE_CHECKING
,
Union
import
torch
from
lmcache.config
import
LMCacheEngineConfig
as
Config
from
lmcache.logging
import
init_logger
from
lmcache.v1.config
import
LMCacheEngineConfig
as
V1Config
if
TYPE_CHECKING
:
from
vllm.config
import
ModelConfig
from
vllm.multimodal.inputs
import
PlaceholderRange
from
vllm.v1.core.sched.output
import
NewRequestData
from
vllm.v1.request
import
Request
logger
=
init_logger
(
__name__
)
ENGINE_NAME
=
"vllm-instance"
# Thread-safe singleton storage
_config_instance
:
Config
|
V1Config
|
None
=
None
_config_lock
=
threading
.
Lock
()
def
is_false
(
value
:
str
)
->
bool
:
"""Check if the given string value is equivalent to 'false'."""
return
value
.
lower
()
in
(
"false"
,
"0"
,
"no"
,
"n"
,
"off"
)
def
lmcache_get_or_create_config
()
->
Config
|
V1Config
:
"""Get the LMCache configuration from the environment variable
`LMCACHE_CONFIG_FILE`. If the environment variable is not set, this
function will return the default configuration.
This function is thread-safe and implements singleton pattern,
ensuring the configuration is loaded only once.
"""
global
_config_instance
# Double-checked locking for thread-safe singleton
if
_config_instance
is
None
:
with
_config_lock
:
if
_config_instance
is
None
:
# Check again within lock
if
is_false
(
os
.
getenv
(
"LMCACHE_USE_EXPERIMENTAL"
,
"True"
)):
logger
.
warning
(
"Detected LMCACHE_USE_EXPERIMENTAL is set to False. "
"Using legacy configuration is deprecated and will "
"be remove soon! Please set LMCACHE_USE_EXPERIMENTAL "
"to True."
)
LMCacheEngineConfig
=
Config
# type: ignore[assignment]
else
:
LMCacheEngineConfig
=
V1Config
# type: ignore[assignment]
if
"LMCACHE_CONFIG_FILE"
not
in
os
.
environ
:
logger
.
warning
(
"No LMCache configuration file is set. Trying to read"
" configurations from the environment variables."
)
logger
.
warning
(
"You can set the configuration file through "
"the environment variable: LMCACHE_CONFIG_FILE"
)
_config_instance
=
LMCacheEngineConfig
.
from_env
()
else
:
config_file
=
os
.
environ
[
"LMCACHE_CONFIG_FILE"
]
logger
.
info
(
"Loading LMCache config file %s"
,
config_file
)
_config_instance
=
LMCacheEngineConfig
.
from_file
(
config_file
)
# Update config from environment variables
_config_instance
.
update_config_from_env
()
return
_config_instance
def
hex_hash_to_int16
(
s
:
str
)
->
int
:
"""
Convert a hex hash string to a 16-bit integer.
"""
return
int
(
s
,
16
)
&
0xFFFF
def
apply_mm_hashes_to_token_ids
(
token_ids
:
torch
.
Tensor
,
mm_hashes
:
list
[
str
],
mm_positions
:
list
[
"PlaceholderRange"
],
)
->
torch
.
Tensor
:
"""
Overwrite token_ids in-place for multimodal placeholders using
efficient slice assignments.
"""
n
=
token_ids
.
size
(
0
)
for
hash_str
,
placeholder
in
zip
(
mm_hashes
,
mm_positions
):
start
,
length
=
placeholder
.
offset
,
placeholder
.
length
if
start
>=
n
:
continue
end
=
min
(
start
+
length
,
n
)
token_ids
[
start
:
end
]
=
hex_hash_to_int16
(
hash_str
)
return
token_ids
def
mla_enabled
(
model_config
:
"ModelConfig"
)
->
bool
:
return
(
hasattr
(
model_config
,
"use_mla"
)
and
isinstance
(
model_config
.
use_mla
,
bool
)
and
model_config
.
use_mla
)
def
create_lmcache_metadata
(
vllm_config
=
None
,
model_config
=
None
,
parallel_config
=
None
,
cache_config
=
None
):
"""
Create LMCacheEngineMetadata from vLLM configuration.
This function extracts common metadata creation logic that was duplicated
across multiple files.
Args:
vllm_config (VllmConfig): vLLM configuration object containing model,
parallel, and cache configs (alternative to
individual config parameters)
model_config (ModelConfig): Model configuration (alternative to
vllm_config)
parallel_config (ParallelConfig): Parallel configuration (alternative
to vllm_config)
cache_config (CacheConfig): Cache configuration (alternative to
vllm_config)
"""
# Third Party
# First Party
from
lmcache.config
import
LMCacheEngineMetadata
from
vllm.utils
import
get_kv_cache_torch_dtype
config
=
lmcache_get_or_create_config
()
# Support both vllm_config object and individual config parameters
if
vllm_config
is
not
None
:
model_cfg
=
vllm_config
.
model_config
parallel_cfg
=
vllm_config
.
parallel_config
cache_cfg
=
vllm_config
.
cache_config
else
:
if
model_config
is
None
or
parallel_config
is
None
or
cache_config
is
None
:
raise
ValueError
(
"Either vllm_config must be provided, or all of "
"model_config, parallel_config, and cache_config must be provided."
)
model_cfg
=
model_config
parallel_cfg
=
parallel_config
cache_cfg
=
cache_config
# Get KV cache dtype
kv_dtype
=
get_kv_cache_torch_dtype
(
cache_cfg
.
cache_dtype
,
model_cfg
.
dtype
)
# Check if MLA is enabled
use_mla
=
mla_enabled
(
model_cfg
)
# Construct KV shape (for memory pool)
num_layer
=
model_cfg
.
get_num_layers
(
parallel_cfg
)
chunk_size
=
config
.
chunk_size
num_kv_head
=
model_cfg
.
get_num_kv_heads
(
parallel_cfg
)
head_size
=
model_cfg
.
get_head_size
()
kv_shape
=
(
num_layer
,
1
if
use_mla
else
2
,
chunk_size
,
num_kv_head
,
head_size
)
# Create metadata
metadata
=
LMCacheEngineMetadata
(
model_cfg
.
model
,
parallel_cfg
.
world_size
,
parallel_cfg
.
rank
,
"vllm"
,
kv_dtype
,
kv_shape
,
use_mla
,
)
return
metadata
,
config
def
extract_mm_features
(
request
:
Union
[
"Request"
,
"NewRequestData"
],
modify
:
bool
=
False
)
->
tuple
[
list
[
str
],
list
[
"PlaceholderRange"
]]:
"""
Normalize multimodal information from a Request into parallel lists.
This helper reads either:
1) `request.mm_features` (objects each exposing `.identifier` and
`.mm_position`), or
2) legacy fields `request.mm_hashes` and `request.mm_positions`.
It returns two equally sized lists: the multimodal hash identifiers and
their corresponding positions. If the request contains no multimodal info,
it returns `([], [])`.
Args:
request (Request): The source object.
modify (bool):
Controls copy semantics for the legacy-path return values.
- If True and legacy fields are used, shallow-copies are returned so
the caller can mutate the lists without affecting `request`.
- If False, the original legacy sequences are returned as-is
(zero-copy); treat them as read-only.
Returns:
tuple[list[str], list[PlaceholderRange]]: (`mm_hashes`, `mm_positions`).
May be `([], [])` when no multimodal data is present.
"""
if
getattr
(
request
,
"mm_features"
,
None
):
mm_hashes
,
mm_positions
=
zip
(
*
((
f
.
identifier
,
f
.
mm_position
)
for
f
in
request
.
mm_features
)
)
return
(
list
(
mm_hashes
),
list
(
mm_positions
))
elif
getattr
(
request
,
"mm_hashes"
,
None
):
if
modify
:
return
(
request
.
mm_hashes
.
copy
(),
# type: ignore
request
.
mm_positions
.
copy
(),
# type: ignore
)
else
:
return
(
request
.
mm_hashes
,
request
.
mm_positions
)
# type: ignore
else
:
return
([],
[])
vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py
0 → 100644
View file @
83f478bb
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Standard
import
os
import
uuid
from
collections.abc
import
Generator
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Any
,
Optional
import
torch
from
lmcache
import
utils
from
lmcache.config
import
LMCacheEngineMetadata
from
lmcache.logging
import
init_logger
from
lmcache.observability
import
LMCStatsMonitor
from
lmcache.utils
import
_lmcache_nvtx_annotate
from
lmcache.v1.cache_engine
import
LMCacheEngine
,
LMCacheEngineBuilder
from
lmcache.v1.compute.blend
import
LMCBlenderBuilder
from
lmcache.v1.config
import
LMCacheEngineConfig
,
_validate_and_set_config_value
from
lmcache.v1.gpu_connector
import
(
VLLMBufferLayerwiseGPUConnector
,
VLLMPagedMemGPUConnectorV2
,
VLLMPagedMemLayerwiseGPUConnector
,
)
from
lmcache.v1.internal_api_server.api_server
import
InternalAPIServer
from
lmcache.v1.lookup_client
import
LookupClientFactory
from
lmcache.v1.lookup_client.lmcache_async_lookup_client
import
(
LMCacheAsyncLookupServer
,
)
from
lmcache.v1.offload_server.zmq_server
import
ZMQOffloadServer
from
lmcache.v1.plugin.plugin_launcher
import
PluginLauncher
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
(
KVConnectorBase_V1
,
KVConnectorMetadata
,
KVConnectorRole
,
)
from
vllm.distributed.kv_transfer.kv_connector.v1.lmcache_integration.utils
import
(
ENGINE_NAME
,
apply_mm_hashes_to_token_ids
,
extract_mm_features
,
lmcache_get_or_create_config
,
mla_enabled
,
)
from
vllm.distributed.parallel_state
import
get_tensor_model_parallel_rank
,
get_tp_group
from
vllm.sampling_params
import
SamplingParams
from
vllm.utils
import
cdiv
,
get_kv_cache_torch_dtype
from
vllm.v1.core.sched.output
import
SchedulerOutput
from
vllm.version
import
__version__
as
VLLM_VERSION
if
TYPE_CHECKING
:
from
vllm.attention.backends.abstract
import
AttentionMetadata
from
vllm.forward_context
import
ForwardContext
from
vllm.multimodal.inputs
import
PlaceholderRange
from
vllm.v1.core.kv_cache_manager
import
KVCacheManager
from
vllm.v1.core.sched.output
import
NewRequestData
from
vllm.v1.request
import
Request
logger
=
init_logger
(
__name__
)
@
dataclass
class
LoadSpec
:
# Number of tokens cached in vLLM
vllm_cached_tokens
:
int
# Number of tokens that are cached in LMCache
lmcache_cached_tokens
:
int
# Whether the scheduler allow us to load the tokens
can_load
:
bool
@
dataclass
class
SaveSpec
:
# Skip already saved tokens
skip_leading_tokens
:
int
# Whether the scheduler allow us to save the tokens
can_save
:
bool
@
dataclass
class
DisaggSpec
:
req_id
:
str
receiver_id
:
str
receiver_host
:
str
receiver_init_port
:
int
receiver_alloc_port
:
int
is_last_prefill
:
bool
=
False
num_transferred_tokens
:
int
=
0
tmp_disagg_tracker
:
dict
[
str
,
DisaggSpec
]
=
{}
def
extract_request_configs
(
sampling_params
:
SamplingParams
)
->
dict
|
None
:
request_configs
=
None
if
(
sampling_params
.
extra_args
is
not
None
and
"kv_transfer_params"
in
sampling_params
.
extra_args
):
kv_transfer_params
=
sampling_params
.
extra_args
.
get
(
"kv_transfer_params"
)
if
kv_transfer_params
is
None
:
return
None
assert
isinstance
(
kv_transfer_params
,
dict
)
for
k
,
v
in
kv_transfer_params
.
items
():
if
k
.
startswith
(
"lmcache."
):
if
request_configs
is
None
:
request_configs
=
{}
request_configs
[
k
]
=
v
return
request_configs
@
dataclass
class
RequestTracker
:
# Request id
req_id
:
str
# Total prompt token length
prompt_len
:
int
# The token ids that has been scheduled so far
token_ids
:
list
[
int
]
# The block ids that has been allocated so far
# NOTE: allocated blocks could be more than the number of tokens
allocated_block_ids
:
list
[
int
]
# The number of tokens that has been saved
num_saved_tokens
:
int
=
0
# Disagg spec for the request
disagg_spec
:
DisaggSpec
|
None
=
None
# Multimodal hashes and positions
mm_hashes
:
list
[
str
]
|
None
=
None
mm_positions
:
list
[
"PlaceholderRange"
]
|
None
=
None
# The configs of the request, includes tags and other configs
request_configs
:
dict
|
None
=
None
# Whether the request is in decode phase
is_decode_phase
=
False
# Whether the request cache should be saved
skip_save
:
bool
=
False
@
_lmcache_nvtx_annotate
@
staticmethod
def
from_new_request
(
lmcache_config
:
LMCacheEngineConfig
,
new_request
:
"NewRequestData"
,
num_tokens_to_compute
:
int
,
lmcache_cached_tokens
:
int
,
skip_save
:
bool
,
)
->
"RequestTracker"
:
"""Create the request tracker from a new request.
Args:
lmcache_config (LMCacheEngineConfig): the LMCache engine config.
new_request (NewRequestData): the new request data.
num_tokens_to_compute (int): the number of tokens that will
be 'computed', including the `num_computed_tokens` (vLLM's
local cache hit) and new tokens that will be scheduled.
lmcache_cached_tokens (int): the number of tokens that are
cached in LMCache.
skip_save (bool): whether the request cache should be saved
"""
# vLLM 0.9.0 update: request.block_ids changed from list[int] to
# list[list[int]]
# Need to check the type of request.block_ids
unfolded_block_ids
=
[]
if
not
isinstance
(
new_request
.
block_ids
[
0
],
list
):
unfolded_block_ids
=
new_request
.
block_ids
.
copy
()
else
:
# According to the vLLM code
# (https://github.com/vllm-project/vllm/blob/main/vllm/v1/core/
# sched/scheduler.py#L943),
# only one KVCacheGroup is supported in connector for now.
unfolded_block_ids
=
new_request
.
block_ids
[
0
].
copy
()
# NOTE: Initialized in `update_state_after_alloc`
disagg_spec
=
tmp_disagg_tracker
.
pop
(
new_request
.
req_id
,
None
)
if
new_request
.
sampling_params
:
request_configs
=
extract_request_configs
(
new_request
.
sampling_params
)
else
:
request_configs
=
None
mm_hashes
,
mm_positions
=
extract_mm_features
(
new_request
,
modify
=
True
)
assert
new_request
.
prompt_token_ids
is
not
None
return
RequestTracker
(
req_id
=
new_request
.
req_id
,
prompt_len
=
len
(
new_request
.
prompt_token_ids
),
token_ids
=
new_request
.
prompt_token_ids
[:
num_tokens_to_compute
].
copy
(),
allocated_block_ids
=
unfolded_block_ids
,
num_saved_tokens
=
lmcache_cached_tokens
,
disagg_spec
=
disagg_spec
,
mm_hashes
=
mm_hashes
,
mm_positions
=
mm_positions
,
skip_save
=
skip_save
,
request_configs
=
request_configs
,
)
def
update
(
self
,
new_token_ids
:
list
[
int
],
new_block_ids
:
tuple
[
list
[
int
],
...]
|
None
|
list
[
int
],
)
->
None
:
"""Update the request tracker when a running request is
scheduled again
"""
self
.
token_ids
.
extend
(
new_token_ids
)
if
new_block_ids
is
None
:
# https://github.com/vllm-project/vllm/commit/
# b029de9902aa3ac58806c8c17776c7074175b6db
new_block_ids
=
[]
elif
len
(
new_block_ids
)
==
0
:
new_block_ids
=
[]
elif
isinstance
(
new_block_ids
,
tuple
):
new_block_ids
=
new_block_ids
[
0
]
elif
isinstance
(
new_block_ids
,
list
):
pass
else
:
raise
ValueError
(
f
"Unsupported new_block_ids type
{
type
(
new_block_ids
)
}
"
)
self
.
allocated_block_ids
.
extend
(
new_block_ids
)
# When a request is scheduled again, and the number of new tokens
# is 1 (excluding chunked prefill), the request is in decode phase.
if
len
(
new_token_ids
)
==
1
:
self
.
is_decode_phase
=
True
@
dataclass
class
ReqMeta
:
# Request id
req_id
:
str
# Request tokens
token_ids
:
list
[
int
]
# torch.Tensor
# Slot mapping
slot_mapping
:
torch
.
Tensor
# Whether is last prefill or not
is_last_prefill
:
bool
=
False
# Skip save or not
save_spec
:
SaveSpec
|
None
=
None
# load_spec
load_spec
:
LoadSpec
|
None
=
None
# disagg spec
disagg_spec
:
DisaggSpec
|
None
=
None
# the configs of the request
request_configs
:
dict
|
None
=
None
@
staticmethod
def
from_request_tracker
(
tracker
:
RequestTracker
,
block_size
:
int
,
lmcache_chunk_size
:
int
=
256
,
load_spec
:
LoadSpec
|
None
=
None
,
discard_partial_chunks
:
bool
=
True
,
save_decode_cache
:
bool
=
False
,
)
->
Optional
[
"ReqMeta"
]:
"""Create the request metadata from a request tracker.
Args:
tracker (RequestTracker): the request tracker.
block_size (int): the block size in vLLM.
lmcache_chunk_size (int): the chunk size for LMCache.
load_spec (Optional[LoadSpec]): the load spec for KV cache loading.
discard_partial_chunks (bool): whether to discard partial chunks.
save_decode_cache (bool): whether to save the cache in decode phase.
Returns:
the request metadata if we need to perform load/save
operations, None otherwise.
"""
input_token_ids
=
tracker
.
token_ids
input_token_len
=
len
(
input_token_ids
)
is_last_prefill
=
False
if
input_token_len
==
tracker
.
prompt_len
:
is_last_prefill
=
True
# For save operation: do not save if the following condition is met
# 1. has already been saved before (num_saved_tokens > 0)
# 2. number of unsaved tokens is not reached the chunk boundary
# 3. if save_decode_cache is False and it is in decode phase
skip_leading_tokens
=
tracker
.
num_saved_tokens
chunk_boundary
=
(
cdiv
(
tracker
.
num_saved_tokens
+
1
,
lmcache_chunk_size
)
*
lmcache_chunk_size
)
# NOTE(vladnosiv): for disagg, you cannot skip saving, as saving is a
# trqansfer. Check if request_configs has lmcache.skip_save set to True
request_skip
=
(
tracker
.
request_configs
or
{}).
get
(
"lmcache.skip_save"
,
False
)
skip_save
=
tracker
.
disagg_spec
is
None
and
(
tracker
.
skip_save
or
(
tracker
.
num_saved_tokens
>
0
and
input_token_len
<
chunk_boundary
)
or
(
tracker
.
is_decode_phase
and
not
save_decode_cache
)
or
request_skip
)
if
skip_save
and
load_spec
is
None
:
return
None
# Calculate number of tokens to save based on discard_partial_chunks
# setting
# NOTE(vladnosiv): for the input_token_len chunk prefill,
# we are required to discard partial chunks,
# as new tokens will be added in the next iteration.
num_tokens_to_save
=
(
(
input_token_len
//
lmcache_chunk_size
*
lmcache_chunk_size
)
if
not
is_last_prefill
or
discard_partial_chunks
else
input_token_len
)
# If we need to save, update the number of saved tokens
if
not
skip_save
:
tracker
.
num_saved_tokens
=
num_tokens_to_save
save_spec
=
SaveSpec
(
skip_leading_tokens
,
not
skip_save
)
# Calculate the token ids and slot mappings for load and save
token_ids
=
input_token_ids
[:
num_tokens_to_save
]
# If the request has multimodal hashes, apply them to the token ids
if
tracker
.
mm_hashes
:
token_ids_tensor
=
torch
.
tensor
(
token_ids
)
assert
tracker
.
mm_positions
is
not
None
,
(
"tracker got mm_hashes but no mm_positions"
)
apply_mm_hashes_to_token_ids
(
token_ids_tensor
,
tracker
.
mm_hashes
,
tracker
.
mm_positions
)
token_ids
=
token_ids_tensor
.
tolist
()
num_blocks
=
len
(
tracker
.
allocated_block_ids
)
if
len
(
token_ids
)
>
num_blocks
*
block_size
:
logger
.
error
(
"The number of tokens is more than the number of blocks."
"Something might be wrong in scheduling logic!"
)
logger
.
error
(
"Num tokens: %d, num blocks: %d, block size: %d"
,
len
(
token_ids
),
num_blocks
,
block_size
,
)
block_ids
=
torch
.
tensor
(
tracker
.
allocated_block_ids
,
dtype
=
torch
.
long
)
block_offsets
=
torch
.
arange
(
0
,
block_size
,
dtype
=
torch
.
long
)
slot_mapping
=
(
block_offsets
.
reshape
((
1
,
block_size
))
+
block_ids
.
reshape
((
num_blocks
,
1
))
*
block_size
)
slot_mapping
=
slot_mapping
.
flatten
()[:
len
(
token_ids
)]
assert
slot_mapping
.
dtype
==
torch
.
long
# For load operation: check whether the request is scheduled to load
if
load_spec
is
not
None
and
load_spec
.
can_load
:
logger
.
debug
(
"Scheduled to load %d tokens for request %s"
,
load_spec
.
lmcache_cached_tokens
,
tracker
.
req_id
,
)
else
:
# Do not load if not in `can_load` state
load_spec
=
None
return
ReqMeta
(
req_id
=
tracker
.
req_id
,
token_ids
=
token_ids
,
slot_mapping
=
slot_mapping
,
is_last_prefill
=
is_last_prefill
,
save_spec
=
save_spec
,
load_spec
=
load_spec
,
disagg_spec
=
tracker
.
disagg_spec
,
request_configs
=
tracker
.
request_configs
,
)
def
need_gpu_interm_buffer
(
lmcache_config
:
LMCacheEngineConfig
):
return
lmcache_config
.
enable_pd
def
_calculate_mtp_layers
(
vllm_config
,
model_config
):
num_mtp_layers
=
0
if
vllm_config
is
not
None
and
vllm_config
.
speculative_config
is
not
None
:
logger
.
info
(
"vllm_config.speculative_config: %s"
,
vllm_config
.
speculative_config
)
# TODO(baoloongmao): Support other MTP methods
if
vllm_config
.
speculative_config
.
method
==
"deepseek_mtp"
:
num_mtp_layers
=
getattr
(
model_config
.
hf_config
,
"num_nextn_predict_layers"
,
0
)
return
num_mtp_layers
def
_init_lmcache_engine
(
lmcache_config
:
LMCacheEngineConfig
,
vllm_config
:
"VllmConfig"
,
)
->
LMCacheEngine
:
"""Initialize the LMCache engine by the given model config and parallel
config. This function will check the environment variable
`LMCACHE_CONFIG_FILE` to load the configuration file. If that environment
variable is not set, this function will return None.
:param lmcache_config: The LMCache configuration.
:type lmcache_config: LMCacheEngineConfig
:param vllm_config: The vLLM configuration.
:type vllm_config: VllmConfig
:return: The initialized LMCache engine
:rtype: LMCacheEngine
"""
if
curr_engine
:
=
LMCacheEngineBuilder
.
get
(
ENGINE_NAME
):
return
curr_engine
model_config
=
vllm_config
.
model_config
parallel_config
=
vllm_config
.
parallel_config
cache_config
=
vllm_config
.
cache_config
assert
isinstance
(
lmcache_config
,
LMCacheEngineConfig
),
(
"LMCache v1 configuration is should be passed."
)
kv_dtype
=
get_kv_cache_torch_dtype
(
cache_config
.
cache_dtype
,
model_config
.
dtype
)
use_mla
=
mla_enabled
(
model_config
)
if
use_mla
and
(
lmcache_config
.
remote_serde
!=
"naive"
and
lmcache_config
.
remote_serde
is
not
None
):
raise
ValueError
(
"MLA only works with naive serde mode.."
)
# construct kv shape (for mem pool)
num_layer
=
model_config
.
get_num_layers
(
parallel_config
)
num_mtp_layers
=
_calculate_mtp_layers
(
vllm_config
,
model_config
)
num_layer
+=
num_mtp_layers
chunk_size
=
lmcache_config
.
chunk_size
num_kv_head
=
model_config
.
get_num_kv_heads
(
parallel_config
)
head_size
=
model_config
.
get_head_size
()
kv_shape
=
(
num_layer
,
1
if
use_mla
else
2
,
chunk_size
,
num_kv_head
,
head_size
)
logger
.
info
(
"use mla: %s, kv shape: %s, num_mtp_layers: %s"
,
use_mla
,
kv_shape
,
num_mtp_layers
,
)
# Change current device.
num_gpus
=
torch
.
cuda
.
device_count
()
local_rank
=
parallel_config
.
rank
%
num_gpus
torch
.
cuda
.
set_device
(
local_rank
)
device
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
metadata
=
LMCacheEngineMetadata
(
model_config
.
model
,
parallel_config
.
world_size
,
parallel_config
.
rank
,
"vllm"
,
kv_dtype
,
kv_shape
,
use_mla
,
)
use_gpu
=
need_gpu_interm_buffer
(
lmcache_config
)
vllm_gpu_connector
:
(
VLLMBufferLayerwiseGPUConnector
|
VLLMPagedMemGPUConnectorV2
|
VLLMPagedMemLayerwiseGPUConnector
)
if
use_mla
and
lmcache_config
.
use_layerwise
:
raise
ValueError
(
"layerwise MLA connector is not supported yet"
)
# When use_mla is True, num_kv_head is 1
hidden_dim_size
=
num_kv_head
*
head_size
if
lmcache_config
.
use_layerwise
:
if
lmcache_config
.
enable_blending
:
# Use layerwise connector for blending
vllm_gpu_connector
=
VLLMBufferLayerwiseGPUConnector
(
hidden_dim_size
,
num_layer
,
use_gpu
=
use_gpu
,
chunk_size
=
chunk_size
,
dtype
=
kv_dtype
,
device
=
device
,
)
else
:
vllm_gpu_connector
=
VLLMPagedMemLayerwiseGPUConnector
(
hidden_dim_size
,
num_layer
,
use_gpu
=
use_gpu
,
chunk_size
=
chunk_size
,
dtype
=
kv_dtype
,
device
=
device
,
)
else
:
vllm_gpu_connector
=
VLLMPagedMemGPUConnectorV2
(
hidden_dim_size
,
num_layer
,
use_gpu
=
use_gpu
,
chunk_size
=
chunk_size
,
dtype
=
kv_dtype
,
device
=
device
,
use_mla
=
use_mla
,
)
tpg
=
get_tp_group
()
engine
=
LMCacheEngineBuilder
.
get_or_create
(
ENGINE_NAME
,
lmcache_config
,
metadata
,
vllm_gpu_connector
,
tpg
.
broadcast
,
tpg
.
broadcast_object
,
)
return
engine
@
dataclass
class
LMCacheConnectorMetadata
(
KVConnectorMetadata
):
requests
:
list
[
ReqMeta
]
=
field
(
default_factory
=
list
)
lookup_requests_in_step
:
list
[
str
]
=
field
(
default_factory
=
list
)
@
_lmcache_nvtx_annotate
def
add_request
(
self
,
req_meta
:
ReqMeta
)
->
None
:
"""Add a request to the metadata.
Args:
req_meta (ReqMeta): the request metadata.
"""
self
.
requests
.
append
(
req_meta
)
class
LMCacheConnectorV1Impl
:
def
__init__
(
self
,
vllm_config
:
"VllmConfig"
,
role
:
KVConnectorRole
,
parent
:
KVConnectorBase_V1
,
):
assert
vllm_config
.
kv_transfer_config
is
not
None
self
.
_parent
=
parent
self
.
_vllm_config
=
vllm_config
self
.
kv_role
=
vllm_config
.
kv_transfer_config
.
kv_role
self
.
worker_count
=
vllm_config
.
parallel_config
.
tensor_parallel_size
config
=
lmcache_get_or_create_config
()
assert
isinstance
(
config
,
LMCacheEngineConfig
),
(
"LMCache v1 configuration is should be passed for vLLM v1."
)
# Put the leading with "lmcache." and matched configs from
# vllm extra_config to the config
kv_connector_extra_config
=
(
vllm_config
.
kv_transfer_config
.
kv_connector_extra_config
)
if
kv_connector_extra_config
:
for
key
,
value
in
kv_connector_extra_config
.
items
():
if
key
.
startswith
(
"lmcache."
):
config_key
=
key
[
8
:]
# Remove "lmcache." prefix
if
_validate_and_set_config_value
(
config
,
config_key
,
value
):
logger
.
info
(
"Updated config %s from vLLM extra config: %s"
,
config_key
,
value
,
)
self
.
config
=
config
self
.
async_loading
=
config
.
enable_async_loading
self
.
layerwise_retrievers
:
list
[
Generator
[
torch
.
Tensor
|
None
,
None
,
None
]]
=
[]
self
.
_stats_monitor
=
LMCStatsMonitor
.
GetOrCreate
()
if
role
==
KVConnectorRole
.
SCHEDULER
:
# Create lookup client using factory
self
.
lookup_client
=
LookupClientFactory
.
create_lookup_client
(
vllm_config
,
config
)
self
.
_unfinished_requests
:
dict
[
str
,
Request
]
=
{}
self
.
_lookup_requests_in_step
:
list
[
str
]
=
[]
self
.
lmcache_engine
=
None
else
:
self
.
lmcache_engine
=
_init_lmcache_engine
(
config
,
vllm_config
,
)
self
.
use_layerwise
=
config
.
use_layerwise
self
.
enable_blending
=
config
.
enable_blending
if
self
.
enable_blending
:
self
.
blender
=
LMCBlenderBuilder
.
get_or_create
(
ENGINE_NAME
,
self
.
lmcache_engine
,
self
.
lmcache_engine
.
gpu_connector
,
config
,
)
# Create lookup server using factory
assert
self
.
lmcache_engine
is
not
None
self
.
lookup_server
=
LookupClientFactory
.
create_lookup_server
(
self
.
lmcache_engine
,
vllm_config
)
self
.
offload_server
=
ZMQOffloadServer
(
self
.
lmcache_engine
,
vllm_config
,
get_tensor_model_parallel_rank
(),
)
# In case of MLA, the lookup server is only created on worker 0
if
self
.
async_loading
and
self
.
lookup_server
is
not
None
:
assert
isinstance
(
self
.
lookup_server
,
LMCacheAsyncLookupServer
)
self
.
lmcache_engine
.
post_init
(
async_lookup_server
=
self
.
lookup_server
)
self
.
kv_caches
:
dict
[
str
,
torch
.
Tensor
]
=
{}
self
.
_block_size
=
vllm_config
.
cache_config
.
block_size
# request_id -> (vllm cached tokens, lmcache cached tokens)
self
.
load_specs
:
dict
[
str
,
LoadSpec
]
=
{}
self
.
kv_cache_manager
:
KVCacheManager
|
None
=
None
# request_id -> full_token_ids
self
.
_request_trackers
:
dict
[
str
,
RequestTracker
]
=
{}
# Whether to discard partial chunks
self
.
_discard_partial_chunks
=
(
vllm_config
.
kv_transfer_config
.
get_from_extra_config
(
"discard_partial_chunks"
,
False
)
or
not
config
.
save_unfull_chunk
)
self
.
_lmcache_chunk_size
=
config
.
chunk_size
self
.
_save_decode_cache
=
config
.
save_decode_cache
self
.
skip_last_n_tokens
=
vllm_config
.
kv_transfer_config
.
get_from_extra_config
(
"skip_last_n_tokens"
,
0
)
self
.
num_layers
=
vllm_config
.
model_config
.
get_num_layers
(
vllm_config
.
parallel_config
)
self
.
current_layer
=
0
self
.
force_skip_save
=
bool
(
os
.
environ
.
get
(
"LMCACHE_FORCE_SKIP_SAVE"
,
False
))
self
.
_requests_priority
:
dict
[
str
,
int
]
=
{}
# TODO(baoloongmao): Internal api server & plugin framework support
# dp > 1
if
(
vllm_config
.
parallel_config
.
data_parallel_size_local
==
1
or
vllm_config
.
parallel_config
.
data_parallel_rank_local
==
0
):
# Start internal API server if enabled
# The enabled check is in the InternalAPIServer constructor
self
.
api_server
=
InternalAPIServer
(
self
)
self
.
api_server
.
start
()
# Launch plugins
self
.
plugin_launcher
=
PluginLauncher
(
self
.
config
,
role
,
self
.
worker_count
,
-
1
if
self
.
lmcache_engine
is
None
# scheduler side
else
self
.
lmcache_engine
.
metadata
.
worker_id
,
)
self
.
plugin_launcher
.
launch_plugins
()
else
:
self
.
api_server
=
None
# type: ignore[assignment]
self
.
plugin_launcher
=
None
# type: ignore[assignment]
logger
.
info
(
"LMCache initialized for role %s with version %s, "
"vllm version %s, lmcache cache_engine metadata: %s"
,
role
,
utils
.
get_version
(),
VLLM_VERSION
,
getattr
(
self
.
lmcache_engine
,
"metadata"
,
None
),
)
def
get_inference_info
(
self
)
->
dict
:
"""Get inference information including vLLM config and related details.
Returns:
dict: Dictionary containing inference information
"""
# Get vLLM config information
vllm_config
=
self
.
_vllm_config
# Use vLLM config's string representation and add specific configs
inference_info
=
{
"vllm_version"
:
VLLM_VERSION
,
"lmcache_version"
:
utils
.
get_version
(),
"vllm_config"
:
str
(
vllm_config
),
"model_config"
:
{
"model"
:
getattr
(
vllm_config
.
model_config
,
"model"
,
None
),
"dtype"
:
str
(
getattr
(
vllm_config
.
model_config
,
"dtype"
,
None
)),
"max_model_len"
:
getattr
(
vllm_config
.
model_config
,
"max_model_len"
,
None
),
"vocab_size"
:
getattr
(
vllm_config
.
model_config
,
"vocab_size"
,
None
),
"num_layers"
:
getattr
(
vllm_config
.
model_config
,
"get_num_layers"
,
lambda
_
:
None
)(
vllm_config
.
parallel_config
),
"num_attention_heads"
:
getattr
(
vllm_config
.
model_config
,
"get_num_attention_heads"
,
lambda
_
:
None
)(
vllm_config
.
parallel_config
),
"num_kv_heads"
:
getattr
(
vllm_config
.
model_config
,
"get_num_kv_heads"
,
lambda
_
:
None
)(
vllm_config
.
parallel_config
),
"head_size"
:
getattr
(
vllm_config
.
model_config
,
"get_head_size"
,
lambda
:
None
)(),
},
"cache_config"
:
{
"block_size"
:
getattr
(
vllm_config
.
cache_config
,
"block_size"
,
None
),
"cache_dtype"
:
str
(
getattr
(
vllm_config
.
cache_config
,
"cache_dtype"
,
None
)
),
"gpu_memory_utilization"
:
getattr
(
vllm_config
.
cache_config
,
"gpu_memory_utilization"
,
None
),
"swap_space"
:
getattr
(
vllm_config
.
cache_config
,
"swap_space"
,
None
),
"enable_prefix_caching"
:
getattr
(
vllm_config
.
cache_config
,
"enable_prefix_caching"
,
None
),
},
}
return
inference_info
def
get_inference_version
(
self
)
->
str
:
"""Get vLLM version information.
Returns:
str: vLLM version string
"""
return
VLLM_VERSION
@
_lmcache_nvtx_annotate
def
_init_kv_caches_from_forward_context
(
self
,
forward_context
:
"ForwardContext"
):
for
layer_name
in
forward_context
.
no_compile_layers
:
attn_layer
=
forward_context
.
no_compile_layers
[
layer_name
]
if
not
hasattr
(
attn_layer
,
"kv_cache"
):
logger
.
debug
(
"The layer %s does not have kv_cache, skip it"
,
layer_name
)
continue
if
layer_name
not
in
self
.
kv_caches
:
self
.
kv_caches
[
layer_name
]
=
attn_layer
.
kv_cache
[
forward_context
.
virtual_engine
]
####################
# Worker side APIs
####################
@
_lmcache_nvtx_annotate
def
start_load_kv
(
self
,
forward_context
:
"ForwardContext"
,
**
kwargs
)
->
None
:
"""Start loading the KV cache from the connector buffer to vLLM's
paged KV buffer.
Args:
forward_context (ForwardContext): the forward context.
Note:
The number of elements in kv_caches and layer_names should be
the same.
"""
self
.
current_layer
=
0
if
len
(
self
.
kv_caches
)
==
0
:
self
.
_init_kv_caches_from_forward_context
(
forward_context
)
metadata
=
self
.
_parent
.
_get_connector_metadata
()
assert
isinstance
(
metadata
,
LMCacheConnectorMetadata
)
assert
len
(
self
.
kv_caches
)
>
0
kvcaches
=
list
(
self
.
kv_caches
.
values
())
attn_metadata
=
forward_context
.
attn_metadata
if
attn_metadata
is
None
:
logger
.
debug
(
"In connector.start_load_kv, but the attn_metadata is None"
)
return
assert
self
.
lmcache_engine
is
not
None
self
.
lmcache_engine
.
post_init
(
kvcaches
=
kvcaches
)
self
.
layerwise_retrievers
=
[]
for
idx
,
request
in
enumerate
(
metadata
.
requests
):
if
request
.
load_spec
is
None
:
continue
last_idx
=
idx
for
idx
,
request
in
enumerate
(
metadata
.
requests
):
if
request
.
load_spec
is
None
:
continue
tokens
=
request
.
token_ids
# TODO: have a pre-allocated buffer to hold the slot_mappings
slot_mapping
=
request
.
slot_mapping
.
cuda
()
assert
len
(
tokens
)
==
len
(
slot_mapping
)
self
.
_stats_monitor
.
update_interval_vllm_hit_tokens
(
request
.
load_spec
.
vllm_cached_tokens
)
token_mask
=
torch
.
ones
(
len
(
tokens
),
dtype
=
torch
.
bool
)
masked_token_count
=
(
request
.
load_spec
.
vllm_cached_tokens
//
self
.
_lmcache_chunk_size
*
self
.
_lmcache_chunk_size
)
token_mask
[:
masked_token_count
]
=
False
lmcache_cached_tokens
=
request
.
load_spec
.
lmcache_cached_tokens
if
self
.
use_layerwise
:
sync
=
idx
==
last_idx
# NOTE(Jiayi): Perform blending before layerwise prefix caching
if
self
.
enable_blending
:
# TODO(Jiayi): Need to make prefix caching and blending
# compatible
self
.
blender
.
blend
(
tokens
[:
lmcache_cached_tokens
],
token_mask
[:
lmcache_cached_tokens
],
kvcaches
=
kvcaches
,
slot_mapping
=
slot_mapping
[:
lmcache_cached_tokens
],
)
else
:
layerwise_retriever
=
self
.
lmcache_engine
.
retrieve_layer
(
tokens
[:
lmcache_cached_tokens
],
token_mask
[:
lmcache_cached_tokens
],
kvcaches
=
kvcaches
,
slot_mapping
=
slot_mapping
[:
lmcache_cached_tokens
],
sync
=
sync
,
)
# NOTE: retrieve for two layers at the first layer
next
(
layerwise_retriever
)
next
(
layerwise_retriever
)
self
.
layerwise_retrievers
.
append
(
layerwise_retriever
)
else
:
ret_token_mask
=
self
.
lmcache_engine
.
retrieve
(
tokens
[:
lmcache_cached_tokens
],
token_mask
[:
lmcache_cached_tokens
],
kvcaches
=
kvcaches
,
slot_mapping
=
slot_mapping
[:
lmcache_cached_tokens
],
request_configs
=
request
.
request_configs
,
req_id
=
request
.
req_id
,
)
# Check the result
num_retrieved_tokens
=
ret_token_mask
.
sum
().
item
()
num_expected_tokens
=
(
lmcache_cached_tokens
-
request
.
load_spec
.
vllm_cached_tokens
)
if
num_retrieved_tokens
<
num_expected_tokens
:
logger
.
error
(
"The number of retrieved tokens is less than the "
"expected number of tokens! This should not happen!"
)
logger
.
error
(
"Num retrieved tokens: %d, num expected tokens: %d"
,
num_retrieved_tokens
,
num_expected_tokens
,
)
@
_lmcache_nvtx_annotate
def
wait_for_layer_load
(
self
,
layer_name
:
str
)
->
None
:
"""Blocking until the KV for a specific layer is loaded into vLLM's
paged buffer.
This interface will be useful for layer-by-layer pipelining.
Args:
layer_name: the name of that layer
"""
if
self
.
layerwise_retrievers
:
logger
.
debug
(
"Waiting for layer %s to be loaded"
,
self
.
current_layer
)
# Wait for the layer to be loaded
for
layerwise_retriever
in
self
.
layerwise_retrievers
:
ret_token_mask
=
next
(
layerwise_retriever
)
if
self
.
current_layer
==
self
.
num_layers
-
1
:
assert
ret_token_mask
is
not
None
num_retrieved_tokens
=
ret_token_mask
.
sum
().
item
()
logger
.
info
(
"Retrieved %s tokens"
,
num_retrieved_tokens
)
return
@
_lmcache_nvtx_annotate
def
save_kv_layer
(
self
,
layer_name
:
str
,
kv_layer
:
torch
.
Tensor
,
attn_metadata
:
"AttentionMetadata"
,
**
kwargs
,
)
->
None
:
"""Start saving the a layer of KV cache from vLLM's paged buffer
to the connector.
Args:
layer_name (str): the name of the layer.
kv_layer (torch.Tensor): the paged KV buffer of the current
layer in vLLM.
attn_metadata (AttentionMetadata): the attention metadata.
"""
assert
self
.
lmcache_engine
is
not
None
if
not
self
.
use_layerwise
:
return
if
self
.
kv_role
==
"kv_consumer"
:
# Don't do save if the role is kv_consumer
return
if
self
.
_parent
.
_connector_metadata
is
None
:
logger
.
warning
(
"In connector.save_kv_layer, but the connector metadata is None"
)
return
connector_metadata
=
self
.
_parent
.
_get_connector_metadata
()
assert
isinstance
(
connector_metadata
,
LMCacheConnectorMetadata
)
assert
len
(
self
.
kv_caches
)
>
0
kvcaches
=
list
(
self
.
kv_caches
.
values
())
if
self
.
current_layer
==
0
:
self
.
layerwise_storers
=
[]
is_first
=
True
for
idx
,
request
in
enumerate
(
connector_metadata
.
requests
):
save_spec
=
request
.
save_spec
if
save_spec
is
None
or
not
save_spec
.
can_save
:
continue
token_ids
=
request
.
token_ids
assert
isinstance
(
token_ids
,
list
)
slot_mapping
=
request
.
slot_mapping
assert
isinstance
(
slot_mapping
,
torch
.
Tensor
)
assert
len
(
slot_mapping
)
==
len
(
token_ids
)
# TODO: have a pre-allocated buffer to hold the slot_mappings
slot_mapping
=
slot_mapping
.
cuda
()
if
self
.
kv_role
==
"kv_producer"
:
skip_leading_tokens
=
0
else
:
skip_leading_tokens
=
save_spec
.
skip_leading_tokens
if
skip_leading_tokens
==
len
(
token_ids
):
continue
# skip this request
# Align to lmcache chunk size
skip_leading_tokens
=
(
skip_leading_tokens
//
self
.
_lmcache_chunk_size
*
self
.
_lmcache_chunk_size
)
store_mask
=
torch
.
ones
(
len
(
token_ids
),
dtype
=
torch
.
bool
)
store_mask
[:
skip_leading_tokens
]
=
False
logger
.
info
(
"Storing KV cache for %d out of %d tokens "
"(skip_leading_tokens=%d) for request %s"
,
len
(
token_ids
)
-
skip_leading_tokens
,
len
(
token_ids
),
skip_leading_tokens
,
request
.
req_id
,
)
# TODO (Jiayi): need to make layerwise storing
# compatible with disagg spec
layerwise_storer
=
self
.
lmcache_engine
.
store_layer
(
token_ids
,
mask
=
store_mask
,
kvcaches
=
kvcaches
,
slot_mapping
=
slot_mapping
,
offset
=
skip_leading_tokens
,
sync
=
is_first
,
)
self
.
layerwise_storers
.
append
(
layerwise_storer
)
if
is_first
:
is_first
=
False
for
layerwise_storer
in
self
.
layerwise_storers
:
next
(
layerwise_storer
)
self
.
current_layer
+=
1
@
_lmcache_nvtx_annotate
def
wait_for_save
(
self
):
"""Blocking until the KV cache is saved to the connector buffer."""
connector_metadata
=
self
.
_parent
.
_get_connector_metadata
()
assert
isinstance
(
connector_metadata
,
LMCacheConnectorMetadata
)
self
.
lmcache_engine
.
lookup_unpin
(
# type: ignore
connector_metadata
.
lookup_requests_in_step
)
if
self
.
kv_role
==
"kv_consumer"
:
# Don't do save if the role is kv_consumer
return
if
self
.
use_layerwise
:
for
layerwise_storer
in
self
.
layerwise_storers
:
next
(
layerwise_storer
)
return
assert
len
(
self
.
kv_caches
)
>
0
kvcaches
=
list
(
self
.
kv_caches
.
values
())
assert
self
.
lmcache_engine
is
not
None
for
request
in
connector_metadata
.
requests
:
save_spec
=
request
.
save_spec
if
(
save_spec
is
None
or
not
save_spec
.
can_save
)
and
self
.
kv_role
!=
"kv_producer"
:
continue
token_ids
=
request
.
token_ids
slot_mapping
=
request
.
slot_mapping
assert
isinstance
(
slot_mapping
,
torch
.
Tensor
)
assert
len
(
slot_mapping
)
==
len
(
token_ids
)
assert
save_spec
is
not
None
# TODO: have a pre-allocated buffer to hold the slot_mappings
slot_mapping
=
slot_mapping
.
cuda
()
skip_leading_tokens
=
save_spec
.
skip_leading_tokens
if
self
.
kv_role
==
"kv_producer"
:
assert
request
.
disagg_spec
is
not
None
skip_leading_tokens
=
min
(
skip_leading_tokens
,
request
.
disagg_spec
.
num_transferred_tokens
)
if
skip_leading_tokens
==
len
(
token_ids
):
continue
# skip this request
# Align to lmcache chunk size
skip_leading_tokens
=
(
skip_leading_tokens
//
self
.
_lmcache_chunk_size
*
self
.
_lmcache_chunk_size
)
store_mask
=
torch
.
ones
(
len
(
token_ids
),
dtype
=
torch
.
bool
)
store_mask
[:
skip_leading_tokens
]
=
False
logger
.
info
(
"Storing KV cache for %d out of %d tokens "
"(skip_leading_tokens=%d) for request %s"
,
len
(
token_ids
)
-
skip_leading_tokens
,
len
(
token_ids
),
skip_leading_tokens
,
request
.
req_id
,
)
is_last_prefill
=
request
.
is_last_prefill
if
is_last_prefill
:
if
request
.
disagg_spec
:
request
.
disagg_spec
.
is_last_prefill
=
True
else
:
token_len
=
len
(
token_ids
)
aligned_token_len
=
(
token_len
//
self
.
_lmcache_chunk_size
*
self
.
_lmcache_chunk_size
)
token_ids
=
token_ids
[:
aligned_token_len
]
store_mask
=
store_mask
[:
aligned_token_len
]
slot_mapping
=
slot_mapping
[:
aligned_token_len
]
self
.
lmcache_engine
.
store
(
token_ids
,
mask
=
store_mask
,
kvcaches
=
kvcaches
,
slot_mapping
=
slot_mapping
,
offset
=
skip_leading_tokens
,
transfer_spec
=
request
.
disagg_spec
,
request_configs
=
request
.
request_configs
,
)
# NOTE(Jiayi): We assume all tokens are saved
save_spec
.
skip_leading_tokens
=
len
(
token_ids
)
if
request
.
disagg_spec
:
request
.
disagg_spec
.
num_transferred_tokens
=
len
(
token_ids
)
@
_lmcache_nvtx_annotate
def
get_finished
(
self
,
finished_req_ids
:
set
[
str
]
)
->
tuple
[
set
[
str
]
|
None
,
set
[
str
]
|
None
]:
return
None
,
None
###################
# Scheduler side APIs
####################
@
_lmcache_nvtx_annotate
def
get_num_new_matched_tokens
(
self
,
request
:
"Request"
,
num_computed_tokens
:
int
,
)
->
int
|
None
:
"""
Check for external KV cache hit.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
"""
if
self
.
kv_role
==
"kv_producer"
and
not
hasattr
(
self
.
lookup_client
,
"supports_producer_reuse"
):
return
0
self
.
_requests_priority
[
request
.
request_id
]
=
request
.
priority
token_ids
=
request
.
prompt_token_ids
# If the request has multimodal hashes, apply them to the token ids
mm_hashes
,
mm_positions
=
extract_mm_features
(
request
)
if
mm_hashes
and
mm_positions
:
# TODO(Jiayi): Optimize this
token_ids_tensor
=
torch
.
tensor
(
request
.
prompt_token_ids
)
apply_mm_hashes_to_token_ids
(
token_ids_tensor
,
mm_hashes
,
mm_positions
)
token_ids
=
token_ids_tensor
.
tolist
()
if
request
.
sampling_params
:
request_configs
=
extract_request_configs
(
request
.
sampling_params
)
else
:
request_configs
=
None
if
self
.
skip_last_n_tokens
>
0
:
assert
token_ids
is
not
None
token_ids
=
token_ids
[:
-
self
.
skip_last_n_tokens
]
lookup_id
=
request
.
request_id
if
self
.
async_loading
else
str
(
uuid
.
uuid4
())
self
.
_lookup_requests_in_step
.
append
(
lookup_id
)
num_external_hit_tokens
=
self
.
lookup_client
.
lookup
(
token_ids
,
lookup_id
=
lookup_id
,
request_configs
=
request_configs
,
)
if
num_external_hit_tokens
is
None
:
logger
.
info
(
"Reqid: %s, Total tokens %d, LMCache hit tokens: None."
,
request
.
request_id
,
request
.
num_tokens
,
)
return
None
# When prompt length is divisible by the block size and all
# blocks are cached, we need to recompute the last token.
# This will be removed in the future if vLLM's scheduler provides
# a better support for this case.
need_to_allocate
=
num_external_hit_tokens
-
num_computed_tokens
# In, full-prompt-hit case, we need to recompute the last token
if
num_external_hit_tokens
==
request
.
num_tokens
:
need_to_allocate
-=
1
logger
.
info
(
"Reqid: %s, Total tokens %d, LMCache hit tokens: %d, need to load: %d"
,
request
.
request_id
,
request
.
num_tokens
,
num_external_hit_tokens
,
need_to_allocate
,
)
self
.
load_specs
[
request
.
request_id
]
=
LoadSpec
(
vllm_cached_tokens
=
num_computed_tokens
,
lmcache_cached_tokens
=
num_external_hit_tokens
,
can_load
=
False
,
)
if
need_to_allocate
<=
0
:
return
0
return
need_to_allocate
@
_lmcache_nvtx_annotate
def
update_state_after_alloc
(
self
,
request
:
"Request"
,
num_external_tokens
:
int
):
"""
Update KVConnector state after temporary buffer alloc.
For SharedStorageConnector, update _request_needs_load
if the CacheManager this allocated blocks for us.
"""
kv_transfer_params
=
(
request
.
kv_transfer_params
if
hasattr
(
request
,
"kv_transfer_params"
)
else
None
)
if
kv_transfer_params
is
not
None
and
"disagg_spec"
in
kv_transfer_params
:
req_disagg_spec
=
kv_transfer_params
[
"disagg_spec"
]
receiver_id
=
req_disagg_spec
[
"receiver_host"
]
+
str
(
req_disagg_spec
[
"receiver_init_port"
]
)
disagg_spec
=
DisaggSpec
(
req_id
=
req_disagg_spec
[
"req_id"
],
receiver_id
=
receiver_id
,
receiver_host
=
req_disagg_spec
[
"receiver_host"
],
receiver_init_port
=
req_disagg_spec
[
"receiver_init_port"
],
receiver_alloc_port
=
req_disagg_spec
[
"receiver_alloc_port"
],
)
tmp_disagg_tracker
[
request
.
request_id
]
=
disagg_spec
self
.
_unfinished_requests
[
request
.
request_id
]
=
request
if
request
.
request_id
not
in
self
.
load_specs
:
# No KV tokens from external KV cache, return
return
if
num_external_tokens
==
0
:
# No need to load anything
self
.
load_specs
[
request
.
request_id
].
can_load
=
False
return
# Only check for non-prompt-hit case
if
(
self
.
load_specs
[
request
.
request_id
].
lmcache_cached_tokens
!=
request
.
num_tokens
):
assert
(
num_external_tokens
>
0
and
num_external_tokens
==
self
.
load_specs
[
request
.
request_id
].
lmcache_cached_tokens
-
self
.
load_specs
[
request
.
request_id
].
vllm_cached_tokens
),
(
f
"Mismatch in number of tokens:
{
num_external_tokens
}
vs "
f
"
{
self
.
load_specs
[
request
.
request_id
].
lmcache_cached_tokens
}
-"
f
"
{
self
.
load_specs
[
request
.
request_id
].
vllm_cached_tokens
}
"
f
" for request
{
request
.
request_id
}
"
)
self
.
load_specs
[
request
.
request_id
].
can_load
=
True
@
_lmcache_nvtx_annotate
def
build_connector_meta
(
self
,
scheduler_output
:
SchedulerOutput
)
->
KVConnectorMetadata
:
"""Attach the connector metadata to the request object.
This function should NOT modify other fields in the scheduler_output
except the `kv_connector_metadata` field.
Also, calling this function will reset the state of the connector.
Args:
scheduler_output (SchedulerOutput): the scheduler output object.
"""
force_skip_save
=
self
.
kv_role
==
"kv_consumer"
or
self
.
force_skip_save
meta
=
LMCacheConnectorMetadata
()
# set and update lookup requests for unpin
meta
.
lookup_requests_in_step
=
self
.
_lookup_requests_in_step
self
.
_lookup_requests_in_step
=
[]
for
finished_req_id
in
scheduler_output
.
finished_req_ids
:
self
.
_request_trackers
.
pop
(
finished_req_id
,
None
)
self
.
_unfinished_requests
.
pop
(
finished_req_id
,
None
)
for
request
in
scheduler_output
.
scheduled_new_reqs
:
# Right now, we only load KV for new requests
load_spec
=
self
.
load_specs
.
pop
(
request
.
req_id
,
None
)
num_tokens_to_compute
=
(
request
.
num_computed_tokens
+
scheduler_output
.
num_scheduled_tokens
[
request
.
req_id
]
)
lmcache_cached_tokens
=
0
if
load_spec
is
not
None
:
lmcache_cached_tokens
=
load_spec
.
lmcache_cached_tokens
request_priority
=
self
.
_requests_priority
.
pop
(
request
.
req_id
,
0
)
skip_save
=
force_skip_save
or
(
self
.
config
.
priority_limit
is
not
None
and
request_priority
>
self
.
config
.
priority_limit
)
request_tracker
=
RequestTracker
.
from_new_request
(
self
.
config
,
request
,
num_tokens_to_compute
,
lmcache_cached_tokens
,
skip_save
,
)
self
.
_request_trackers
[
request
.
req_id
]
=
request_tracker
req_meta
=
ReqMeta
.
from_request_tracker
(
request_tracker
,
self
.
_block_size
,
self
.
_lmcache_chunk_size
,
load_spec
=
load_spec
,
discard_partial_chunks
=
self
.
_discard_partial_chunks
,
save_decode_cache
=
self
.
_save_decode_cache
,
)
if
req_meta
is
not
None
:
meta
.
add_request
(
req_meta
)
cached_reqs
=
scheduler_output
.
scheduled_cached_reqs
# NOTE: For backward compatibility with vllm version < 0.9.2,
# In the latest vllm version, the type of scheduled_cached_reqs has
# changed from list to object `CachedRequestData`
if
isinstance
(
cached_reqs
,
list
):
for
i
,
req
in
enumerate
(
cached_reqs
):
request_tracker
=
self
.
_request_trackers
[
req
.
req_id
]
request_tracker
.
update
(
req
.
new_token_ids
,
req
.
new_block_ids
)
req_meta
=
ReqMeta
.
from_request_tracker
(
request_tracker
,
self
.
_block_size
,
self
.
_lmcache_chunk_size
,
load_spec
=
None
,
discard_partial_chunks
=
self
.
_discard_partial_chunks
,
)
if
req_meta
is
not
None
:
meta
.
add_request
(
req_meta
)
return
meta
for
i
,
req_id
in
enumerate
(
cached_reqs
.
req_ids
):
request_tracker
=
self
.
_request_trackers
[
req_id
]
num_new_tokens
=
scheduler_output
.
num_scheduled_tokens
[
req_id
]
if
cached_request
:
=
self
.
_unfinished_requests
.
get
(
req_id
):
num_current_tokens
=
len
(
request_tracker
.
token_ids
)
new_token_ids
=
cached_request
.
all_token_ids
[
num_current_tokens
:
num_current_tokens
+
num_new_tokens
]
else
:
raise
ValueError
(
f
"Request
{
req_id
}
is not in _unfinished_requests, "
f
"but it is scheduled to be cached"
)
new_block_ids
=
cached_reqs
.
new_block_ids
[
i
]
request_tracker
.
update
(
new_token_ids
,
new_block_ids
)
req_meta
=
ReqMeta
.
from_request_tracker
(
request_tracker
,
self
.
_block_size
,
self
.
_lmcache_chunk_size
,
load_spec
=
None
,
discard_partial_chunks
=
self
.
_discard_partial_chunks
,
save_decode_cache
=
self
.
_save_decode_cache
,
)
if
req_meta
is
not
None
:
meta
.
add_request
(
req_meta
)
return
meta
@
_lmcache_nvtx_annotate
def
request_finished
(
self
,
request
:
"Request"
,
block_ids
:
list
[
int
],
)
->
tuple
[
bool
,
dict
[
str
,
Any
]
|
None
]:
params
=
(
request
.
kv_transfer_params
if
hasattr
(
request
,
"kv_transfer_params"
)
else
None
)
return_params
=
None
# NOTE: Used to stream back the first token
# for disagg prefill
if
params
is
not
None
and
"ret_first_tok"
in
params
:
return_params
=
{
"first_tok"
:
request
.
_output_token_ids
[
0
],
}
return
False
,
return_params
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment