Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f4fd3a96
Commit
f4fd3a96
authored
Sep 16, 2025
by
yangshj1
Browse files
add pd lmcache
parent
4a80b456
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
84 additions
and
3 deletions
+84
-3
vllm/attention/layer.py
vllm/attention/layer.py
+9
-1
vllm/distributed/kv_transfer/kv_transfer_state.py
vllm/distributed/kv_transfer/kv_transfer_state.py
+17
-1
vllm/v1/core/sched/scheduler.py
vllm/v1/core/sched/scheduler.py
+39
-0
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+16
-0
vllm/zero_overhead/v1/gpu_model_runner.py
vllm/zero_overhead/v1/gpu_model_runner.py
+3
-1
No files found.
vllm/attention/layer.py
View file @
f4fd3a96
...
...
@@ -7,11 +7,13 @@ import torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
KVConnectorRole
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionType
from
vllm.attention.selector
import
backend_name_to_enum
,
get_attn_backend
from
vllm.config
import
CacheConfig
,
get_current_vllm_config
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
get_lmcache_connector
,
has_kv_transfer_group
,
is_v1_kv_transfer_group
)
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
...
...
@@ -22,7 +24,7 @@ from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from
vllm.platforms
import
_Backend
,
current_platform
from
vllm.utils
import
direct_register_custom_op
from
vllm.v1.attention.backends.utils
import
validate_kv_sharing_target
from
vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector
import
LMCacheConnectorV1
class
Attention
(
nn
.
Module
):
"""Attention layer.
...
...
@@ -376,6 +378,9 @@ def wait_for_kv_layer_from_connector(layer_name: str):
assert
isinstance
(
attn_metadata
,
dict
)
connector
.
wait_for_layer_load
(
layer_name
)
get_lmcache_connector
().
wait_for_layer_load
(
layer_name
)
def
maybe_save_kv_layer_to_connector
(
layer_name
:
str
,
...
...
@@ -393,6 +398,9 @@ def maybe_save_kv_layer_to_connector(
assert
isinstance
(
attn_metadata
,
dict
)
connector
.
save_kv_layer
(
layer_name
,
kv_cache_layer
,
attn_metadata
[
layer_name
])
get_lmcache_connector
().
save_kv_layer
(
layer_name
,
kv_cache_layer
,
attn_metadata
[
layer_name
])
def
unified_attention
(
...
...
vllm/distributed/kv_transfer/kv_transfer_state.py
View file @
f4fd3a96
...
...
@@ -3,24 +3,32 @@
from
typing
import
TYPE_CHECKING
,
Optional
from
vllm
import
envs
from
vllm.config
import
get_current_vllm_config
from
vllm.distributed.kv_transfer.kv_connector.base
import
KVConnectorBaseType
from
vllm.distributed.kv_transfer.kv_connector.factory
import
(
KVConnectorFactory
)
from
vllm.distributed.kv_transfer.kv_connector.v1
import
(
KVConnectorBase_V1
,
KVConnectorRole
)
from
vllm.distributed.parallel_state
import
get_world_group
from
vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector
import
LMCacheConnectorV1
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
KVConnectorRole
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
# Process-wide KV connector singleton; stays None until
# ensure_kv_transfer_initialized() assigns it.
_KV_CONNECTOR_AGENT: Optional[KVConnectorBaseType] = None

# Process-wide LMCache connector singleton, returned by
# get_lmcache_connector(). Annotated with the concrete connector class so the
# declaration agrees with that accessor's declared return type.
_KV_LMCACHE_CONNECTOR_AGENT: Optional[LMCacheConnectorV1] = None
def get_kv_transfer_group() -> KVConnectorBaseType:
    """Return the module-level KV connector agent.

    Raises:
        AssertionError: if the agent is still ``None`` (not initialized).
    """
    agent = _KV_CONNECTOR_AGENT
    assert agent is not None, (
        "disaggregated KV cache transfer parallel group is not initialized")
    return agent
def get_lmcache_connector() -> LMCacheConnectorV1:
    """Return the module-level LMCache connector agent.

    Raises:
        AssertionError: if the agent is still ``None`` (not initialized).
    """
    agent = _KV_LMCACHE_CONNECTOR_AGENT
    assert agent is not None, (
        "LM cache transfer parallel group is not initialized")
    return agent
def has_kv_transfer_group() -> bool:
    """Report whether the module-level KV connector agent has been set."""
    initialized = _KV_CONNECTOR_AGENT is not None
    return initialized
...
...
@@ -54,6 +62,7 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
"""
global
_KV_CONNECTOR_AGENT
global
_KV_LMCACHE_CONNECTOR_AGENT
if
vllm_config
.
kv_transfer_config
is
None
:
return
...
...
@@ -63,6 +72,13 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
if
envs
.
VLLM_USE_V1
:
_KV_CONNECTOR_AGENT
=
KVConnectorFactory
.
create_connector_v1
(
config
=
vllm_config
,
role
=
KVConnectorRole
.
WORKER
)
lmcache_config
=
vllm_config
lmcache_config
.
kv_transfer_config
.
kv_role
=
"kv_both"
lmcache_connector
:
LMCacheConnectorV1
=
LMCacheConnectorV1
(
lmcache_config
,
role
=
KVConnectorRole
.
WORKER
)
_KV_LMCACHE_CONNECTOR_AGENT
=
lmcache_connector
else
:
_KV_CONNECTOR_AGENT
=
KVConnectorFactory
.
create_connector_v0
(
rank
=
get_world_group
().
rank
,
...
...
vllm/v1/core/sched/scheduler.py
View file @
f4fd3a96
...
...
@@ -34,6 +34,7 @@ from vllm.v1.outputs import ModelRunnerOutput
from
vllm.v1.request
import
Request
,
RequestStatus
from
vllm.v1.spec_decode.metrics
import
SpecDecodingStats
from
vllm.v1.structured_output
import
StructuredOutputManager
from
vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector
import
LMCacheConnectorV1
logger
=
init_logger
(
__name__
)
...
...
@@ -86,6 +87,11 @@ class Scheduler(SchedulerInterface):
self
.
connector
=
KVConnectorFactory
.
create_connector_v1
(
config
=
self
.
vllm_config
,
role
=
KVConnectorRole
.
SCHEDULER
)
lmcache_config
=
self
.
vllm_config
lmcache_config
.
kv_transfer_config
.
kv_role
=
"kv_both"
self
.
lmcache_connector
:
LMCacheConnectorV1
=
LMCacheConnectorV1
(
lmcache_config
,
role
=
KVConnectorRole
.
SCHEDULER
)
self
.
kv_event_publisher
=
EventPublisherFactory
.
create
(
self
.
kv_events_config
,
self
.
parallel_config
.
data_parallel_rank
,
...
...
@@ -389,6 +395,12 @@ class Scheduler(SchedulerInterface):
self
.
connector
.
get_num_new_matched_tokens
(
request
,
num_new_local_computed_tokens
))
if
self
.
lmcache_connector
is
not
None
:
num_external_computed_tokens
,
load_kv_async
=
(
self
.
lmcache_connector
.
get_num_new_matched_tokens
(
request
,
num_new_local_computed_tokens
))
# Total computed tokens (local + external).
num_computed_tokens
=
(
num_new_local_computed_tokens
+
num_external_computed_tokens
)
...
...
@@ -463,6 +475,13 @@ class Scheduler(SchedulerInterface):
num_external_computed_tokens
,
)
if
self
.
lmcache_connector
is
not
None
:
self
.
lmcache_connector
.
update_state_after_alloc
(
request
,
new_computed_blocks
+
new_blocks
,
num_external_computed_tokens
,
)
# Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None.
request
=
self
.
waiting
.
pop_request
()
...
...
@@ -578,6 +597,10 @@ class Scheduler(SchedulerInterface):
meta
=
self
.
connector
.
build_connector_meta
(
scheduler_output
)
scheduler_output
.
kv_connector_metadata
=
meta
if
self
.
lmcache_connector
is
not
None
:
meta
=
self
.
lmcache_connector
.
build_connector_meta
(
scheduler_output
)
scheduler_output
.
kv_connector_metadata
=
meta
events
=
self
.
kv_cache_manager
.
take_events
()
if
events
:
batch
=
KVEventBatch
(
ts
=
time
.
time
(),
events
=
events
)
...
...
@@ -675,6 +698,11 @@ class Scheduler(SchedulerInterface):
num_external_computed_tokens
,
load_kv_async
=
(
self
.
connector
.
get_num_new_matched_tokens
(
request
,
num_new_local_computed_tokens
))
if
self
.
lmcache_connector
is
not
None
:
num_external_computed_tokens
,
load_kv_async
=
(
self
.
lmcache_connector
.
get_num_new_matched_tokens
(
request
,
num_new_local_computed_tokens
))
# Total computed tokens (local + external).
num_computed_tokens
=
(
num_new_local_computed_tokens
+
...
...
@@ -750,6 +778,13 @@ class Scheduler(SchedulerInterface):
num_external_computed_tokens
,
)
if
self
.
lmcache_connector
is
not
None
:
self
.
lmcache_connector
.
update_state_after_alloc
(
request
,
new_computed_blocks
+
new_blocks
,
num_external_computed_tokens
,
)
# Request was already popped from self.waiting
# unless it was re-added above due to new_blocks being None.
request
=
self
.
waiting
.
pop_request
()
...
...
@@ -994,6 +1029,10 @@ class Scheduler(SchedulerInterface):
meta
=
self
.
connector
.
build_connector_meta
(
scheduler_output
)
scheduler_output
.
kv_connector_metadata
=
meta
if
self
.
lmcache_connector
is
not
None
:
meta
=
self
.
lmcache_connector
.
build_connector_meta
(
scheduler_output
)
scheduler_output
.
kv_connector_metadata
=
meta
events
=
self
.
kv_cache_manager
.
take_events
()
if
events
:
batch
=
KVEventBatch
(
ts
=
time
.
time
(),
events
=
events
)
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
f4fd3a96
...
...
@@ -23,6 +23,7 @@ from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config
)
from
vllm.distributed.eplb.eplb_state
import
EplbState
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
get_lmcache_connector
,
has_kv_transfer_group
)
from
vllm.distributed.kv_transfer.kv_connector.v1
import
KVConnectorBase_V1
from
vllm.distributed.parallel_state
import
(
...
...
@@ -73,6 +74,9 @@ from vllm.two_batch_overlap.v1.model_input_split_v1 import tbo_split_and_execute
from
..sample.logits_processor
import
LogitsProcessorManager
from
.utils
import
(
gather_mm_placeholders
,
initialize_kv_cache_for_kv_sharing
,
sanity_check_mm_encoder_outputs
,
scatter_mm_placeholders
)
from
vllm.config
import
CacheConfig
,
get_current_vllm_config
from
vllm.distributed.kv_transfer.kv_connector.v1.lmcache_connector
import
LMCacheConnectorV1
from
vllm.distributed.kv_transfer.kv_connector.v1.base
import
KVConnectorRole
if
TYPE_CHECKING
:
import
xgrammar
as
xgr
...
...
@@ -1573,6 +1577,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
has_kv_transfer_group
():
get_kv_transfer_group
().
clear_connector_metadata
()
get_lmcache_connector
().
clear_connector_metadata
()
self
.
eplb_step
()
return
ModelRunnerOutput
(
...
...
@@ -1736,11 +1742,19 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# Do this here to save a collective_rpc.
kv_connector
.
start_load_kv
(
get_forward_context
())
lmcache_connector
=
get_lmcache_connector
()
lmcache_connector
.
bind_connector_metadata
(
scheduler_output
.
kv_connector_metadata
)
lmcache_connector
.
start_load_kv
(
get_forward_context
())
@
staticmethod
def
maybe_wait_for_kv_save
()
->
None
:
if
has_kv_transfer_group
():
get_kv_transfer_group
().
wait_for_save
()
lmcache_connector
=
get_lmcache_connector
()
lmcache_connector
.
wait_for_save
()
@
staticmethod
def
get_finished_kv_transfers
(
scheduler_output
:
"SchedulerOutput"
,
...
...
@@ -2690,6 +2704,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if
has_kv_transfer_group
():
get_kv_transfer_group
().
register_kv_caches
(
kv_caches
)
get_lmcache_connector
().
register_kv_caches
(
kv_caches
)
def
get_kv_cache_spec
(
self
)
->
dict
[
str
,
KVCacheSpec
]:
"""
Generates the KVCacheSpec by parsing the kv cache format from each
...
...
vllm/zero_overhead/v1/gpu_model_runner.py
View file @
f4fd3a96
...
...
@@ -3,7 +3,7 @@ from typing import Any, Optional, Union
import
torch
import
numpy
as
np
from
vllm
import
envs
from
vllm.distributed.kv_transfer.kv_transfer_state
import
get_kv_transfer_group
,
has_kv_transfer_group
from
vllm.distributed.kv_transfer.kv_transfer_state
import
get_kv_transfer_group
,
has_kv_transfer_group
,
get_lmcache_connector
from
vllm.distributed.parallel_state
import
get_pp_group
,
get_tp_group
from
vllm.forward_context
import
set_forward_context
from
vllm.sequence
import
IntermediateTensors
...
...
@@ -727,6 +727,8 @@ class V1ZeroModelRunner(GPUModelRunner):
if
has_kv_transfer_group
():
get_kv_transfer_group
().
clear_connector_metadata
()
get_lmcache_connector
().
clear_connector_metadata
()
self
.
eplb_step
()
model_output
=
ZeroV1ModelRunnerOutput
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment