Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3adc766e
Commit
3adc766e
authored
Jan 27, 2026
by
laibao
Browse files
refactor: 抽离 flash_attn 的 KV compression 逻辑到 vllm/v1/kv_compression
parent
9db5ff3b
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
1076 additions
and
656 deletions
+1076
-656
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+46
-656
vllm/v1/kv_compression/compaction_step.py
vllm/v1/kv_compression/compaction_step.py
+70
-0
vllm/v1/kv_compression/flash_attn_hooks.py
vllm/v1/kv_compression/flash_attn_hooks.py
+186
-0
vllm/v1/kv_compression/forward_context.py
vllm/v1/kv_compression/forward_context.py
+66
-0
vllm/v1/kv_compression/kv_cache_view.py
vllm/v1/kv_compression/kv_cache_view.py
+38
-0
vllm/v1/kv_compression/metadata.py
vllm/v1/kv_compression/metadata.py
+75
-0
vllm/v1/kv_compression/prompt_end.py
vllm/v1/kv_compression/prompt_end.py
+187
-0
vllm/v1/kv_compression/slot_mapping.py
vllm/v1/kv_compression/slot_mapping.py
+127
-0
vllm/v1/kv_compression/snapkv_score.py
vllm/v1/kv_compression/snapkv_score.py
+150
-0
vllm/v1/kv_compression/topk_select.py
vllm/v1/kv_compression/topk_select.py
+131
-0
No files found.
vllm/v1/attention/backends/flash_attn.py
View file @
3adc766e
...
@@ -2,13 +2,12 @@
...
@@ -2,13 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention layer with FlashAttention."""
"""Attention layer with FlashAttention."""
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
ClassVar
,
Optional
,
Tuple
,
Union
from
typing
import
TYPE_CHECKING
,
Any
,
ClassVar
,
Optional
,
Tuple
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.forward_context
import
get_forward_context
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionImpl
,
AttentionMetadata
,
AttentionType
,
AttentionMetadata
,
AttentionType
,
...
@@ -33,18 +32,21 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
...
@@ -33,18 +32,21 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
cdiv
from
vllm.utils
import
cdiv
from
vllm.triton_utils
import
HAS_TRITON
from
vllm.v1.attention.backends.utils
import
(
from
vllm.v1.attention.backends.utils
import
(
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
get_kv_cache_layout
,
AttentionMetadataBuilder
,
CommonAttentionMetadata
,
get_kv_cache_layout
,
make_local_attention_virtual_batches
)
make_local_attention_virtual_batches
)
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_cache_interface
import
AttentionSpec
from
vllm.v1.kv_compression.flash_attn_hooks
import
(
maybe_compact_kv_cache_flash_attn
,
maybe_compute_prompt_end_payload_flash_attn
,
)
from
vllm.v1.kv_compression.metadata
import
build_kv_compression_attn_metadata
from
vllm.v1.worker.block_table
import
BlockTable
from
vllm.v1.worker.block_table
import
BlockTable
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
from
vllm.v1.worker.gpu_model_runner
import
GPUModelRunner
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_DISABLE_SNAPKV_TRITON
:
bool
=
False
# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
# NOTE(woosuk): This is an arbitrary number. Tune it if needed.
_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
=
16
_DEFAULT_MAX_NUM_SPLITS_FOR_CUDA_GRAPH
=
16
...
@@ -280,37 +282,11 @@ class FlashAttentionMetadataBuilder(
...
@@ -280,37 +282,11 @@ class FlashAttentionMetadataBuilder(
block_table
.
slot_mapping
[
num_actual_tokens
:].
fill_
(
-
1
)
block_table
.
slot_mapping
[
num_actual_tokens
:].
fill_
(
-
1
)
slot_mapping
=
block_table
.
slot_mapping
[:
num_actual_tokens
]
slot_mapping
=
block_table
.
slot_mapping
[:
num_actual_tokens
]
kv_meta
=
build_kv_compression_attn_metadata
(
kv_compression_must_keep
=
None
runner
=
self
.
runner
,
kv_compression_topk_budget
=
None
num_reqs
=
num_reqs
,
kv_compression_topk_budget_max
:
Optional
[
int
]
=
None
num_actual_tokens
=
num_actual_tokens
,
kv_compression_prompt_end
=
None
)
kv_compression_prompt_lens
=
None
kv_compression_prompt_topk_keep
=
None
kv_compression_prompt_topk_keep_max
:
Optional
[
int
]
=
None
if
(
envs
.
VLLM_ENABLE_KV_COMPRESSION
and
self
.
runner
.
kv_compression_needs_compaction
):
kv_compression_must_keep
=
self
.
runner
.
kv_compression_must_keep
[:
num_actual_tokens
]
kv_compression_topk_budget
=
self
.
runner
.
kv_compression_topk_budget
[:
num_reqs
]
# Avoid device->host sync by reading from the CPU staging buffer.
if
num_reqs
>
0
:
kv_compression_topk_budget_max
=
int
(
self
.
runner
.
kv_compression_topk_budget_np
[:
num_reqs
].
max
())
else
:
kv_compression_topk_budget_max
=
0
elif
(
envs
.
VLLM_ENABLE_KV_COMPRESSION
and
self
.
runner
.
scheduler_config
.
chunked_prefill_enabled
):
# Scheme 3: compute global prompt indices only on the last prefill
# chunk (per request), and perform the actual cache compaction
# before the first decode step.
if
num_reqs
>
0
and
self
.
runner
.
kv_compression_prompt_end_np
[:
num_reqs
].
any
():
kv_compression_prompt_end
=
self
.
runner
.
kv_compression_prompt_end
[:
num_reqs
]
kv_compression_prompt_lens
=
self
.
runner
.
kv_compression_prompt_lens
[:
num_reqs
]
kv_compression_prompt_topk_keep
=
self
.
runner
.
kv_compression_prompt_topk_keep
[:
num_reqs
]
kv_compression_prompt_topk_keep_max
=
int
(
self
.
runner
.
kv_compression_prompt_topk_keep_max
or
0
)
if
self
.
aot_sliding_window
is
None
:
if
self
.
aot_sliding_window
is
None
:
self
.
aot_sliding_window
=
(
-
1
,
-
1
)
self
.
aot_sliding_window
=
(
-
1
,
-
1
)
...
@@ -470,13 +446,13 @@ class FlashAttentionMetadataBuilder(
...
@@ -470,13 +446,13 @@ class FlashAttentionMetadataBuilder(
cu_prefix_query_lens
=
cu_prefix_query_lens
,
cu_prefix_query_lens
=
cu_prefix_query_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
prefix_kv_lens
=
prefix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
suffix_kv_lens
=
suffix_kv_lens
,
kv_compression_must_keep
=
kv_
compression_
must_keep
,
kv_compression_must_keep
=
kv_
meta
.
must_keep
,
kv_compression_topk_budget
=
kv_
compression_
topk_budget
,
kv_compression_topk_budget
=
kv_
meta
.
topk_budget
,
kv_compression_topk_budget_max
=
kv_
compression_
topk_budget_max
,
kv_compression_topk_budget_max
=
kv_
meta
.
topk_budget_max
,
kv_compression_prompt_end
=
kv_
compression_
prompt_end
,
kv_compression_prompt_end
=
kv_
meta
.
prompt_end
,
kv_compression_prompt_lens
=
kv_
compression_
prompt_lens
,
kv_compression_prompt_lens
=
kv_
meta
.
prompt_lens
,
kv_compression_prompt_topk_keep
=
kv_
compression_
prompt_topk_keep
,
kv_compression_prompt_topk_keep
=
kv_
meta
.
prompt_topk_keep
,
kv_compression_prompt_topk_keep_max
=
kv_
compression_
prompt_topk_keep_max
,
kv_compression_prompt_topk_keep_max
=
kv_
meta
.
prompt_topk_keep_max
,
local_attn_metadata
=
local_attn_metadata
,
local_attn_metadata
=
local_attn_metadata
,
prefix_scheduler_metadata
=
prefix_scheduler_metadata
,
prefix_scheduler_metadata
=
prefix_scheduler_metadata
,
max_num_splits
=
max_num_splits
,
max_num_splits
=
max_num_splits
,
...
@@ -651,32 +627,16 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -651,32 +627,16 @@ class FlashAttentionImpl(AttentionImpl):
layer
.
_q_scale
)
layer
.
_q_scale
)
query
=
query
.
reshape
((
num_tokens
,
num_heads
,
head_size
))
query
=
query
.
reshape
((
num_tokens
,
num_heads
,
head_size
))
# Scheme 3 (chunked prefill): on the last prompt chunk, compute global
if
envs
.
VLLM_ENABLE_KV_COMPRESSION
:
# prompt indices (score/topk) and cache them in the forward context for
maybe_compute_prompt_end_payload_flash_attn
(
# the model runner to consume before the first decode step.
kv_sharing_target_layer_name
=
self
.
kv_sharing_target_layer_name
,
if
(
envs
.
VLLM_ENABLE_KV_COMPRESSION
query
=
query
,
and
self
.
kv_sharing_target_layer_name
is
None
num_actual_tokens
=
num_actual_tokens
,
and
attn_metadata
.
kv_compression_prompt_end
is
not
None
key_cache
=
key_cache
,
and
attn_metadata
.
kv_compression_prompt_lens
is
not
None
cache_block_size
=
cache_block_size
,
and
attn_metadata
.
kv_compression_prompt_topk_keep
is
not
None
):
attn_metadata
=
attn_metadata
,
forward_context
=
get_forward_context
()
sm_scale
=
self
.
scale
,
payload
=
getattr
(
forward_context
,
"_kv_compression_prompt_payload"
,
)
None
)
if
payload
is
None
:
payload
=
_compute_prompt_end_indices
(
query
=
query
[:
num_actual_tokens
],
key_cache
=
key_cache
,
query_start_loc
=
attn_metadata
.
query_start_loc
,
block_table
=
attn_metadata
.
block_table
,
prompt_end
=
attn_metadata
.
kv_compression_prompt_end
,
prompt_lens
=
attn_metadata
.
kv_compression_prompt_lens
,
topk_keep
=
attn_metadata
.
kv_compression_prompt_topk_keep
,
topk_keep_max
=
attn_metadata
.
kv_compression_prompt_topk_keep_max
,
sm_scale
=
self
.
scale
,
)
if
payload
is
not
None
:
setattr
(
forward_context
,
"_kv_compression_prompt_payload"
,
payload
)
# Compute attention and update output up to `num_actual_tokens`.
# Compute attention and update output up to `num_actual_tokens`.
use_local_attn
=
\
use_local_attn
=
\
...
@@ -758,127 +718,24 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -758,127 +718,24 @@ class FlashAttentionImpl(AttentionImpl):
# Optional KV compaction pass for token-shared KV compression.
# Optional KV compaction pass for token-shared KV compression.
# This rewrites a selected subset of newly written KV entries into
# This rewrites a selected subset of newly written KV entries into
# a packed layout for the next step.
# a packed layout for the next step.
if
(
envs
.
VLLM_ENABLE_KV_COMPRESSION
if
envs
.
VLLM_ENABLE_KV_COMPRESSION
:
and
self
.
kv_sharing_target_layer_name
is
None
):
maybe_compact_kv_cache_flash_attn
(
dst
=
None
kv_sharing_target_layer_name
=
self
.
kv_sharing_target_layer_name
,
if
(
attn_metadata
.
kv_compression_must_keep
is
not
None
layer
=
layer
,
and
attn_metadata
.
kv_compression_topk_budget
query
=
query
,
is
not
None
):
key
=
key
,
forward_context
=
get_forward_context
()
value
=
value
,
per_layer_topk
=
envs
.
VLLM_KV_COMPRESSION_TOPK_PER_LAYER
key_cache
=
key_cache
,
if
per_layer_topk
:
value_cache
=
value_cache
,
layer_name
=
getattr
(
layer
,
"layer_name"
,
None
)
num_actual_tokens
=
num_actual_tokens
,
if
layer_name
is
None
:
cache_block_size
=
cache_block_size
,
layer_name
=
str
(
id
(
layer
))
attn_metadata
=
attn_metadata
,
dst_by_layer
=
getattr
(
sm_scale
=
self
.
scale
,
forward_context
,
"_kv_compression_compact_slots_by_layer"
,
kv_cache_dtype
=
self
.
kv_cache_dtype
,
None
)
reshape_and_cache
=
(
reshape_and_cache_cuda
if
dst_by_layer
is
None
:
if
current_platform
.
is_rocm
()
else
dst_by_layer
=
{}
reshape_and_cache_flash
),
setattr
(
)
forward_context
,
"_kv_compression_compact_slots_by_layer"
,
dst_by_layer
,
)
dst
=
dst_by_layer
.
get
(
layer_name
)
else
:
dst
=
getattr
(
forward_context
,
"_kv_compression_compact_slots"
,
None
)
if
dst
is
None
:
topk_budget
=
attn_metadata
.
kv_compression_topk_budget
token_scores
:
Optional
[
torch
.
Tensor
]
=
None
# If there is no Top-K budget for any request in this
# step, selection does not depend on token scores.
# Skipping SnapKV scoring avoids unnecessary compute.
topk_budget_max
=
int
(
attn_metadata
.
kv_compression_topk_budget_max
or
0
)
if
topk_budget_max
>
0
:
# Mixed batch optimization: avoid scoring requests
# with a zero Top-K budget by setting their
# per-request window to 0 (kernel early-return).
window
=
int
(
envs
.
VLLM_KV_COMPRESSION_SNAPKV_WINDOW
)
w
=
torch
.
where
(
topk_budget
>
0
,
torch
.
full_like
(
topk_budget
,
window
),
torch
.
zeros_like
(
topk_budget
),
)
token_scores
=
_snapkv_like_token_scores
(
query
=
query
[:
num_actual_tokens
],
key
=
key
[:
num_actual_tokens
],
query_start_loc
=
attn_metadata
.
query_start_loc
,
window
=
w
,
sm_scale
=
self
.
scale
,
)
dst
=
_topk_kv_compact_slot_mapping
(
token_scores
=
token_scores
,
must_keep
=
attn_metadata
.
kv_compression_must_keep
,
topk_budget
=
topk_budget
,
query_start_loc
=
attn_metadata
.
query_start_loc
,
seq_lens
=
attn_metadata
.
seq_lens
,
block_table
=
attn_metadata
.
block_table
,
block_size
=
cache_block_size
,
max_query_len
=
attn_metadata
.
max_query_len
,
topk_budget_max
=
topk_budget_max
,
)
if
per_layer_topk
:
dst_by_layer
[
layer_name
]
=
dst
else
:
setattr
(
forward_context
,
"_kv_compression_compact_slots"
,
dst
)
if
dst
is
not
None
:
src
=
attn_metadata
.
slot_mapping
rewrite_mask
=
(
dst
>=
0
)
&
(
dst
!=
src
)
# Avoid host-side synchronization (`torch.any(...)`) and
# dynamic boolean-indexing gathers. Instead, construct a
# per-token destination mapping where non-rewrite tokens
# are marked as -1, which the cache kernels treat as
# padding and skip.
dst_rewrite
=
torch
.
where
(
rewrite_mask
,
dst
,
-
1
)
def
_writeback
(
dst_mapping
:
torch
.
Tensor
)
->
None
:
if
not
current_platform
.
is_rocm
():
reshape_and_cache_flash
(
key
,
value
,
key_cache
,
value_cache
,
dst_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
else
:
if
(
envs
.
VLLM_USE_OPT_RESHAPE_AND_CACHE
and
key
.
dtype
==
value
.
dtype
and
key
.
dtype
==
torch
.
float16
):
from
lightop
import
reshape_and_cache_cuda
reshape_and_cache_cuda
(
key
,
value
,
key_cache
,
value_cache
,
dst_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
else
:
from
vllm.attention.utils.fa_utils
import
(
reshape_and_cache_cuda
)
reshape_and_cache_cuda
(
key
,
value
,
key_cache
,
value_cache
,
dst_mapping
,
self
.
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
_writeback
(
dst_rewrite
)
return
output
return
output
assert
not
use_local_attn
,
(
assert
not
use_local_attn
,
(
...
@@ -937,251 +794,6 @@ class FlashAttentionImpl(AttentionImpl):
...
@@ -937,251 +794,6 @@ class FlashAttentionImpl(AttentionImpl):
return
output
return
output
def _prompt_end_topk_keep_indices(
    *,
    token_scores: torch.Tensor,  # [T] float32
    prompt_lens: torch.Tensor,  # [B] int32
    topk_keep: torch.Tensor,  # [B] int32 (candidates only)
    protected_prefix: int,
    protected_suffix: int,
    keep_last_token: bool,
    topk_keep_max: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Select kept prompt indices (ascending) for one-shot compaction.

    Tokens are laid out back to back, request after request; request
    boundaries are the cumulative sum of ``prompt_lens``, so
    ``token_scores`` is expected to hold exactly ``sum(prompt_lens)``
    entries. For every request the kept set is the union of

    * the protected prefix / suffix (and the last token when
      ``keep_last_token`` is set), and
    * the ``topk_keep[b]`` highest-scoring remaining tokens (capped by the
      number of remaining candidates).

    Args:
        token_scores: Per-token scores in the packed layout.
        prompt_lens: Number of prompt tokens per request.
        topk_keep: Top-K budget per request, counted over non-protected
            candidates only. Negative values are treated as 0.
        protected_prefix: Leading tokens per request that are always kept.
        protected_suffix: Trailing tokens per request that are always kept.
        keep_last_token: Also always keep the last token of each request.
        topk_keep_max: Optional host-side upper bound on the per-request
            Top-K budget. When given it replaces a device->host read of the
            effective maximum. It is clamped to ``[0, max(prompt_lens)]``.

    Returns:
        idx_sorted: [B, K_max] int32, per-request kept token indices
            (0..L-1) in ascending order; entries past ``keep_len[b]`` are 0.
        keep_len: [B] int32, number of kept tokens per request.
    """
    device = token_scores.device
    B = int(prompt_lens.numel())
    if B == 0:
        empty = torch.empty((0, 0), device=device, dtype=torch.int32)
        return empty, torch.empty((0, ), device=device, dtype=torch.int32)

    # Request boundaries in the packed [T] layout.
    prompt_lens_i64 = prompt_lens.to(torch.long)
    cu = torch.zeros((B + 1, ), device=device, dtype=torch.long)
    cu[1:] = torch.cumsum(prompt_lens_i64, dim=0)
    starts = cu[:B]
    ends = cu[1:]

    T = int(token_scores.numel())
    if T == 0:
        empty = torch.empty((B, 0), device=device, dtype=torch.int32)
        return empty, torch.zeros((B, ), device=device, dtype=torch.int32)

    # Map each packed token to (request id, offset within the request).
    # right=True makes idx == ends[b] belong to request b + 1, i.e. request
    # segments are half-open [start, end).
    token_idx = torch.arange(T, device=device, dtype=torch.long)
    req_ids = torch.bucketize(token_idx, ends, right=True)  # [T]
    start_per_token = starts.index_select(0, req_ids)
    pos_in_req = token_idx - start_per_token  # [T]

    # Must-keep mask (protected prefix/suffix + optional last prompt token).
    non_negative_lens = torch.clamp(prompt_lens_i64, min=0)
    prefix_len = non_negative_lens.clamp_max(max(protected_prefix, 0))
    suffix = non_negative_lens.clamp_max(max(protected_suffix, 0))
    suffix_start = (prompt_lens_i64 - suffix).clamp_min(0)
    prefix_len_t = prefix_len.index_select(0, req_ids)
    suffix_start_t = suffix_start.index_select(0, req_ids)
    must_keep = (pos_in_req < prefix_len_t) | (pos_in_req >= suffix_start_t)
    if keep_last_token:
        last = (prompt_lens_i64 - 1).clamp_min(0)
        last_t = last.index_select(0, req_ids)
        must_keep |= pos_in_req == last_t

    # Effective per-request budget: never more than the candidates left.
    cand_counts = torch.zeros((B, ), device=device, dtype=torch.long)
    cand_counts.scatter_add_(0, req_ids, (~must_keep).to(torch.long))
    k_eff = torch.minimum(topk_keep.to(torch.long).clamp_min(0), cand_counts)

    # CPU-known bound avoids a device->host sync; clamp for safety.
    if topk_keep_max is None:
        k_max = int(k_eff.max().item())
    else:
        k_max = max(int(topk_keep_max), 0)

    L_max = 0
    if k_max > 0:
        L_max = int(prompt_lens_i64.max().item())
        # torch.topk requires k <= size of the reduced dim. The CPU-side
        # bound may exceed the longest prompt, so clamp it here. k_eff is
        # already <= L_max, hence no selected column is lost.
        k_max = min(k_max, L_max)

    keep_mask = must_keep.clone()
    if k_max > 0:
        # Padded [B, L_max] score matrix; must-keep and padding slots are
        # -inf so they rank last in the batched Top-K.
        masked_scores = token_scores.to(torch.float32).masked_fill(
            must_keep, float("-inf"))
        scores_flat = masked_scores.new_full((B * L_max, ), float("-inf"))
        linear = req_ids * L_max + pos_in_req
        scores_flat[linear] = masked_scores
        scores = scores_flat.view(B, L_max)
        topk_pos = torch.topk(scores, k=k_max, dim=1).indices  # [B, k_max]

        # Only the first k_eff[b] columns of row b are real selections.
        col_mask = (torch.arange(k_max, device=device).unsqueeze(0)
                    < k_eff.unsqueeze(1))
        global_sel = starts.unsqueeze(1) + topk_pos.to(torch.long)
        # Columns masked out above may point at padding; clamp keeps the
        # scatter in range and their contribution is 0 anyway.
        flat_idx = global_sel.reshape(-1).clamp_(0, T - 1)
        flat_val = col_mask.reshape(-1).to(torch.int32)
        tmp = torch.zeros((T, ), device=device, dtype=torch.int32)
        tmp.scatter_add_(0, flat_idx, flat_val)
        keep_mask |= tmp > 0

    keep_len = torch.zeros((B, ), device=device, dtype=torch.long)
    keep_len.scatter_add_(0, req_ids, keep_mask.to(torch.long))
    keep_max_len = int(keep_len.max().item())
    if keep_max_len <= 0:
        empty = torch.empty((B, 0), device=device, dtype=torch.int32)
        return empty, keep_len.to(torch.int32)

    # Stable, order-preserving index list using segment-local ranks:
    # rank = (#kept tokens up to and including this one)
    #        - (#kept tokens before the request start) - 1.
    keep_prefix = torch.cumsum(keep_mask.to(torch.long), dim=0)  # [T]
    start_minus_1 = (starts - 1).clamp_min(0)
    prefix_before_all = keep_prefix.index_select(0, start_minus_1)
    prefix_before = torch.where(starts > 0, prefix_before_all,
                                torch.zeros_like(prefix_before_all))
    prefix_before_t = prefix_before.index_select(0, req_ids)
    local_rank = keep_prefix - prefix_before_t - 1  # [T]

    idx_sorted = torch.zeros((B, keep_max_len),
                             device=device,
                             dtype=torch.int32)
    lin_out = (req_ids * keep_max_len + local_rank).masked_select(keep_mask)
    vals = pos_in_req.to(torch.int32).masked_select(keep_mask)
    idx_sorted.view(-1).scatter_(0, lin_out, vals)
    return idx_sorted, keep_len.to(torch.int32)
def _compute_prompt_end_indices(
    *,
    query: torch.Tensor,  # [T, Hq, D] scheduled tokens for this step
    key_cache: torch.Tensor,  # layer KV cache view (platform-dependent)
    query_start_loc: torch.Tensor,  # [B+1] int32
    block_table: torch.Tensor,  # [B, max_blocks] int32
    prompt_end: torch.Tensor,  # [B] bool
    prompt_lens: torch.Tensor,  # [B] int32
    topk_keep: torch.Tensor,  # [B] int32
    topk_keep_max: Optional[int],
    sm_scale: float,
) -> Optional[dict[str, torch.Tensor]]:
    """Compute one-shot prompt compaction indices on the last prefill chunk.

    For every request flagged in ``prompt_end`` this packs the last
    ``VLLM_KV_COMPRESSION_SNAPKV_WINDOW`` query tokens of the step, gathers
    the request's prompt keys from the paged cache, scores each key, sums
    the scores across tensor-parallel ranks and finally selects which
    prompt positions to keep via ``_prompt_end_topk_keep_indices``.

    Returns:
        ``None`` when no request is flagged or no flagged request has query
        tokens in this step. Otherwise a dict with

        * ``"req_indices"``: [B_sel] int32 batch indices of flagged requests,
        * ``"idx_sorted"``: [B_sel, K_max] int32 kept prompt positions,
        * ``"keep_len"``: [B_sel] int32 number of kept positions,
        * ``"prompt_lens"``: [B_sel] int32 prompt length per request.
    """
    device = query.device
    if prompt_end.numel() == 0:
        return None
    sel = torch.nonzero(prompt_end, as_tuple=False).flatten()
    if int(sel.numel()) == 0:
        return None

    window = int(envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW)
    keep_last = bool(envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN)
    protected_prefix = int(envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX)
    protected_suffix = int(envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX)

    # Build packed Q window (last `window` queries per selected request).
    sel_list = sel.to(device="cpu", dtype=torch.int64).tolist()
    qsl = query_start_loc.to(device="cpu", dtype=torch.int64).tolist()
    q_chunks = []
    cu_q = [0]
    w_list = []
    for b in sel_list:
        s = qsl[b]
        e = qsl[b + 1]
        win = min(window, max(0, e - s))
        w_list.append(win)
        if win > 0:
            q_chunks.append(query[e - win:e])
        cu_q.append(cu_q[-1] + win)
    if cu_q[-1] <= 0:
        return None

    # cu_q[-1] > 0 guarantees at least one chunk was collected.
    q_packed = torch.cat(q_chunks, dim=0)
    cu_seqlens_q = torch.tensor(cu_q, device=device, dtype=torch.int32)
    w = torch.tensor(w_list, device=device, dtype=torch.int32)

    # Gather full prompt keys for the selected requests into a packed
    # [T, Hk, D].
    prompt_lens_sel = prompt_lens.index_select(0, sel).to(torch.int32)
    topk_keep_sel = topk_keep.index_select(0, sel).to(torch.int32)
    cu_seqlens_k = torch.zeros((int(prompt_lens_sel.numel()) + 1, ),
                               device=device,
                               dtype=torch.int32)
    if int(prompt_lens_sel.numel()) > 0:
        cu_seqlens_k[1:] = torch.cumsum(prompt_lens_sel, dim=0)
    block_table_sel = block_table.index_select(0, sel).to(torch.int32)

    if not current_platform.is_rocm():
        # CUDA cache view: [num_blocks, block_size, H, D]
        #               -> [num_blocks, H, block_size, D]
        key_cache_view = key_cache.permute(0, 2, 1, 3)
    else:
        key_cache_view = key_cache

    from vllm.v1.attention.kv_compression.kv_cache_triton import (
        gather_k_to_packed_triton)
    k_packed = gather_k_to_packed_triton(
        key_cache_view,
        block_table_sel,
        prompt_lens_sel,
        cu_seqlens_k,
    )

    # SnapKV Triton scoring (token-shared via sum over KV heads).
    from vllm.v1.attention.kv_compression.snapkv_triton import (
        query_aware_key_scores)
    try:
        scores_per_head = query_aware_key_scores(
            q=q_packed,
            k=k_packed,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            w=w,
            sm_scale=float(sm_scale),
            pool=False,
            protect_last=False,
            normalize=False,
        )
        token_scores = scores_per_head.sum(dim=1)
    except Exception as exc:
        # Fallback: PyTorch reference scoring (slow but
        # correctness-oriented).
        Hq = q_packed.shape[1]
        Hk = k_packed.shape[1]
        D = q_packed.shape[2]
        if Hq % Hk != 0:
            # The reference path cannot group heads either; surface the
            # original kernel error.
            raise
        # Make the (much slower) fallback visible instead of silent.
        logger.warning(
            "Triton SnapKV prompt-end scoring failed; falling back to "
            "PyTorch. Error: %s", exc)
        group = Hq // Hk
        token_scores = torch.zeros((k_packed.shape[0], ),
                                   device=device,
                                   dtype=torch.float32)
        # Read key boundaries to the host once rather than issuing two
        # `.item()` syncs per request.
        cu_k = cu_seqlens_k.tolist()
        for qs, qe, ks, ke in zip(cu_q[:-1], cu_q[1:], cu_k[:-1], cu_k[1:]):
            if qe <= qs or ke <= ks:
                continue
            q_win = q_packed[qs:qe]  # [win, Hq, D]
            # Average the query heads that share one KV head.
            q_win = q_win.reshape(q_win.shape[0], Hk, group, D).mean(dim=2)
            k_all = k_packed[ks:ke]
            qh = q_win.permute(1, 0, 2).to(torch.float32)
            kh = k_all.permute(1, 0, 2).to(torch.float32)
            logits = torch.matmul(qh, kh.transpose(1, 2)) * float(sm_scale)
            probs = torch.softmax(logits, dim=-1)
            # Sum over window queries, then over KV heads.
            token_scores[ks:ke] = probs.sum(dim=1).sum(dim=0)

    # Sum scores across tensor-parallel ranks before selecting indices.
    from vllm.distributed.parallel_state import get_tp_group
    token_scores = get_tp_group().all_reduce(token_scores)

    idx_sorted, keep_len = _prompt_end_topk_keep_indices(
        token_scores=token_scores,
        prompt_lens=prompt_lens_sel,
        topk_keep=topk_keep_sel,
        protected_prefix=protected_prefix,
        protected_suffix=protected_suffix,
        keep_last_token=keep_last,
        topk_keep_max=topk_keep_max,
    )
    return {
        "req_indices": sel.to(torch.int32),
        "idx_sorted": idx_sorted,  # [B_sel, K_max] int32
        "keep_len": keep_len,  # [B_sel] int32
        "prompt_lens": prompt_lens_sel,  # [B_sel] int32
    }
def
use_cascade_attention
(
def
use_cascade_attention
(
common_prefix_len
:
int
,
common_prefix_len
:
int
,
query_lens
:
np
.
ndarray
,
query_lens
:
np
.
ndarray
,
...
@@ -1393,225 +1005,3 @@ def cascade_attention(
...
@@ -1393,225 +1005,3 @@ def cascade_attention(
# Merge prefix and suffix outputs, and store the result in output.
# Merge prefix and suffix outputs, and store the result in output.
merge_attn_states
(
output
,
prefix_output
,
prefix_lse
,
suffix_output
,
merge_attn_states
(
output
,
prefix_output
,
prefix_lse
,
suffix_output
,
suffix_lse
)
suffix_lse
)
def _snapkv_like_token_scores(
    *,
    query: torch.Tensor,  # [T, Hq, D]
    key: torch.Tensor,  # [T, Hkv, D]
    query_start_loc: torch.Tensor,  # [B+1]
    window: Union[int, torch.Tensor],
    sm_scale: float,
) -> torch.Tensor:
    """Score every key of a packed varlen batch, SnapKV style.

    For each request the last ``window`` query tokens attend to the earlier
    keys of the same scheduled segment; the attention mass each key receives
    is summed over KV heads (token-shared). ``window`` is either one int for
    all requests or a tensor with one entry per request.

    The Triton kernel is tried first when it is usable; if it raises, Triton
    is disabled for the rest of the process and the PyTorch reference loop
    below is used. Both paths finish with a tensor-parallel all-reduce.
    """
    global _DISABLE_SNAPKV_TRITON

    device = query.device
    num_tokens, num_q_heads, head_dim = query.shape
    num_kv_heads = key.shape[1]
    if num_q_heads % num_kv_heads != 0:
        raise ValueError("Query heads must be a multiple of KV heads.")

    # NOTE: Triton SnapKV scoring on ROCm is experimental. It is enabled by
    # default (uses a ROCm-safe kernel); set
    # VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM=0 to force the PyTorch
    # reference implementation.
    triton_usable = (
        HAS_TRITON and not _DISABLE_SNAPKV_TRITON and device.type == "cuda"
        and (not current_platform.is_rocm()
             or envs.VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM)
        and query.stride(-1) == 1 and key.stride(-1) == 1)
    if triton_usable:
        try:
            from vllm.v1.attention.kv_compression.snapkv_triton import (
                query_aware_key_scores)
            kernel_window = int(window) if isinstance(window, int) else window
            per_head = query_aware_key_scores(
                q=query,
                k=key,
                cu_seqlens_q=query_start_loc,
                cu_seqlens_k=query_start_loc,
                w=kernel_window,
                sm_scale=float(sm_scale),
                pool=False,
                protect_last=False,
                normalize=False,
            )
            summed = per_head.sum(dim=1)
            from vllm.distributed.parallel_state import get_tp_group
            return get_tp_group().all_reduce(summed)
        except Exception as exc:
            # Remember the failure so later calls skip Triton entirely.
            _DISABLE_SNAPKV_TRITON = True
            logger.warning(
                "Triton SnapKV scoring failed; falling back to PyTorch. "
                "Error: %s", exc)

    # ---- PyTorch reference path ----
    heads_per_kv = num_q_heads // num_kv_heads
    # Read boundaries on host (small tensor).
    bounds = query_start_loc.tolist()
    num_reqs = len(bounds) - 1

    per_req_window = None
    if not isinstance(window, int):
        if int(window.numel()) != num_reqs:
            raise ValueError("window must be a scalar int or have shape [B].")
        per_req_window = window.to(device="cpu", dtype=torch.int64).tolist()

    result = torch.zeros((num_tokens, ), device=device, dtype=torch.float32)
    for req in range(num_reqs):
        seg_start = int(bounds[req])
        seg_end = int(bounds[req + 1])
        seg_len = seg_end - seg_start
        if seg_len <= 0:
            continue
        if per_req_window is None:
            requested = int(window)
        else:
            requested = int(per_req_window[req])
        if requested <= 0:
            continue
        span = min(requested, seg_len)
        num_keys = seg_len - span
        if num_keys <= 0:
            # The window covers the whole segment; nothing earlier to score.
            continue

        # Aggregate query heads to KV heads (token-shared selection).
        q_tail = query[seg_end - span:seg_end]  # [span, Hq, D]
        q_tail = q_tail.reshape(span, num_kv_heads, heads_per_kv,
                                head_dim).mean(dim=2)  # [span, Hkv, D]
        k_head = key[seg_start:seg_start + num_keys]  # [num_keys, Hkv, D]

        q_by_head = q_tail.permute(1, 0, 2).to(torch.float32)
        k_by_head = k_head.permute(1, 0, 2).to(torch.float32)
        attn = torch.softmax(
            torch.matmul(q_by_head, k_by_head.transpose(1, 2)) * sm_scale,
            dim=-1)  # [Hkv, span, num_keys]
        # Sum over (window queries, heads) -> per-key token score.
        result[seg_start:seg_start + num_keys] = attn.sum(dim=1).sum(dim=0)

    # Aggregate across tensor-parallel ranks so every rank selects the same
    # token indices.
    from vllm.distributed.parallel_state import get_tp_group
    return get_tp_group().all_reduce(result)
def _topk_kv_compact_slot_mapping(
    *,
    token_scores: Optional[torch.Tensor],  # [T] float32
    must_keep: torch.Tensor,  # [T] bool
    topk_budget: torch.Tensor,  # [B] int32
    query_start_loc: torch.Tensor,  # [B+1]
    seq_lens: torch.Tensor,  # [B] int32
    block_table: torch.Tensor,  # [B, max_blocks]
    block_size: int,
    max_query_len: Optional[int] = None,
    topk_budget_max: Optional[int] = None,
) -> torch.Tensor:
    """Compute, per scheduled token, the KV cache slot it is compacted into.

    The result has shape [T] and dtype int64:
      - a value ``>= 0`` means the token is kept and must be rewritten to
        that KV cache slot;
      - ``-1`` means the token is dropped after the step.

    Kept tokens are the ``must_keep`` ones plus, per request, the
    ``topk_budget[b]`` best-scoring remaining tokens. Kept tokens of a
    request are placed contiguously, in their original order, starting at
    KV position ``seq_lens[b] - scheduled_len[b]``.

    ``max_query_len`` / ``topk_budget_max`` are optional host-side bounds;
    when omitted they are read back from the device.

    Raises:
        ValueError: if a Top-K selection is required but ``token_scores``
            is ``None``.
    """
    dev = must_keep.device
    num_tokens = int(must_keep.numel())
    num_reqs = int(topk_budget.numel())

    all_dropped = torch.full((num_tokens, ), -1, device=dev, dtype=torch.int64)
    if num_tokens == 0 or num_reqs == 0:
        return all_dropped

    # Half-open [start, end) request segments inside the packed layout.
    seg_start = query_start_loc[:num_reqs].to(torch.long)
    seg_end = query_start_loc[1:num_reqs + 1].to(torch.long)
    seg_len = seg_end - seg_start  # [B]
    if seg_len.numel() == 0:
        return all_dropped

    # A host-provided max length spares a device->host sync.
    if max_query_len is None:
        longest = int(seg_len.max().item())
    else:
        longest = int(max_query_len)
    if longest <= 0:
        return all_dropped

    # (request, offset) coordinate of every packed token. right=True sends
    # idx == seg_end[b] to request b + 1.
    flat_pos = torch.arange(num_tokens, device=dev, dtype=torch.long)
    owner = torch.bucketize(flat_pos, seg_end, right=True)  # [T]
    offset = flat_pos - seg_start.index_select(0, owner)  # [T]

    # Budget per request, capped by the number of non-must-keep tokens.
    forced = torch.zeros(num_reqs, device=dev, dtype=torch.long)
    forced.scatter_add_(0, owner, must_keep.to(torch.long))
    candidates = (seg_len.to(torch.long) - forced).clamp_min(0)
    budget = torch.minimum(
        topk_budget.to(torch.long).clamp_min(0), candidates)

    if topk_budget_max is None:
        widest = int(budget.max().item())
    else:
        widest = min(int(topk_budget_max), longest)

    kept = must_keep.clone()
    if widest > 0:
        if token_scores is None:
            raise ValueError("token_scores must be provided when k_max > 0.")
        # Padded [B, longest] matrix; must-keep and padding cells are -inf
        # so a single batched Top-K never prefers them.
        cand_scores = token_scores.to(dtype=torch.float32).masked_fill(
            must_keep, float("-inf"))
        padded = cand_scores.new_full((num_reqs * longest, ), float("-inf"))
        padded[owner * longest + offset] = cand_scores
        ranked = torch.topk(padded.view(num_reqs, longest), k=widest,
                            dim=1).indices  # [B, widest]

        # Row b only uses its first budget[b] columns.
        in_budget = (torch.arange(widest, device=dev).unsqueeze(0)
                     < budget.unsqueeze(1))  # [B, widest]

        # Fixed-size scatter-add marks the selected tokens without any
        # data-dependent indexing.
        picked = seg_start.unsqueeze(1) + ranked.to(torch.long)
        picked_flat = picked.reshape(-1).clamp_(0, num_tokens - 1)
        hits = torch.zeros((num_tokens, ), device=dev, dtype=torch.int32)
        hits.scatter_add_(0, picked_flat,
                          in_budget.reshape(-1).to(torch.int32))
        kept |= hits > 0

    # Rank of each kept token among the kept tokens of its own request.
    running = torch.cumsum(kept.to(torch.long), dim=0)  # [T]
    before_idx = (seg_start - 1).clamp_min(0)
    kept_before = running.index_select(0, before_idx.to(torch.long))
    kept_before = torch.where(seg_start > 0, kept_before,
                              torch.zeros_like(kept_before))  # [B]
    rank = running - kept_before.index_select(0, owner) - 1  # [T]

    # KV length before this step's tokens were written.
    kv_base = (seq_lens[:num_reqs].to(torch.long) -
               seg_len.to(torch.long)).clamp_min(0)
    target_pos = kv_base.index_select(0, owner) + rank  # [T]

    target_block = target_pos // block_size
    target_off = target_pos - target_block * block_size
    # Dropped tokens may yield out-of-range block indices; clamp so the
    # gather is safe (their result is discarded by `kept` below).
    last_block = int(block_table.shape[1]) - 1
    safe_block = target_block.clamp(0, last_block).to(torch.long)
    physical = block_table[owner, safe_block]
    target_slot = physical.to(torch.long) * block_size + target_off

    return torch.where(kept, target_slot.to(torch.int64), all_dropped)
vllm/v1/kv_compression/compaction_step.py
0 → 100644
View file @
3adc766e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
import
torch
import
vllm.envs
as
envs
from
vllm.v1.kv_compression.slot_mapping
import
topk_kv_compact_slot_mapping
from
vllm.v1.kv_compression.snapkv_score
import
snapkv_like_token_scores
def snapkv_window_for_topk_budget(
    *,
    topk_budget: torch.Tensor,  # [B] int32
    window: int,
) -> torch.Tensor:
    """Return the SnapKV window size to use for each request in the batch.

    A request whose Top-K budget is positive gets ``window``; every other
    request gets 0. Per the original note, a zero window lets the Triton
    scoring kernel early-return for requests that need no token scores.

    Args:
        topk_budget: Per-request Top-K budget.
        window: Window size applied to requests with a positive budget.

    Returns:
        A tensor shaped and typed like ``topk_budget``.
    """
    wants_scores = topk_budget > 0
    uniform_window = torch.full_like(topk_budget, int(window))
    # Zero the window wherever the budget is not strictly positive.
    return uniform_window.masked_fill(~wants_scores, 0)
def compute_compact_dst_slots_for_step(
    *,
    query: torch.Tensor,  # [T, Hq, D] for this step
    key: torch.Tensor,  # [T, Hkv, D] for this step
    query_start_loc: torch.Tensor,  # [B+1]
    seq_lens: torch.Tensor,  # [B] int32
    block_table: torch.Tensor,  # [B, max_blocks]
    block_size: int,
    must_keep: torch.Tensor,  # [T] bool
    topk_budget: torch.Tensor,  # [B] int32
    topk_budget_max: int,
    max_query_len: int,
    sm_scale: float,
) -> torch.Tensor:
    """Compute the per-token KV compaction destination slots for one step.

    Token scores are only computed when ``topk_budget_max`` is positive;
    otherwise ``None`` is handed to the slot-mapping helper. The SnapKV window
    comes from ``envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW`` and is zeroed for
    requests without a Top-K budget.

    Returns:
        Whatever ``topk_kv_compact_slot_mapping`` returns for these inputs.
    """
    budget_cap = int(topk_budget_max)

    if budget_cap > 0:
        per_request_window = snapkv_window_for_topk_budget(
            topk_budget=topk_budget,
            window=int(envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW),
        )
        token_scores = snapkv_like_token_scores(
            query=query,
            key=key,
            query_start_loc=query_start_loc,
            window=per_request_window,
            sm_scale=float(sm_scale),
        )
    else:
        # No request has a Top-K budget, so scoring is skipped entirely.
        token_scores = None

    return topk_kv_compact_slot_mapping(
        token_scores=token_scores,
        must_keep=must_keep,
        topk_budget=topk_budget,
        query_start_loc=query_start_loc,
        seq_lens=seq_lens,
        block_table=block_table,
        block_size=int(block_size),
        max_query_len=int(max_query_len),
        topk_budget_max=budget_cap,
    )
vllm/v1/kv_compression/flash_attn_hooks.py
0 → 100644
View file @
3adc766e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
typing
import
Any
,
Optional
,
Protocol
import
torch
import
vllm.envs
as
envs
from
vllm.forward_context
import
get_forward_context
from
vllm.platforms
import
current_platform
from
vllm.v1.kv_compression.compaction_step
import
compute_compact_dst_slots_for_step
from
vllm.v1.kv_compression.forward_context
import
(
get_kv_compression_compact_slots
,
get_kv_compression_prompt_payload
,
set_kv_compression_compact_slots
,
set_kv_compression_prompt_payload
,
)
from
vllm.v1.kv_compression.prompt_end
import
compute_prompt_end_indices
from
vllm.v1.kv_compression.slot_mapping
import
kv_compaction_dst_rewrite_mapping
class _ReshapeAndCacheFn(Protocol):
    """Call signature expected of the ``reshape_and_cache`` callable.

    ``maybe_compact_kv_cache_flash_attn`` receives a callable of this shape and
    invokes it positionally with the key/value tensors, both caches, a slot
    mapping, the cache dtype string and the layer's k/v scales.
    """

    def __call__(
        self,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
    ) -> None:
        """Protocol stub; implementations return nothing."""
def maybe_compute_prompt_end_payload_flash_attn(
    *,
    kv_sharing_target_layer_name: Optional[str],
    query: torch.Tensor,
    num_actual_tokens: int,
    key_cache: torch.Tensor,
    cache_block_size: int,
    attn_metadata: Any,
    sm_scale: float,
) -> None:
    """Compute and stash prompt-end Top-K indices (chunked-prefill scheme 3).

    The function is a no-op when KV compression is disabled, when the layer
    shares another layer's KV cache, when the attention metadata carries no
    prompt-end fields, or when a payload is already stored on the current
    forward context. Otherwise the payload returned by
    ``compute_prompt_end_indices`` is stored on the forward context; per the
    original note, the model runner consumes it later for a one-shot prompt KV
    compaction before the first decode step.
    """
    compression_off = not envs.VLLM_ENABLE_KV_COMPRESSION
    if compression_off or kv_sharing_target_layer_name is not None:
        return

    prompt_end = getattr(attn_metadata, "kv_compression_prompt_end", None)
    prompt_lens = getattr(attn_metadata, "kv_compression_prompt_lens", None)
    topk_keep = getattr(attn_metadata, "kv_compression_prompt_topk_keep", None)
    if any(field is None for field in (prompt_end, prompt_lens, topk_keep)):
        return

    ctx = get_forward_context()
    # Only the first layer to get here computes the payload for this forward.
    if get_kv_compression_prompt_payload(ctx) is not None:
        return

    result = compute_prompt_end_indices(
        query=query[:num_actual_tokens],
        key_cache=key_cache,
        block_size=cache_block_size,
        query_start_loc=attn_metadata.query_start_loc,
        block_table=attn_metadata.block_table,
        prompt_end=prompt_end,
        prompt_lens=prompt_lens,
        topk_keep=topk_keep,
        topk_keep_max=getattr(
            attn_metadata, "kv_compression_prompt_topk_keep_max", None),
        sm_scale=sm_scale,
    )
    if result is None:
        return
    set_kv_compression_prompt_payload(ctx, result)
def maybe_compact_kv_cache_flash_attn(
    *,
    kv_sharing_target_layer_name: Optional[str],
    layer: Any,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    num_actual_tokens: int,
    cache_block_size: int,
    attn_metadata: Any,
    sm_scale: float,
    kv_cache_dtype: str,
    reshape_and_cache: _ReshapeAndCacheFn,
) -> None:
    """Optional per-step KV compaction for scheme 1/2 token-shared selection.

    Destination slots are looked up on the forward context (per layer when
    ``envs.VLLM_KV_COMPRESSION_TOPK_PER_LAYER`` is set) and computed and cached
    there on a miss. Tokens whose destination differs from their current slot
    are then rewritten into the cache through ``reshape_and_cache``, or through
    ``lightop.reshape_and_cache_cuda`` on ROCm when the optimized kernel is
    enabled and key/value are both float16.

    The function is a no-op when KV compression is disabled, when the layer
    shares another layer's KV cache, or when the metadata lacks the
    ``kv_compression_must_keep`` / ``kv_compression_topk_budget`` fields.
    """
    if (not envs.VLLM_ENABLE_KV_COMPRESSION
            or kv_sharing_target_layer_name is not None):
        return

    must_keep = getattr(attn_metadata, "kv_compression_must_keep", None)
    topk_budget = getattr(attn_metadata, "kv_compression_topk_budget", None)
    if must_keep is None or topk_budget is None:
        return

    ctx = get_forward_context()
    per_layer_topk = envs.VLLM_KV_COMPRESSION_TOPK_PER_LAYER
    dst = get_kv_compression_compact_slots(
        ctx,
        per_layer_topk=per_layer_topk,
        layer=layer,
    )
    if dst is None:
        # Cache miss: compute the destinations once and store them for reuse.
        topk_budget_max = int(
            getattr(attn_metadata, "kv_compression_topk_budget_max", 0) or 0)
        dst = compute_compact_dst_slots_for_step(
            query=query[:num_actual_tokens],
            key=key[:num_actual_tokens],
            query_start_loc=attn_metadata.query_start_loc,
            seq_lens=attn_metadata.seq_lens,
            block_table=attn_metadata.block_table,
            block_size=cache_block_size,
            must_keep=must_keep,
            topk_budget=topk_budget,
            topk_budget_max=topk_budget_max,
            max_query_len=attn_metadata.max_query_len,
            sm_scale=sm_scale,
        )
        set_kv_compression_compact_slots(
            ctx,
            per_layer_topk=per_layer_topk,
            layer=layer,
            dst=dst,
        )
    if dst is None:
        return

    dst_rewrite = kv_compaction_dst_rewrite_mapping(
        dst_slots=dst,
        src_slots=attn_metadata.slot_mapping,
    )

    # Pick the cache writer once; every branch passes the same arguments.
    write_kv = reshape_and_cache
    if current_platform.is_rocm() and (envs.VLLM_USE_OPT_RESHAPE_AND_CACHE
                                       and key.dtype == value.dtype
                                       and key.dtype == torch.float16):
        # ROCm: optionally prefer the optimized reshape-and-cache kernel.
        from lightop import reshape_and_cache_cuda
        write_kv = reshape_and_cache_cuda

    write_kv(
        key,
        value,
        key_cache,
        value_cache,
        dst_rewrite,
        kv_cache_dtype,
        layer._k_scale,
        layer._v_scale,
    )
vllm/v1/kv_compression/forward_context.py
0 → 100644
View file @
3adc766e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
typing
import
Any
,
Optional
import
torch
# Forward-context attribute name under which the prompt-end payload is stored.
_PROMPT_PAYLOAD_ATTR
=
"_kv_compression_prompt_payload"
# Forward-context attribute name for the compaction slots shared by all layers
# (used when per-layer Top-K is off).
_COMPACT_SLOTS_ATTR
=
"_kv_compression_compact_slots"
# Forward-context attribute name for the dict of compaction slots keyed by
# layer (used when per-layer Top-K is on).
_COMPACT_SLOTS_BY_LAYER_ATTR
=
"_kv_compression_compact_slots_by_layer"
def get_kv_compression_prompt_payload(
    forward_context: Any,
) -> Optional[dict[str, torch.Tensor]]:
    """Return the prompt-end payload stored on ``forward_context``, if any.

    Returns:
        The stored payload, or ``None`` when the attribute has not been set.
    """
    payload = getattr(forward_context, _PROMPT_PAYLOAD_ATTR, None)
    return payload
def set_kv_compression_prompt_payload(
    forward_context: Any,
    payload: dict[str, torch.Tensor],
) -> None:
    """Store ``payload`` on ``forward_context`` as the prompt-end payload.

    Any previously stored payload is overwritten.
    """
    setattr(forward_context, _PROMPT_PAYLOAD_ATTR, payload)
def _kv_compression_layer_key(layer: Any) -> str:
    """Return the string key identifying ``layer`` in per-layer caches.

    The layer's ``layer_name`` attribute is used when it exists and is not
    ``None``; otherwise the key falls back to the object's ``id()``.
    """
    name = getattr(layer, "layer_name", None)
    return str(id(layer)) if name is None else str(name)
def get_kv_compression_compact_slots(
    forward_context: Any,
    *,
    per_layer_topk: bool,
    layer: Any,
) -> Optional[torch.Tensor]:
    """Look up the cached compaction destination slots for this forward pass.

    Args:
        forward_context: Object the slots were stored on.
        per_layer_topk: When true, read the per-layer dict entry for ``layer``;
            otherwise read the single value shared by all layers.
        layer: Layer whose entry is requested (only used per layer).

    Returns:
        The cached slots, or ``None`` when nothing has been stored yet.
    """
    if not per_layer_topk:
        return getattr(forward_context, _COMPACT_SLOTS_ATTR, None)

    slots_by_layer = getattr(forward_context, _COMPACT_SLOTS_BY_LAYER_ATTR,
                             None)
    if slots_by_layer is None:
        return None
    return slots_by_layer.get(_kv_compression_layer_key(layer))
def set_kv_compression_compact_slots(
    forward_context: Any,
    *,
    per_layer_topk: bool,
    layer: Any,
    dst: torch.Tensor,
) -> None:
    """Cache compaction destination slots on ``forward_context``.

    Args:
        forward_context: Object to store the slots on.
        per_layer_topk: When true, store ``dst`` in a dict keyed by ``layer``
            (the dict is created on first use); otherwise store it as the
            single value shared by all layers.
        layer: Layer the slots belong to (only used per layer).
        dst: Destination slots to cache.
    """
    if not per_layer_topk:
        setattr(forward_context, _COMPACT_SLOTS_ATTR, dst)
        return

    slots_by_layer = getattr(forward_context, _COMPACT_SLOTS_BY_LAYER_ATTR,
                             None)
    if slots_by_layer is None:
        slots_by_layer = {}
        setattr(forward_context, _COMPACT_SLOTS_BY_LAYER_ATTR, slots_by_layer)
    slots_by_layer[_kv_compression_layer_key(layer)] = dst
vllm/v1/kv_compression/kv_cache_view.py
0 → 100644
View file @
3adc766e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
import
torch
from
vllm.platforms
import
current_platform
def paged_k_cache_view_for_triton_gather(
    *,
    key_cache: torch.Tensor,
    block_size: int,
) -> torch.Tensor:
    """Return a KV-cache key view in [num_blocks, H, block_size, D] layout.

    Supports both:
    - [num_blocks, block_size, H, D] (typical CUDA FlashAttention v1 layout)
    - [num_blocks, H, block_size, D] (ROCm FlashAttention v1, or external
      connectors that expose the cache in HND shape)

    The layout is detected from which of dims 1 and 2 equals ``block_size``.
    When exactly one of them matches, the answer is unambiguous. When both
    match (number of heads equals the block size) or neither matches, the
    shape carries no information and the platform default is used: ROCm caches
    are returned unchanged, all others are permuted.

    Args:
        key_cache: 4-D paged key cache.
        block_size: Number of tokens per cache block.

    Returns:
        ``key_cache`` itself or a permuted view of it (no copy is made).

    Raises:
        ValueError: If ``key_cache`` is not 4-dimensional.
    """
    if key_cache.ndim != 4:
        raise ValueError("key_cache must be a 4D tensor.")

    tokens_per_block = int(block_size)
    dim1_is_block = int(key_cache.shape[1]) == tokens_per_block
    dim2_is_block = int(key_cache.shape[2]) == tokens_per_block

    # Common case: [B, T, H, D] -> [B, H, T, D]
    if dim1_is_block and not dim2_is_block:
        return key_cache.permute(0, 2, 1, 3)
    # Already in [B, H, T, D] (ROCm / HND-shaped external caches).
    if dim2_is_block and not dim1_is_block:
        return key_cache

    # Ambiguous (H == block_size) or unrecognized shape: fall back to the
    # platform's historical layout. Checking dim 1 first here would wrongly
    # swap heads and tokens of an HND cache whose head count equals block_size.
    if current_platform.is_rocm():
        return key_cache
    return key_cache.permute(0, 2, 1, 3)
vllm/v1/kv_compression/metadata.py
0 → 100644
View file @
3adc766e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
dataclasses
import
dataclass
from
typing
import
Any
,
Optional
import
torch
import
vllm.envs
as
envs
@dataclass
class KVCompressionAttentionMetadata:
    """Per-batch KV compression metadata consumed by attention backends.

    Every field defaults to ``None``, which is what a freshly built instance
    carries when KV compression is disabled or not applicable to the step.
    """

    # --- Per-step compaction (filled by the "scheme 1/2" builder path) ---
    # Sliced from the runner buffer to the step's actual token count.
    must_keep: Optional[torch.Tensor] = None
    # Sliced from the runner buffer to the number of requests.
    topk_budget: Optional[torch.Tensor] = None
    # Host-side maximum of the per-request budgets.
    topk_budget_max: Optional[int] = None

    # --- Prompt-end one-shot scoring (filled by the "scheme 3" path) ---
    # All three tensors are sliced to the number of requests.
    prompt_end: Optional[torch.Tensor] = None
    prompt_lens: Optional[torch.Tensor] = None
    prompt_topk_keep: Optional[torch.Tensor] = None
    # Host-side integer read from the runner.
    prompt_topk_keep_max: Optional[int] = None
def build_kv_compression_attn_metadata(
    *,
    runner: Any,
    num_reqs: int,
    num_actual_tokens: int,
) -> KVCompressionAttentionMetadata:
    """Build the KV compression metadata for one attention step.

    Keeps backend code thin by centralizing the choice between per-step
    compaction (scheme 1/2) and prompt-end one-shot scoring (scheme 3).

    Args:
        runner: Model runner exposing the ``kv_compression_*`` buffers.
        num_reqs: Number of requests in the step.
        num_actual_tokens: Number of scheduled tokens in the step.

    Returns:
        Metadata with the fields of the selected scheme filled in; all fields
        stay ``None`` when compression is disabled or nothing applies.
    """
    meta = KVCompressionAttentionMetadata()
    if not envs.VLLM_ENABLE_KV_COMPRESSION:
        return meta

    # Scheme 1/2: compute compaction destinations every step.
    if getattr(runner, "kv_compression_needs_compaction", False):
        meta.must_keep = runner.kv_compression_must_keep[:num_actual_tokens]
        meta.topk_budget = runner.kv_compression_topk_budget[:num_reqs]
        # Avoid device->host sync by reading from the CPU staging buffer.
        meta.topk_budget_max = (int(
            runner.kv_compression_topk_budget_np[:num_reqs].max())
                                if num_reqs > 0 else 0)
        return meta

    # Scheme 3: compute global prompt indices only on the last prefill chunk,
    # and perform the actual cache compaction before the first decode step.
    scheduler_config = getattr(runner, "scheduler_config", None)
    chunked_prefill = scheduler_config is not None and getattr(
        scheduler_config, "chunked_prefill_enabled", False)
    if not chunked_prefill or num_reqs <= 0:
        return meta
    # Nothing to do unless at least one request ends its prompt in this step.
    if not runner.kv_compression_prompt_end_np[:num_reqs].any():
        return meta

    meta.prompt_end = runner.kv_compression_prompt_end[:num_reqs]
    meta.prompt_lens = runner.kv_compression_prompt_lens[:num_reqs]
    meta.prompt_topk_keep = runner.kv_compression_prompt_topk_keep[:num_reqs]
    meta.prompt_topk_keep_max = int(
        getattr(runner, "kv_compression_prompt_topk_keep_max", 0) or 0)
    return meta
vllm/v1/kv_compression/prompt_end.py
0 → 100644
View file @
3adc766e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
typing
import
Optional
import
torch
import
vllm.envs
as
envs
from
vllm.v1.kv_compression.kv_cache_view
import
paged_k_cache_view_for_triton_gather
from
vllm.v1.kv_compression.snapkv_score
import
snapkv_query_aware_token_scores
from
vllm.v1.kv_compression.topk_select
import
(
_packed_varlen_coords
,
_topk_keep_mask_and_local_rank
)
def _prompt_end_topk_keep_indices(
    *,
    token_scores: torch.Tensor,  # [T] float32
    prompt_lens: torch.Tensor,  # [B] int32
    topk_keep: torch.Tensor,  # [B] int32 (candidates only)
    protected_prefix: int,
    protected_suffix: int,
    keep_last_token: bool,
    topk_keep_max: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Select the kept prompt indices (ascending) for one-shot compaction.

    Tokens in the protected prefix/suffix (and optionally the last prompt
    token) are always kept; the remaining selection is delegated to
    ``_topk_keep_mask_and_local_rank``.

    Returns:
        idx_sorted: [B, K_max] int32, per-request kept token indices (0..L-1)
        keep_len: [B] int32, number of kept tokens per request
    """
    device = token_scores.device
    num_reqs = int(prompt_lens.numel())
    if num_reqs == 0:
        return (
            torch.empty((0, 0), device=device, dtype=torch.int32),
            torch.empty((0, ), device=device, dtype=torch.int32),
        )

    lens = prompt_lens.to(torch.long)
    cu_lens = torch.zeros((num_reqs + 1, ), device=device, dtype=torch.long)
    cu_lens[1:] = torch.cumsum(lens, dim=0)

    num_tokens = int(token_scores.numel())
    if num_tokens == 0:
        return (
            torch.empty((num_reqs, 0), device=device, dtype=torch.int32),
            torch.zeros((num_reqs, ), device=device, dtype=torch.int32),
        )

    starts, _, lengths, req_ids, pos_in_req = _packed_varlen_coords(
        cu_seqlens=cu_lens,
        total_tokens=num_tokens,
    )

    # Must-keep mask (protected prefix/suffix + optional last prompt token).
    nonneg_lens = lens.clamp_min(0)
    prefix_len = nonneg_lens.clamp_max(max(protected_prefix, 0))
    suffix_len = nonneg_lens.clamp_max(max(protected_suffix, 0))
    suffix_start = (lens - suffix_len).clamp_min(0)
    in_prefix = pos_in_req < prefix_len.index_select(0, req_ids)
    in_suffix = pos_in_req >= suffix_start.index_select(0, req_ids)
    must_keep = in_prefix | in_suffix
    if keep_last_token:
        last_pos = (lens - 1).clamp_min(0).index_select(0, req_ids)
        must_keep = must_keep | (pos_in_req == last_pos)

    keep_mask, local_rank, keep_len = _topk_keep_mask_and_local_rank(
        token_scores=token_scores,
        must_keep=must_keep,
        topk_budget=topk_keep,
        starts=starts,
        lengths=lengths,
        req_ids=req_ids,
        pos_in_req=pos_in_req,
        max_len=int(lens.max().item()),
        topk_budget_max=topk_keep_max,
    )

    # num_reqs > 0 is guaranteed here, so the max is always defined.
    widest = int(keep_len.max().item())
    if widest <= 0:
        nothing = torch.empty((num_reqs, 0), device=device, dtype=torch.int32)
        return nothing, keep_len

    # Scatter each kept token's in-request position into row `req`, column
    # `local_rank` of a zero-initialized [B, widest] table.
    idx_sorted = torch.zeros((num_reqs, widest),
                             device=device,
                             dtype=torch.int32)
    flat_dst = torch.masked_select(req_ids * widest + local_rank, keep_mask)
    kept_pos = torch.masked_select(pos_in_req.to(torch.int32), keep_mask)
    idx_sorted.view(-1).scatter_(0, flat_dst, kept_pos)
    return idx_sorted, keep_len
def compute_prompt_end_indices(
    *,
    query: torch.Tensor,  # [T, Hq, D] scheduled tokens for this step
    key_cache: torch.Tensor,  # layer KV cache view (platform-dependent)
    block_size: int,
    query_start_loc: torch.Tensor,  # [B+1] int32
    block_table: torch.Tensor,  # [B, max_blocks] int32
    prompt_end: torch.Tensor,  # [B] bool
    prompt_lens: torch.Tensor,  # [B] int32
    topk_keep: torch.Tensor,  # [B] int32
    topk_keep_max: Optional[int],
    sm_scale: float,
) -> Optional[dict[str, torch.Tensor]]:
    """Compute one-shot prompt compaction indices on the last prefill chunk.

    For every request flagged in ``prompt_end``, the last queries of its
    scheduled segment (up to the SnapKV window) are scored against the full
    prompt keys gathered from the paged cache, and the kept indices are chosen.

    Returns:
        ``None`` when no request is flagged or no query window is available;
        otherwise a dict with ``req_indices``, ``idx_sorted``, ``keep_len`` and
        ``prompt_lens`` for the selected requests.
    """
    device = query.device
    if prompt_end.numel() == 0:
        return None
    sel = torch.nonzero(prompt_end, as_tuple=False).flatten()
    if int(sel.numel()) == 0:
        return None

    window = int(envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW)
    keep_last = bool(envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN)
    protected_prefix = int(envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX)
    protected_suffix = int(envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX)

    # Build packed Q window (last `window` queries per selected request).
    # Both conversions copy to the host.
    selected_reqs = sel.to(device="cpu", dtype=torch.int64).tolist()
    bounds = query_start_loc.to(device="cpu", dtype=torch.int64).tolist()

    window_sizes = []
    window_slices = []
    for req in selected_reqs:
        seg_start = int(bounds[req])
        seg_end = int(bounds[req + 1])
        win = min(window, max(0, seg_end - seg_start))
        window_sizes.append(int(win))
        if win > 0:
            window_slices.append(query[seg_end - win:seg_end])

    cu_q = [0]
    for win in window_sizes:
        cu_q.append(cu_q[-1] + win)
    if cu_q[-1] <= 0:
        return None

    if window_slices:
        q_packed = torch.cat(window_slices, dim=0)
    else:
        q_packed = query[:0]
    cu_seqlens_q = torch.tensor(cu_q, device=device, dtype=torch.int32)
    w = torch.tensor(window_sizes, device=device, dtype=torch.int32)

    # Gather full prompt keys for the selected requests into a packed
    # [T, Hk, D].
    prompt_lens_sel = prompt_lens.index_select(0, sel).to(torch.int32)
    topk_keep_sel = topk_keep.index_select(0, sel).to(torch.int32)
    num_selected = int(prompt_lens_sel.numel())
    cu_seqlens_k = torch.zeros((num_selected + 1, ),
                               device=device,
                               dtype=torch.int32)
    if num_selected > 0:
        cu_seqlens_k[1:] = torch.cumsum(prompt_lens_sel, dim=0)

    block_table_sel = block_table.index_select(0, sel).to(torch.int32)
    key_cache_view = paged_k_cache_view_for_triton_gather(
        key_cache=key_cache,
        block_size=int(block_size),
    )

    # Imported lazily, as in the original.
    from vllm.v1.kv_compression.kv_cache_triton import (
        gather_k_to_packed_triton)
    k_packed = gather_k_to_packed_triton(
        key_cache_view,
        block_table_sel,
        prompt_lens_sel,
        cu_seqlens_k,
    )

    token_scores = snapkv_query_aware_token_scores(
        query=q_packed,
        key=k_packed,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        window=w,
        sm_scale=float(sm_scale),
    )

    idx_sorted, keep_len = _prompt_end_topk_keep_indices(
        token_scores=token_scores,
        prompt_lens=prompt_lens_sel,
        topk_keep=topk_keep_sel,
        protected_prefix=protected_prefix,
        protected_suffix=protected_suffix,
        keep_last_token=keep_last,
        topk_keep_max=topk_keep_max,
    )

    payload = {
        "req_indices": sel.to(torch.int32),
        "idx_sorted": idx_sorted,  # [B_sel, K_max] int32
        "keep_len": keep_len,  # [B_sel] int32
        "prompt_lens": prompt_lens_sel,  # [B_sel] int32
    }
    return payload
vllm/v1/kv_compression/slot_mapping.py
0 → 100644
View file @
3adc766e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
typing
import
Optional
import
torch
from
vllm.v1.kv_compression.topk_select
import
(
_packed_varlen_coords
,
_topk_keep_mask_and_local_rank
)
def _dst_slots_from_keep_mask_and_local_rank(
    *,
    keep_mask: torch.Tensor,  # [T] bool
    local_rank: torch.Tensor,  # [T] int64
    seq_lens: torch.Tensor,  # [B] int32
    lengths: torch.Tensor,  # [B] int64
    req_ids: torch.Tensor,  # [T] int64
    block_table: torch.Tensor,  # [B, max_blocks] int32
    block_size: int,
) -> torch.Tensor:
    """Translate ``keep_mask``/``local_rank`` into per-token KV cache slots.

    Each kept token lands at logical position ``base_kv + local_rank`` of its
    request, which is mapped through ``block_table`` to a physical slot.

    Returns:
        A [T] int64 tensor holding the destination slot for kept tokens and
        -1 for dropped tokens.
    """
    num_tokens = int(keep_mask.numel())
    dropped = torch.full((num_tokens, ),
                         -1,
                         device=keep_mask.device,
                         dtype=torch.int64)
    num_reqs = int(seq_lens.numel())
    if num_tokens == 0 or num_reqs == 0:
        return dropped

    # Base KV cache position for this step (i.e., KV length before writing
    # this scheduled segment). With KV compression enabled, seq_lens is
    # derived from num_kv_tokens + scheduled_len, so
    # base_kv == seq_lens - scheduled_len.
    kv_len_before = (seq_lens[:num_reqs].to(torch.long) -
                     lengths.to(torch.long)).clamp_min(0)
    dest_pos = kv_len_before.index_select(0, req_ids) + local_rank  # [T]

    logical_block = dest_pos // block_size
    offset_in_block = dest_pos - logical_block * block_size

    # Safe indexing for dropped tokens (ignored by keep_mask anyway).
    last_block = int(block_table.shape[1]) - 1
    safe_block = logical_block.clamp(0, last_block).to(torch.long)
    physical_block = block_table[req_ids, safe_block].to(torch.long)
    dest_slot = physical_block * block_size + offset_in_block
    return torch.where(keep_mask, dest_slot.to(torch.int64), dropped)
def topk_kv_compact_slot_mapping(
    *,
    token_scores: Optional[torch.Tensor],  # [T] float32
    must_keep: torch.Tensor,  # [T] bool
    topk_budget: torch.Tensor,  # [B] int32
    query_start_loc: torch.Tensor,  # [B+1]
    seq_lens: torch.Tensor,  # [B] int32
    block_table: torch.Tensor,  # [B, max_blocks]
    block_size: int,
    max_query_len: Optional[int] = None,
    topk_budget_max: Optional[int] = None,
) -> torch.Tensor:
    """Build a per-token destination slot mapping for KV compaction.

    Returns a tensor `dst_slots` of shape [T] where:
    - `dst_slots[i] >= 0` indicates token i should be kept and rewritten to
      that KV cache slot.
    - `dst_slots[i] == -1` indicates token i is dropped after the step.

    An all ``-1`` mapping is returned when there are no tokens, no requests,
    or the longest scheduled segment has non-positive length.
    """
    num_tokens = int(must_keep.numel())
    num_reqs = int(topk_budget.numel())
    all_dropped = torch.full((num_tokens, ),
                             -1,
                             device=must_keep.device,
                             dtype=torch.int64)
    if num_tokens == 0 or num_reqs == 0:
        return all_dropped

    starts, _, lengths, req_ids, pos_in_req = _packed_varlen_coords(
        cu_seqlens=query_start_loc,
        total_tokens=num_tokens,
    )
    if lengths.numel() == 0:
        return all_dropped

    # Prefer the CPU-known max query length (piecewise graph), to avoid
    # device->host synchronization.
    if max_query_len is not None:
        longest = int(max_query_len)
    else:
        longest = int(lengths.max().item())
    if longest <= 0:
        return all_dropped

    keep_mask, local_rank, _ = _topk_keep_mask_and_local_rank(
        token_scores=token_scores,
        must_keep=must_keep,
        topk_budget=topk_budget,
        starts=starts,
        lengths=lengths,
        req_ids=req_ids,
        pos_in_req=pos_in_req,
        max_len=longest,
        topk_budget_max=topk_budget_max,
    )
    return _dst_slots_from_keep_mask_and_local_rank(
        keep_mask=keep_mask,
        local_rank=local_rank,
        seq_lens=seq_lens[:num_reqs],
        lengths=lengths,
        req_ids=req_ids,
        block_table=block_table,
        block_size=int(block_size),
    )
def kv_compaction_dst_rewrite_mapping(
    *,
    dst_slots: torch.Tensor,  # [T] int64
    src_slots: torch.Tensor,  # [T] int64
) -> torch.Tensor:
    """Keep only the destinations of kept tokens that actually move.

    A token is rewritten when its destination is non-negative and differs from
    its source slot. Every other entry becomes -1, which (per the original
    note) the cache kernels treat as padding and skip.
    """
    stays_or_dropped = (dst_slots < 0) | (dst_slots == src_slots)
    skip = torch.full_like(dst_slots, -1)
    return torch.where(stays_or_dropped, skip, dst_slots)
vllm/v1/kv_compression/snapkv_score.py
0 → 100644
View file @
3adc766e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
typing
import
Union
import
torch
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
HAS_TRITON
# Module-level logger; used below to warn when Triton scoring fails.
logger
=
init_logger
(
__name__
)
# Process-wide switch: set to True the first time the Triton scoring path
# raises, after which snapkv_query_aware_token_scores skips Triton entirely.
_DISABLE_SNAPKV_TRITON
:
bool
=
False
def _snapkv_reference_token_scores(
    *,
    query: torch.Tensor,  # [N_q, Hq, D]
    key: torch.Tensor,  # [N_k, Hkv, D]
    cu_seqlens_q: torch.Tensor,  # [B+1]
    cu_seqlens_k: torch.Tensor,  # [B+1]
    window: Union[int, torch.Tensor],
    sm_scale: float,
) -> torch.Tensor:
    """PyTorch reference scoring; returns rank-local [N_k] float32 scores."""
    _, Hq, D = query.shape
    N_k, Hkv, _ = key.shape
    group = Hq // Hkv

    # Read boundaries and (optional) per-request window sizes on host.
    qsl = cu_seqlens_q.tolist()
    ksl = cu_seqlens_k.tolist()
    B = len(qsl) - 1
    if len(ksl) - 1 != B:
        raise ValueError("cu_seqlens_q and cu_seqlens_k must match.")
    wsl = None
    if not isinstance(window, int):
        if int(window.numel()) != B:
            raise ValueError("window must be a scalar int or have shape [B].")
        wsl = window.to(device="cpu", dtype=torch.int64).tolist()

    scores = torch.zeros((N_k, ), device=query.device, dtype=torch.float32)
    for b in range(B):
        qs, qe = int(qsl[b]), int(qsl[b + 1])
        ks, ke = int(ksl[b]), int(ksl[b + 1])
        q_len = qe - qs
        k_len = ke - ks
        if q_len <= 0 or k_len <= 0:
            continue
        win_b = int(window) if wsl is None else int(wsl[b])
        if win_b <= 0:
            continue
        win = min(win_b, q_len, k_len)
        # The last `win` keys are not scored; their entries stay 0.
        k_eff_end = ke - win
        if k_eff_end <= ks:
            continue

        q_win = query[qe - win:qe]  # [win, Hq, D]
        # Aggregate query heads to KV heads (token-shared selection).
        q_win = q_win.reshape(win, Hkv, group, D).mean(dim=2)  # [win, Hkv, D]
        k_eff = key[ks:k_eff_end]  # [K_eff, Hkv, D]

        qh = q_win.permute(1, 0, 2).to(torch.float32)  # [Hkv, win, D]
        kh = k_eff.permute(1, 0, 2).to(torch.float32)  # [Hkv, K_eff, D]
        logits = torch.matmul(qh, kh.transpose(1, 2)) * float(sm_scale)
        probs = torch.softmax(logits, dim=-1)
        scores[ks:k_eff_end] = probs.sum(dim=1).sum(dim=0)
    return scores


def snapkv_query_aware_token_scores(
    *,
    query: torch.Tensor,  # [N_q, Hq, D]
    key: torch.Tensor,  # [N_k, Hkv, D]
    cu_seqlens_q: torch.Tensor,  # [B+1]
    cu_seqlens_k: torch.Tensor,  # [B+1]
    window: Union[int, torch.Tensor],
    sm_scale: float,
) -> torch.Tensor:
    """Compute token-shared SnapKV scores for packed, varlen q/k inputs.

    Returns a [N_k] float32 tensor, reduced across TP ranks so every rank makes
    an identical Top-K selection.

    The Triton kernel is tried first when available. If it raises, Triton is
    disabled for the rest of the process and the PyTorch reference is used.
    The tensor-parallel all-reduce runs exactly once, outside the fallback
    guard, so a collective failure is never mistaken for a Triton failure and
    never triggers a second collective.

    Args:
        query: Packed queries.
        key: Packed keys.
        cu_seqlens_q: Cumulative query segment boundaries.
        cu_seqlens_k: Cumulative key segment boundaries.
        window: Scalar window size, or one window size per request.
        sm_scale: Scale applied to the q.k logits.

    Raises:
        ValueError: If query/key are not 3-D, their head sizes differ, the
            query heads are not a multiple of the KV heads, the two cu_seqlens
            tensors differ in length, or a tensor ``window`` does not have one
            entry per request.
    """
    global _DISABLE_SNAPKV_TRITON
    device = query.device

    if query.ndim != 3 or key.ndim != 3:
        raise ValueError("query and key must be 3D tensors.")
    _, Hq, D = query.shape
    _, Hkv, Dk = key.shape
    if D != Dk:
        raise ValueError("query and key must have the same head size.")
    if Hq % Hkv != 0:
        raise ValueError("Query heads must be a multiple of KV heads.")
    if cu_seqlens_q.numel() != cu_seqlens_k.numel():
        raise ValueError("cu_seqlens_q and cu_seqlens_k must match.")

    # NOTE: Triton SnapKV scoring on ROCm is experimental. It is enabled by
    # default (uses a ROCm-safe kernel); set
    # VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM=0 to force the PyTorch
    # reference implementation.
    use_triton = (HAS_TRITON and not _DISABLE_SNAPKV_TRITON
                  and device.type == "cuda"
                  and (not current_platform.is_rocm()
                       or envs.VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM)
                  and query.stride(-1) == 1 and key.stride(-1) == 1)

    token_scores = None
    if use_triton:
        try:
            from vllm.v1.kv_compression.snapkv_triton import (
                query_aware_key_scores)

            w = int(window) if isinstance(window, int) else window
            scores_per_head = query_aware_key_scores(
                q=query,
                k=key,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                w=w,
                sm_scale=float(sm_scale),
                pool=False,
                protect_last=False,
                normalize=False,
            )
            token_scores = scores_per_head.sum(dim=1)
        except Exception as e:
            # Deliberate best-effort: any Triton failure switches this process
            # to the PyTorch reference for good.
            _DISABLE_SNAPKV_TRITON = True
            logger.warning(
                "Triton SnapKV scoring failed; falling back to PyTorch. "
                "Error: %s", e)

    if token_scores is None:
        token_scores = _snapkv_reference_token_scores(
            query=query,
            key=key,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            window=window,
            sm_scale=sm_scale,
        )

    from vllm.distributed.parallel_state import get_tp_group
    return get_tp_group().all_reduce(token_scores)
def snapkv_like_token_scores(
    *,
    query: torch.Tensor,  # [T, Hq, D]
    key: torch.Tensor,  # [T, Hkv, D]
    query_start_loc: torch.Tensor,  # [B+1]
    window: Union[int, torch.Tensor],
    sm_scale: float,
) -> torch.Tensor:
    """Score tokens SnapKV-style when q and k use one shared packed layout.

    Thin convenience wrapper: ``query_start_loc`` is forwarded as both the
    query-side and the key-side cumulative boundaries of
    ``snapkv_query_aware_token_scores``. All other arguments are passed
    through unchanged and the callee's result is returned as-is.
    """
    # Queries and keys are packed identically, so a single boundary tensor
    # describes both sides.
    shared_boundaries = query_start_loc
    token_scores = snapkv_query_aware_token_scores(
        query=query,
        key=key,
        cu_seqlens_q=shared_boundaries,
        cu_seqlens_k=shared_boundaries,
        window=window,
        sm_scale=sm_scale,
    )
    return token_scores
vllm/v1/kv_compression/topk_select.py
0 → 100644
View file @
3adc766e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
typing
import
Optional
import
torch
def
_packed_varlen_coords
(
*
,
cu_seqlens
:
torch
.
Tensor
,
# [B+1]
total_tokens
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Compute packed varlen segment coordinates.
Returns:
starts: [B] int64, segment start offsets (inclusive)
ends: [B] int64, segment end offsets (exclusive)
lengths: [B] int64, segment lengths (ends - starts)
req_ids: [T] int64, request id for each token in packed [0, T)
pos_in_req: [T] int64, position within its request segment
"""
device
=
cu_seqlens
.
device
B
=
int
(
cu_seqlens
.
numel
()
-
1
)
if
B
<=
0
:
empty
=
torch
.
empty
((
0
,
),
device
=
device
,
dtype
=
torch
.
long
)
t_empty
=
torch
.
empty
((
0
,
),
device
=
device
,
dtype
=
torch
.
long
)
return
empty
,
empty
,
empty
,
t_empty
,
t_empty
starts
=
cu_seqlens
[:
B
].
to
(
torch
.
long
)
ends
=
cu_seqlens
[
1
:
B
+
1
].
to
(
torch
.
long
)
lengths
=
ends
-
starts
if
total_tokens
<=
0
:
t_empty
=
torch
.
empty
((
0
,
),
device
=
device
,
dtype
=
torch
.
long
)
return
starts
,
ends
,
lengths
,
t_empty
,
t_empty
token_idx
=
torch
.
arange
(
total_tokens
,
device
=
device
,
dtype
=
torch
.
long
)
req_ids
=
torch
.
bucketize
(
token_idx
,
ends
,
right
=
True
)
# [T]
start_per_token
=
starts
.
index_select
(
0
,
req_ids
)
pos_in_req
=
token_idx
-
start_per_token
return
starts
,
ends
,
lengths
,
req_ids
,
pos_in_req
def
_topk_keep_mask_and_local_rank
(
*
,
token_scores
:
Optional
[
torch
.
Tensor
],
# [T] float32
must_keep
:
torch
.
Tensor
,
# [T] bool
topk_budget
:
torch
.
Tensor
,
# [B] int32
starts
:
torch
.
Tensor
,
# [B] int64
lengths
:
torch
.
Tensor
,
# [B] int64
req_ids
:
torch
.
Tensor
,
# [T] int64
pos_in_req
:
torch
.
Tensor
,
# [T] int64
max_len
:
Optional
[
int
]
=
None
,
topk_budget_max
:
Optional
[
int
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Compute keep_mask/local_rank for token-shared Top-K selection.
Returns:
keep_mask: [T] bool, selected tokens (includes must_keep)
local_rank: [T] int64, rank among kept tokens within each request
keep_len: [B] int32, number of kept tokens per request
"""
device
=
must_keep
.
device
T
=
int
(
must_keep
.
numel
())
B
=
int
(
topk_budget
.
numel
())
keep_mask
=
must_keep
.
clone
()
if
T
==
0
or
B
==
0
:
local_rank
=
torch
.
empty
((
T
,
),
device
=
device
,
dtype
=
torch
.
long
)
keep_len
=
torch
.
zeros
((
B
,
),
device
=
device
,
dtype
=
torch
.
int32
)
return
keep_mask
,
local_rank
,
keep_len
if
max_len
is
None
:
L_max
=
int
(
lengths
.
max
().
item
())
if
lengths
.
numel
()
>
0
else
0
else
:
L_max
=
int
(
max_len
)
if
L_max
<
0
:
L_max
=
0
must_keep_counts
=
torch
.
zeros
((
B
,
),
device
=
device
,
dtype
=
torch
.
long
)
must_keep_counts
.
scatter_add_
(
0
,
req_ids
,
must_keep
.
to
(
torch
.
long
))
cand_counts
=
(
lengths
.
to
(
torch
.
long
)
-
must_keep_counts
).
clamp_min
(
0
)
k_eff
=
torch
.
minimum
(
topk_budget
.
to
(
torch
.
long
).
clamp_min
(
0
),
cand_counts
)
# CPU-known bound avoids a device->host sync; clamp for safety.
if
topk_budget_max
is
None
:
k_max
=
int
(
k_eff
.
max
().
item
())
if
k_eff
.
numel
()
>
0
else
0
else
:
k_max
=
int
(
topk_budget_max
)
if
k_max
<
0
:
k_max
=
0
if
k_max
>
L_max
:
k_max
=
L_max
if
k_max
>
0
:
if
token_scores
is
None
:
raise
ValueError
(
"token_scores must be provided when k_max > 0."
)
masked_scores
=
token_scores
.
to
(
torch
.
float32
).
masked_fill
(
must_keep
,
float
(
"-inf"
))
scores_flat
=
masked_scores
.
new_full
((
B
*
L_max
,
),
float
(
"-inf"
))
linear
=
req_ids
*
L_max
+
pos_in_req
scores_flat
[
linear
]
=
masked_scores
scores
=
scores_flat
.
view
(
B
,
L_max
)
topk_pos
=
torch
.
topk
(
scores
,
k
=
k_max
,
dim
=
1
).
indices
# [B, k_max]
col_mask
=
torch
.
arange
(
k_max
,
device
=
device
).
unsqueeze
(
0
)
<
k_eff
.
unsqueeze
(
1
)
global_sel
=
starts
.
unsqueeze
(
1
)
+
topk_pos
.
to
(
torch
.
long
)
# [B,k_max]
flat_idx
=
global_sel
.
reshape
(
-
1
).
clamp_
(
0
,
T
-
1
)
flat_val
=
col_mask
.
reshape
(
-
1
).
to
(
torch
.
int32
)
tmp
=
torch
.
zeros
((
T
,
),
device
=
device
,
dtype
=
torch
.
int32
)
tmp
.
scatter_add_
(
0
,
flat_idx
,
flat_val
)
keep_mask
|=
tmp
>
0
keep_len
=
torch
.
zeros
((
B
,
),
device
=
device
,
dtype
=
torch
.
long
)
keep_len
.
scatter_add_
(
0
,
req_ids
,
keep_mask
.
to
(
torch
.
long
))
# Stable, order-preserving local rank using segment-local prefix sums.
keep_prefix
=
torch
.
cumsum
(
keep_mask
.
to
(
torch
.
long
),
dim
=
0
)
# [T]
start_minus_1
=
(
starts
-
1
).
clamp_min
(
0
)
prefix_before_all
=
keep_prefix
.
index_select
(
0
,
start_minus_1
)
prefix_before
=
torch
.
where
(
starts
>
0
,
prefix_before_all
,
torch
.
zeros_like
(
prefix_before_all
))
# [B]
prefix_before_per_token
=
prefix_before
.
index_select
(
0
,
req_ids
)
# [T]
local_rank
=
keep_prefix
-
prefix_before_per_token
-
1
# [T]
return
keep_mask
,
local_rank
,
keep_len
.
to
(
torch
.
int32
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment