Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
863f93e6
Commit
863f93e6
authored
Jan 23, 2026
by
laibao
Browse files
feat: kvpress flash_attn 实现非 chunked Top‑K compaction
parent
a9ebf337
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
351 additions
and
1 deletion
+351
-1
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+351
-1
No files found.
vllm/v1/attention/backends/flash_attn.py
View file @
863f93e6
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
"""Attention layer with FlashAttention."""
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
ClassVar
,
Optional
,
Tuple
from
typing
import
TYPE_CHECKING
,
Any
,
ClassVar
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
...
@@ -33,6 +33,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
...
@@ -33,6 +33,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
from
vllm.triton_utils
import
HAS_TRITON
from
vllm.v1.attention.backends.utils
import
(
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
get_kv_cache_layout
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
get_kv_cache_layout
,
make_local_attention_virtual_batches
)
make_local_attention_virtual_batches
)
...
@@ -43,6 +44,7 @@ if TYPE_CHECKING:
...
@@ -43,6 +44,7 @@ if TYPE_CHECKING:
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_DISABLE_SNAPKV_TRITON
:
bool
=
False
# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
=
16
_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
=
16
...
@@ -592,8 +594,10 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -592,8 +594,10 @@ class FlashAttentionImpl(AttentionImpl):
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
num_actual_tokens
=
attn_metadata
.
num_actual_tokens
if
not
current_platform
.
is_rocm
():
if
not
current_platform
.
is_rocm
():
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
key_cache
,
value_cache
=
kv_cache
.
unbind
(
0
)
cache_block_size
=
key_cache
.
shape
[
-
3
]
else
:
else
:
key_cache
,
value_cache
=
kv_cache
key_cache
,
value_cache
=
kv_cache
cache_block_size
=
key_cache
.
shape
[
-
2
]
if
self
.
kv_sharing_target_layer_name
is
None
:
if
self
.
kv_sharing_target_layer_name
is
None
:
# Reshape the input keys and values and store them in the cache.
# Reshape the input keys and values and store them in the cache.
...
@@ -751,6 +755,130 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -751,6 +755,130 @@ class FlashAttentionImpl(AttentionImpl):
# num_splits=attn_metadata.max_num_splits,
# num_splits=attn_metadata.max_num_splits,
is_prefix_cache
=
True
,
is_prefix_cache
=
True
,
)
)
# Optional KV compaction pass for token-shared KV compression.
# This rewrites a selected subset of newly written KV entries into
# a packed layout for the next step.
if
(
envs
.
VLLM_ENABLE_KV_COMPRESSION
and
self
.
kv_sharing_target_layer_name
is
None
):
dst
=
None
if
(
attn_metadata
.
kv_compression_must_keep
is
not
None
and
attn_metadata
.
kv_compression_topk_budget
is
not
None
):
forward_context
=
get_forward_context
()
per_layer_topk
=
envs
.
VLLM_KV_COMPRESSION_TOPK_PER_LAYER
if
per_layer_topk
:
layer_name
=
getattr
(
layer
,
"layer_name"
,
None
)
if
layer_name
is
None
:
layer_name
=
str
(
id
(
layer
))
dst_by_layer
=
getattr
(
forward_context
,
"_kv_compression_compact_slots_by_layer"
,
None
)
if
dst_by_layer
is
None
:
dst_by_layer
=
{}
setattr
(
forward_context
,
"_kv_compression_compact_slots_by_layer"
,
dst_by_layer
,
)
dst
=
dst_by_layer
.
get
(
layer_name
)
else
:
dst
=
getattr
(
forward_context
,
"_kv_compression_compact_slots"
,
None
)
if
dst
is
None
:
topk_budget
=
attn_metadata
.
kv_compression_topk_budget
token_scores
:
Optional
[
torch
.
Tensor
]
=
None
# If there is no Top-K budget for any request in this
# step, selection does not depend on token scores.
# Skipping SnapKV scoring avoids unnecessary compute.
topk_budget_max
=
int
(
attn_metadata
.
kv_compression_topk_budget_max
or
0
)
if
topk_budget_max
>
0
:
# Mixed batch optimization: avoid scoring requests
# with a zero Top-K budget by setting their
# per-request window to 0 (kernel early-return).
window
=
int
(
envs
.
VLLM_KV_COMPRESSION_SNAPKV_WINDOW
)
w
=
torch
.
where
(
topk_budget
>
0
,
torch
.
full_like
(
topk_budget
,
window
),
torch
.
zeros_like
(
topk_budget
),
)
token_scores
=
_snapkv_like_token_scores
(
query
=
query
[:
num_actual_tokens
],
key
=
key
[:
num_actual_tokens
],
query_start_loc
=
attn_metadata
.
query_start_loc
,
window
=
w
,
sm_scale
=
self
.
scale
,
)
dst
=
_topk_kv_compact_slot_mapping
(
token_scores
=
token_scores
,
must_keep
=
attn_metadata
.
kv_compression_must_keep
,
topk_budget
=
topk_budget
,
query_start_loc
=
attn_metadata
.
query_start_loc
,
seq_lens
=
attn_metadata
.
seq_lens
,
block_table
=
attn_metadata
.
block_table
,
block_size
=
cache_block_size
,
max_query_len
=
attn_metadata
.
max_query_len
,
topk_budget_max
=
topk_budget_max
,
)
if
per_layer_topk
:
dst_by_layer
[
layer_name
]
=
dst
else
:
setattr
(
forward_context
,
"_kv_compression_compact_slots"
,
dst
)
if
dst
is
not
None
:
src
=
attn_metadata
.
slot_mapping
rewrite_mask
=
(
dst
>=
0
)
&
(
dst
!=
src
)
# Avoid host-side synchronization (`torch.any(...)`) and
# dynamic boolean-indexing gathers. Instead, construct a
# per-token destination mapping where non-rewrite tokens
# are marked as -1, which the cache kernels treat as
# padding and skip.
dst_rewrite
=
torch
.
where
(
rewrite_mask
,
dst
,
-
1
)
def
_writeback
(
dst_mapping
:
torch
.
Tensor
)
->
None
:
if
not
current_platform
.
is_rocm
():
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
dst_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
else
:
if
(
envs
.
VLLM_USE_OPT_RESHAPE_AND_CACHE
and
key
.
dtype
==
value
.
dtype
and
key
.
dtype
==
torch
.
float16
):
from
lightop
import
reshape_and_cache_cuda
reshape_and_cache_cuda
(
key
,
value
,
key_cache
,
value_cache
,
dst_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
else
:
from
vllm.attention.utils.fa_utils
import
(
reshape_and_cache_cuda
)
reshape_and_cache_cuda
(
key
,
value
,
key_cache
,
value_cache
,
dst_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
_writeback
(
dst_rewrite
)
return
output
return
output
assert
not
use_local_attn
,
(
assert
not
use_local_attn
,
(
...
@@ -1265,3 +1393,225 @@ def cascade_attention(
...
@@ -1265,3 +1393,225 @@ def cascade_attention(
# Merge prefix and suffix outputs, and store the result in output.
# Merge prefix and suffix outputs, and store the result in output.
merge_attn_states
(
output
,
prefix_output
,
prefix_lse
,
suffix_output
,
merge_attn_states
(
output
,
prefix_output
,
prefix_lse
,
suffix_output
,
suffix_lse
)
suffix_lse
)
def
_snapkv_like_token_scores
(
*
,
query
:
torch
.
Tensor
,
# [T, Hq, D]
key
:
torch
.
Tensor
,
# [T, Hkv, D]
query_start_loc
:
torch
.
Tensor
,
# [B+1]
window
:
Union
[
int
,
torch
.
Tensor
],
sm_scale
:
float
,
)
->
torch
.
Tensor
:
"""Compute token-shared SnapKV-like scores for a packed varlen batch.
Scores are computed as the attention mass from the last `window` query
tokens to the earlier keys within the same scheduled segment (per request),
summed across KV heads.
Prefers a Triton implementation when available; falls back to a (slower)
PyTorch reference implementation otherwise.
"""
global
_DISABLE_SNAPKV_TRITON
device
=
query
.
device
T
,
Hq
,
D
=
query
.
shape
Hkv
=
key
.
shape
[
1
]
if
Hq
%
Hkv
!=
0
:
raise
ValueError
(
"Query heads must be a multiple of KV heads."
)
# NOTE: Triton SnapKV scoring on ROCm is experimental. It is enabled by
# default (uses a ROCm-safe kernel); set
# VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM=0 to force the PyTorch
# reference implementation.
if
(
HAS_TRITON
and
not
_DISABLE_SNAPKV_TRITON
and
device
.
type
==
"cuda"
and
(
not
current_platform
.
is_rocm
()
or
envs
.
VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM
)
and
query
.
stride
(
-
1
)
==
1
and
key
.
stride
(
-
1
)
==
1
):
try
:
from
vllm.v1.attention.kv_compression.snapkv_triton
import
(
query_aware_key_scores
)
w
=
int
(
window
)
if
isinstance
(
window
,
int
)
else
window
scores_per_head
=
query_aware_key_scores
(
q
=
query
,
k
=
key
,
cu_seqlens_q
=
query_start_loc
,
cu_seqlens_k
=
query_start_loc
,
w
=
w
,
sm_scale
=
float
(
sm_scale
),
pool
=
False
,
protect_last
=
False
,
normalize
=
False
,
)
token_scores
=
scores_per_head
.
sum
(
dim
=
1
)
from
vllm.distributed.parallel_state
import
get_tp_group
return
get_tp_group
().
all_reduce
(
token_scores
)
except
Exception
as
e
:
_DISABLE_SNAPKV_TRITON
=
True
logger
.
warning
(
"Triton SnapKV scoring failed; falling back to PyTorch. "
"Error: %s"
,
e
)
group
=
Hq
//
Hkv
# Read boundaries on host (small tensor).
qsl
=
query_start_loc
.
tolist
()
B
=
len
(
qsl
)
-
1
wsl
=
None
if
not
isinstance
(
window
,
int
):
if
int
(
window
.
numel
())
!=
B
:
raise
ValueError
(
"window must be a scalar int or have shape [B]."
)
wsl
=
window
.
to
(
device
=
"cpu"
,
dtype
=
torch
.
int64
).
tolist
()
scores
=
torch
.
zeros
((
T
,
),
device
=
device
,
dtype
=
torch
.
float32
)
for
b
in
range
(
B
):
s
=
int
(
qsl
[
b
])
e
=
int
(
qsl
[
b
+
1
])
L
=
e
-
s
if
L
<=
0
:
continue
win_b
=
int
(
window
)
if
wsl
is
None
else
int
(
wsl
[
b
])
if
win_b
<=
0
:
continue
win
=
min
(
win_b
,
L
)
k_eff_end
=
L
-
win
if
k_eff_end
<=
0
:
continue
q_win
=
query
[
e
-
win
:
e
]
# [win, Hq, D]
# Aggregate query heads to KV heads (token-shared selection).
q_win
=
q_win
.
reshape
(
win
,
Hkv
,
group
,
D
).
mean
(
dim
=
2
)
# [win, Hkv, D]
k_eff
=
key
[
s
:
s
+
k_eff_end
]
# [k_eff_end, Hkv, D]
qh
=
q_win
.
permute
(
1
,
0
,
2
).
to
(
torch
.
float32
)
# [Hkv, win, D]
kh
=
k_eff
.
permute
(
1
,
0
,
2
).
to
(
torch
.
float32
)
# [Hkv, k_eff_end, D]
logits
=
torch
.
matmul
(
qh
,
kh
.
transpose
(
1
,
2
))
*
sm_scale
# [Hkv, win, K]
probs
=
torch
.
softmax
(
logits
,
dim
=-
1
)
# Sum over (heads, window queries) -> per-key token score.
scores
[
s
:
s
+
k_eff_end
]
=
probs
.
sum
(
dim
=
1
).
sum
(
dim
=
0
)
# Aggregate across tensor-parallel ranks so every rank selects the same
# token indices.
from
vllm.distributed.parallel_state
import
get_tp_group
return
get_tp_group
().
all_reduce
(
scores
)
def
_topk_kv_compact_slot_mapping
(
*
,
token_scores
:
Optional
[
torch
.
Tensor
],
# [T] float32
must_keep
:
torch
.
Tensor
,
# [T] bool
topk_budget
:
torch
.
Tensor
,
# [B] int32
query_start_loc
:
torch
.
Tensor
,
# [B+1]
seq_lens
:
torch
.
Tensor
,
# [B] int32
block_table
:
torch
.
Tensor
,
# [B, max_blocks]
block_size
:
int
,
max_query_len
:
Optional
[
int
]
=
None
,
topk_budget_max
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
:
"""Build a per-token destination slot mapping for KV compaction.
Returns a tensor `dst_slots` of shape [T] where:
- `dst_slots[i] >= 0` indicates token i should be kept and rewritten to
that KV cache slot.
- `dst_slots[i] == -1` indicates token i is dropped after the step.
"""
device
=
must_keep
.
device
T
=
int
(
must_keep
.
numel
())
B
=
int
(
topk_budget
.
numel
())
dst_slots
=
torch
.
full
((
T
,
),
-
1
,
device
=
device
,
dtype
=
torch
.
int64
)
if
T
==
0
or
B
==
0
:
return
dst_slots
# Per-request segment boundaries in the packed [T] layout.
# NOTE: `query_start_loc` is already sliced to [B+1] by the model runner.
starts
=
query_start_loc
[:
B
].
to
(
torch
.
long
)
ends
=
query_start_loc
[
1
:
B
+
1
].
to
(
torch
.
long
)
lengths
=
ends
-
starts
# [B]
if
lengths
.
numel
()
==
0
:
return
dst_slots
# Prefer the CPU-known max query length (piecewise graph), to avoid
# device->host synchronization.
L_max
=
int
(
max_query_len
)
if
max_query_len
is
not
None
else
int
(
lengths
.
max
().
item
())
if
L_max
<=
0
:
return
dst_slots
# Map each token to its (request, offset-within-request) coordinate.
token_idx
=
torch
.
arange
(
T
,
device
=
device
,
dtype
=
torch
.
long
)
# For monotonic `ends` (cu_seqlens), this returns the request id for each
# token in the packed layout.
# Use right=True so that idx==ends[b] maps to the *next* request (b+1),
# i.e., request segments are [start, end) in the packed layout.
req_ids
=
torch
.
bucketize
(
token_idx
,
ends
,
right
=
True
)
# [T]
start_per_token
=
starts
.
index_select
(
0
,
req_ids
)
pos_in_req
=
token_idx
-
start_per_token
# [T] in [0, L_b)
# Clamp the per-request top-k budget to the number of candidate tokens
# (excluding must_keep).
must_keep_counts
=
torch
.
zeros
(
B
,
device
=
device
,
dtype
=
torch
.
long
)
must_keep_counts
.
scatter_add_
(
0
,
req_ids
,
must_keep
.
to
(
torch
.
long
))
cand_counts
=
(
lengths
.
to
(
torch
.
long
)
-
must_keep_counts
).
clamp_min
(
0
)
k_eff
=
torch
.
minimum
(
topk_budget
.
to
(
torch
.
long
).
clamp_min
(
0
),
cand_counts
)
# Prefer an upper bound from CPU (piecewise graph), to avoid sync.
if
topk_budget_max
is
not
None
:
k_max
=
min
(
int
(
topk_budget_max
),
L_max
)
else
:
k_max
=
int
(
k_eff
.
max
().
item
())
# Build a padded [B, L_max] score matrix for a single batched Top-K call.
# Must-keep and padding positions are set to -inf to avoid selection.
keep_mask
=
must_keep
.
clone
()
if
k_max
>
0
:
if
token_scores
is
None
:
raise
ValueError
(
"token_scores must be provided when k_max > 0."
)
masked_scores
=
token_scores
.
to
(
dtype
=
torch
.
float32
).
masked_fill
(
must_keep
,
float
(
"-inf"
))
scores_flat
=
masked_scores
.
new_full
((
B
*
L_max
,
),
float
(
"-inf"
))
linear
=
req_ids
*
L_max
+
pos_in_req
scores_flat
[
linear
]
=
masked_scores
scores
=
scores_flat
.
view
(
B
,
L_max
)
topk_pos
=
torch
.
topk
(
scores
,
k
=
k_max
,
dim
=
1
).
indices
# [B, k_max]
# Select only the first k_eff[b] entries for each request b.
col_mask
=
torch
.
arange
(
k_max
,
device
=
device
).
unsqueeze
(
0
)
<
k_eff
.
unsqueeze
(
1
)
# [B, k_max]
# Avoid host-side synchronization from dynamic indexing. Instead, mark
# selected tokens via a fixed-size scatter-add.
global_sel
=
starts
.
unsqueeze
(
1
)
+
topk_pos
.
to
(
torch
.
long
)
# [B, k_max]
flat_idx
=
global_sel
.
reshape
(
-
1
).
clamp_
(
0
,
T
-
1
)
flat_val
=
col_mask
.
reshape
(
-
1
).
to
(
torch
.
int32
)
tmp
=
torch
.
zeros
((
T
,
),
device
=
device
,
dtype
=
torch
.
int32
)
tmp
.
scatter_add_
(
0
,
flat_idx
,
flat_val
)
keep_mask
|=
tmp
>
0
# Compute segment-local ranks (0..kept-1) for kept tokens, preserving token
# order within each request, without dynamic indexing (graph-friendly).
keep_prefix
=
torch
.
cumsum
(
keep_mask
.
to
(
torch
.
long
),
dim
=
0
)
# [T]
start_minus_1
=
(
starts
-
1
).
clamp_min
(
0
)
prefix_before_all
=
keep_prefix
.
index_select
(
0
,
start_minus_1
.
to
(
torch
.
long
))
prefix_before
=
torch
.
where
(
starts
>
0
,
prefix_before_all
,
torch
.
zeros_like
(
prefix_before_all
))
# [B]
prefix_before_per_token
=
prefix_before
.
index_select
(
0
,
req_ids
)
# [T]
local_rank
=
keep_prefix
-
prefix_before_per_token
-
1
# [T]
# Base KV cache position for this step (i.e., KV length before writing this
# scheduled segment). With KV compression enabled, seq_lens is derived from
# num_kv_tokens + scheduled_len, so base_kv == seq_lens - scheduled_len.
base_kv
=
(
seq_lens
[:
B
].
to
(
torch
.
long
)
-
lengths
.
to
(
torch
.
long
)).
clamp_min
(
0
)
base_kv_per_token
=
base_kv
.
index_select
(
0
,
req_ids
)
# [T]
dest_pos
=
base_kv_per_token
+
local_rank
# [T]
dest_block_idx
=
dest_pos
//
block_size
dest_off
=
dest_pos
-
dest_block_idx
*
block_size
# Safe indexing for dropped tokens (ignored by keep_mask anyway).
max_blocks
=
int
(
block_table
.
shape
[
1
])
dest_block_idx_safe
=
dest_block_idx
.
clamp_
(
0
,
max_blocks
-
1
).
to
(
torch
.
long
)
block_nums
=
block_table
[
req_ids
,
dest_block_idx_safe
]
dest_slot
=
block_nums
.
to
(
torch
.
long
)
*
block_size
+
dest_off
dst_slots
=
torch
.
where
(
keep_mask
,
dest_slot
.
to
(
torch
.
int64
),
dst_slots
)
return
dst_slots
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment