Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3da2c829
"tests/fault_tolerance/vscode:/vscode.git/clone" did not exist on "8ed69ea2f8f73a00512bfe15045e7803bb9b63cb"
Commit
3da2c829
authored
Feb 24, 2026
by
laibao
Browse files
feat(kvpress): 新增 Top-K budget 与选择工具
parent
d41ca128
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
514 additions
and
0 deletions
+514
-0
vllm/v1/kv_compression/__init__.py
vllm/v1/kv_compression/__init__.py
+8
-0
vllm/v1/kv_compression/budget.py
vllm/v1/kv_compression/budget.py
+248
-0
vllm/v1/kv_compression/slot_mapping.py
vllm/v1/kv_compression/slot_mapping.py
+127
-0
vllm/v1/kv_compression/topk_select.py
vllm/v1/kv_compression/topk_select.py
+131
-0
No files found.
vllm/v1/kv_compression/__init__.py
0 → 100644
View file @
3da2c829
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
.budget
import
(
# noqa: F401
compute_topk_budget_step
,
count_prompt_must_keep_in_range
,
)
vllm/v1/kv_compression/budget.py
0 → 100644
View file @
3da2c829
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
import
math
def
_clamp_int
(
value
:
int
,
lo
:
int
,
hi
:
int
)
->
int
:
if
value
<
lo
:
return
lo
if
value
>
hi
:
return
hi
return
value
def
_intersection_len
(
a0
:
int
,
a1
:
int
,
b0
:
int
,
b1
:
int
)
->
int
:
start
=
a0
if
a0
>
b0
else
b0
end
=
a1
if
a1
<
b1
else
b1
return
max
(
0
,
end
-
start
)
def
_protected_prefix_len
(
prompt_len
:
int
,
protected_prefix
:
int
)
->
int
:
return
min
(
max
(
protected_prefix
,
0
),
max
(
prompt_len
,
0
))
def
_protected_suffix_start
(
prompt_len
:
int
,
protected_suffix
:
int
)
->
int
:
prompt_len
=
max
(
prompt_len
,
0
)
suffix
=
min
(
max
(
protected_suffix
,
0
),
prompt_len
)
return
prompt_len
-
suffix
def
count_prompt_must_keep_in_range
(
*
,
prompt_len
:
int
,
start_pos
:
int
,
end_pos
:
int
,
protected_prefix
:
int
,
protected_suffix
:
int
,
keep_last_token
:
bool
,
)
->
int
:
"""Count prompt tokens in [start_pos, end_pos) that are always kept."""
prompt_len
=
max
(
prompt_len
,
0
)
if
prompt_len
==
0
:
return
0
start
=
_clamp_int
(
start_pos
,
0
,
prompt_len
)
end
=
_clamp_int
(
end_pos
,
0
,
prompt_len
)
if
end
<=
start
:
return
0
prefix_len
=
_protected_prefix_len
(
prompt_len
,
protected_prefix
)
suffix_start
=
_protected_suffix_start
(
prompt_len
,
protected_suffix
)
keep_prefix
=
_intersection_len
(
start
,
end
,
0
,
prefix_len
)
keep_suffix
=
_intersection_len
(
start
,
end
,
suffix_start
,
prompt_len
)
overlap
=
_intersection_len
(
start
,
end
,
suffix_start
,
prefix_len
)
kept
=
keep_prefix
+
keep_suffix
-
overlap
if
keep_last_token
:
last
=
prompt_len
-
1
if
start
<=
last
<
end
:
already_kept
=
(
last
<
prefix_len
)
or
(
last
>=
suffix_start
)
if
not
already_kept
:
kept
+=
1
return
kept
def
_count_prompt_candidates_upto
(
*
,
prompt_len
:
int
,
pos
:
int
,
protected_prefix
:
int
,
protected_suffix
:
int
,
keep_last_token
:
bool
,
)
->
int
:
"""Count prompt candidates in [0, pos) eligible for Top-K selection."""
prompt_len
=
max
(
prompt_len
,
0
)
if
prompt_len
==
0
:
return
0
x
=
_clamp_int
(
pos
,
0
,
prompt_len
)
prefix_len
=
_protected_prefix_len
(
prompt_len
,
protected_prefix
)
suffix_start
=
_protected_suffix_start
(
prompt_len
,
protected_suffix
)
mid_end
=
min
(
x
,
suffix_start
)
cand
=
max
(
0
,
mid_end
-
min
(
prefix_len
,
mid_end
))
if
keep_last_token
:
last
=
prompt_len
-
1
if
prefix_len
<=
last
<
mid_end
:
cand
-=
1
return
max
(
cand
,
0
)
def
_candidate_total
(
*
,
prompt_len
:
int
,
protected_prefix
:
int
,
protected_suffix
:
int
,
keep_last_token
:
bool
,
)
->
int
:
return
_count_prompt_candidates_upto
(
prompt_len
=
prompt_len
,
pos
=
prompt_len
,
protected_prefix
=
protected_prefix
,
protected_suffix
=
protected_suffix
,
keep_last_token
=
keep_last_token
,
)
def
_candidate_keep_total
(
*
,
candidate_total
:
int
,
prompt_ratio
:
float
,
prompt_budget
:
int
,
)
->
int
:
if
candidate_total
<=
0
:
return
0
if
prompt_budget
>=
0
:
return
min
(
prompt_budget
,
candidate_total
)
ratio
=
max
(
0.0
,
min
(
float
(
prompt_ratio
),
1.0
))
keep
=
int
(
math
.
floor
(
candidate_total
*
ratio
+
0.5
))
return
_clamp_int
(
keep
,
0
,
candidate_total
)
def
compute_topk_budget_step
(
*
,
prompt_len
:
int
,
start_pos
:
int
,
end_pos
:
int
,
protected_prefix
:
int
,
protected_suffix
:
int
,
keep_last_token
:
bool
,
prompt_ratio
:
float
,
prompt_budget
:
int
,
)
->
int
:
"""Compute how many prompt candidate tokens to select for this step.
The budget applies to the *non-protected* prompt region and is distributed
across multiple prefill steps using a prefix-proportional rule:
budget_upto(x) = floor(total_keep * candidates_upto(x) / candidates_total)
The step's budget is the delta between its end and start positions.
"""
total
=
_candidate_total
(
prompt_len
=
prompt_len
,
protected_prefix
=
protected_prefix
,
protected_suffix
=
protected_suffix
,
keep_last_token
=
keep_last_token
,
)
if
total
<=
0
:
return
0
total_keep
=
_candidate_keep_total
(
candidate_total
=
total
,
prompt_ratio
=
prompt_ratio
,
prompt_budget
=
prompt_budget
,
)
if
total_keep
<=
0
:
return
0
cand_upto_start
=
_count_prompt_candidates_upto
(
prompt_len
=
prompt_len
,
pos
=
start_pos
,
protected_prefix
=
protected_prefix
,
protected_suffix
=
protected_suffix
,
keep_last_token
=
keep_last_token
,
)
cand_upto_end
=
_count_prompt_candidates_upto
(
prompt_len
=
prompt_len
,
pos
=
end_pos
,
protected_prefix
=
protected_prefix
,
protected_suffix
=
protected_suffix
,
keep_last_token
=
keep_last_token
,
)
step_total
=
max
(
0
,
cand_upto_end
-
cand_upto_start
)
if
step_total
==
0
:
return
0
bud_upto_start
=
(
total_keep
*
cand_upto_start
)
//
total
bud_upto_end
=
(
total_keep
*
cand_upto_end
)
//
total
step_keep
=
bud_upto_end
-
bud_upto_start
return
_clamp_int
(
step_keep
,
0
,
step_total
)
def
compute_prompt_topk_keep_total
(
*
,
prompt_len
:
int
,
protected_prefix
:
int
,
protected_suffix
:
int
,
keep_last_token
:
bool
,
prompt_ratio
:
float
,
prompt_budget
:
int
,
)
->
int
:
"""Compute how many *candidate* prompt tokens to keep in total.
This excludes tokens in the protected prefix/suffix region (and optionally
the last prompt token) which are always kept.
"""
total
=
_candidate_total
(
prompt_len
=
prompt_len
,
protected_prefix
=
protected_prefix
,
protected_suffix
=
protected_suffix
,
keep_last_token
=
keep_last_token
,
)
if
total
<=
0
:
return
0
return
_candidate_keep_total
(
candidate_total
=
total
,
prompt_ratio
=
prompt_ratio
,
prompt_budget
=
prompt_budget
,
)
def
compute_prompt_keep_len
(
*
,
prompt_len
:
int
,
protected_prefix
:
int
,
protected_suffix
:
int
,
keep_last_token
:
bool
,
prompt_ratio
:
float
,
prompt_budget
:
int
,
)
->
int
:
"""Compute total kept prompt tokens after compression (must-keep + Top-K)."""
prompt_len
=
max
(
prompt_len
,
0
)
if
prompt_len
==
0
:
return
0
kept_must_keep
=
count_prompt_must_keep_in_range
(
prompt_len
=
prompt_len
,
start_pos
=
0
,
end_pos
=
prompt_len
,
protected_prefix
=
protected_prefix
,
protected_suffix
=
protected_suffix
,
keep_last_token
=
keep_last_token
,
)
kept_topk
=
compute_prompt_topk_keep_total
(
prompt_len
=
prompt_len
,
protected_prefix
=
protected_prefix
,
protected_suffix
=
protected_suffix
,
keep_last_token
=
keep_last_token
,
prompt_ratio
=
prompt_ratio
,
prompt_budget
=
prompt_budget
,
)
return
_clamp_int
(
kept_must_keep
+
kept_topk
,
0
,
prompt_len
)
vllm/v1/kv_compression/slot_mapping.py
0 → 100644
View file @
3da2c829
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
typing
import
Optional
import
torch
from
vllm.v1.kv_compression.topk_select
import
(
_packed_varlen_coords
,
_topk_keep_mask_and_local_rank
)
def
_dst_slots_from_keep_mask_and_local_rank
(
*
,
keep_mask
:
torch
.
Tensor
,
# [T] bool
local_rank
:
torch
.
Tensor
,
# [T] int64
seq_lens
:
torch
.
Tensor
,
# [B] int32
lengths
:
torch
.
Tensor
,
# [B] int64
req_ids
:
torch
.
Tensor
,
# [T] int64
block_table
:
torch
.
Tensor
,
# [B, max_blocks] int32
block_size
:
int
,
)
->
torch
.
Tensor
:
"""Convert keep_mask/local_rank into a per-token KV destination slot."""
device
=
keep_mask
.
device
T
=
int
(
keep_mask
.
numel
())
dst_slots
=
torch
.
full
((
T
,
),
-
1
,
device
=
device
,
dtype
=
torch
.
int64
)
if
T
==
0
:
return
dst_slots
B
=
int
(
seq_lens
.
numel
())
if
B
==
0
:
return
dst_slots
# Base KV cache position for this step (i.e., KV length before writing this
# scheduled segment). With KV compression enabled, seq_lens is derived from
# num_kv_tokens + scheduled_len, so base_kv == seq_lens - scheduled_len.
base_kv
=
(
seq_lens
[:
B
].
to
(
torch
.
long
)
-
lengths
.
to
(
torch
.
long
)).
clamp_min
(
0
)
base_kv_per_token
=
base_kv
.
index_select
(
0
,
req_ids
)
# [T]
dest_pos
=
base_kv_per_token
+
local_rank
# [T]
dest_block_idx
=
dest_pos
//
block_size
dest_off
=
dest_pos
-
dest_block_idx
*
block_size
# Safe indexing for dropped tokens (ignored by keep_mask anyway).
max_blocks
=
int
(
block_table
.
shape
[
1
])
dest_block_idx_safe
=
dest_block_idx
.
clamp_
(
0
,
max_blocks
-
1
).
to
(
torch
.
long
)
block_nums
=
block_table
[
req_ids
,
dest_block_idx_safe
]
dest_slot
=
block_nums
.
to
(
torch
.
long
)
*
block_size
+
dest_off
return
torch
.
where
(
keep_mask
,
dest_slot
.
to
(
torch
.
int64
),
dst_slots
)
def
topk_kv_compact_slot_mapping
(
*
,
token_scores
:
Optional
[
torch
.
Tensor
],
# [T] float32
must_keep
:
torch
.
Tensor
,
# [T] bool
topk_budget
:
torch
.
Tensor
,
# [B] int32
query_start_loc
:
torch
.
Tensor
,
# [B+1]
seq_lens
:
torch
.
Tensor
,
# [B] int32
block_table
:
torch
.
Tensor
,
# [B, max_blocks]
block_size
:
int
,
max_query_len
:
Optional
[
int
]
=
None
,
topk_budget_max
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
:
"""Build a per-token destination slot mapping for KV compaction.
Returns a tensor `dst_slots` of shape [T] where:
- `dst_slots[i] >= 0` indicates token i should be kept and rewritten to
that KV cache slot.
- `dst_slots[i] == -1` indicates token i is dropped after the step.
"""
device
=
must_keep
.
device
T
=
int
(
must_keep
.
numel
())
B
=
int
(
topk_budget
.
numel
())
dst_slots
=
torch
.
full
((
T
,
),
-
1
,
device
=
device
,
dtype
=
torch
.
int64
)
if
T
==
0
or
B
==
0
:
return
dst_slots
starts
,
_
,
lengths
,
req_ids
,
pos_in_req
=
_packed_varlen_coords
(
cu_seqlens
=
query_start_loc
,
total_tokens
=
T
,
)
if
lengths
.
numel
()
==
0
:
return
dst_slots
# Prefer the CPU-known max query length (piecewise graph), to avoid
# device->host synchronization.
L_max
=
int
(
max_query_len
)
if
max_query_len
is
not
None
else
int
(
lengths
.
max
().
item
())
if
L_max
<=
0
:
return
dst_slots
keep_mask
,
local_rank
,
_
=
_topk_keep_mask_and_local_rank
(
token_scores
=
token_scores
,
must_keep
=
must_keep
,
topk_budget
=
topk_budget
,
starts
=
starts
,
lengths
=
lengths
,
req_ids
=
req_ids
,
pos_in_req
=
pos_in_req
,
max_len
=
L_max
,
topk_budget_max
=
topk_budget_max
,
)
return
_dst_slots_from_keep_mask_and_local_rank
(
keep_mask
=
keep_mask
,
local_rank
=
local_rank
,
seq_lens
=
seq_lens
[:
B
],
lengths
=
lengths
,
req_ids
=
req_ids
,
block_table
=
block_table
,
block_size
=
int
(
block_size
),
)
def
kv_compaction_dst_rewrite_mapping
(
*
,
dst_slots
:
torch
.
Tensor
,
# [T] int64
src_slots
:
torch
.
Tensor
,
# [T] int64
)
->
torch
.
Tensor
:
"""Filter a dst slot mapping so only moved kept tokens are rewritten.
Non-rewrite tokens are marked as -1, which the cache kernels treat as
padding and skip.
"""
rewrite_mask
=
(
dst_slots
>=
0
)
&
(
dst_slots
!=
src_slots
)
return
torch
.
where
(
rewrite_mask
,
dst_slots
,
-
1
)
vllm/v1/kv_compression/topk_select.py
0 → 100644
View file @
3da2c829
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
from
typing
import
Optional
import
torch
def
_packed_varlen_coords
(
*
,
cu_seqlens
:
torch
.
Tensor
,
# [B+1]
total_tokens
:
int
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Compute packed varlen segment coordinates.
Returns:
starts: [B] int64, segment start offsets (inclusive)
ends: [B] int64, segment end offsets (exclusive)
lengths: [B] int64, segment lengths (ends - starts)
req_ids: [T] int64, request id for each token in packed [0, T)
pos_in_req: [T] int64, position within its request segment
"""
device
=
cu_seqlens
.
device
B
=
int
(
cu_seqlens
.
numel
()
-
1
)
if
B
<=
0
:
empty
=
torch
.
empty
((
0
,
),
device
=
device
,
dtype
=
torch
.
long
)
t_empty
=
torch
.
empty
((
0
,
),
device
=
device
,
dtype
=
torch
.
long
)
return
empty
,
empty
,
empty
,
t_empty
,
t_empty
starts
=
cu_seqlens
[:
B
].
to
(
torch
.
long
)
ends
=
cu_seqlens
[
1
:
B
+
1
].
to
(
torch
.
long
)
lengths
=
ends
-
starts
if
total_tokens
<=
0
:
t_empty
=
torch
.
empty
((
0
,
),
device
=
device
,
dtype
=
torch
.
long
)
return
starts
,
ends
,
lengths
,
t_empty
,
t_empty
token_idx
=
torch
.
arange
(
total_tokens
,
device
=
device
,
dtype
=
torch
.
long
)
req_ids
=
torch
.
bucketize
(
token_idx
,
ends
,
right
=
True
)
# [T]
start_per_token
=
starts
.
index_select
(
0
,
req_ids
)
pos_in_req
=
token_idx
-
start_per_token
return
starts
,
ends
,
lengths
,
req_ids
,
pos_in_req
def
_topk_keep_mask_and_local_rank
(
*
,
token_scores
:
Optional
[
torch
.
Tensor
],
# [T] float32
must_keep
:
torch
.
Tensor
,
# [T] bool
topk_budget
:
torch
.
Tensor
,
# [B] int32
starts
:
torch
.
Tensor
,
# [B] int64
lengths
:
torch
.
Tensor
,
# [B] int64
req_ids
:
torch
.
Tensor
,
# [T] int64
pos_in_req
:
torch
.
Tensor
,
# [T] int64
max_len
:
Optional
[
int
]
=
None
,
topk_budget_max
:
Optional
[
int
]
=
None
,
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
"""Compute keep_mask/local_rank for token-shared Top-K selection.
Returns:
keep_mask: [T] bool, selected tokens (includes must_keep)
local_rank: [T] int64, rank among kept tokens within each request
keep_len: [B] int32, number of kept tokens per request
"""
device
=
must_keep
.
device
T
=
int
(
must_keep
.
numel
())
B
=
int
(
topk_budget
.
numel
())
keep_mask
=
must_keep
.
clone
()
if
T
==
0
or
B
==
0
:
local_rank
=
torch
.
empty
((
T
,
),
device
=
device
,
dtype
=
torch
.
long
)
keep_len
=
torch
.
zeros
((
B
,
),
device
=
device
,
dtype
=
torch
.
int32
)
return
keep_mask
,
local_rank
,
keep_len
if
max_len
is
None
:
L_max
=
int
(
lengths
.
max
().
item
())
if
lengths
.
numel
()
>
0
else
0
else
:
L_max
=
int
(
max_len
)
if
L_max
<
0
:
L_max
=
0
must_keep_counts
=
torch
.
zeros
((
B
,
),
device
=
device
,
dtype
=
torch
.
long
)
must_keep_counts
.
scatter_add_
(
0
,
req_ids
,
must_keep
.
to
(
torch
.
long
))
cand_counts
=
(
lengths
.
to
(
torch
.
long
)
-
must_keep_counts
).
clamp_min
(
0
)
k_eff
=
torch
.
minimum
(
topk_budget
.
to
(
torch
.
long
).
clamp_min
(
0
),
cand_counts
)
# CPU-known bound avoids a device->host sync; clamp for safety.
if
topk_budget_max
is
None
:
k_max
=
int
(
k_eff
.
max
().
item
())
if
k_eff
.
numel
()
>
0
else
0
else
:
k_max
=
int
(
topk_budget_max
)
if
k_max
<
0
:
k_max
=
0
if
k_max
>
L_max
:
k_max
=
L_max
if
k_max
>
0
:
if
token_scores
is
None
:
raise
ValueError
(
"token_scores must be provided when k_max > 0."
)
masked_scores
=
token_scores
.
to
(
torch
.
float32
).
masked_fill
(
must_keep
,
float
(
"-inf"
))
scores_flat
=
masked_scores
.
new_full
((
B
*
L_max
,
),
float
(
"-inf"
))
linear
=
req_ids
*
L_max
+
pos_in_req
scores_flat
[
linear
]
=
masked_scores
scores
=
scores_flat
.
view
(
B
,
L_max
)
topk_pos
=
torch
.
topk
(
scores
,
k
=
k_max
,
dim
=
1
).
indices
# [B, k_max]
col_mask
=
torch
.
arange
(
k_max
,
device
=
device
).
unsqueeze
(
0
)
<
k_eff
.
unsqueeze
(
1
)
global_sel
=
starts
.
unsqueeze
(
1
)
+
topk_pos
.
to
(
torch
.
long
)
# [B,k_max]
flat_idx
=
global_sel
.
reshape
(
-
1
).
clamp_
(
0
,
T
-
1
)
flat_val
=
col_mask
.
reshape
(
-
1
).
to
(
torch
.
int32
)
tmp
=
torch
.
zeros
((
T
,
),
device
=
device
,
dtype
=
torch
.
int32
)
tmp
.
scatter_add_
(
0
,
flat_idx
,
flat_val
)
keep_mask
|=
tmp
>
0
keep_len
=
torch
.
zeros
((
B
,
),
device
=
device
,
dtype
=
torch
.
long
)
keep_len
.
scatter_add_
(
0
,
req_ids
,
keep_mask
.
to
(
torch
.
long
))
# Stable, order-preserving local rank using segment-local prefix sums.
keep_prefix
=
torch
.
cumsum
(
keep_mask
.
to
(
torch
.
long
),
dim
=
0
)
# [T]
start_minus_1
=
(
starts
-
1
).
clamp_min
(
0
)
prefix_before_all
=
keep_prefix
.
index_select
(
0
,
start_minus_1
)
prefix_before
=
torch
.
where
(
starts
>
0
,
prefix_before_all
,
torch
.
zeros_like
(
prefix_before_all
))
# [B]
prefix_before_per_token
=
prefix_before
.
index_select
(
0
,
req_ids
)
# [T]
local_rank
=
keep_prefix
-
prefix_before_per_token
-
1
# [T]
return
keep_mask
,
local_rank
,
keep_len
.
to
(
torch
.
int32
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment