Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8518b304
Unverified
Commit
8518b304
authored
Jan 23, 2026
by
Nick Hill
Committed by
GitHub
Jan 23, 2026
Browse files
[Model Runner V2] Add KV Connector support (#32742)
Signed-off-by:
Nick Hill
<
nickhill123@gmail.com
>
parent
2d6b5371
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
159 additions
and
16 deletions
+159
-16
vllm/v1/worker/gpu/attn_utils.py
vllm/v1/worker/gpu/attn_utils.py
+2
-1
vllm/v1/worker/gpu/kv_connector.py
vllm/v1/worker/gpu/kv_connector.py
+125
-0
vllm/v1/worker/gpu/model_runner.py
vllm/v1/worker/gpu/model_runner.py
+28
-13
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+4
-2
No files found.
vllm/v1/worker/gpu/attn_utils.py
View file @
8518b304
...
...
@@ -133,10 +133,11 @@ def init_kv_cache(
kv_cache_config
:
KVCacheConfig
,
attn_backends
:
dict
[
str
,
AttentionBackend
],
device
:
torch
.
device
,
)
->
None
:
)
->
dict
[
str
,
torch
.
Tensor
]
:
kv_cache_raw_tensors
=
_allocate_kv_cache
(
kv_cache_config
,
device
)
kv_caches
=
_reshape_kv_cache
(
kv_cache_config
,
kv_cache_raw_tensors
,
attn_backends
)
bind_kv_cache
(
kv_caches
,
forward_context
,
runner_kv_caches
)
return
kv_caches
def
build_attn_metadata
(
...
...
vllm/v1/worker/gpu/kv_connector.py
0 → 100644
View file @
8518b304
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import
copy
from
typing
import
TYPE_CHECKING
import
torch
from
vllm.config
import
VllmConfig
from
vllm.distributed.kv_transfer
import
(
get_kv_transfer_group
,
has_kv_transfer_group
,
kv_transfer_state
,
)
from
vllm.distributed.kv_transfer.kv_connector.utils
import
copy_kv_blocks
from
vllm.forward_context
import
(
get_forward_context
,
is_forward_context_available
,
set_forward_context
,
)
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
KVConnectorOutput
,
ModelRunnerOutput
,
)
if
TYPE_CHECKING
:
from
vllm.v1.core.sched.output
import
SchedulerOutput
class
KVConnector
:
"""KVConnector interface used by GPUModelRunner."""
def
pre_forward
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
pass
def
post_forward
(
self
,
scheduler_output
:
"SchedulerOutput"
,
wait_for_save
:
bool
=
True
)
->
KVConnectorOutput
|
None
:
return
None
def
no_forward
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
ModelRunnerOutput
:
return
EMPTY_MODEL_RUNNER_OUTPUT
def
set_disabled
(
self
,
disabled
:
bool
)
->
None
:
pass
class
ActiveKVConnector
(
KVConnector
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
kv_caches_dict
:
dict
[
str
,
torch
.
Tensor
]
):
self
.
vllm_config
=
vllm_config
self
.
kv_connector
=
get_kv_transfer_group
()
# Register kv caches with KV Connector if applicable.
# TODO: support cross_layers_kv_cache
# (see https://github.com/vllm-project/vllm/pull/27743)
self
.
kv_connector
.
register_kv_caches
(
kv_caches_dict
)
self
.
kv_connector
.
set_host_xfer_buffer_ops
(
copy_kv_blocks
)
self
.
_disabled
=
False
def
pre_forward
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
None
:
if
self
.
_disabled
:
return
if
scheduler_output
.
preempted_req_ids
:
self
.
kv_connector
.
handle_preemptions
(
scheduler_output
.
preempted_req_ids
)
assert
scheduler_output
.
kv_connector_metadata
is
not
None
self
.
kv_connector
.
bind_connector_metadata
(
scheduler_output
.
kv_connector_metadata
)
# TODO: sort out KV Connectors' use of forward_context
if
is_forward_context_available
():
self
.
kv_connector
.
start_load_kv
(
get_forward_context
())
else
:
with
set_forward_context
(
None
,
self
.
vllm_config
):
self
.
kv_connector
.
start_load_kv
(
get_forward_context
())
def
post_forward
(
self
,
scheduler_output
:
"SchedulerOutput"
,
wait_for_save
:
bool
=
True
)
->
KVConnectorOutput
|
None
:
if
self
.
_disabled
:
return
None
output
=
KVConnectorOutput
()
if
wait_for_save
:
self
.
kv_connector
.
wait_for_save
()
output
.
finished_sending
,
output
.
finished_recving
=
(
self
.
kv_connector
.
get_finished
(
scheduler_output
.
finished_req_ids
)
)
output
.
invalid_block_ids
=
self
.
kv_connector
.
get_block_ids_with_load_errors
()
output
.
kv_connector_stats
=
self
.
kv_connector
.
get_kv_connector_stats
()
output
.
kv_cache_events
=
self
.
kv_connector
.
get_kv_connector_kv_cache_events
()
self
.
kv_connector
.
clear_connector_metadata
()
return
output
def
no_forward
(
self
,
scheduler_output
:
"SchedulerOutput"
)
->
ModelRunnerOutput
:
if
self
.
_disabled
:
return
EMPTY_MODEL_RUNNER_OUTPUT
self
.
pre_forward
(
scheduler_output
)
kv_connector_output
=
self
.
post_forward
(
scheduler_output
,
wait_for_save
=
False
)
if
kv_connector_output
is
None
or
kv_connector_output
.
is_empty
():
return
EMPTY_MODEL_RUNNER_OUTPUT
output
=
copy
.
copy
(
EMPTY_MODEL_RUNNER_OUTPUT
)
output
.
kv_connector_output
=
kv_connector_output
return
output
def
set_disabled
(
self
,
disabled
:
bool
)
->
None
:
# Ensure that layer-wise connector hooks aren't called when disabled.
kv_transfer_state
.
_KV_CONNECTOR_AGENT
=
None
if
disabled
else
self
.
kv_connector
self
.
_disabled
=
disabled
NO_OP_KV_CONNECTOR
=
KVConnector
()
def
get_kv_connector
(
vllm_config
:
VllmConfig
,
kv_caches_dict
:
dict
[
str
,
torch
.
Tensor
]
)
->
KVConnector
:
if
not
has_kv_transfer_group
():
# No-op connector.
return
NO_OP_KV_CONNECTOR
return
ActiveKVConnector
(
vllm_config
,
kv_caches_dict
)
vllm/v1/worker/gpu/model_runner.py
View file @
8518b304
...
...
@@ -20,10 +20,7 @@ from vllm.utils.mem_utils import DeviceMemoryProfiler, format_gib
from
vllm.utils.torch_utils
import
STR_DTYPE_TO_TORCH_DTYPE
from
vllm.v1.core.sched.output
import
GrammarOutput
,
SchedulerOutput
from
vllm.v1.kv_cache_interface
import
KVCacheConfig
from
vllm.v1.outputs
import
(
EMPTY_MODEL_RUNNER_OUTPUT
,
ModelRunnerOutput
,
)
from
vllm.v1.outputs
import
ModelRunnerOutput
from
vllm.v1.worker.gpu.async_utils
import
AsyncOutput
from
vllm.v1.worker.gpu.attn_utils
import
(
build_attn_metadata
,
...
...
@@ -48,6 +45,11 @@ from vllm.v1.worker.gpu.input_batch import (
prepare_pos_seq_lens
,
prepare_prefill_inputs
,
)
from
vllm.v1.worker.gpu.kv_connector
import
(
NO_OP_KV_CONNECTOR
,
KVConnector
,
get_kv_connector
,
)
from
vllm.v1.worker.gpu.mm.encoder_runner
import
EncoderRunner
from
vllm.v1.worker.gpu.mm.mrope_utils
import
MRopeState
from
vllm.v1.worker.gpu.sample.output
import
SamplerOutput
...
...
@@ -57,13 +59,12 @@ from vllm.v1.worker.gpu.spec_decode import init_speculator
from
vllm.v1.worker.gpu.spec_decode.rejection_sample
import
rejection_sample
from
vllm.v1.worker.gpu.states
import
RequestState
from
vllm.v1.worker.gpu.structured_outputs
import
StructuredOutputsWorker
from
vllm.v1.worker.kv_connector_model_runner_mixin
import
KVConnectorModelRunnerMixin
from
vllm.v1.worker.lora_model_runner_mixin
import
LoRAModelRunnerMixin
logger
=
init_logger
(
__name__
)
class
GPUModelRunner
(
LoRAModelRunnerMixin
,
KVConnectorModelRunnerMixin
):
class
GPUModelRunner
(
LoRAModelRunnerMixin
):
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
...
...
@@ -172,6 +173,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
tmp_cu_num_logits
=
UvaBufferPool
(
self
.
max_num_reqs
+
1
,
torch
.
int32
)
self
.
tmp_query_start_loc
=
UvaBufferPool
(
self
.
max_num_reqs
+
1
,
torch
.
int32
)
self
.
kv_connector
:
KVConnector
=
NO_OP_KV_CONNECTOR
def
update_max_model_len
(
self
,
max_model_len
:
int
)
->
None
:
self
.
max_model_len
=
max_model_len
self
.
req_states
.
max_model_len
=
max_model_len
...
...
@@ -248,13 +251,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
self
.
kv_caches
:
list
[
torch
.
Tensor
]
=
[]
init_kv_cache
(
kv_caches_dict
=
init_kv_cache
(
self
.
kv_caches
,
self
.
compilation_config
.
static_forward_context
,
self
.
kv_cache_config
,
self
.
attn_backends
,
self
.
device
,
)
self
.
kv_connector
=
get_kv_connector
(
self
.
vllm_config
,
kv_caches_dict
)
# Attention groups are not supported.
self
.
attn_groups
=
[]
# type: ignore
...
...
@@ -291,18 +296,22 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_tokens_per_request
[
-
1
]
+=
num_tokens
%
num_reqs
assert
sum
(
num_tokens_per_request
)
==
num_tokens
num_scheduled_tokens
=
{
f
"_dummy_req_
{
i
}
"
:
n
um_tokens_per_request
[
i
]
for
i
in
range
(
num_reqs
)
f
"_dummy_req_
{
i
}
"
:
n
for
i
,
n
in
enumerate
(
num_tokens_per_request
)
}
dummy_scheduler_output
=
SchedulerOutput
.
make_empty
()
dummy_scheduler_output
.
total_num_scheduled_tokens
=
num_tokens
dummy_scheduler_output
.
num_scheduled_tokens
=
num_scheduled_tokens
# Disable any use of KVConnector for dummy runs.
self
.
kv_connector
.
set_disabled
(
True
)
# Execute the model.
self
.
execute_model
(
dummy_scheduler_output
,
dummy_run
=
True
,
skip_attn_for_dummy_run
=
skip_attn
)
self
.
kv_connector
.
set_disabled
(
False
)
assert
self
.
execute_model_state
is
not
None
hidden_states
,
input_batch
=
self
.
execute_model_state
hidden_states
,
input_batch
,
_
=
self
.
execute_model_state
sample_hidden_states
=
hidden_states
[
input_batch
.
logits_indices
]
return
hidden_states
,
sample_hidden_states
...
...
@@ -792,7 +801,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self
.
block_tables
.
apply_staged_writes
()
if
scheduler_output
.
total_num_scheduled_tokens
==
0
:
# No need to run the model.
return
EMPTY_MODEL_RUNNER_OUTPUT
empty_output
=
self
.
kv_connector
.
no_forward
(
scheduler_output
)
return
empty_output
# Get the CUDA graph size. None means no CUDA graph is used.
cudagraph_size
=
self
.
cudagraph_manager
.
get_cudagraph_size
(
...
...
@@ -809,7 +819,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
if
num_tokens_after_padding
==
0
:
# All DP ranks have zero tokens to run.
return
EMPTY_MODEL_RUNNER_OUTPUT
empty_output
=
self
.
kv_connector
.
no_forward
(
scheduler_output
)
return
empty_output
if
not
dummy_run
:
# Common case.
...
...
@@ -860,6 +871,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Run CUDA graph.
# NOTE(woosuk): Here, we don't need to pass the input tensors,
# because they are already copied to the CUDA graph input buffers.
self
.
kv_connector
.
pre_forward
(
scheduler_output
)
hidden_states
=
self
.
cudagraph_manager
.
run
(
input_batch
.
num_tokens_after_padding
)
...
...
@@ -877,13 +889,15 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
cudagraph_runtime_mode
=
CUDAGraphMode
.
NONE
,
num_tokens_across_dp
=
num_tokens_across_dp
,
):
self
.
kv_connector
.
pre_forward
(
scheduler_output
)
hidden_states
=
self
.
model
(
input_ids
=
input_batch
.
input_ids
,
positions
=
positions
,
inputs_embeds
=
input_batch
.
inputs_embeds
,
)
self
.
execute_model_state
=
hidden_states
,
input_batch
kv_connector_output
=
self
.
kv_connector
.
post_forward
(
scheduler_output
)
self
.
execute_model_state
=
hidden_states
,
input_batch
,
kv_connector_output
return
None
@
torch
.
inference_mode
()
...
...
@@ -892,7 +906,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
grammar_output
:
GrammarOutput
|
None
,
)
->
AsyncOutput
|
ModelRunnerOutput
:
assert
self
.
execute_model_state
is
not
None
hidden_states
,
input_batch
=
self
.
execute_model_state
hidden_states
,
input_batch
,
kv_connector_output
=
self
.
execute_model_state
self
.
execute_model_state
=
None
# type: ignore
sampler_output
,
num_sampled
,
num_rejected
=
self
.
sample
(
...
...
@@ -917,6 +931,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_id_to_index
=
{
req_id
:
i
for
i
,
req_id
in
enumerate
(
input_batch
.
req_ids
)},
sampled_token_ids
=
None
,
# type: ignore
prompt_logprobs_dict
=
prompt_logprobs_dict
,
# type: ignore[arg-type]
kv_connector_output
=
kv_connector_output
,
)
async_output
=
AsyncOutput
(
model_runner_output
=
model_runner_output
,
...
...
vllm/v1/worker/gpu_worker.py
View file @
8518b304
...
...
@@ -24,6 +24,7 @@ from vllm.distributed import (
from
vllm.distributed.ec_transfer
import
ensure_ec_transfer_initialized
from
vllm.distributed.kv_transfer
import
(
ensure_kv_transfer_initialized
,
ensure_kv_transfer_shutdown
,
get_kv_transfer_group
,
has_kv_transfer_group
,
)
...
...
@@ -921,8 +922,9 @@ class Worker(WorkerBase):
)
def
shutdown
(
self
)
->
None
:
if
runner
:
=
getattr
(
self
,
"model_runner"
,
None
):
runner
.
ensure_kv_transfer_shutdown
()
# has_kv_transfer_group can be None during interpreter shutdown.
if
ensure_kv_transfer_shutdown
is
not
None
:
ensure_kv_transfer_shutdown
()
if
self
.
profiler
is
not
None
:
self
.
profiler
.
shutdown
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment