Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3adc766e
Commit
3adc766e
authored
Jan 27, 2026
by
laibao
Browse files
refactor: 抽离 flash_attn 的 KV compression 逻辑到 vllm/v1/kv_compression
parent
9db5ff3b
Changes
10
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
1076 additions
and
656 deletions
+1076
-656
vllm/v1/attention/backends/flash_attn.py
vllm/v1/attention/backends/flash_attn.py
+46
-656
vllm/v1/kv_compression/compaction_step.py
vllm/v1/kv_compression/compaction_step.py
+70
-0
vllm/v1/kv_compression/flash_attn_hooks.py
vllm/v1/kv_compression/flash_attn_hooks.py
+186
-0
vllm/v1/kv_compression/forward_context.py
vllm/v1/kv_compression/forward_context.py
+66
-0
vllm/v1/kv_compression/kv_cache_view.py
vllm/v1/kv_compression/kv_cache_view.py
+38
-0
vllm/v1/kv_compression/metadata.py
vllm/v1/kv_compression/metadata.py
+75
-0
vllm/v1/kv_compression/prompt_end.py
vllm/v1/kv_compression/prompt_end.py
+187
-0
vllm/v1/kv_compression/slot_mapping.py
vllm/v1/kv_compression/slot_mapping.py
+127
-0
vllm/v1/kv_compression/snapkv_score.py
vllm/v1/kv_compression/snapkv_score.py
+150
-0
vllm/v1/kv_compression/topk_select.py
vllm/v1/kv_compression/topk_select.py
+131
-0
No files found.
vllm/v1/attention/backends/flash_attn.py
View file @
3adc766e
This diff is collapsed.
Click to expand it.
vllm/v1/kv_compression/compaction_step.py
0 → 100644
View file @
3adc766e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
import
torch
import
vllm.envs
as
envs
from
vllm.v1.kv_compression.slot_mapping
import
topk_kv_compact_slot_mapping
from
vllm.v1.kv_compression.snapkv_score
import
snapkv_like_token_scores
def
snapkv_window_for_topk_budget
(
*
,
topk_budget
:
torch
.
Tensor
,
# [B] int32
window
:
int
,
)
->
torch
.
Tensor
:
"""Build per-request SnapKV window sizes for mixed batches.
Requests with a zero Top-K budget do not need token scores; setting their
window to 0 lets the Triton scoring kernel early-return.
"""
return
torch
.
where
(
topk_budget
>
0
,
torch
.
full_like
(
topk_budget
,
int
(
window
)),
torch
.
zeros_like
(
topk_budget
),
)
def
compute_compact_dst_slots_for_step
(
*
,
query
:
torch
.
Tensor
,
# [T, Hq, D] for this step
key
:
torch
.
Tensor
,
# [T, Hkv, D] for this step
query_start_loc
:
torch
.
Tensor
,
# [B+1]
seq_lens
:
torch
.
Tensor
,
# [B] int32
block_table
:
torch
.
Tensor
,
# [B, max_blocks]
block_size
:
int
,
must_keep
:
torch
.
Tensor
,
# [T] bool
topk_budget
:
torch
.
Tensor
,
# [B] int32
topk_budget_max
:
int
,
max_query_len
:
int
,
sm_scale
:
float
,
)
->
torch
.
Tensor
:
"""Compute per-token KV compaction destinations for one step."""
token_scores
=
None
if
int
(
topk_budget_max
)
>
0
:
w
=
snapkv_window_for_topk_budget
(
topk_budget
=
topk_budget
,
window
=
int
(
envs
.
VLLM_KV_COMPRESSION_SNAPKV_WINDOW
),
)
token_scores
=
snapkv_like_token_scores
(
query
=
query
,
key
=
key
,
query_start_loc
=
query_start_loc
,
window
=
w
,
sm_scale
=
float
(
sm_scale
),
)
return
topk_kv_compact_slot_mapping
(
token_scores
=
token_scores
,
must_keep
=
must_keep
,
topk_budget
=
topk_budget
,
query_start_loc
=
query_start_loc
,
seq_lens
=
seq_lens
,
block_table
=
block_table
,
block_size
=
int
(
block_size
),
max_query_len
=
int
(
max_query_len
),
topk_budget_max
=
int
(
topk_budget_max
),
)
vllm/v1/kv_compression/flash_attn_hooks.py
0 → 100644
View file @
3adc766e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
typing
import
Any
,
Optional
,
Protocol
import
torch
import
vllm.envs
as
envs
from
vllm.forward_context
import
get_forward_context
from
vllm.platforms
import
current_platform
from
vllm.v1.kv_compression.compaction_step
import
compute_compact_dst_slots_for_step
from
vllm.v1.kv_compression.forward_context
import
(
get_kv_compression_compact_slots
,
get_kv_compression_prompt_payload
,
set_kv_compression_compact_slots
,
set_kv_compression_prompt_payload
,
)
from
vllm.v1.kv_compression.prompt_end
import
compute_prompt_end_indices
from
vllm.v1.kv_compression.slot_mapping
import
kv_compaction_dst_rewrite_mapping
class
_ReshapeAndCacheFn
(
Protocol
):
def
__call__
(
self
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
)
->
None
:
...
def
maybe_compute_prompt_end_payload_flash_attn
(
*
,
kv_sharing_target_layer_name
:
Optional
[
str
],
query
:
torch
.
Tensor
,
num_actual_tokens
:
int
,
key_cache
:
torch
.
Tensor
,
cache_block_size
:
int
,
attn_metadata
:
Any
,
sm_scale
:
float
,
)
->
None
:
"""Compute and stash prompt-end Top-K indices for chunked-prefill scheme 3.
The payload is cached in the forward context and later consumed by the
model runner to perform one-shot prompt KV compaction before the first
decode step.
"""
if
not
envs
.
VLLM_ENABLE_KV_COMPRESSION
or
kv_sharing_target_layer_name
is
not
None
:
return
prompt_end
=
getattr
(
attn_metadata
,
"kv_compression_prompt_end"
,
None
)
prompt_lens
=
getattr
(
attn_metadata
,
"kv_compression_prompt_lens"
,
None
)
topk_keep
=
getattr
(
attn_metadata
,
"kv_compression_prompt_topk_keep"
,
None
)
if
prompt_end
is
None
or
prompt_lens
is
None
or
topk_keep
is
None
:
return
forward_context
=
get_forward_context
()
if
get_kv_compression_prompt_payload
(
forward_context
)
is
not
None
:
return
payload
=
compute_prompt_end_indices
(
query
=
query
[:
num_actual_tokens
],
key_cache
=
key_cache
,
block_size
=
cache_block_size
,
query_start_loc
=
attn_metadata
.
query_start_loc
,
block_table
=
attn_metadata
.
block_table
,
prompt_end
=
prompt_end
,
prompt_lens
=
prompt_lens
,
topk_keep
=
topk_keep
,
topk_keep_max
=
getattr
(
attn_metadata
,
"kv_compression_prompt_topk_keep_max"
,
None
),
sm_scale
=
sm_scale
,
)
if
payload
is
not
None
:
set_kv_compression_prompt_payload
(
forward_context
,
payload
)
def
maybe_compact_kv_cache_flash_attn
(
*
,
kv_sharing_target_layer_name
:
Optional
[
str
],
layer
:
Any
,
query
:
torch
.
Tensor
,
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
num_actual_tokens
:
int
,
cache_block_size
:
int
,
attn_metadata
:
Any
,
sm_scale
:
float
,
kv_cache_dtype
:
str
,
reshape_and_cache
:
_ReshapeAndCacheFn
,
)
->
None
:
"""Optional per-step KV compaction for scheme 1/2 token-shared selection."""
if
not
envs
.
VLLM_ENABLE_KV_COMPRESSION
or
kv_sharing_target_layer_name
is
not
None
:
return
must_keep
=
getattr
(
attn_metadata
,
"kv_compression_must_keep"
,
None
)
topk_budget
=
getattr
(
attn_metadata
,
"kv_compression_topk_budget"
,
None
)
if
must_keep
is
None
or
topk_budget
is
None
:
return
forward_context
=
get_forward_context
()
per_layer_topk
=
envs
.
VLLM_KV_COMPRESSION_TOPK_PER_LAYER
dst
=
get_kv_compression_compact_slots
(
forward_context
,
per_layer_topk
=
per_layer_topk
,
layer
=
layer
,
)
if
dst
is
None
:
topk_budget_max
=
int
(
getattr
(
attn_metadata
,
"kv_compression_topk_budget_max"
,
0
)
or
0
)
dst
=
compute_compact_dst_slots_for_step
(
query
=
query
[:
num_actual_tokens
],
key
=
key
[:
num_actual_tokens
],
query_start_loc
=
attn_metadata
.
query_start_loc
,
seq_lens
=
attn_metadata
.
seq_lens
,
block_table
=
attn_metadata
.
block_table
,
block_size
=
cache_block_size
,
must_keep
=
must_keep
,
topk_budget
=
topk_budget
,
topk_budget_max
=
topk_budget_max
,
max_query_len
=
attn_metadata
.
max_query_len
,
sm_scale
=
sm_scale
,
)
set_kv_compression_compact_slots
(
forward_context
,
per_layer_topk
=
per_layer_topk
,
layer
=
layer
,
dst
=
dst
,
)
if
dst
is
None
:
return
src
=
attn_metadata
.
slot_mapping
dst_rewrite
=
kv_compaction_dst_rewrite_mapping
(
dst_slots
=
dst
,
src_slots
=
src
)
if
not
current_platform
.
is_rocm
():
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
dst_rewrite
,
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
return
# ROCm: optionally prefer the optimized reshape-and-cache kernel.
if
(
envs
.
VLLM_USE_OPT_RESHAPE_AND_CACHE
and
key
.
dtype
==
value
.
dtype
and
key
.
dtype
==
torch
.
float16
):
from
lightop
import
reshape_and_cache_cuda
reshape_and_cache_cuda
(
key
,
value
,
key_cache
,
value_cache
,
dst_rewrite
,
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
else
:
reshape_and_cache
(
key
,
value
,
key_cache
,
value_cache
,
dst_rewrite
,
kv_cache_dtype
,
layer
.
_k_scale
,
layer
.
_v_scale
,
)
vllm/v1/kv_compression/forward_context.py
0 → 100644
View file @
3adc766e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
typing
import
Any
,
Optional
import
torch
_PROMPT_PAYLOAD_ATTR
=
"_kv_compression_prompt_payload"
_COMPACT_SLOTS_ATTR
=
"_kv_compression_compact_slots"
_COMPACT_SLOTS_BY_LAYER_ATTR
=
"_kv_compression_compact_slots_by_layer"
def
get_kv_compression_prompt_payload
(
forward_context
:
Any
,
)
->
Optional
[
dict
[
str
,
torch
.
Tensor
]]:
return
getattr
(
forward_context
,
_PROMPT_PAYLOAD_ATTR
,
None
)
def
set_kv_compression_prompt_payload
(
forward_context
:
Any
,
payload
:
dict
[
str
,
torch
.
Tensor
],
)
->
None
:
setattr
(
forward_context
,
_PROMPT_PAYLOAD_ATTR
,
payload
)
def
_kv_compression_layer_key
(
layer
:
Any
)
->
str
:
layer_name
=
getattr
(
layer
,
"layer_name"
,
None
)
if
layer_name
is
None
:
layer_name
=
str
(
id
(
layer
))
return
str
(
layer_name
)
def
get_kv_compression_compact_slots
(
forward_context
:
Any
,
*
,
per_layer_topk
:
bool
,
layer
:
Any
,
)
->
Optional
[
torch
.
Tensor
]:
if
per_layer_topk
:
dst_by_layer
=
getattr
(
forward_context
,
_COMPACT_SLOTS_BY_LAYER_ATTR
,
None
)
if
dst_by_layer
is
None
:
return
None
return
dst_by_layer
.
get
(
_kv_compression_layer_key
(
layer
))
return
getattr
(
forward_context
,
_COMPACT_SLOTS_ATTR
,
None
)
def
set_kv_compression_compact_slots
(
forward_context
:
Any
,
*
,
per_layer_topk
:
bool
,
layer
:
Any
,
dst
:
torch
.
Tensor
,
)
->
None
:
if
per_layer_topk
:
dst_by_layer
=
getattr
(
forward_context
,
_COMPACT_SLOTS_BY_LAYER_ATTR
,
None
)
if
dst_by_layer
is
None
:
dst_by_layer
=
{}
setattr
(
forward_context
,
_COMPACT_SLOTS_BY_LAYER_ATTR
,
dst_by_layer
)
dst_by_layer
[
_kv_compression_layer_key
(
layer
)]
=
dst
else
:
setattr
(
forward_context
,
_COMPACT_SLOTS_ATTR
,
dst
)
vllm/v1/kv_compression/kv_cache_view.py
0 → 100644
View file @
3adc766e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
import
torch
from
vllm.platforms
import
current_platform
def
paged_k_cache_view_for_triton_gather
(
*
,
key_cache
:
torch
.
Tensor
,
block_size
:
int
,
)
->
torch
.
Tensor
:
"""Return a KV-cache key view in [num_blocks, H, block_size, D] layout.
Supports both:
- [num_blocks, block_size, H, D] (typical CUDA FlashAttention v1 layout)
- [num_blocks, H, block_size, D] (ROCm FlashAttention v1, or external
connectors that expose the cache in HND shape)
"""
if
key_cache
.
ndim
!=
4
:
raise
ValueError
(
"key_cache must be a 4D tensor."
)
# Common case: [B, T, H, D] -> [B, H, T, D]
if
int
(
key_cache
.
shape
[
1
])
==
int
(
block_size
):
return
key_cache
.
permute
(
0
,
2
,
1
,
3
)
# Already in [B, H, T, D] (ROCm / HND-shaped external caches).
if
int
(
key_cache
.
shape
[
2
])
==
int
(
block_size
):
return
key_cache
# Fallback: preserve historical behavior.
if
current_platform
.
is_rocm
():
return
key_cache
return
key_cache
.
permute
(
0
,
2
,
1
,
3
)
vllm/v1/kv_compression/metadata.py
0 → 100644
View file @
3adc766e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
dataclasses
import
dataclass
from
typing
import
Any
,
Optional
import
torch
import
vllm.envs
as
envs
@
dataclass
class
KVCompressionAttentionMetadata
:
"""Per-batch KV compression metadata consumed by attention backends."""
must_keep
:
Optional
[
torch
.
Tensor
]
=
None
topk_budget
:
Optional
[
torch
.
Tensor
]
=
None
topk_budget_max
:
Optional
[
int
]
=
None
prompt_end
:
Optional
[
torch
.
Tensor
]
=
None
prompt_lens
:
Optional
[
torch
.
Tensor
]
=
None
prompt_topk_keep
:
Optional
[
torch
.
Tensor
]
=
None
prompt_topk_keep_max
:
Optional
[
int
]
=
None
def
build_kv_compression_attn_metadata
(
*
,
runner
:
Any
,
num_reqs
:
int
,
num_actual_tokens
:
int
,
)
->
KVCompressionAttentionMetadata
:
"""Build KV compression metadata for one attention step.
This helper keeps backend code thin and centralizes the logic for selecting
between per-step compaction (scheme 1/2) and prompt-end one-shot scoring
(scheme 3).
"""
meta
=
KVCompressionAttentionMetadata
()
if
not
envs
.
VLLM_ENABLE_KV_COMPRESSION
:
return
meta
# Scheme 1/2: compute compaction destinations every step.
if
getattr
(
runner
,
"kv_compression_needs_compaction"
,
False
):
meta
.
must_keep
=
runner
.
kv_compression_must_keep
[:
num_actual_tokens
]
meta
.
topk_budget
=
runner
.
kv_compression_topk_budget
[:
num_reqs
]
# Avoid device->host sync by reading from the CPU staging buffer.
if
num_reqs
>
0
:
meta
.
topk_budget_max
=
int
(
runner
.
kv_compression_topk_budget_np
[:
num_reqs
].
max
())
else
:
meta
.
topk_budget_max
=
0
return
meta
# Scheme 3: compute global prompt indices only on the last prefill chunk,
# and perform the actual cache compaction before the first decode step.
scheduler_config
=
getattr
(
runner
,
"scheduler_config"
,
None
)
if
scheduler_config
is
None
or
not
getattr
(
scheduler_config
,
"chunked_prefill_enabled"
,
False
):
return
meta
if
num_reqs
<=
0
:
return
meta
if
not
runner
.
kv_compression_prompt_end_np
[:
num_reqs
].
any
():
return
meta
meta
.
prompt_end
=
runner
.
kv_compression_prompt_end
[:
num_reqs
]
meta
.
prompt_lens
=
runner
.
kv_compression_prompt_lens
[:
num_reqs
]
meta
.
prompt_topk_keep
=
runner
.
kv_compression_prompt_topk_keep
[:
num_reqs
]
meta
.
prompt_topk_keep_max
=
int
(
getattr
(
runner
,
"kv_compression_prompt_topk_keep_max"
,
0
)
or
0
)
return
meta
vllm/v1/kv_compression/prompt_end.py
0 → 100644
View file @
3adc766e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
typing
import
Optional
import
torch
import
vllm.envs
as
envs
from
vllm.v1.kv_compression.kv_cache_view
import
paged_k_cache_view_for_triton_gather
from
vllm.v1.kv_compression.snapkv_score
import
snapkv_query_aware_token_scores
from
vllm.v1.kv_compression.topk_select
import
(
_packed_varlen_coords
,
_topk_keep_mask_and_local_rank
)
def
_prompt_end_topk_keep_indices
(
*
,
token_scores
:
torch
.
Tensor
,
# [T] float32
prompt_lens
:
torch
.
Tensor
,
# [B] int32
topk_keep
:
torch
.
Tensor
,
# [B] int32 (candidates only)
protected_prefix
:
int
,
protected_suffix
:
int
,
keep_last_token
:
bool
,
topk_keep_max
:
Optional
[
int
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Select kept prompt indices (ascending) for one-shot compaction.
Returns:
idx_sorted: [B, K_max] int32, per-request kept token indices (0..L-1)
keep_len: [B] int32, number of kept tokens per request
"""
device
=
token_scores
.
device
B
=
int
(
prompt_lens
.
numel
())
if
B
==
0
:
empty
=
torch
.
empty
((
0
,
0
),
device
=
device
,
dtype
=
torch
.
int32
)
return
empty
,
torch
.
empty
((
0
,
),
device
=
device
,
dtype
=
torch
.
int32
)
prompt_lens_i64
=
prompt_lens
.
to
(
torch
.
long
)
cu
=
torch
.
zeros
((
B
+
1
,
),
device
=
device
,
dtype
=
torch
.
long
)
cu
[
1
:]
=
torch
.
cumsum
(
prompt_lens_i64
,
dim
=
0
)
T
=
int
(
token_scores
.
numel
())
if
T
==
0
:
empty
=
torch
.
empty
((
B
,
0
),
device
=
device
,
dtype
=
torch
.
int32
)
return
empty
,
torch
.
zeros
((
B
,
),
device
=
device
,
dtype
=
torch
.
int32
)
starts
,
_
,
lengths
,
req_ids
,
pos_in_req
=
_packed_varlen_coords
(
cu_seqlens
=
cu
,
total_tokens
=
T
,
)
# Must-keep mask (protected prefix/suffix + optional last prompt token).
prefix_len
=
torch
.
clamp
(
prompt_lens_i64
,
min
=
0
).
clamp_max
(
max
(
protected_prefix
,
0
))
suffix
=
torch
.
clamp
(
prompt_lens_i64
,
min
=
0
).
clamp_max
(
max
(
protected_suffix
,
0
))
suffix_start
=
(
prompt_lens_i64
-
suffix
).
clamp_min
(
0
)
prefix_len_t
=
prefix_len
.
index_select
(
0
,
req_ids
)
suffix_start_t
=
suffix_start
.
index_select
(
0
,
req_ids
)
must_keep
=
(
pos_in_req
<
prefix_len_t
)
|
(
pos_in_req
>=
suffix_start_t
)
if
keep_last_token
:
last
=
(
prompt_lens_i64
-
1
).
clamp_min
(
0
)
last_t
=
last
.
index_select
(
0
,
req_ids
)
must_keep
|=
pos_in_req
==
last_t
keep_mask
,
local_rank
,
keep_len
=
_topk_keep_mask_and_local_rank
(
token_scores
=
token_scores
,
must_keep
=
must_keep
,
topk_budget
=
topk_keep
,
starts
=
starts
,
lengths
=
lengths
,
req_ids
=
req_ids
,
pos_in_req
=
pos_in_req
,
max_len
=
int
(
prompt_lens_i64
.
max
().
item
()),
topk_budget_max
=
topk_keep_max
,
)
keep_max_len
=
int
(
keep_len
.
max
().
item
())
if
B
>
0
else
0
if
keep_max_len
<=
0
:
empty
=
torch
.
empty
((
B
,
0
),
device
=
device
,
dtype
=
torch
.
int32
)
return
empty
,
keep_len
idx_sorted
=
torch
.
zeros
((
B
,
keep_max_len
),
device
=
device
,
dtype
=
torch
.
int32
)
lin_out
=
(
req_ids
*
keep_max_len
+
local_rank
).
masked_select
(
keep_mask
)
vals
=
pos_in_req
.
to
(
torch
.
int32
).
masked_select
(
keep_mask
)
idx_sorted
.
view
(
-
1
).
scatter_
(
0
,
lin_out
,
vals
)
return
idx_sorted
,
keep_len
def
compute_prompt_end_indices
(
*
,
query
:
torch
.
Tensor
,
# [T, Hq, D] scheduled tokens for this step
key_cache
:
torch
.
Tensor
,
# layer KV cache view (platform-dependent)
block_size
:
int
,
query_start_loc
:
torch
.
Tensor
,
# [B+1] int32
block_table
:
torch
.
Tensor
,
# [B, max_blocks] int32
prompt_end
:
torch
.
Tensor
,
# [B] bool
prompt_lens
:
torch
.
Tensor
,
# [B] int32
topk_keep
:
torch
.
Tensor
,
# [B] int32
topk_keep_max
:
Optional
[
int
],
sm_scale
:
float
,
)
->
Optional
[
dict
[
str
,
torch
.
Tensor
]]:
"""Compute one-shot prompt compaction indices on the last prefill chunk."""
device
=
query
.
device
if
prompt_end
.
numel
()
==
0
:
return
None
sel
=
torch
.
nonzero
(
prompt_end
,
as_tuple
=
False
).
flatten
()
if
int
(
sel
.
numel
())
==
0
:
return
None
window
=
int
(
envs
.
VLLM_KV_COMPRESSION_SNAPKV_WINDOW
)
keep_last
=
bool
(
envs
.
VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN
)
protected_prefix
=
int
(
envs
.
VLLM_KV_COMPRESSION_PROTECTED_PREFIX
)
protected_suffix
=
int
(
envs
.
VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
)
# Build packed Q window (last `window` queries per selected request).
sel_list
=
sel
.
to
(
device
=
"cpu"
,
dtype
=
torch
.
int64
).
tolist
()
qsl
=
query_start_loc
.
to
(
device
=
"cpu"
,
dtype
=
torch
.
int64
).
tolist
()
q_chunks
=
[]
cu_q
=
[
0
]
w_list
=
[]
for
b
in
sel_list
:
s
=
int
(
qsl
[
b
])
e
=
int
(
qsl
[
b
+
1
])
q_len
=
max
(
0
,
e
-
s
)
win
=
min
(
window
,
q_len
)
w_list
.
append
(
int
(
win
))
if
win
>
0
:
q_chunks
.
append
(
query
[
e
-
win
:
e
])
cu_q
.
append
(
cu_q
[
-
1
]
+
int
(
win
))
if
cu_q
[
-
1
]
<=
0
:
return
None
q_packed
=
torch
.
cat
(
q_chunks
,
dim
=
0
)
if
q_chunks
else
query
[:
0
]
cu_seqlens_q
=
torch
.
tensor
(
cu_q
,
device
=
device
,
dtype
=
torch
.
int32
)
w
=
torch
.
tensor
(
w_list
,
device
=
device
,
dtype
=
torch
.
int32
)
# Gather full prompt keys for the selected requests into a packed [T, Hk, D].
prompt_lens_sel
=
prompt_lens
.
index_select
(
0
,
sel
).
to
(
torch
.
int32
)
topk_keep_sel
=
topk_keep
.
index_select
(
0
,
sel
).
to
(
torch
.
int32
)
cu_seqlens_k
=
torch
.
zeros
((
int
(
prompt_lens_sel
.
numel
())
+
1
,
),
device
=
device
,
dtype
=
torch
.
int32
)
if
int
(
prompt_lens_sel
.
numel
())
>
0
:
cu_seqlens_k
[
1
:]
=
torch
.
cumsum
(
prompt_lens_sel
,
dim
=
0
)
block_table_sel
=
block_table
.
index_select
(
0
,
sel
).
to
(
torch
.
int32
)
key_cache_view
=
paged_k_cache_view_for_triton_gather
(
key_cache
=
key_cache
,
block_size
=
int
(
block_size
),
)
from
vllm.v1.kv_compression.kv_cache_triton
import
(
gather_k_to_packed_triton
)
k_packed
=
gather_k_to_packed_triton
(
key_cache_view
,
block_table_sel
,
prompt_lens_sel
,
cu_seqlens_k
,
)
token_scores
=
snapkv_query_aware_token_scores
(
query
=
q_packed
,
key
=
k_packed
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_k
=
cu_seqlens_k
,
window
=
w
,
sm_scale
=
float
(
sm_scale
),
)
idx_sorted
,
keep_len
=
_prompt_end_topk_keep_indices
(
token_scores
=
token_scores
,
prompt_lens
=
prompt_lens_sel
,
topk_keep
=
topk_keep_sel
,
protected_prefix
=
protected_prefix
,
protected_suffix
=
protected_suffix
,
keep_last_token
=
keep_last
,
topk_keep_max
=
topk_keep_max
,
)
return
{
"req_indices"
:
sel
.
to
(
torch
.
int32
),
"idx_sorted"
:
idx_sorted
,
# [B_sel, K_max] int32
"keep_len"
:
keep_len
,
# [B_sel] int32
"prompt_lens"
:
prompt_lens_sel
,
# [B_sel] int32
}
vllm/v1/kv_compression/slot_mapping.py
0 → 100644
View file @
3adc766e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
typing
import
Optional
import
torch
from
vllm.v1.kv_compression.topk_select
import
(
_packed_varlen_coords
,
_topk_keep_mask_and_local_rank
)
def
_dst_slots_from_keep_mask_and_local_rank
(
*
,
keep_mask
:
torch
.
Tensor
,
# [T] bool
local_rank
:
torch
.
Tensor
,
# [T] int64
seq_lens
:
torch
.
Tensor
,
# [B] int32
lengths
:
torch
.
Tensor
,
# [B] int64
req_ids
:
torch
.
Tensor
,
# [T] int64
block_table
:
torch
.
Tensor
,
# [B, max_blocks] int32
block_size
:
int
,
)
->
torch
.
Tensor
:
"""Convert keep_mask/local_rank into a per-token KV destination slot."""
device
=
keep_mask
.
device
T
=
int
(
keep_mask
.
numel
())
dst_slots
=
torch
.
full
((
T
,
),
-
1
,
device
=
device
,
dtype
=
torch
.
int64
)
if
T
==
0
:
return
dst_slots
B
=
int
(
seq_lens
.
numel
())
if
B
==
0
:
return
dst_slots
# Base KV cache position for this step (i.e., KV length before writing this
# scheduled segment). With KV compression enabled, seq_lens is derived from
# num_kv_tokens + scheduled_len, so base_kv == seq_lens - scheduled_len.
base_kv
=
(
seq_lens
[:
B
].
to
(
torch
.
long
)
-
lengths
.
to
(
torch
.
long
)).
clamp_min
(
0
)
base_kv_per_token
=
base_kv
.
index_select
(
0
,
req_ids
)
# [T]
dest_pos
=
base_kv_per_token
+
local_rank
# [T]
dest_block_idx
=
dest_pos
//
block_size
dest_off
=
dest_pos
-
dest_block_idx
*
block_size
# Safe indexing for dropped tokens (ignored by keep_mask anyway).
max_blocks
=
int
(
block_table
.
shape
[
1
])
dest_block_idx_safe
=
dest_block_idx
.
clamp_
(
0
,
max_blocks
-
1
).
to
(
torch
.
long
)
block_nums
=
block_table
[
req_ids
,
dest_block_idx_safe
]
dest_slot
=
block_nums
.
to
(
torch
.
long
)
*
block_size
+
dest_off
return
torch
.
where
(
keep_mask
,
dest_slot
.
to
(
torch
.
int64
),
dst_slots
)
def
topk_kv_compact_slot_mapping
(
*
,
token_scores
:
Optional
[
torch
.
Tensor
],
# [T] float32
must_keep
:
torch
.
Tensor
,
# [T] bool
topk_budget
:
torch
.
Tensor
,
# [B] int32
query_start_loc
:
torch
.
Tensor
,
# [B+1]
seq_lens
:
torch
.
Tensor
,
# [B] int32
block_table
:
torch
.
Tensor
,
# [B, max_blocks]
block_size
:
int
,
max_query_len
:
Optional
[
int
]
=
None
,
topk_budget_max
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
:
"""Build a per-token destination slot mapping for KV compaction.
Returns a tensor `dst_slots` of shape [T] where:
- `dst_slots[i] >= 0` indicates token i should be kept and rewritten to
that KV cache slot.
- `dst_slots[i] == -1` indicates token i is dropped after the step.
"""
device
=
must_keep
.
device
T
=
int
(
must_keep
.
numel
())
B
=
int
(
topk_budget
.
numel
())
dst_slots
=
torch
.
full
((
T
,
),
-
1
,
device
=
device
,
dtype
=
torch
.
int64
)
if
T
==
0
or
B
==
0
:
return
dst_slots
starts
,
_
,
lengths
,
req_ids
,
pos_in_req
=
_packed_varlen_coords
(
cu_seqlens
=
query_start_loc
,
total_tokens
=
T
,
)
if
lengths
.
numel
()
==
0
:
return
dst_slots
# Prefer the CPU-known max query length (piecewise graph), to avoid
# device->host synchronization.
L_max
=
int
(
max_query_len
)
if
max_query_len
is
not
None
else
int
(
lengths
.
max
().
item
())
if
L_max
<=
0
:
return
dst_slots
keep_mask
,
local_rank
,
_
=
_topk_keep_mask_and_local_rank
(
token_scores
=
token_scores
,
must_keep
=
must_keep
,
topk_budget
=
topk_budget
,
starts
=
starts
,
lengths
=
lengths
,
req_ids
=
req_ids
,
pos_in_req
=
pos_in_req
,
max_len
=
L_max
,
topk_budget_max
=
topk_budget_max
,
)
return
_dst_slots_from_keep_mask_and_local_rank
(
keep_mask
=
keep_mask
,
local_rank
=
local_rank
,
seq_lens
=
seq_lens
[:
B
],
lengths
=
lengths
,
req_ids
=
req_ids
,
block_table
=
block_table
,
block_size
=
int
(
block_size
),
)
def
kv_compaction_dst_rewrite_mapping
(
*
,
dst_slots
:
torch
.
Tensor
,
# [T] int64
src_slots
:
torch
.
Tensor
,
# [T] int64
)
->
torch
.
Tensor
:
"""Filter a dst slot mapping so only moved kept tokens are rewritten.
Non-rewrite tokens are marked as -1, which the cache kernels treat as
padding and skip.
"""
rewrite_mask
=
(
dst_slots
>=
0
)
&
(
dst_slots
!=
src_slots
)
return
torch
.
where
(
rewrite_mask
,
dst_slots
,
-
1
)
vllm/v1/kv_compression/snapkv_score.py
0 → 100644
View file @
3adc766e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
typing
import
Union
import
torch
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
HAS_TRITON
logger
=
init_logger
(
__name__
)
_DISABLE_SNAPKV_TRITON
:
bool
=
False
def
snapkv_query_aware_token_scores
(
*
,
query
:
torch
.
Tensor
,
# [N_q, Hq, D]
key
:
torch
.
Tensor
,
# [N_k, Hkv, D]
cu_seqlens_q
:
torch
.
Tensor
,
# [B+1]
cu_seqlens_k
:
torch
.
Tensor
,
# [B+1]
window
:
Union
[
int
,
torch
.
Tensor
],
sm_scale
:
float
,
)
->
torch
.
Tensor
:
"""Compute token-shared SnapKV scores for packed, varlen q/k inputs.
Returns a [N_k] float32 tensor, reduced across TP ranks so every rank makes
an identical Top-K selection.
"""
global
_DISABLE_SNAPKV_TRITON
device
=
query
.
device
if
query
.
ndim
!=
3
or
key
.
ndim
!=
3
:
raise
ValueError
(
"query and key must be 3D tensors."
)
_
,
Hq
,
D
=
query
.
shape
N_k
,
Hkv
,
Dk
=
key
.
shape
if
D
!=
Dk
:
raise
ValueError
(
"query and key must have the same head size."
)
if
Hq
%
Hkv
!=
0
:
raise
ValueError
(
"Query heads must be a multiple of KV heads."
)
if
cu_seqlens_q
.
numel
()
!=
cu_seqlens_k
.
numel
():
raise
ValueError
(
"cu_seqlens_q and cu_seqlens_k must match."
)
# NOTE: Triton SnapKV scoring on ROCm is experimental. It is enabled by
# default (uses a ROCm-safe kernel); set
# VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM=0 to force the PyTorch
# reference implementation.
if
(
HAS_TRITON
and
not
_DISABLE_SNAPKV_TRITON
and
device
.
type
==
"cuda"
and
(
not
current_platform
.
is_rocm
()
or
envs
.
VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM
)
and
query
.
stride
(
-
1
)
==
1
and
key
.
stride
(
-
1
)
==
1
):
try
:
from
vllm.v1.kv_compression.snapkv_triton
import
(
query_aware_key_scores
)
w
=
int
(
window
)
if
isinstance
(
window
,
int
)
else
window
scores_per_head
=
query_aware_key_scores
(
q
=
query
,
k
=
key
,
cu_seqlens_q
=
cu_seqlens_q
,
cu_seqlens_k
=
cu_seqlens_k
,
w
=
w
,
sm_scale
=
float
(
sm_scale
),
pool
=
False
,
protect_last
=
False
,
normalize
=
False
,
)
token_scores
=
scores_per_head
.
sum
(
dim
=
1
)
from
vllm.distributed.parallel_state
import
get_tp_group
return
get_tp_group
().
all_reduce
(
token_scores
)
except
Exception
as
e
:
_DISABLE_SNAPKV_TRITON
=
True
logger
.
warning
(
"Triton SnapKV scoring failed; falling back to PyTorch. "
"Error: %s"
,
e
)
group
=
Hq
//
Hkv
# Read boundaries and (optional) per-request window sizes on host.
qsl
=
cu_seqlens_q
.
tolist
()
ksl
=
cu_seqlens_k
.
tolist
()
B
=
len
(
qsl
)
-
1
if
len
(
ksl
)
-
1
!=
B
:
raise
ValueError
(
"cu_seqlens_q and cu_seqlens_k must match."
)
wsl
=
None
if
not
isinstance
(
window
,
int
):
if
int
(
window
.
numel
())
!=
B
:
raise
ValueError
(
"window must be a scalar int or have shape [B]."
)
wsl
=
window
.
to
(
device
=
"cpu"
,
dtype
=
torch
.
int64
).
tolist
()
scores
=
torch
.
zeros
((
N_k
,
),
device
=
device
,
dtype
=
torch
.
float32
)
for
b
in
range
(
B
):
qs
=
int
(
qsl
[
b
])
qe
=
int
(
qsl
[
b
+
1
])
ks
=
int
(
ksl
[
b
])
ke
=
int
(
ksl
[
b
+
1
])
q_len
=
qe
-
qs
k_len
=
ke
-
ks
if
q_len
<=
0
or
k_len
<=
0
:
continue
win_b
=
int
(
window
)
if
wsl
is
None
else
int
(
wsl
[
b
])
if
win_b
<=
0
:
continue
win
=
min
(
win_b
,
q_len
,
k_len
)
k_eff_end
=
ke
-
win
if
k_eff_end
<=
ks
:
continue
q_win
=
query
[
qe
-
win
:
qe
]
# [win, Hq, D]
# Aggregate query heads to KV heads (token-shared selection).
q_win
=
q_win
.
reshape
(
win
,
Hkv
,
group
,
D
).
mean
(
dim
=
2
)
# [win, Hkv, D]
k_eff
=
key
[
ks
:
k_eff_end
]
# [K_eff, Hkv, D]
qh
=
q_win
.
permute
(
1
,
0
,
2
).
to
(
torch
.
float32
)
# [Hkv, win, D]
kh
=
k_eff
.
permute
(
1
,
0
,
2
).
to
(
torch
.
float32
)
# [Hkv, K_eff, D]
logits
=
torch
.
matmul
(
qh
,
kh
.
transpose
(
1
,
2
))
*
float
(
sm_scale
)
probs
=
torch
.
softmax
(
logits
,
dim
=-
1
)
scores
[
ks
:
k_eff_end
]
=
probs
.
sum
(
dim
=
1
).
sum
(
dim
=
0
)
from
vllm.distributed.parallel_state
import
get_tp_group
return
get_tp_group
().
all_reduce
(
scores
)
def
snapkv_like_token_scores
(
*
,
query
:
torch
.
Tensor
,
# [T, Hq, D]
key
:
torch
.
Tensor
,
# [T, Hkv, D]
query_start_loc
:
torch
.
Tensor
,
# [B+1]
window
:
Union
[
int
,
torch
.
Tensor
],
sm_scale
:
float
,
)
->
torch
.
Tensor
:
"""SnapKV-like token scores when q/k share the same packed layout."""
return
snapkv_query_aware_token_scores
(
query
=
query
,
key
=
key
,
cu_seqlens_q
=
query_start_loc
,
cu_seqlens_k
=
query_start_loc
,
window
=
window
,
sm_scale
=
sm_scale
,
)
vllm/v1/kv_compression/topk_select.py
0 → 100644
View file @
3adc766e
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
typing
import
Optional
import
torch
def
_packed_varlen_coords
(
*
,
cu_seqlens
:
torch
.
Tensor
,
# [B+1]
total_tokens
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Compute packed varlen segment coordinates.
Returns:
starts: [B] int64, segment start offsets (inclusive)
ends: [B] int64, segment end offsets (exclusive)
lengths: [B] int64, segment lengths (ends - starts)
req_ids: [T] int64, request id for each token in packed [0, T)
pos_in_req: [T] int64, position within its request segment
"""
device
=
cu_seqlens
.
device
B
=
int
(
cu_seqlens
.
numel
()
-
1
)
if
B
<=
0
:
empty
=
torch
.
empty
((
0
,
),
device
=
device
,
dtype
=
torch
.
long
)
t_empty
=
torch
.
empty
((
0
,
),
device
=
device
,
dtype
=
torch
.
long
)
return
empty
,
empty
,
empty
,
t_empty
,
t_empty
starts
=
cu_seqlens
[:
B
].
to
(
torch
.
long
)
ends
=
cu_seqlens
[
1
:
B
+
1
].
to
(
torch
.
long
)
lengths
=
ends
-
starts
if
total_tokens
<=
0
:
t_empty
=
torch
.
empty
((
0
,
),
device
=
device
,
dtype
=
torch
.
long
)
return
starts
,
ends
,
lengths
,
t_empty
,
t_empty
token_idx
=
torch
.
arange
(
total_tokens
,
device
=
device
,
dtype
=
torch
.
long
)
req_ids
=
torch
.
bucketize
(
token_idx
,
ends
,
right
=
True
)
# [T]
start_per_token
=
starts
.
index_select
(
0
,
req_ids
)
pos_in_req
=
token_idx
-
start_per_token
return
starts
,
ends
,
lengths
,
req_ids
,
pos_in_req
def
_topk_keep_mask_and_local_rank
(
*
,
token_scores
:
Optional
[
torch
.
Tensor
],
# [T] float32
must_keep
:
torch
.
Tensor
,
# [T] bool
topk_budget
:
torch
.
Tensor
,
# [B] int32
starts
:
torch
.
Tensor
,
# [B] int64
lengths
:
torch
.
Tensor
,
# [B] int64
req_ids
:
torch
.
Tensor
,
# [T] int64
pos_in_req
:
torch
.
Tensor
,
# [T] int64
max_len
:
Optional
[
int
]
=
None
,
topk_budget_max
:
Optional
[
int
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Compute keep_mask/local_rank for token-shared Top-K selection.
Returns:
keep_mask: [T] bool, selected tokens (includes must_keep)
local_rank: [T] int64, rank among kept tokens within each request
keep_len: [B] int32, number of kept tokens per request
"""
device
=
must_keep
.
device
T
=
int
(
must_keep
.
numel
())
B
=
int
(
topk_budget
.
numel
())
keep_mask
=
must_keep
.
clone
()
if
T
==
0
or
B
==
0
:
local_rank
=
torch
.
empty
((
T
,
),
device
=
device
,
dtype
=
torch
.
long
)
keep_len
=
torch
.
zeros
((
B
,
),
device
=
device
,
dtype
=
torch
.
int32
)
return
keep_mask
,
local_rank
,
keep_len
if
max_len
is
None
:
L_max
=
int
(
lengths
.
max
().
item
())
if
lengths
.
numel
()
>
0
else
0
else
:
L_max
=
int
(
max_len
)
if
L_max
<
0
:
L_max
=
0
must_keep_counts
=
torch
.
zeros
((
B
,
),
device
=
device
,
dtype
=
torch
.
long
)
must_keep_counts
.
scatter_add_
(
0
,
req_ids
,
must_keep
.
to
(
torch
.
long
))
cand_counts
=
(
lengths
.
to
(
torch
.
long
)
-
must_keep_counts
).
clamp_min
(
0
)
k_eff
=
torch
.
minimum
(
topk_budget
.
to
(
torch
.
long
).
clamp_min
(
0
),
cand_counts
)
# CPU-known bound avoids a device->host sync; clamp for safety.
if
topk_budget_max
is
None
:
k_max
=
int
(
k_eff
.
max
().
item
())
if
k_eff
.
numel
()
>
0
else
0
else
:
k_max
=
int
(
topk_budget_max
)
if
k_max
<
0
:
k_max
=
0
if
k_max
>
L_max
:
k_max
=
L_max
if
k_max
>
0
:
if
token_scores
is
None
:
raise
ValueError
(
"token_scores must be provided when k_max > 0."
)
masked_scores
=
token_scores
.
to
(
torch
.
float32
).
masked_fill
(
must_keep
,
float
(
"-inf"
))
scores_flat
=
masked_scores
.
new_full
((
B
*
L_max
,
),
float
(
"-inf"
))
linear
=
req_ids
*
L_max
+
pos_in_req
scores_flat
[
linear
]
=
masked_scores
scores
=
scores_flat
.
view
(
B
,
L_max
)
topk_pos
=
torch
.
topk
(
scores
,
k
=
k_max
,
dim
=
1
).
indices
# [B, k_max]
col_mask
=
torch
.
arange
(
k_max
,
device
=
device
).
unsqueeze
(
0
)
<
k_eff
.
unsqueeze
(
1
)
global_sel
=
starts
.
unsqueeze
(
1
)
+
topk_pos
.
to
(
torch
.
long
)
# [B,k_max]
flat_idx
=
global_sel
.
reshape
(
-
1
).
clamp_
(
0
,
T
-
1
)
flat_val
=
col_mask
.
reshape
(
-
1
).
to
(
torch
.
int32
)
tmp
=
torch
.
zeros
((
T
,
),
device
=
device
,
dtype
=
torch
.
int32
)
tmp
.
scatter_add_
(
0
,
flat_idx
,
flat_val
)
keep_mask
|=
tmp
>
0
keep_len
=
torch
.
zeros
((
B
,
),
device
=
device
,
dtype
=
torch
.
long
)
keep_len
.
scatter_add_
(
0
,
req_ids
,
keep_mask
.
to
(
torch
.
long
))
# Stable, order-preserving local rank using segment-local prefix sums.
keep_prefix
=
torch
.
cumsum
(
keep_mask
.
to
(
torch
.
long
),
dim
=
0
)
# [T]
start_minus_1
=
(
starts
-
1
).
clamp_min
(
0
)
prefix_before_all
=
keep_prefix
.
index_select
(
0
,
start_minus_1
)
prefix_before
=
torch
.
where
(
starts
>
0
,
prefix_before_all
,
torch
.
zeros_like
(
prefix_before_all
))
# [B]
prefix_before_per_token
=
prefix_before
.
index_select
(
0
,
req_ids
)
# [T]
local_rank
=
keep_prefix
-
prefix_before_per_token
-
1
# [T]
return
keep_mask
,
local_rank
,
keep_len
.
to
(
torch
.
int32
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment