# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Budget arithmetic for compressing prompt tokens during prefill.

Every prompt position in ``[0, prompt_len)`` belongs to one of two groups:

* **must-keep** -- the protected prefix ``[0, protected_prefix)``, the
  protected suffix ``[prompt_len - protected_suffix, prompt_len)`` and, when
  ``keep_last_token`` is set, the final position ``prompt_len - 1``;
* **candidate** -- every other position. Candidates compete in a Top-K
  selection. Its size is ``prompt_budget`` when that is non-negative;
  otherwise it is ``floor(candidates * ratio + 0.5)``, where ``ratio`` is
  ``prompt_ratio`` clamped to ``[0, 1]``. In both cases the size is capped
  at the number of candidates.

Protected lengths are clipped into ``[0, prompt_len]``, and negative lengths
are treated as zero.
"""

from __future__ import annotations

import math


def _clamp_int(value: int, lo: int, hi: int) -> int:
    """Return ``value`` limited to ``[lo, hi]`` (the lower bound is tested first)."""
    return lo if value < lo else hi if value > hi else value


def _intersection_len(a0: int, a1: int, b0: int, b1: int) -> int:
    """Return the overlap length of half-open ranges ``[a0, a1)`` and ``[b0, b1)``."""
    lower = max(a0, b0)
    upper = min(a1, b1)
    return upper - lower if upper > lower else 0


def _protected_prefix_len(prompt_len: int, protected_prefix: int) -> int:
    """Return the number of protected leading positions, within ``[0, prompt_len]``."""
    return _clamp_int(protected_prefix, 0, max(prompt_len, 0))


def _protected_suffix_start(prompt_len: int, protected_suffix: int) -> int:
    """Return the first position of the protected suffix, within ``[0, prompt_len]``."""
    length = max(prompt_len, 0)
    return length - _clamp_int(protected_suffix, 0, length)


def count_prompt_must_keep_in_range(
    *,
    prompt_len: int,
    start_pos: int,
    end_pos: int,
    protected_prefix: int,
    protected_suffix: int,
    keep_last_token: bool,
) -> int:
    """Return how many positions in ``[start_pos, end_pos)`` are must-keep.

    The window is first clipped to ``[0, prompt_len)``. A position counts
    when it lies in the protected prefix or the protected suffix, or when it
    is the final prompt position and ``keep_last_token`` is set. A position
    matched by several of these rules is counted once.
    """
    n = max(prompt_len, 0)
    lo = _clamp_int(start_pos, 0, n)
    hi = _clamp_int(end_pos, 0, n)
    if n == 0 or hi <= lo:
        return 0

    head = _protected_prefix_len(n, protected_prefix)
    tail = _protected_suffix_start(n, protected_suffix)

    # Inclusion-exclusion: the prefix [0, head) and the suffix [tail, n)
    # overlap on [tail, head) whenever head > tail.
    in_head = _intersection_len(lo, hi, 0, head)
    in_tail = _intersection_len(lo, hi, tail, n)
    in_both = _intersection_len(lo, hi, tail, head)
    kept = in_head + in_tail - in_both

    # The last token adds one only if neither protected region already
    # covers it, i.e. it sits strictly between them.
    final = n - 1
    if keep_last_token and lo <= final < hi and head <= final < tail:
        kept += 1
    return kept


def _count_prompt_candidates_upto(
    *,
    prompt_len: int,
    pos: int,
    protected_prefix: int,
    protected_suffix: int,
    keep_last_token: bool,
) -> int:
    """Return the number of Top-K candidate positions in ``[0, pos)``.

    Candidates are the positions in ``[prefix_len, suffix_start)``. When
    ``keep_last_token`` is set and the last position falls in that range,
    it is excluded because it is always kept.
    """
    n = max(prompt_len, 0)
    if n == 0:
        return 0

    head = _protected_prefix_len(n, protected_prefix)
    tail = _protected_suffix_start(n, protected_suffix)

    # Restrict the candidate range to [0, pos). window_end never passes the
    # suffix start, so the count is simply how far it extends beyond head.
    window_end = min(_clamp_int(pos, 0, n), tail)
    count = max(window_end - head, 0)

    final = n - 1
    if keep_last_token and head <= final < window_end:
        count -= 1
    return max(count, 0)


def _candidate_total(
    *,
    prompt_len: int,
    protected_prefix: int,
    protected_suffix: int,
    keep_last_token: bool,
) -> int:
    """Return the number of candidate positions across the entire prompt."""
    # Counting up to prompt_len covers every prompt position.
    return _count_prompt_candidates_upto(
        prompt_len=prompt_len,
        pos=prompt_len,
        protected_prefix=protected_prefix,
        protected_suffix=protected_suffix,
        keep_last_token=keep_last_token,
    )


def _candidate_keep_total(
    *,
    candidate_total: int,
    prompt_ratio: float,
    prompt_budget: int,
) -> int:
    """Return how many of ``candidate_total`` candidates the Top-K step keeps.

    A non-negative ``prompt_budget`` is an absolute count, capped at
    ``candidate_total``. A negative ``prompt_budget`` switches to ratio mode:
    ``prompt_ratio`` is clamped to ``[0, 1]`` and the scaled count is rounded
    half up.
    """
    if candidate_total <= 0:
        return 0
    if prompt_budget >= 0:
        return min(prompt_budget, candidate_total)

    # The argument order matters: with this order a NaN ratio becomes 0.0.
    fraction = max(0.0, min(float(prompt_ratio), 1.0))
    rounded = math.floor(candidate_total * fraction + 0.5)
    return _clamp_int(rounded, 0, candidate_total)


def compute_topk_budget_step(
    *,
    prompt_len: int,
    start_pos: int,
    end_pos: int,
    protected_prefix: int,
    protected_suffix: int,
    keep_last_token: bool,
    prompt_ratio: float,
    prompt_budget: int,
) -> int:
    """Return how many candidate tokens to select in the step ``[start_pos, end_pos)``.

    The overall candidate budget is spread over prefill steps in proportion
    to how many candidates each step covers. Let ``C(x)`` be the candidate
    count in ``[0, x)``, ``T`` the total candidate count and ``K`` the
    overall budget. Then::

        budget_upto(x) = (K * C(x)) // T

    This step receives ``budget_upto(end_pos) - budget_upto(start_pos)``,
    clipped to ``[0, C(end_pos) - C(start_pos)]``. Because ``budget_upto``
    depends only on position, contiguous steps that together cover
    ``[0, prompt_len)`` receive shares that sum to ``K``.
    """
    region = dict(
        prompt_len=prompt_len,
        protected_prefix=protected_prefix,
        protected_suffix=protected_suffix,
        keep_last_token=keep_last_token,
    )

    total = _candidate_total(**region)
    if total <= 0:
        return 0

    total_keep = _candidate_keep_total(
        candidate_total=total,
        prompt_ratio=prompt_ratio,
        prompt_budget=prompt_budget,
    )
    if total_keep <= 0:
        return 0

    seen_before = _count_prompt_candidates_upto(pos=start_pos, **region)
    seen_after = _count_prompt_candidates_upto(pos=end_pos, **region)
    step_total = max(0, seen_after - seen_before)
    if step_total == 0:
        return 0

    def budget_upto(seen: int) -> int:
        # Floor of the proportional share, computed exactly in integers.
        return total_keep * seen // total

    step_keep = budget_upto(seen_after) - budget_upto(seen_before)
    return _clamp_int(step_keep, 0, step_total)


def compute_prompt_topk_keep_total(
    *,
    prompt_len: int,
    protected_prefix: int,
    protected_suffix: int,
    keep_last_token: bool,
    prompt_ratio: float,
    prompt_budget: int,
) -> int:
    """Return how many candidate (non-must-keep) prompt tokens Top-K keeps overall.

    Tokens in the protected prefix and suffix are not counted here, and
    neither is the last token when ``keep_last_token`` is set; all of those
    are always kept.
    """
    total = _candidate_total(
        prompt_len=prompt_len,
        protected_prefix=protected_prefix,
        protected_suffix=protected_suffix,
        keep_last_token=keep_last_token,
    )
    # _candidate_keep_total already returns 0 for a non-positive total.
    return _candidate_keep_total(
        candidate_total=total,
        prompt_ratio=prompt_ratio,
        prompt_budget=prompt_budget,
    )


def compute_prompt_keep_len(
    *,
    prompt_len: int,
    protected_prefix: int,
    protected_suffix: int,
    keep_last_token: bool,
    prompt_ratio: float,
    prompt_budget: int,
) -> int:
    """Return the prompt length after compression (must-keep plus Top-K kept tokens).

    The result is clamped to ``[0, prompt_len]``.
    """
    n = max(prompt_len, 0)
    if not n:
        return 0

    must_keep = count_prompt_must_keep_in_range(
        prompt_len=n,
        start_pos=0,
        end_pos=n,
        protected_prefix=protected_prefix,
        protected_suffix=protected_suffix,
        keep_last_token=keep_last_token,
    )
    topk = compute_prompt_topk_keep_total(
        prompt_len=n,
        protected_prefix=protected_prefix,
        protected_suffix=protected_suffix,
        keep_last_token=keep_last_token,
        prompt_ratio=prompt_ratio,
        prompt_budget=prompt_budget,
    )
    return _clamp_int(must_keep + topk, 0, n)