Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a9ebf337
Commit
a9ebf337
authored
Jan 22, 2026
by
laibao
Browse files
feat: kvpress flash_attn(scheme 3)生成 prompt-end payload
parent
b6a27380
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
273 additions
and
0 deletions
+273
-0
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+273
-0
No files found.
vllm/v1/attention/backends/flash_attn.py
View file @
a9ebf337
...
...
@@ -8,6 +8,7 @@ import numpy as np
import
torch
import
vllm.envs
as
envs
from
vllm.forward_context
import
get_forward_context
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
,
...
...
@@ -646,6 +647,33 @@ class FlashAttentionImpl(AttentionImpl):
layer
.
_q_scale
)
query
=
query
.
reshape
((
num_tokens
,
num_heads
,
head_size
))
# Scheme 3 (chunked prefill): on the last prompt chunk, compute global
# prompt indices (score/topk) and cache them in the forward context for
# the model runner to consume before the first decode step.
if
(
envs
.
VLLM_ENABLE_KV_COMPRESSION
and
self
.
kv_sharing_target_layer_name
is
None
and
attn_metadata
.
kv_compression_prompt_end
is
not
None
and
attn_metadata
.
kv_compression_prompt_lens
is
not
None
and
attn_metadata
.
kv_compression_prompt_topk_keep
is
not
None
):
forward_context
=
get_forward_context
()
payload
=
getattr
(
forward_context
,
"_kv_compression_prompt_payload"
,
None
)
if
payload
is
None
:
payload
=
_compute_prompt_end_indices
(
query
=
query
[:
num_actual_tokens
],
key_cache
=
key_cache
,
query_start_loc
=
attn_metadata
.
query_start_loc
,
block_table
=
attn_metadata
.
block_table
,
prompt_end
=
attn_metadata
.
kv_compression_prompt_end
,
prompt_lens
=
attn_metadata
.
kv_compression_prompt_lens
,
topk_keep
=
attn_metadata
.
kv_compression_prompt_topk_keep
,
topk_keep_max
=
attn_metadata
.
kv_compression_prompt_topk_keep_max
,
sm_scale
=
self
.
scale
,
)
if
payload
is
not
None
:
setattr
(
forward_context
,
"_kv_compression_prompt_payload"
,
payload
)
# Compute attention and update output up to `num_actual_tokens`.
use_local_attn
=
\
(
self
.
use_irope
and
attn_metadata
.
local_attn_metadata
is
not
None
)
...
...
@@ -781,6 +809,251 @@ class FlashAttentionImpl(AttentionImpl):
return
output
def
_prompt_end_topk_keep_indices
(
*
,
token_scores
:
torch
.
Tensor
,
# [T] float32
prompt_lens
:
torch
.
Tensor
,
# [B] int32
topk_keep
:
torch
.
Tensor
,
# [B] int32 (candidates only)
protected_prefix
:
int
,
protected_suffix
:
int
,
keep_last_token
:
bool
,
topk_keep_max
:
Optional
[
int
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Select kept prompt indices (ascending) for one-shot compaction.
Returns:
idx_sorted: [B, K_max] int32, per-request kept token indices (0..L-1)
keep_len: [B] int32, number of kept tokens per request
"""
device
=
token_scores
.
device
B
=
int
(
prompt_lens
.
numel
())
if
B
==
0
:
empty
=
torch
.
empty
((
0
,
0
),
device
=
device
,
dtype
=
torch
.
int32
)
return
empty
,
torch
.
empty
((
0
,
),
device
=
device
,
dtype
=
torch
.
int32
)
prompt_lens_i64
=
prompt_lens
.
to
(
torch
.
long
)
cu
=
torch
.
zeros
((
B
+
1
,
),
device
=
device
,
dtype
=
torch
.
long
)
cu
[
1
:]
=
torch
.
cumsum
(
prompt_lens_i64
,
dim
=
0
)
starts
=
cu
[:
B
]
ends
=
cu
[
1
:]
T
=
int
(
token_scores
.
numel
())
if
T
==
0
:
empty
=
torch
.
empty
((
B
,
0
),
device
=
device
,
dtype
=
torch
.
int32
)
return
empty
,
torch
.
zeros
((
B
,
),
device
=
device
,
dtype
=
torch
.
int32
)
token_idx
=
torch
.
arange
(
T
,
device
=
device
,
dtype
=
torch
.
long
)
req_ids
=
torch
.
bucketize
(
token_idx
,
ends
,
right
=
True
)
# [T]
start_per_token
=
starts
.
index_select
(
0
,
req_ids
)
pos_in_req
=
token_idx
-
start_per_token
# [T]
# Must-keep mask (protected prefix/suffix + optional last prompt token).
prefix_len
=
torch
.
clamp
(
prompt_lens_i64
,
min
=
0
).
clamp_max
(
max
(
protected_prefix
,
0
))
suffix
=
torch
.
clamp
(
prompt_lens_i64
,
min
=
0
).
clamp_max
(
max
(
protected_suffix
,
0
))
suffix_start
=
(
prompt_lens_i64
-
suffix
).
clamp_min
(
0
)
prefix_len_t
=
prefix_len
.
index_select
(
0
,
req_ids
)
suffix_start_t
=
suffix_start
.
index_select
(
0
,
req_ids
)
must_keep
=
(
pos_in_req
<
prefix_len_t
)
|
(
pos_in_req
>=
suffix_start_t
)
if
keep_last_token
:
last
=
(
prompt_lens_i64
-
1
).
clamp_min
(
0
)
last_t
=
last
.
index_select
(
0
,
req_ids
)
must_keep
|=
pos_in_req
==
last_t
cand_counts
=
torch
.
zeros
((
B
,
),
device
=
device
,
dtype
=
torch
.
long
)
cand_counts
.
scatter_add_
(
0
,
req_ids
,
(
~
must_keep
).
to
(
torch
.
long
))
k_eff
=
torch
.
minimum
(
topk_keep
.
to
(
torch
.
long
).
clamp_min
(
0
),
cand_counts
)
# CPU-known bound avoids a device->host sync; clamp for safety.
if
topk_keep_max
is
None
:
k_max
=
int
(
k_eff
.
max
().
item
())
else
:
k_max
=
int
(
topk_keep_max
)
if
k_max
<
0
:
k_max
=
0
keep_mask
=
must_keep
.
clone
()
if
k_max
>
0
:
L_max
=
int
(
prompt_lens_i64
.
max
().
item
())
masked_scores
=
token_scores
.
to
(
torch
.
float32
).
masked_fill
(
must_keep
,
float
(
"-inf"
))
scores_flat
=
masked_scores
.
new_full
((
B
*
L_max
,
),
float
(
"-inf"
))
linear
=
req_ids
*
L_max
+
pos_in_req
scores_flat
[
linear
]
=
masked_scores
scores
=
scores_flat
.
view
(
B
,
L_max
)
topk_pos
=
torch
.
topk
(
scores
,
k
=
k_max
,
dim
=
1
).
indices
# [B, k_max]
col_mask
=
(
torch
.
arange
(
k_max
,
device
=
device
).
unsqueeze
(
0
)
<
k_eff
.
unsqueeze
(
1
))
global_sel
=
starts
.
unsqueeze
(
1
)
+
topk_pos
.
to
(
torch
.
long
)
# [B,k_max]
flat_idx
=
global_sel
.
reshape
(
-
1
).
clamp_
(
0
,
T
-
1
)
flat_val
=
col_mask
.
reshape
(
-
1
).
to
(
torch
.
int32
)
tmp
=
torch
.
zeros
((
T
,
),
device
=
device
,
dtype
=
torch
.
int32
)
tmp
.
scatter_add_
(
0
,
flat_idx
,
flat_val
)
keep_mask
|=
tmp
>
0
keep_len
=
torch
.
zeros
((
B
,
),
device
=
device
,
dtype
=
torch
.
long
)
keep_len
.
scatter_add_
(
0
,
req_ids
,
keep_mask
.
to
(
torch
.
long
))
keep_max_len
=
int
(
keep_len
.
max
().
item
())
if
B
>
0
else
0
if
keep_max_len
<=
0
:
empty
=
torch
.
empty
((
B
,
0
),
device
=
device
,
dtype
=
torch
.
int32
)
return
empty
,
keep_len
.
to
(
torch
.
int32
)
# Stable, order-preserving index list using segment-local ranks.
keep_prefix
=
torch
.
cumsum
(
keep_mask
.
to
(
torch
.
long
),
dim
=
0
)
# [T]
start_minus_1
=
(
starts
-
1
).
clamp_min
(
0
)
prefix_before_all
=
keep_prefix
.
index_select
(
0
,
start_minus_1
)
prefix_before
=
torch
.
where
(
starts
>
0
,
prefix_before_all
,
torch
.
zeros_like
(
prefix_before_all
))
prefix_before_t
=
prefix_before
.
index_select
(
0
,
req_ids
)
local_rank
=
keep_prefix
-
prefix_before_t
-
1
# [T]
idx_sorted
=
torch
.
zeros
((
B
,
keep_max_len
),
device
=
device
,
dtype
=
torch
.
int32
)
lin_out
=
(
req_ids
*
keep_max_len
+
local_rank
).
masked_select
(
keep_mask
)
vals
=
pos_in_req
.
to
(
torch
.
int32
).
masked_select
(
keep_mask
)
idx_sorted
.
view
(
-
1
).
scatter_
(
0
,
lin_out
,
vals
)
return
idx_sorted
,
keep_len
.
to
(
torch
.
int32
)
def _compute_prompt_end_indices(
    *,
    query: torch.Tensor,  # [T, Hq, D] scheduled tokens for this step
    key_cache: torch.Tensor,  # layer KV cache view (platform-dependent)
    query_start_loc: torch.Tensor,  # [B+1] int32
    block_table: torch.Tensor,  # [B, max_blocks] int32
    prompt_end: torch.Tensor,  # [B] bool
    prompt_lens: torch.Tensor,  # [B] int32
    topk_keep: torch.Tensor,  # [B] int32
    topk_keep_max: Optional[int],
    sm_scale: float,
) -> Optional[dict[str, torch.Tensor]]:
    """Compute one-shot prompt compaction indices on the last prefill chunk.

    For every request flagged in ``prompt_end`` the last
    ``VLLM_KV_COMPRESSION_SNAPKV_WINDOW`` queries of this step are scored
    against that request's prompt keys gathered from ``key_cache``. The
    per-token scores are all-reduced over the tensor-parallel group and handed
    to ``_prompt_end_topk_keep_indices`` to pick the tokens to keep.

    Args:
        query: Queries of the tokens scheduled in this step.
        key_cache: Key cache of the current layer.
        query_start_loc: Cumulative query offsets per request.
        block_table: Per-request KV block table.
        prompt_end: Marks requests whose prompt ends in this step.
        prompt_lens: Prompt length per request.
        topk_keep: Number of non-protected tokens to keep per request.
        topk_keep_max: Optional host-side upper bound on ``topk_keep``.
        sm_scale: Softmax scale applied to the attention logits.

    Returns:
        ``None`` when no request is flagged or the flagged requests have no
        queries in this step. Otherwise a dict with:
          * ``"req_indices"``: [B_sel] int32 batch rows of flagged requests,
          * ``"idx_sorted"``: kept token positions per flagged request,
          * ``"keep_len"``: [B_sel] int32 number of kept tokens,
          * ``"prompt_lens"``: [B_sel] int32 prompt lengths.
    """
    import logging

    device = query.device
    if prompt_end.numel() == 0:
        return None
    sel = torch.nonzero(prompt_end, as_tuple=False).flatten()
    if int(sel.numel()) == 0:
        return None

    window = int(envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW)
    keep_last = bool(envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN)
    protected_prefix = int(envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX)
    protected_suffix = int(envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX)

    # Build packed Q window (last `window` queries per selected request).
    sel_list = sel.to(device="cpu", dtype=torch.int64).tolist()
    qsl = query_start_loc.to(device="cpu", dtype=torch.int64).tolist()
    q_chunks = []
    cu_q = [0]
    w_list = []
    for b in sel_list:
        s = int(qsl[b])
        e = int(qsl[b + 1])
        q_len = max(0, e - s)
        win = min(window, q_len)
        w_list.append(int(win))
        if win > 0:
            q_chunks.append(query[e - win:e])
        cu_q.append(cu_q[-1] + int(win))
    if cu_q[-1] <= 0:
        return None
    # A positive total window implies at least one chunk was collected.
    q_packed = torch.cat(q_chunks, dim=0)
    cu_seqlens_q = torch.tensor(cu_q, device=device, dtype=torch.int32)
    w = torch.tensor(w_list, device=device, dtype=torch.int32)

    # Gather full prompt keys for the selected requests into a packed
    # [T, Hk, D].
    prompt_lens_sel = prompt_lens.index_select(0, sel).to(torch.int32)
    topk_keep_sel = topk_keep.index_select(0, sel).to(torch.int32)
    cu_seqlens_k = torch.zeros(
        (int(prompt_lens_sel.numel()) + 1,), device=device, dtype=torch.int32)
    # `sel` is non-empty here, so there is always at least one length.
    cu_seqlens_k[1:] = torch.cumsum(prompt_lens_sel, dim=0)
    block_table_sel = block_table.index_select(0, sel).to(torch.int32)

    if not current_platform.is_rocm():
        # CUDA cache view: [num_blocks, block_size, H, D]
        #               -> [num_blocks, H, block_size, D]
        key_cache_view = key_cache.permute(0, 2, 1, 3)
    else:
        key_cache_view = key_cache

    from vllm.v1.attention.kv_compression.kv_cache_triton import (
        gather_k_to_packed_triton)
    k_packed = gather_k_to_packed_triton(
        key_cache_view,
        block_table_sel,
        prompt_lens_sel,
        cu_seqlens_k,
    )

    # SnapKV Triton scoring (token-shared via sum over KV heads).
    from vllm.v1.attention.kv_compression.snapkv_triton import (
        query_aware_key_scores)
    try:
        scores_per_head = query_aware_key_scores(
            q=q_packed,
            k=k_packed,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            w=w,
            sm_scale=float(sm_scale),
            pool=False,
            protect_last=False,
            normalize=False,
        )
        token_scores = scores_per_head.sum(dim=1)
    except Exception:
        # Fallback: PyTorch reference scoring (slow but
        # correctness-oriented).
        Hq = q_packed.shape[1]
        Hk = k_packed.shape[1]
        D = q_packed.shape[2]
        if Hq % Hk != 0:
            # The reference path needs query heads to split evenly over KV
            # heads; otherwise surface the original kernel error.
            raise
        # Do not degrade silently: record why the fast path was skipped.
        logging.getLogger(__name__).warning(
            "SnapKV Triton scoring failed; falling back to the PyTorch "
            "reference implementation.",
            exc_info=True)
        group = Hq // Hk
        token_scores = torch.zeros(
            (k_packed.shape[0],), device=device, dtype=torch.float32)
        # One host copy instead of two device->host syncs per request.
        cu_k = cu_seqlens_k.tolist()
        for qs, qe, ks, ke in zip(cu_q, cu_q[1:], cu_k, cu_k[1:]):
            if qe <= qs or ke <= ks:
                continue
            q_win = q_packed[qs:qe]  # [win, Hq, D]
            # Average the query heads that share one KV head.
            q_win = q_win.reshape(q_win.shape[0], Hk, group, D).mean(dim=2)
            k_all = k_packed[ks:ke]
            qh = q_win.permute(1, 0, 2).to(torch.float32)
            kh = k_all.permute(1, 0, 2).to(torch.float32)
            # NOTE(review): no causal mask is applied here - confirm this
            # matches the Triton kernel's behavior.
            logits = torch.matmul(qh, kh.transpose(1, 2)) * float(sm_scale)
            probs = torch.softmax(logits, dim=-1)
            # Sum attention mass over window queries, then over KV heads.
            token_scores[ks:ke] = probs.sum(dim=1).sum(dim=0)

    from vllm.distributed.parallel_state import get_tp_group
    token_scores = get_tp_group().all_reduce(token_scores)

    idx_sorted, keep_len = _prompt_end_topk_keep_indices(
        token_scores=token_scores,
        prompt_lens=prompt_lens_sel,
        topk_keep=topk_keep_sel,
        protected_prefix=protected_prefix,
        protected_suffix=protected_suffix,
        keep_last_token=keep_last,
        topk_keep_max=topk_keep_max,
    )
    return {
        "req_indices": sel.to(torch.int32),
        "idx_sorted": idx_sorted,  # [B_sel, K_max] int32
        "keep_len": keep_len,  # [B_sel] int32
        "prompt_lens": prompt_lens_sel,  # [B_sel] int32
    }
def
use_cascade_attention
(
common_prefix_len
:
int
,
query_lens
:
np
.
ndarray
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment