Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d3acd4a5
Commit
d3acd4a5
authored
Jan 20, 2026
by
laibao
Browse files
feat: kvpress新增 KV 压缩预算计算模块
parent
ade2749c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
256 additions
and
0 deletions
+256
-0
vllm/v1/kv_compression/__init__.py
vllm/v1/kv_compression/__init__.py
+8
-0
vllm/v1/kv_compression/budget.py
vllm/v1/kv_compression/budget.py
+248
-0
No files found.
vllm/v1/kv_compression/__init__.py
0 → 100644
View file @
d3acd4a5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
.budget
import
(
# noqa: F401
compute_topk_budget_step
,
count_prompt_must_keep_in_range
,
)
vllm/v1/kv_compression/budget.py
0 → 100644
View file @
d3acd4a5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from
__future__
import
annotations
import
math
def
_clamp_int
(
value
:
int
,
lo
:
int
,
hi
:
int
)
->
int
:
if
value
<
lo
:
return
lo
if
value
>
hi
:
return
hi
return
value
def
_intersection_len
(
a0
:
int
,
a1
:
int
,
b0
:
int
,
b1
:
int
)
->
int
:
start
=
a0
if
a0
>
b0
else
b0
end
=
a1
if
a1
<
b1
else
b1
return
max
(
0
,
end
-
start
)
def
_protected_prefix_len
(
prompt_len
:
int
,
protected_prefix
:
int
)
->
int
:
return
min
(
max
(
protected_prefix
,
0
),
max
(
prompt_len
,
0
))
def
_protected_suffix_start
(
prompt_len
:
int
,
protected_suffix
:
int
)
->
int
:
prompt_len
=
max
(
prompt_len
,
0
)
suffix
=
min
(
max
(
protected_suffix
,
0
),
prompt_len
)
return
prompt_len
-
suffix
def count_prompt_must_keep_in_range(
    *,
    prompt_len: int,
    start_pos: int,
    end_pos: int,
    protected_prefix: int,
    protected_suffix: int,
    keep_last_token: bool,
) -> int:
    """Count always-kept prompt tokens inside ``[start_pos, end_pos)``.

    A token is always kept when it lies in the protected prefix, in the
    protected suffix, or (when ``keep_last_token`` is set) is the final prompt
    token. A token covered by several of these rules is counted once.
    ``start_pos`` and ``end_pos`` are clamped to ``[0, prompt_len]``.
    """
    n_prompt = max(prompt_len, 0)
    if n_prompt == 0:
        return 0

    lo = _clamp_int(start_pos, 0, n_prompt)
    hi = _clamp_int(end_pos, 0, n_prompt)
    if hi <= lo:
        return 0

    head_end = _protected_prefix_len(n_prompt, protected_prefix)
    tail_begin = _protected_suffix_start(n_prompt, protected_suffix)

    # Inclusion-exclusion: when prefix and suffix overlap, the shared span
    # [tail_begin, head_end) would otherwise be counted twice.
    in_head = _intersection_len(lo, hi, 0, head_end)
    in_tail = _intersection_len(lo, hi, tail_begin, n_prompt)
    in_both = _intersection_len(lo, hi, tail_begin, head_end)
    must_keep = in_head + in_tail - in_both

    if not keep_last_token:
        return must_keep

    # The final token adds one only if it is inside the requested range and
    # is not already covered by the prefix or the suffix.
    final_idx = n_prompt - 1
    in_window = lo <= final_idx < hi
    covered = final_idx < head_end or final_idx >= tail_begin
    if in_window and not covered:
        must_keep += 1
    return must_keep
def _count_prompt_candidates_upto(
    *,
    prompt_len: int,
    pos: int,
    protected_prefix: int,
    protected_suffix: int,
    keep_last_token: bool,
) -> int:
    """Return the number of Top-K candidates among positions ``[0, pos)``.

    Candidates are the prompt positions after the protected prefix and before
    the protected suffix. When ``keep_last_token`` is set and the final prompt
    token falls inside the counted window, it is excluded as well.
    ``pos`` is clamped to ``[0, prompt_len]``.
    """
    n_prompt = max(prompt_len, 0)
    if n_prompt == 0:
        return 0

    upto = _clamp_int(pos, 0, n_prompt)
    head_end = _protected_prefix_len(n_prompt, protected_prefix)
    tail_begin = _protected_suffix_start(n_prompt, protected_suffix)

    # The candidate window is [head_end, tail_begin) truncated at ``upto``.
    window_end = min(upto, tail_begin)
    window_begin = min(head_end, window_end)
    count = max(0, window_end - window_begin)

    if keep_last_token and head_end <= n_prompt - 1 < window_end:
        count -= 1
    return max(count, 0)
def _candidate_total(
    *,
    prompt_len: int,
    protected_prefix: int,
    protected_suffix: int,
    keep_last_token: bool,
) -> int:
    """Return the number of Top-K candidates in the whole prompt.

    This is the candidate count up to ``pos=prompt_len``.
    """
    whole_prompt = _count_prompt_candidates_upto(
        prompt_len=prompt_len,
        pos=prompt_len,
        protected_prefix=protected_prefix,
        protected_suffix=protected_suffix,
        keep_last_token=keep_last_token,
    )
    return whole_prompt
def
_candidate_keep_total
(
*
,
candidate_total
:
int
,
prompt_ratio
:
float
,
prompt_budget
:
int
,
)
->
int
:
if
candidate_total
<=
0
:
return
0
if
prompt_budget
>=
0
:
return
min
(
prompt_budget
,
candidate_total
)
ratio
=
max
(
0.0
,
min
(
float
(
prompt_ratio
),
1.0
))
keep
=
int
(
math
.
floor
(
candidate_total
*
ratio
+
0.5
))
return
_clamp_int
(
keep
,
0
,
candidate_total
)
def compute_topk_budget_step(
    *,
    prompt_len: int,
    start_pos: int,
    end_pos: int,
    protected_prefix: int,
    protected_suffix: int,
    keep_last_token: bool,
    prompt_ratio: float,
    prompt_budget: int,
) -> int:
    """Return the Top-K selection budget for one prefill step.

    The overall budget covers only the non-protected (candidate) part of the
    prompt. It is spread over prefill steps in proportion to the number of
    candidates seen so far:

        budget_upto(x) = floor(
            total_keep * candidates_upto(x) / candidates_total)

    The budget for the step ``[start_pos, end_pos)`` is
    ``budget_upto(end_pos) - budget_upto(start_pos)``, limited to the number
    of candidates the step contains.
    """
    # Keyword arguments describing the prompt layout, shared by the helpers.
    layout = {
        "prompt_len": prompt_len,
        "protected_prefix": protected_prefix,
        "protected_suffix": protected_suffix,
        "keep_last_token": keep_last_token,
    }

    n_candidates = _candidate_total(**layout)
    if n_candidates <= 0:
        return 0

    keep_budget = _candidate_keep_total(
        candidate_total=n_candidates,
        prompt_ratio=prompt_ratio,
        prompt_budget=prompt_budget,
    )
    if keep_budget <= 0:
        return 0

    seen_before = _count_prompt_candidates_upto(pos=start_pos, **layout)
    seen_after = _count_prompt_candidates_upto(pos=end_pos, **layout)
    in_step = max(0, seen_after - seen_before)
    if in_step == 0:
        return 0

    # Integer floor division keeps the cumulative budgets exact.
    budget_before = (keep_budget * seen_before) // n_candidates
    budget_after = (keep_budget * seen_after) // n_candidates
    return _clamp_int(budget_after - budget_before, 0, in_step)
def compute_prompt_topk_keep_total(
    *,
    prompt_len: int,
    protected_prefix: int,
    protected_suffix: int,
    keep_last_token: bool,
    prompt_ratio: float,
    prompt_budget: int,
) -> int:
    """Return the total number of *candidate* prompt tokens to keep.

    Tokens in the protected prefix/suffix (and the last prompt token when
    ``keep_last_token`` is set) are always kept and are not part of this
    count.
    """
    n_candidates = _candidate_total(
        prompt_len=prompt_len,
        protected_prefix=protected_prefix,
        protected_suffix=protected_suffix,
        keep_last_token=keep_last_token,
    )
    if n_candidates > 0:
        return _candidate_keep_total(
            candidate_total=n_candidates,
            prompt_ratio=prompt_ratio,
            prompt_budget=prompt_budget,
        )
    return 0
def compute_prompt_keep_len(
    *,
    prompt_len: int,
    protected_prefix: int,
    protected_suffix: int,
    keep_last_token: bool,
    prompt_ratio: float,
    prompt_budget: int,
) -> int:
    """Return the number of prompt tokens kept after compression.

    The result is the always-kept tokens over the whole prompt plus the Top-K
    candidate budget, limited to ``[0, prompt_len]``.
    """
    n_prompt = max(prompt_len, 0)
    if n_prompt == 0:
        return 0

    # Keyword arguments describing the prompt layout, shared by both counts.
    layout = {
        "prompt_len": n_prompt,
        "protected_prefix": protected_prefix,
        "protected_suffix": protected_suffix,
        "keep_last_token": keep_last_token,
    }

    forced = count_prompt_must_keep_in_range(
        start_pos=0,
        end_pos=n_prompt,
        **layout,
    )
    selected = compute_prompt_topk_keep_total(
        prompt_ratio=prompt_ratio,
        prompt_budget=prompt_budget,
        **layout,
    )
    return _clamp_int(forced + selected, 0, n_prompt)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment