Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
da35d84f
Commit
da35d84f
authored
Feb 24, 2026
by
laibao
Browse files
feat(kvpress): FlashAttention 接入 KV 压缩 hooks
parent
dbcb0376
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
556 additions
and
0 deletions
+556
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+107
-0
vllm/v1/kv_compression/compaction_step.py
vllm/v1/kv_compression/compaction_step.py
+70
-0
vllm/v1/kv_compression/flash_attn_hooks.py
vllm/v1/kv_compression/flash_attn_hooks.py
+192
-0
vllm/v1/kv_compression/prompt_end.py
vllm/v1/kv_compression/prompt_end.py
+187
-0
No files found.
vllm/v1/attention/backends/flash_attn.py
View file @
da35d84f
...
@@ -61,6 +61,10 @@ from vllm.v1.attention.backends.utils import (
...
@@ -61,6 +61,10 @@ from vllm.v1.attention.backends.utils import (
)
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.v1.kv_compression.flash_attn_hooks
import
(
maybe_compute_prompt_end_payload_flash_attn
,
maybe_compact_kv_cache_flash_attn
,
)
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -248,6 +252,11 @@ class FlashAttentionMetadata:
...
@@ -248,6 +252,11 @@ class FlashAttentionMetadata:
# |-- query_len ---|
# |-- query_len ---|
num_actual_tokens
:
int
# Number of tokens excluding padding.
num_actual_tokens
:
int
# Number of tokens excluding padding.
# When input tensors are padded (e.g., for sequence parallelism / piecewise
# CUDA graphs), `num_actual_tokens` may include padding. KV compression
# helpers require the unpadded scheduled token count to match
# `query_start_loc[-1]`.
num_unpadded_tokens
:
int
|
None
=
None
max_query_len
:
int
max_query_len
:
int
query_start_loc
:
torch
.
Tensor
query_start_loc
:
torch
.
Tensor
max_seq_len
:
int
max_seq_len
:
int
...
@@ -262,6 +271,17 @@ class FlashAttentionMetadata:
...
@@ -262,6 +271,17 @@ class FlashAttentionMetadata:
prefix_kv_lens
:
torch
.
Tensor
|
None
prefix_kv_lens
:
torch
.
Tensor
|
None
suffix_kv_lens
:
torch
.
Tensor
|
None
suffix_kv_lens
:
torch
.
Tensor
|
None
# KV compression metadata for token-shared selection.
kv_compression_must_keep
:
torch
.
Tensor
|
None
=
None
kv_compression_topk_budget
:
torch
.
Tensor
|
None
=
None
# CPU-known max Top-K budget for this step (avoids device->host sync).
kv_compression_topk_budget_max
:
int
|
None
=
None
# Chunked prefill: prompt-end one-shot scoring/Top-K (scheme 3).
kv_compression_prompt_end
:
torch
.
Tensor
|
None
=
None
# [B] bool
kv_compression_prompt_lens
:
torch
.
Tensor
|
None
=
None
# [B] int32
kv_compression_prompt_topk_keep
:
torch
.
Tensor
|
None
=
None
# [B] int32
kv_compression_prompt_topk_keep_max
:
int
|
None
=
None
# For GQA DCP
# For GQA DCP
max_dcp_context_kv_len
:
int
|
None
=
None
max_dcp_context_kv_len
:
int
|
None
=
None
dcp_context_kv_lens
:
torch
.
Tensor
|
None
=
None
dcp_context_kv_lens
:
torch
.
Tensor
|
None
=
None
...
@@ -546,6 +566,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
...
@@ -546,6 +566,7 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
attn_metadata
=
FlashAttentionMetadata
(
attn_metadata
=
FlashAttentionMetadata
(
num_actual_tokens
=
num_actual_tokens
,
num_actual_tokens
=
num_actual_tokens
,
num_unpadded_tokens
=
common_attn_metadata
.
num_unpadded_tokens
,
max_query_len
=
max_query_len
,
max_query_len
=
max_query_len
,
query_start_loc
=
query_start_loc
,
query_start_loc
=
query_start_loc
,
max_seq_len
=
max_seq_len
,
max_seq_len
=
max_seq_len
,
...
@@ -560,6 +581,13 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
...
@@ -560,6 +581,13 @@ class FlashAttentionMetadataBuilder(AttentionMetadataBuilder[FlashAttentionMetad
cu_prefix_query_lens
=
cu_prefix_query_lens
,
cu_prefix_query_lens
=
cu_prefix_query_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
kv_compression_must_keep
=
common_attn_metadata
.
kv_compression_must_keep
,
kv_compression_topk_budget
=
common_attn_metadata
.
kv_compression_topk_budget
,
kv_compression_topk_budget_max
=
common_attn_metadata
.
kv_compression_topk_budget_max
,
kv_compression_prompt_end
=
common_attn_metadata
.
kv_compression_prompt_end
,
kv_compression_prompt_lens
=
common_attn_metadata
.
kv_compression_prompt_lens
,
kv_compression_prompt_topk_keep
=
common_attn_metadata
.
kv_compression_prompt_topk_keep
,
kv_compression_prompt_topk_keep_max
=
common_attn_metadata
.
kv_compression_prompt_topk_keep_max
,
prefix_scheduler_metadata
=
prefix_scheduler_metadata
,
prefix_scheduler_metadata
=
prefix_scheduler_metadata
,
max_num_splits
=
max_num_splits
,
max_num_splits
=
max_num_splits
,
causal
=
causal
,
causal
=
causal
,
...
@@ -729,6 +757,25 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -729,6 +757,25 @@ class FlashAttentionImpl(AttentionImpl):
key_cache
=
key_cache
.
view
(
dtype
)
key_cache
=
key_cache
.
view
(
dtype
)
value_cache
=
value_cache
.
view
(
dtype
)
value_cache
=
value_cache
.
view
(
dtype
)
num_unpadded_tokens
=
(
attn_metadata
.
num_unpadded_tokens
if
attn_metadata
.
num_unpadded_tokens
is
not
None
else
num_actual_tokens
)
cache_block_size
=
int
(
key_cache
.
shape
[
2
]
if
current_platform
.
is_rocm
()
else
key_cache
.
shape
[
1
]
)
if
envs
.
VLLM_ENABLE_KV_COMPRESSION
:
maybe_compute_prompt_end_payload_flash_attn
(
kv_sharing_target_layer_name
=
self
.
kv_sharing_target_layer_name
,
query
=
query
,
num_actual_tokens
=
num_unpadded_tokens
,
key_cache
=
key_cache
,
cache_block_size
=
cache_block_size
,
attn_metadata
=
attn_metadata
,
sm_scale
=
self
.
scale
,
)
if
not
attn_metadata
.
use_cascade
:
if
not
attn_metadata
.
use_cascade
:
cu_seqlens_q
=
attn_metadata
.
query_start_loc
cu_seqlens_q
=
attn_metadata
.
query_start_loc
seqused_k
=
attn_metadata
.
seq_lens
seqused_k
=
attn_metadata
.
seq_lens
...
@@ -761,6 +808,26 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -761,6 +808,26 @@ class FlashAttentionImpl(AttentionImpl):
k_descale
=
k_descale
,
k_descale
=
k_descale
,
v_descale
=
v_descale
,
v_descale
=
v_descale
,
)
)
if
envs
.
VLLM_ENABLE_KV_COMPRESSION
:
maybe_compact_kv_cache_flash_attn
(
kv_sharing_target_layer_name
=
self
.
kv_sharing_target_layer_name
,
layer
=
layer
,
query
=
query
,
key
=
key
,
value
=
value
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
num_actual_tokens
=
num_unpadded_tokens
,
cache_block_size
=
cache_block_size
,
attn_metadata
=
attn_metadata
,
sm_scale
=
self
.
scale
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
reshape_and_cache
=
(
reshape_and_cache_cuda
if
current_platform
.
is_rocm
()
else
reshape_and_cache_flash
),
)
return
output
return
output
else
:
else
:
sliding_window_size
=
(
sliding_window_size
=
(
...
@@ -822,6 +889,26 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -822,6 +889,26 @@ class FlashAttentionImpl(AttentionImpl):
num_splits
=
attn_metadata
.
max_num_splits
,
num_splits
=
attn_metadata
.
max_num_splits
,
s_aux
=
self
.
sinks
,
s_aux
=
self
.
sinks
,
)
)
if
envs
.
VLLM_ENABLE_KV_COMPRESSION
:
maybe_compact_kv_cache_flash_attn
(
kv_sharing_target_layer_name
=
self
.
kv_sharing_target_layer_name
,
layer
=
layer
,
query
=
query
,
key
=
key
,
value
=
value
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
num_actual_tokens
=
num_unpadded_tokens
,
cache_block_size
=
cache_block_size
,
attn_metadata
=
attn_metadata
,
sm_scale
=
self
.
scale
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
reshape_and_cache
=
(
reshape_and_cache_cuda
if
current_platform
.
is_rocm
()
else
reshape_and_cache_flash
),
)
return
output
return
output
# Cascade attention (rare case).
# Cascade attention (rare case).
...
@@ -851,6 +938,26 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -851,6 +938,26 @@ class FlashAttentionImpl(AttentionImpl):
v_descale
=
layer
.
_v_scale
,
v_descale
=
layer
.
_v_scale
,
s_aux
=
self
.
sinks
,
s_aux
=
self
.
sinks
,
)
)
if
envs
.
VLLM_ENABLE_KV_COMPRESSION
:
maybe_compact_kv_cache_flash_attn
(
kv_sharing_target_layer_name
=
self
.
kv_sharing_target_layer_name
,
layer
=
layer
,
query
=
query
,
key
=
key
,
value
=
value
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
num_actual_tokens
=
num_unpadded_tokens
,
cache_block_size
=
cache_block_size
,
attn_metadata
=
attn_metadata
,
sm_scale
=
self
.
scale
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
reshape_and_cache
=
(
reshape_and_cache_cuda
if
current_platform
.
is_rocm
()
else
reshape_and_cache_flash
),
)
return
output
return
output
def
do_kv_cache_update
(
def
do_kv_cache_update
(
...
...
vllm/v1/kv_compression/compaction_step.py
0 → 100644
View file @
da35d84f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
import
torch
import
vllm.envs
as
envs
from
vllm.v1.kv_compression.slot_mapping
import
topk_kv_compact_slot_mapping
from
vllm.v1.kv_compression.snapkv_score
import
snapkv_like_token_scores
def snapkv_window_for_topk_budget(
    *,
    topk_budget: torch.Tensor,  # [B] int32
    window: int,
) -> torch.Tensor:
    """Return the SnapKV window to use for each request of a mixed batch.

    Entries of ``topk_budget`` that are strictly positive are mapped to
    ``int(window)``; all remaining entries are mapped to 0. Requests with a
    zero Top-K budget do not need token scores, and a zero window lets the
    Triton scoring kernel early-return for them.

    The result has the same shape, dtype and device as ``topk_budget``.
    """
    # Start from an all-zero window and switch it on only where a budget exists.
    window_sizes = torch.zeros_like(topk_budget)
    needs_scores = topk_budget > 0
    return window_sizes.masked_fill(needs_scores, int(window))
def compute_compact_dst_slots_for_step(
    *,
    query: torch.Tensor,  # [T, Hq, D] for this step
    key: torch.Tensor,  # [T, Hkv, D] for this step
    query_start_loc: torch.Tensor,  # [B+1]
    seq_lens: torch.Tensor,  # [B] int32
    block_table: torch.Tensor,  # [B, max_blocks]
    block_size: int,
    must_keep: torch.Tensor,  # [T] bool
    topk_budget: torch.Tensor,  # [B] int32
    topk_budget_max: int,
    max_query_len: int,
    sm_scale: float,
) -> torch.Tensor:
    """Compute per-token KV compaction destinations for one step.

    Token scores are only computed when ``topk_budget_max`` is positive;
    otherwise ``None`` is forwarded as ``token_scores``. The slot mapping
    itself is produced by ``topk_kv_compact_slot_mapping``, whose result is
    returned unchanged.
    """
    budget_cap = int(topk_budget_max)

    if budget_cap > 0:
        # Per-request window: zero for requests without a Top-K budget.
        per_request_window = snapkv_window_for_topk_budget(
            topk_budget=topk_budget,
            window=int(envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW),
        )
        token_scores = snapkv_like_token_scores(
            query=query,
            key=key,
            query_start_loc=query_start_loc,
            window=per_request_window,
            sm_scale=float(sm_scale),
        )
    else:
        # No request in this step has a Top-K budget, so scoring is skipped.
        token_scores = None

    return topk_kv_compact_slot_mapping(
        token_scores=token_scores,
        must_keep=must_keep,
        topk_budget=topk_budget,
        query_start_loc=query_start_loc,
        seq_lens=seq_lens,
        block_table=block_table,
        block_size=int(block_size),
        max_query_len=int(max_query_len),
        topk_budget_max=budget_cap,
    )
vllm/v1/kv_compression/flash_attn_hooks.py
0 → 100644
View file @
da35d84f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
typing
import
Any
,
Optional
,
Protocol
import
torch
import
vllm.envs
as
envs
from
vllm.forward_context
import
get_forward_context
from
vllm.platforms
import
current_platform
from
vllm.v1.kv_compression.compaction_step
import
compute_compact_dst_slots_for_step
from
vllm.v1.kv_compression.forward_context
import
(
get_kv_compression_compact_slots
,
get_kv_compression_prompt_payload
,
set_kv_compression_compact_slots
,
set_kv_compression_prompt_payload
,
)
from
vllm.v1.kv_compression.prompt_end
import
compute_prompt_end_indices
from
vllm.v1.kv_compression.slot_mapping
import
kv_compaction_dst_rewrite_mapping
class _ReshapeAndCacheFn(Protocol):
    """Structural type of the ``reshape_and_cache`` callable.

    ``maybe_compact_kv_cache_flash_attn`` accepts a callable of this shape and
    invokes it positionally with the step's key/value tensors, the layer's
    key/value caches, a rewritten slot mapping, the KV-cache dtype string and
    the layer's ``_k_scale`` / ``_v_scale``.

    NOTE(review): presumably the callable writes ``key``/``value`` into the
    caches at ``slot_mapping`` -- confirm against the concrete kernels.
    """

    def __call__(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
    ) -> None:
        """Signature only; returns ``None``."""
        ...
def maybe_compute_prompt_end_payload_flash_attn(
    *,
    kv_sharing_target_layer_name: Optional[str],
    query: torch.Tensor,
    num_actual_tokens: int,
    key_cache: torch.Tensor,
    cache_block_size: int,
    attn_metadata: Any,
    sm_scale: float,
) -> None:
    """Compute and stash prompt-end Top-K indices for chunked-prefill scheme 3.

    The payload is cached in the forward context and later consumed by the
    model runner to perform one-shot prompt KV compaction before the first
    decode step.

    The call is a no-op when any of the following holds:

    * KV compression is disabled or a KV-sharing target layer is set;
    * one of the ``kv_compression_prompt_*`` tensors is missing on
      ``attn_metadata``;
    * ``kv_compression_prompt_end`` is empty;
    * the forward context already holds a prompt payload;
    * ``compute_prompt_end_indices`` returns ``None``.
    """
    # Proceed only when compression is on and no KV-sharing target is set.
    if not (envs.VLLM_ENABLE_KV_COMPRESSION and kv_sharing_target_layer_name is None):
        return

    prompt_end, prompt_lens, topk_keep = (
        getattr(attn_metadata, field_name, None)
        for field_name in (
            "kv_compression_prompt_end",
            "kv_compression_prompt_lens",
            "kv_compression_prompt_topk_keep",
        )
    )
    if any(item is None for item in (prompt_end, prompt_lens, topk_keep)):
        return

    num_reqs = int(prompt_end.numel())
    if num_reqs <= 0:
        return

    ctx = get_forward_context()
    # A payload is already stored for this forward context; do not recompute.
    if get_kv_compression_prompt_payload(ctx) is not None:
        return

    # Optional CPU-side bound; None when the metadata does not provide one.
    keep_cap = getattr(attn_metadata, "kv_compression_prompt_topk_keep_max", None)

    payload = compute_prompt_end_indices(
        query=query[:num_actual_tokens],
        key_cache=key_cache,
        block_size=cache_block_size,
        query_start_loc=attn_metadata.query_start_loc[: num_reqs + 1],
        block_table=attn_metadata.block_table[:num_reqs],
        prompt_end=prompt_end,
        prompt_lens=prompt_lens,
        topk_keep=topk_keep,
        topk_keep_max=keep_cap,
        sm_scale=sm_scale,
    )
    if payload is None:
        return
    set_kv_compression_prompt_payload(ctx, payload)
def maybe_compact_kv_cache_flash_attn(
    *,
    kv_sharing_target_layer_name: Optional[str],
    layer: Any,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_actual_tokens: int,
    cache_block_size: int,
    attn_metadata: Any,
    sm_scale: float,
    kv_cache_dtype: str,
    reshape_and_cache: _ReshapeAndCacheFn,
) -> None:
    """Optional per-step KV compaction for scheme 1/2 token-shared selection.

    Steps:

    1. Return early when compression is disabled, a KV-sharing target layer
       is set, the must-keep / Top-K budget metadata is missing, or the
       budget tensor is empty.
    2. Look up compaction destination slots in the forward context; when none
       are stored, compute them with ``compute_compact_dst_slots_for_step``
       and store the result.
    3. Build the rewrite mapping from the destination slots and the step's
       ``slot_mapping`` and hand it to a reshape-and-cache callable together
       with ``key``/``value`` and the layer's caches and scales.

    On ROCm, ``lightop.reshape_and_cache_cuda`` is used instead of
    ``reshape_and_cache`` when ``VLLM_USE_OPT_RESHAPE_AND_CACHE`` is set and
    both ``key`` and ``value`` are float16.
    """
    # Proceed only when compression is on and no KV-sharing target is set.
    if not (envs.VLLM_ENABLE_KV_COMPRESSION and kv_sharing_target_layer_name is None):
        return

    must_keep = getattr(attn_metadata, "kv_compression_must_keep", None)
    topk_budget = getattr(attn_metadata, "kv_compression_topk_budget", None)
    if must_keep is None or topk_budget is None:
        return

    num_reqs = int(topk_budget.numel())
    if num_reqs <= 0:
        return

    ctx = get_forward_context()
    per_layer_topk = envs.VLLM_KV_COMPRESSION_TOPK_PER_LAYER

    dst_slots = get_kv_compression_compact_slots(
        ctx,
        per_layer_topk=per_layer_topk,
        layer=layer,
    )
    if dst_slots is None:
        # Nothing stored yet: compute the destinations and store them.
        budget_cap = int(
            getattr(attn_metadata, "kv_compression_topk_budget_max", 0) or 0
        )
        dst_slots = compute_compact_dst_slots_for_step(
            query=query[:num_actual_tokens],
            key=key[:num_actual_tokens],
            query_start_loc=attn_metadata.query_start_loc[: num_reqs + 1],
            seq_lens=attn_metadata.seq_lens[:num_reqs],
            block_table=attn_metadata.block_table[:num_reqs],
            block_size=cache_block_size,
            must_keep=must_keep[:num_actual_tokens],
            topk_budget=topk_budget,
            topk_budget_max=budget_cap,
            max_query_len=attn_metadata.max_query_len,
            sm_scale=sm_scale,
        )
        set_kv_compression_compact_slots(
            ctx,
            per_layer_topk=per_layer_topk,
            layer=layer,
            dst=dst_slots,
        )
    if dst_slots is None:
        return

    src_slots = attn_metadata.slot_mapping[:num_actual_tokens]
    dst_rewrite = kv_compaction_dst_rewrite_mapping(
        dst_slots=dst_slots,
        src_slots=src_slots,
    )

    # Default to the caller-supplied kernel; ROCm may swap in the optimized one.
    write_kv = reshape_and_cache
    if (
        current_platform.is_rocm()
        and envs.VLLM_USE_OPT_RESHAPE_AND_CACHE
        and key.dtype == value.dtype
        and key.dtype == torch.float16
    ):
        from lightop import reshape_and_cache_cuda

        write_kv = reshape_and_cache_cuda

    write_kv(
        key,
        value,
        key_cache,
        value_cache,
        dst_rewrite,
        kv_cache_dtype,
        layer._k_scale,
        layer._v_scale,
    )
vllm/v1/kv_compression/prompt_end.py
0 → 100644
View file @
da35d84f
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
typing
import
Optional
import
torch
import
vllm.envs
as
envs
from
vllm.v1.kv_compression.kv_cache_view
import
paged_k_cache_view_for_triton_gather
from
vllm.v1.kv_compression.snapkv_score
import
snapkv_query_aware_token_scores
from
vllm.v1.kv_compression.topk_select
import
(
_packed_varlen_coords
,
_topk_keep_mask_and_local_rank
)
def _prompt_end_topk_keep_indices(
    *,
    token_scores: torch.Tensor,  # [T] float32
    prompt_lens: torch.Tensor,  # [B] int32
    topk_keep: torch.Tensor,  # [B] int32 (candidates only)
    protected_prefix: int,
    protected_suffix: int,
    keep_last_token: bool,
    topk_keep_max: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Select kept prompt indices (ascending) for one-shot compaction.

    Tokens in the first ``protected_prefix`` and last ``protected_suffix``
    positions of each prompt (and, when ``keep_last_token`` is set, the last
    prompt token) are marked must-keep; the remaining selection is delegated
    to ``_topk_keep_mask_and_local_rank``.

    Returns:
        idx_sorted: [B, K_max] int32, per-request kept token indices (0..L-1).
            Slots that are not written stay 0.
        keep_len: [B] int32, number of kept tokens per request.

    Early exits: an empty batch yields ``([0, 0], [0])``-shaped tensors, an
    empty ``token_scores`` yields a ``[B, 0]`` index tensor with all-zero
    ``keep_len``, and a non-positive maximum keep length yields a ``[B, 0]``
    index tensor with the computed ``keep_len``.
    """
    dev = token_scores.device
    num_reqs = int(prompt_lens.numel())

    def _no_indices(rows: int) -> torch.Tensor:
        # Index tensor with zero kept columns.
        return torch.empty((rows, 0), device=dev, dtype=torch.int32)

    if num_reqs == 0:
        return _no_indices(0), torch.empty((0,), device=dev, dtype=torch.int32)

    lens64 = prompt_lens.to(torch.long)
    cu_lens = torch.zeros((num_reqs + 1,), device=dev, dtype=torch.long)
    cu_lens[1:] = torch.cumsum(lens64, dim=0)

    num_tokens = int(token_scores.numel())
    if num_tokens == 0:
        return (
            _no_indices(num_reqs),
            torch.zeros((num_reqs,), device=dev, dtype=torch.int32),
        )

    starts, _, lengths, req_ids, pos_in_req = _packed_varlen_coords(
        cu_seqlens=cu_lens,
        total_tokens=num_tokens,
    )

    # Must-keep mask (protected prefix/suffix + optional last prompt token).
    lens_nonneg = torch.clamp(lens64, min=0)
    prefix_len = lens_nonneg.clamp_max(max(protected_prefix, 0))
    suffix_len = lens_nonneg.clamp_max(max(protected_suffix, 0))
    suffix_start = (lens64 - suffix_len).clamp_min(0)

    in_prefix = pos_in_req < prefix_len.index_select(0, req_ids)
    in_suffix = pos_in_req >= suffix_start.index_select(0, req_ids)
    must_keep = in_prefix | in_suffix
    if keep_last_token:
        last_pos = (lens64 - 1).clamp_min(0).index_select(0, req_ids)
        must_keep = must_keep | (pos_in_req == last_pos)

    longest_prompt = int(lens64.max().item())
    keep_mask, local_rank, keep_len = _topk_keep_mask_and_local_rank(
        token_scores=token_scores,
        must_keep=must_keep,
        topk_budget=topk_keep,
        starts=starts,
        lengths=lengths,
        req_ids=req_ids,
        pos_in_req=pos_in_req,
        max_len=longest_prompt,
        topk_budget_max=topk_keep_max,
    )

    # num_reqs > 0 is guaranteed here by the early return above.
    keep_max_len = int(keep_len.max().item())
    if keep_max_len <= 0:
        return _no_indices(num_reqs), keep_len

    idx_sorted = torch.zeros(
        (num_reqs, keep_max_len), device=dev, dtype=torch.int32
    )
    # Row-major destination of every kept token inside the [B, K_max] table.
    flat_dst = torch.masked_select(req_ids * keep_max_len + local_rank, keep_mask)
    kept_pos = torch.masked_select(pos_in_req.to(torch.int32), keep_mask)
    flat_view = idx_sorted.view(-1)
    flat_view.scatter_(0, flat_dst, kept_pos)
    return idx_sorted, keep_len
def compute_prompt_end_indices(
    *,
    query: torch.Tensor,  # [T, Hq, D] scheduled tokens for this step
    key_cache: torch.Tensor,  # layer KV cache view (platform-dependent)
    block_size: int,
    query_start_loc: torch.Tensor,  # [B+1] int32
    block_table: torch.Tensor,  # [B, max_blocks] int32
    prompt_end: torch.Tensor,  # [B] bool
    prompt_lens: torch.Tensor,  # [B] int32
    topk_keep: torch.Tensor,  # [B] int32
    topk_keep_max: Optional[int],
    sm_scale: float,
) -> Optional[dict[str, torch.Tensor]]:
    """Compute one-shot prompt compaction indices on the last prefill chunk.

    Only requests flagged in ``prompt_end`` are processed. For each of them
    the last ``min(VLLM_KV_COMPRESSION_SNAPKV_WINDOW, query_len)`` query rows
    are packed, the prompt keys are gathered from the paged cache, tokens are
    scored, and the kept indices are selected.

    Returns:
        ``None`` when ``prompt_end`` is empty, flags no request, or the packed
        query window is empty. Otherwise a dict with:

        * ``"req_indices"``: int32 indices of the selected requests;
        * ``"idx_sorted"``: [B_sel, K_max] int32 kept token indices;
        * ``"keep_len"``: [B_sel] int32 kept-token counts;
        * ``"prompt_lens"``: [B_sel] int32 prompt lengths.
    """
    dev = query.device

    if prompt_end.numel() == 0:
        return None
    sel = torch.nonzero(prompt_end, as_tuple=False).flatten()
    if int(sel.numel()) == 0:
        return None

    max_window = int(envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW)
    keep_last = bool(envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN)
    prefix_guard = int(envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX)
    suffix_guard = int(envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX)

    # Build packed Q window (last `window` queries per selected request).
    # Both conversions copy to the host so the bounds can be read as ints.
    selected_reqs = sel.to(device="cpu", dtype=torch.int64).tolist()
    q_bounds = query_start_loc.to(device="cpu", dtype=torch.int64).tolist()

    q_ends = [int(q_bounds[req + 1]) for req in selected_reqs]
    windows = [
        int(min(max_window, max(0, end - int(q_bounds[req]))))
        for req, end in zip(selected_reqs, q_ends)
    ]
    q_slices = [
        query[end - win : end] for end, win in zip(q_ends, windows) if win > 0
    ]

    # Cumulative window lengths; one entry per selected request plus a leading 0.
    q_offsets = [0]
    for win in windows:
        q_offsets.append(q_offsets[-1] + win)
    if q_offsets[-1] <= 0:
        return None

    if q_slices:
        q_packed = torch.cat(q_slices, dim=0)
    else:
        q_packed = query[:0]
    cu_seqlens_q = torch.tensor(q_offsets, device=dev, dtype=torch.int32)
    window_sizes = torch.tensor(windows, device=dev, dtype=torch.int32)

    # Gather full prompt keys for the selected requests into a packed [T, Hk, D].
    prompt_lens_sel = prompt_lens.index_select(0, sel).to(torch.int32)
    topk_keep_sel = topk_keep.index_select(0, sel).to(torch.int32)

    num_sel = int(prompt_lens_sel.numel())
    cu_seqlens_k = torch.zeros((num_sel + 1,), device=dev, dtype=torch.int32)
    if num_sel > 0:
        cu_seqlens_k[1:] = torch.cumsum(prompt_lens_sel, dim=0)

    block_table_sel = block_table.index_select(0, sel).to(torch.int32)
    key_cache_view = paged_k_cache_view_for_triton_gather(
        key_cache=key_cache,
        block_size=int(block_size),
    )

    from vllm.v1.kv_compression.kv_cache_triton import gather_k_to_packed_triton

    k_packed = gather_k_to_packed_triton(
        key_cache_view,
        block_table_sel,
        prompt_lens_sel,
        cu_seqlens_k,
    )

    token_scores = snapkv_query_aware_token_scores(
        query=q_packed,
        key=k_packed,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        window=window_sizes,
        sm_scale=float(sm_scale),
    )

    idx_sorted, keep_len = _prompt_end_topk_keep_indices(
        token_scores=token_scores,
        prompt_lens=prompt_lens_sel,
        topk_keep=topk_keep_sel,
        protected_prefix=prefix_guard,
        protected_suffix=suffix_guard,
        keep_last_token=keep_last,
        topk_keep_max=topk_keep_max,
    )

    result = {
        "req_indices": sel.to(torch.int32),
        "idx_sorted": idx_sorted,  # [B_sel, K_max] int32
        "keep_len": keep_len,  # [B_sel] int32
        "prompt_lens": prompt_lens_sel,  # [B_sel] int32
    }
    return result
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment