Commit d3acd4a5 authored by laibao's avatar laibao
Browse files

feat: kvpress新增 KV 压缩预算计算模块

parent ade2749c
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .budget import ( # noqa: F401
compute_topk_budget_step,
count_prompt_must_keep_in_range,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import math
def _clamp_int(value: int, lo: int, hi: int) -> int:
if value < lo:
return lo
if value > hi:
return hi
return value
def _intersection_len(a0: int, a1: int, b0: int, b1: int) -> int:
start = a0 if a0 > b0 else b0
end = a1 if a1 < b1 else b1
return max(0, end - start)
def _protected_prefix_len(prompt_len: int, protected_prefix: int) -> int:
return min(max(protected_prefix, 0), max(prompt_len, 0))
def _protected_suffix_start(prompt_len: int, protected_suffix: int) -> int:
prompt_len = max(prompt_len, 0)
suffix = min(max(protected_suffix, 0), prompt_len)
return prompt_len - suffix
def count_prompt_must_keep_in_range(
    *,
    prompt_len: int,
    start_pos: int,
    end_pos: int,
    protected_prefix: int,
    protected_suffix: int,
    keep_last_token: bool,
) -> int:
    """Count the always-kept prompt tokens inside ``[start_pos, end_pos)``.

    A token is always kept when it lies in the protected prefix, in the
    protected suffix, or (when ``keep_last_token`` is set) is the final
    prompt token. Each position is counted at most once.

    Args:
        prompt_len: Prompt length; non-positive values yield ``0``.
        start_pos: Inclusive range start, clamped to ``[0, prompt_len]``.
        end_pos: Exclusive range end, clamped to ``[0, prompt_len]``.
        protected_prefix: Number of leading tokens that are always kept.
        protected_suffix: Number of trailing tokens that are always kept.
        keep_last_token: Whether the final prompt token is always kept.

    Returns:
        The number of must-keep positions in the clamped range.
    """
    if prompt_len <= 0:
        return 0

    lo = _clamp_int(start_pos, 0, prompt_len)
    hi = _clamp_int(end_pos, 0, prompt_len)
    if hi <= lo:
        return 0

    head_end = _protected_prefix_len(prompt_len, protected_prefix)
    tail_begin = _protected_suffix_start(prompt_len, protected_suffix)

    # Inclusion-exclusion: when prefix and suffix overlap, the shared span
    # [tail_begin, head_end) would otherwise be counted twice.
    in_head = _intersection_len(lo, hi, 0, head_end)
    in_tail = _intersection_len(lo, hi, tail_begin, prompt_len)
    in_both = _intersection_len(lo, hi, tail_begin, head_end)
    must_keep = in_head + in_tail - in_both

    if not keep_last_token:
        return must_keep

    final_idx = prompt_len - 1
    last_in_range = lo <= final_idx < hi
    # The last token only adds to the count if neither protected region
    # already covers it.
    last_unprotected = head_end <= final_idx < tail_begin
    if last_in_range and last_unprotected:
        must_keep += 1
    return must_keep
def _count_prompt_candidates_upto(
*,
prompt_len: int,
pos: int,
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
) -> int:
"""Count prompt candidates in [0, pos) eligible for Top-K selection."""
prompt_len = max(prompt_len, 0)
if prompt_len == 0:
return 0
x = _clamp_int(pos, 0, prompt_len)
prefix_len = _protected_prefix_len(prompt_len, protected_prefix)
suffix_start = _protected_suffix_start(prompt_len, protected_suffix)
mid_end = min(x, suffix_start)
cand = max(0, mid_end - min(prefix_len, mid_end))
if keep_last_token:
last = prompt_len - 1
if prefix_len <= last < mid_end:
cand -= 1
return max(cand, 0)
def _candidate_total(
    *,
    prompt_len: int,
    protected_prefix: int,
    protected_suffix: int,
    keep_last_token: bool,
) -> int:
    """Return the number of Top-K candidates across the whole prompt.

    This is the candidate count for the range ``[0, prompt_len)``.
    """
    whole_prompt_end = prompt_len
    return _count_prompt_candidates_upto(
        pos=whole_prompt_end,
        prompt_len=prompt_len,
        keep_last_token=keep_last_token,
        protected_suffix=protected_suffix,
        protected_prefix=protected_prefix,
    )
def _candidate_keep_total(
*,
candidate_total: int,
prompt_ratio: float,
prompt_budget: int,
) -> int:
if candidate_total <= 0:
return 0
if prompt_budget >= 0:
return min(prompt_budget, candidate_total)
ratio = max(0.0, min(float(prompt_ratio), 1.0))
keep = int(math.floor(candidate_total * ratio + 0.5))
return _clamp_int(keep, 0, candidate_total)
def compute_topk_budget_step(
    *,
    prompt_len: int,
    start_pos: int,
    end_pos: int,
    protected_prefix: int,
    protected_suffix: int,
    keep_last_token: bool,
    prompt_ratio: float,
    prompt_budget: int,
) -> int:
    """Return the number of candidate tokens to select in this step.

    The overall keep budget covers only the *non-protected* prompt region.
    It is spread over multiple prefill steps with a prefix-proportional
    rule:

        budget_upto(x) = floor(total_keep * candidates_upto(x) / candidates_total)

    The budget of the step covering ``[start_pos, end_pos)`` is
    ``budget_upto(end_pos) - budget_upto(start_pos)``, limited to the number
    of candidates that actually fall inside the step.

    Args:
        prompt_len: Prompt length.
        start_pos: Inclusive start of this step's token range.
        end_pos: Exclusive end of this step's token range.
        protected_prefix: Number of leading tokens that are always kept.
        protected_suffix: Number of trailing tokens that are always kept.
        keep_last_token: Whether the final prompt token is always kept.
        prompt_ratio: Fraction of candidates to keep when ``prompt_budget``
            is negative.
        prompt_budget: Absolute candidate budget when non-negative.

    Returns:
        The step's selection count; ``0`` when there are no candidates or
        nothing is to be kept.
    """
    region = dict(
        prompt_len=prompt_len,
        protected_prefix=protected_prefix,
        protected_suffix=protected_suffix,
        keep_last_token=keep_last_token,
    )

    pool = _candidate_total(**region)
    if pool <= 0:
        return 0

    quota = _candidate_keep_total(
        candidate_total=pool,
        prompt_ratio=prompt_ratio,
        prompt_budget=prompt_budget,
    )
    if quota <= 0:
        return 0

    seen_before = _count_prompt_candidates_upto(pos=start_pos, **region)
    seen_through = _count_prompt_candidates_upto(pos=end_pos, **region)
    in_step = seen_through - seen_before
    if in_step <= 0:
        return 0

    # Taking the difference of cumulative floors makes per-step budgets sum
    # to the floor of the cumulative budget, regardless of step boundaries.
    granted_before = (quota * seen_before) // pool
    granted_through = (quota * seen_through) // pool
    return _clamp_int(granted_through - granted_before, 0, in_step)
def compute_prompt_topk_keep_total(
    *,
    prompt_len: int,
    protected_prefix: int,
    protected_suffix: int,
    keep_last_token: bool,
    prompt_ratio: float,
    prompt_budget: int,
) -> int:
    """Return the total number of *candidate* prompt tokens to keep.

    Tokens in the protected prefix/suffix (and, optionally, the last prompt
    token) are always kept and are therefore not part of this count.

    Args:
        prompt_len: Prompt length.
        protected_prefix: Number of leading tokens that are always kept.
        protected_suffix: Number of trailing tokens that are always kept.
        keep_last_token: Whether the final prompt token is always kept.
        prompt_ratio: Fraction of candidates to keep when ``prompt_budget``
            is negative.
        prompt_budget: Absolute candidate budget when non-negative.

    Returns:
        The number of candidates to keep, or ``0`` when there are none.
    """
    pool = _candidate_total(
        keep_last_token=keep_last_token,
        protected_suffix=protected_suffix,
        protected_prefix=protected_prefix,
        prompt_len=prompt_len,
    )
    if not pool > 0:
        return 0
    return _candidate_keep_total(
        prompt_budget=prompt_budget,
        prompt_ratio=prompt_ratio,
        candidate_total=pool,
    )
def compute_prompt_keep_len(
    *,
    prompt_len: int,
    protected_prefix: int,
    protected_suffix: int,
    keep_last_token: bool,
    prompt_ratio: float,
    prompt_budget: int,
) -> int:
    """Return how many prompt tokens remain after compression.

    The result is the must-keep count over the whole prompt plus the Top-K
    candidate keep total, limited to ``[0, prompt_len]``.

    Args:
        prompt_len: Prompt length; non-positive values yield ``0``.
        protected_prefix: Number of leading tokens that are always kept.
        protected_suffix: Number of trailing tokens that are always kept.
        keep_last_token: Whether the final prompt token is always kept.
        prompt_ratio: Fraction of candidates to keep when ``prompt_budget``
            is negative.
        prompt_budget: Absolute candidate budget when non-negative.

    Returns:
        The total number of kept prompt tokens.
    """
    if prompt_len <= 0:
        return 0

    region = dict(
        prompt_len=prompt_len,
        protected_prefix=protected_prefix,
        protected_suffix=protected_suffix,
        keep_last_token=keep_last_token,
    )
    forced = count_prompt_must_keep_in_range(
        start_pos=0,
        end_pos=prompt_len,
        **region,
    )
    selected = compute_prompt_topk_keep_total(
        prompt_ratio=prompt_ratio,
        prompt_budget=prompt_budget,
        **region,
    )
    return _clamp_int(forced + selected, 0, prompt_len)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment