Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
54b1bf44
Commit
54b1bf44
authored
Jan 22, 2026
by
laibao
Browse files
feat: kvpress runner 支持 chunked prefill prompt-end 一次性 KV compaction
parent
faf55520
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
185 additions
and
4 deletions
+185
-4
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+185
-4
No files found.
vllm/v1/worker/gpu_model_runner.py
View file @
54b1bf44
...
...
@@ -875,7 +875,6 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# where M is the max_model_len.
token_indices
=
(
positions_np
+
req_indices
*
self
.
input_batch
.
token_ids_cpu
.
shape
[
1
])
# NOTE(woosuk): We use torch.index_select instead of np.take here
# because torch.index_select is much faster than np.take for large
# tensors.
...
...
@@ -1571,6 +1570,154 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
finished_recving
=
finished_recving
,
)
def
_stash_kv_compression_prompt_payload
(
self
)
->
None
:
"""Persist prompt-end compaction indices from the forward context."""
if
(
not
envs
.
VLLM_ENABLE_KV_COMPRESSION
or
not
self
.
scheduler_config
.
chunked_prefill_enabled
):
return
forward_context
=
get_forward_context
()
payload
=
getattr
(
forward_context
,
"_kv_compression_prompt_payload"
,
None
)
if
payload
is
None
:
return
req_indices
=
payload
.
get
(
"req_indices"
)
idx_sorted
=
payload
.
get
(
"idx_sorted"
)
keep_len
=
payload
.
get
(
"keep_len"
)
prompt_lens
=
payload
.
get
(
"prompt_lens"
)
if
(
req_indices
is
None
or
idx_sorted
is
None
or
keep_len
is
None
or
prompt_lens
is
None
):
return
req_indices_cpu
=
req_indices
.
to
(
device
=
"cpu"
,
dtype
=
torch
.
int64
).
tolist
()
keep_cpu
=
keep_len
.
to
(
device
=
"cpu"
,
dtype
=
torch
.
int64
).
tolist
()
prompt_cpu
=
prompt_lens
.
to
(
device
=
"cpu"
,
dtype
=
torch
.
int64
).
tolist
()
for
i
,
b
in
enumerate
(
req_indices_cpu
):
if
b
<
0
or
b
>=
len
(
self
.
input_batch
.
req_ids
):
continue
req_id
=
self
.
input_batch
.
req_ids
[
b
]
if
req_id
is
None
:
continue
rs
=
self
.
requests
.
get
(
req_id
)
if
rs
is
None
:
continue
rs
.
kv_compression_prompt_idx_sorted
=
idx_sorted
[
i
]
rs
.
kv_compression_prompt_keep_len
=
int
(
keep_cpu
[
i
])
rs
.
kv_compression_prompt_prompt_len
=
int
(
prompt_cpu
[
i
])
def
_maybe_apply_kv_compression_prompt_compaction
(
self
)
->
None
:
"""Apply one-shot prompt KV compaction before the first decode step."""
if
(
not
envs
.
VLLM_ENABLE_KV_COMPRESSION
or
not
self
.
scheduler_config
.
chunked_prefill_enabled
):
return
pending_req_ids
:
list
[
str
]
=
[]
for
req_id
in
self
.
input_batch
.
req_ids
:
if
req_id
is
None
:
continue
rs
=
self
.
requests
.
get
(
req_id
)
if
rs
is
None
:
continue
if
rs
.
kv_compression_prompt_idx_sorted
is
None
:
continue
# Only apply once the prompt is fully ingested (decode stage).
if
rs
.
num_computed_tokens
<
rs
.
num_prompt_tokens
:
continue
pending_req_ids
.
append
(
req_id
)
if
not
pending_req_ids
:
return
device
=
self
.
device
pending_states
:
list
[
tuple
[
str
,
torch
.
Tensor
,
int
]]
=
[]
for
req_id
in
pending_req_ids
:
rs
=
self
.
requests
[
req_id
]
keep
=
rs
.
kv_compression_prompt_keep_len
idx
=
rs
.
kv_compression_prompt_idx_sorted
if
keep
is
None
or
idx
is
None
:
continue
keep_i
=
int
(
keep
)
if
keep_i
<=
0
:
# No prompt tokens kept; clear and skip.
rs
.
kv_compression_prompt_idx_sorted
=
None
rs
.
kv_compression_prompt_keep_len
=
None
rs
.
kv_compression_prompt_prompt_len
=
None
continue
pending_states
.
append
((
req_id
,
idx
,
keep_i
))
if
not
pending_states
:
return
B
=
len
(
pending_states
)
keep_list
=
[
k
for
_
,
_
,
k
in
pending_states
]
K_max
=
max
(
keep_list
)
idx_batch
=
torch
.
zeros
((
B
,
K_max
),
device
=
device
,
dtype
=
torch
.
int32
)
for
i
,
(
_
,
row
,
k
)
in
enumerate
(
pending_states
):
idx_batch
[
i
,
:
k
]
=
row
[:
k
].
to
(
device
=
device
,
dtype
=
torch
.
int32
)
keep_tensor
=
torch
.
tensor
(
keep_list
,
device
=
device
,
dtype
=
torch
.
int32
)
from
vllm.v1.attention.kv_compression.kv_cache_triton
import
(
front_compact_inplace_fa_triton
,
make_fa_cache_view
)
# Apply compaction to every attention layer's KV cache in-place.
for
group_id
,
kv_cache_group_spec
in
enumerate
(
self
.
kv_cache_config
.
kv_cache_groups
):
max_blocks
=
0
for
req_id
,
_
,
_
in
pending_states
:
rs
=
self
.
requests
[
req_id
]
if
group_id
>=
len
(
rs
.
block_ids
):
continue
max_blocks
=
max
(
max_blocks
,
len
(
rs
.
block_ids
[
group_id
]))
if
max_blocks
==
0
:
continue
block_table_cpu
=
torch
.
zeros
((
B
,
max_blocks
),
dtype
=
torch
.
int32
,
device
=
"cpu"
)
for
i
,
(
req_id
,
_
,
_
)
in
enumerate
(
pending_states
):
rs
=
self
.
requests
[
req_id
]
if
group_id
>=
len
(
rs
.
block_ids
):
continue
ids
=
rs
.
block_ids
[
group_id
]
if
ids
:
block_table_cpu
[
i
,
:
len
(
ids
)]
=
torch
.
tensor
(
ids
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
block_table
=
block_table_cpu
.
to
(
device
=
device
,
non_blocking
=
True
)
for
layer_name
in
kv_cache_group_spec
.
layer_names
:
layer_index
=
self
.
_extract_layer_index
(
layer_name
)
if
layer_index
>=
len
(
self
.
kv_caches
):
continue
kv_cache
=
self
.
kv_caches
[
layer_index
]
if
not
current_platform
.
is_rocm
():
if
not
isinstance
(
kv_cache
,
torch
.
Tensor
):
continue
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
else
:
if
(
not
isinstance
(
kv_cache
,
(
tuple
,
list
))
or
len
(
kv_cache
)
!=
2
):
continue
key_cache
,
value_cache
=
kv_cache
k_view
,
v_view
=
make_fa_cache_view
(
key_cache
=
key_cache
,
value_cache
=
value_cache
)
front_compact_inplace_fa_triton
(
k_view
,
v_view
,
block_table
,
idx_batch
,
keep_tensor
,
)
# Clear pending state after successful compaction.
for
req_id
,
_
,
_
in
pending_states
:
rs
=
self
.
requests
.
get
(
req_id
)
if
rs
is
None
:
continue
rs
.
kv_compression_prompt_idx_sorted
=
None
rs
.
kv_compression_prompt_keep_len
=
None
rs
.
kv_compression_prompt_prompt_len
=
None
@
torch
.
inference_mode
()
def
execute_model
(
self
,
...
...
@@ -1667,7 +1814,23 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs
=
self
.
full_cuda_graph
and
not
attention_cuda_graphs
if
envs
.
VLLM_ENABLE_TBO
and
scheduler_output
.
total_num_scheduled_tokens
>=
envs
.
VLLM_TBO_MIN_TOKENS
:
# Chunked prefill (scheme 3): apply one-shot prompt KV compaction before
# the first decode step writes/reads KV at the compressed positions.
self
.
_maybe_apply_kv_compression_prompt_compaction
()
use_tbo
=
(
envs
.
VLLM_ENABLE_TBO
and
scheduler_output
.
total_num_scheduled_tokens
>=
envs
.
VLLM_TBO_MIN_TOKENS
)
if
(
use_tbo
and
envs
.
VLLM_ENABLE_KV_COMPRESSION
and
self
.
scheduler_config
.
chunked_prefill_enabled
):
# NOTE: the TBO path does not call `_stash_kv_compression_prompt_payload`
# inside its `set_forward_context`, so scheme-3 prompt-end payloads
# would be dropped and the next-step compaction would never run.
logger
.
warning_once
(
"TBO is currently incompatible with chunked prefill KV "
"compression (scheme 3); running without TBO."
)
use_tbo
=
False
if
use_tbo
:
model_output
,
finished_sending
,
finished_recving
=
\
tbo_split_and_execute_model
(
self
,
attn_metadata
,
num_input_tokens
,
num_tokens_across_dp
,
input_ids
,
positions
,
...
...
@@ -1694,6 +1857,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
)
self
.
_stash_kv_compression_prompt_payload
()
self
.
maybe_wait_for_kv_save
()
finished_sending
,
finished_recving
=
(
...
...
@@ -1719,6 +1883,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
)
self
.
_stash_kv_compression_prompt_payload
()
self
.
maybe_wait_for_kv_save
()
finished_sending
,
finished_recving
=
(
...
...
@@ -3686,7 +3851,21 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
# If attention doesn't support CUDA Graphs for this batch, but we
# compiled with full CUDA graphs, we have to skip them entirely.
skip_cuda_graphs
=
self
.
full_cuda_graph
and
not
attention_cuda_graphs
if
envs
.
VLLM_ENABLE_TBO
and
scheduler_output
.
total_num_scheduled_tokens
>=
envs
.
VLLM_TBO_MIN_TOKENS
:
# Chunked prefill (scheme 3): apply one-shot prompt KV compaction before
# the first decode step writes/reads KV at the compressed positions.
self
.
_maybe_apply_kv_compression_prompt_compaction
()
use_tbo
=
(
envs
.
VLLM_ENABLE_TBO
and
scheduler_output
.
total_num_scheduled_tokens
>=
envs
.
VLLM_TBO_MIN_TOKENS
)
if
(
use_tbo
and
envs
.
VLLM_ENABLE_KV_COMPRESSION
and
self
.
scheduler_config
.
chunked_prefill_enabled
):
logger
.
warning_once
(
"TBO is currently incompatible with chunked prefill KV "
"compression (scheme 3); running without TBO."
)
use_tbo
=
False
if
use_tbo
:
model_output
,
finished_sending
,
finished_recving
=
\
tbo_split_and_execute_model
(
self
,
attn_metadata
,
num_input_tokens
,
num_tokens_across_dp
,
input_ids
,
positions
,
...
...
@@ -3713,6 +3892,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
)
self
.
_stash_kv_compression_prompt_payload
()
self
.
maybe_wait_for_kv_save
()
finished_sending
,
finished_recving
=
(
...
...
@@ -3738,6 +3918,7 @@ class GPUModelRunnerMTP(GPUModelRunnerBase):
intermediate_tensors
=
intermediate_tensors
,
inputs_embeds
=
inputs_embeds
,
)
self
.
_stash_kv_compression_prompt_payload
()
self
.
maybe_wait_for_kv_save
()
finished_sending
,
finished_recving
=
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment