"tests/fault_tolerance/vscode:/vscode.git/clone" did not exist on "8ed69ea2f8f73a00512bfe15045e7803bb9b63cb"
Commit 3da2c829 authored by laibao's avatar laibao
Browse files

feat(kvpress): 新增 Top-K budget 与选择工具

parent d41ca128
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .budget import ( # noqa: F401
compute_topk_budget_step,
count_prompt_must_keep_in_range,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import math
def _clamp_int(value: int, lo: int, hi: int) -> int:
if value < lo:
return lo
if value > hi:
return hi
return value
def _intersection_len(a0: int, a1: int, b0: int, b1: int) -> int:
start = a0 if a0 > b0 else b0
end = a1 if a1 < b1 else b1
return max(0, end - start)
def _protected_prefix_len(prompt_len: int, protected_prefix: int) -> int:
return min(max(protected_prefix, 0), max(prompt_len, 0))
def _protected_suffix_start(prompt_len: int, protected_suffix: int) -> int:
prompt_len = max(prompt_len, 0)
suffix = min(max(protected_suffix, 0), prompt_len)
return prompt_len - suffix
def count_prompt_must_keep_in_range(
*,
prompt_len: int,
start_pos: int,
end_pos: int,
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
) -> int:
"""Count prompt tokens in [start_pos, end_pos) that are always kept."""
prompt_len = max(prompt_len, 0)
if prompt_len == 0:
return 0
start = _clamp_int(start_pos, 0, prompt_len)
end = _clamp_int(end_pos, 0, prompt_len)
if end <= start:
return 0
prefix_len = _protected_prefix_len(prompt_len, protected_prefix)
suffix_start = _protected_suffix_start(prompt_len, protected_suffix)
keep_prefix = _intersection_len(start, end, 0, prefix_len)
keep_suffix = _intersection_len(start, end, suffix_start, prompt_len)
overlap = _intersection_len(start, end, suffix_start, prefix_len)
kept = keep_prefix + keep_suffix - overlap
if keep_last_token:
last = prompt_len - 1
if start <= last < end:
already_kept = (last < prefix_len) or (last >= suffix_start)
if not already_kept:
kept += 1
return kept
def _count_prompt_candidates_upto(
*,
prompt_len: int,
pos: int,
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
) -> int:
"""Count prompt candidates in [0, pos) eligible for Top-K selection."""
prompt_len = max(prompt_len, 0)
if prompt_len == 0:
return 0
x = _clamp_int(pos, 0, prompt_len)
prefix_len = _protected_prefix_len(prompt_len, protected_prefix)
suffix_start = _protected_suffix_start(prompt_len, protected_suffix)
mid_end = min(x, suffix_start)
cand = max(0, mid_end - min(prefix_len, mid_end))
if keep_last_token:
last = prompt_len - 1
if prefix_len <= last < mid_end:
cand -= 1
return max(cand, 0)
def _candidate_total(
*,
prompt_len: int,
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
) -> int:
return _count_prompt_candidates_upto(
prompt_len=prompt_len,
pos=prompt_len,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last_token,
)
def _candidate_keep_total(
*,
candidate_total: int,
prompt_ratio: float,
prompt_budget: int,
) -> int:
if candidate_total <= 0:
return 0
if prompt_budget >= 0:
return min(prompt_budget, candidate_total)
ratio = max(0.0, min(float(prompt_ratio), 1.0))
keep = int(math.floor(candidate_total * ratio + 0.5))
return _clamp_int(keep, 0, candidate_total)
def compute_topk_budget_step(
*,
prompt_len: int,
start_pos: int,
end_pos: int,
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
prompt_ratio: float,
prompt_budget: int,
) -> int:
"""Compute how many prompt candidate tokens to select for this step.
The budget applies to the *non-protected* prompt region and is distributed
across multiple prefill steps using a prefix-proportional rule:
budget_upto(x) = floor(total_keep * candidates_upto(x) / candidates_total)
The step's budget is the delta between its end and start positions.
"""
total = _candidate_total(
prompt_len=prompt_len,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last_token,
)
if total <= 0:
return 0
total_keep = _candidate_keep_total(
candidate_total=total,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
if total_keep <= 0:
return 0
cand_upto_start = _count_prompt_candidates_upto(
prompt_len=prompt_len,
pos=start_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last_token,
)
cand_upto_end = _count_prompt_candidates_upto(
prompt_len=prompt_len,
pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last_token,
)
step_total = max(0, cand_upto_end - cand_upto_start)
if step_total == 0:
return 0
bud_upto_start = (total_keep * cand_upto_start) // total
bud_upto_end = (total_keep * cand_upto_end) // total
step_keep = bud_upto_end - bud_upto_start
return _clamp_int(step_keep, 0, step_total)
def compute_prompt_topk_keep_total(
*,
prompt_len: int,
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
prompt_ratio: float,
prompt_budget: int,
) -> int:
"""Compute how many *candidate* prompt tokens to keep in total.
This excludes tokens in the protected prefix/suffix region (and optionally
the last prompt token) which are always kept.
"""
total = _candidate_total(
prompt_len=prompt_len,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last_token,
)
if total <= 0:
return 0
return _candidate_keep_total(
candidate_total=total,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
def compute_prompt_keep_len(
*,
prompt_len: int,
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
prompt_ratio: float,
prompt_budget: int,
) -> int:
"""Compute total kept prompt tokens after compression (must-keep + Top-K)."""
prompt_len = max(prompt_len, 0)
if prompt_len == 0:
return 0
kept_must_keep = count_prompt_must_keep_in_range(
prompt_len=prompt_len,
start_pos=0,
end_pos=prompt_len,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last_token,
)
kept_topk = compute_prompt_topk_keep_total(
prompt_len=prompt_len,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last_token,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
return _clamp_int(kept_must_keep + kept_topk, 0, prompt_len)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import Optional
import torch
from vllm.v1.kv_compression.topk_select import (_packed_varlen_coords,
_topk_keep_mask_and_local_rank)
def _dst_slots_from_keep_mask_and_local_rank(
*,
keep_mask: torch.Tensor, # [T] bool
local_rank: torch.Tensor, # [T] int64
seq_lens: torch.Tensor, # [B] int32
lengths: torch.Tensor, # [B] int64
req_ids: torch.Tensor, # [T] int64
block_table: torch.Tensor, # [B, max_blocks] int32
block_size: int,
) -> torch.Tensor:
"""Convert keep_mask/local_rank into a per-token KV destination slot."""
device = keep_mask.device
T = int(keep_mask.numel())
dst_slots = torch.full((T, ), -1, device=device, dtype=torch.int64)
if T == 0:
return dst_slots
B = int(seq_lens.numel())
if B == 0:
return dst_slots
# Base KV cache position for this step (i.e., KV length before writing this
# scheduled segment). With KV compression enabled, seq_lens is derived from
# num_kv_tokens + scheduled_len, so base_kv == seq_lens - scheduled_len.
base_kv = (seq_lens[:B].to(torch.long) - lengths.to(torch.long)).clamp_min(0)
base_kv_per_token = base_kv.index_select(0, req_ids) # [T]
dest_pos = base_kv_per_token + local_rank # [T]
dest_block_idx = dest_pos // block_size
dest_off = dest_pos - dest_block_idx * block_size
# Safe indexing for dropped tokens (ignored by keep_mask anyway).
max_blocks = int(block_table.shape[1])
dest_block_idx_safe = dest_block_idx.clamp_(0, max_blocks - 1).to(torch.long)
block_nums = block_table[req_ids, dest_block_idx_safe]
dest_slot = block_nums.to(torch.long) * block_size + dest_off
return torch.where(keep_mask, dest_slot.to(torch.int64), dst_slots)
def topk_kv_compact_slot_mapping(
*,
token_scores: Optional[torch.Tensor], # [T] float32
must_keep: torch.Tensor, # [T] bool
topk_budget: torch.Tensor, # [B] int32
query_start_loc: torch.Tensor, # [B+1]
seq_lens: torch.Tensor, # [B] int32
block_table: torch.Tensor, # [B, max_blocks]
block_size: int,
max_query_len: Optional[int] = None,
topk_budget_max: Optional[int] = None,
) -> torch.Tensor:
"""Build a per-token destination slot mapping for KV compaction.
Returns a tensor `dst_slots` of shape [T] where:
- `dst_slots[i] >= 0` indicates token i should be kept and rewritten to
that KV cache slot.
- `dst_slots[i] == -1` indicates token i is dropped after the step.
"""
device = must_keep.device
T = int(must_keep.numel())
B = int(topk_budget.numel())
dst_slots = torch.full((T, ), -1, device=device, dtype=torch.int64)
if T == 0 or B == 0:
return dst_slots
starts, _, lengths, req_ids, pos_in_req = _packed_varlen_coords(
cu_seqlens=query_start_loc,
total_tokens=T,
)
if lengths.numel() == 0:
return dst_slots
# Prefer the CPU-known max query length (piecewise graph), to avoid
# device->host synchronization.
L_max = int(max_query_len) if max_query_len is not None else int(
lengths.max().item())
if L_max <= 0:
return dst_slots
keep_mask, local_rank, _ = _topk_keep_mask_and_local_rank(
token_scores=token_scores,
must_keep=must_keep,
topk_budget=topk_budget,
starts=starts,
lengths=lengths,
req_ids=req_ids,
pos_in_req=pos_in_req,
max_len=L_max,
topk_budget_max=topk_budget_max,
)
return _dst_slots_from_keep_mask_and_local_rank(
keep_mask=keep_mask,
local_rank=local_rank,
seq_lens=seq_lens[:B],
lengths=lengths,
req_ids=req_ids,
block_table=block_table,
block_size=int(block_size),
)
def kv_compaction_dst_rewrite_mapping(
*,
dst_slots: torch.Tensor, # [T] int64
src_slots: torch.Tensor, # [T] int64
) -> torch.Tensor:
"""Filter a dst slot mapping so only moved kept tokens are rewritten.
Non-rewrite tokens are marked as -1, which the cache kernels treat as
padding and skip.
"""
rewrite_mask = (dst_slots >= 0) & (dst_slots != src_slots)
return torch.where(rewrite_mask, dst_slots, -1)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import Optional
import torch
def _packed_varlen_coords(
*,
cu_seqlens: torch.Tensor, # [B+1]
total_tokens: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute packed varlen segment coordinates.
Returns:
starts: [B] int64, segment start offsets (inclusive)
ends: [B] int64, segment end offsets (exclusive)
lengths: [B] int64, segment lengths (ends - starts)
req_ids: [T] int64, request id for each token in packed [0, T)
pos_in_req: [T] int64, position within its request segment
"""
device = cu_seqlens.device
B = int(cu_seqlens.numel() - 1)
if B <= 0:
empty = torch.empty((0, ), device=device, dtype=torch.long)
t_empty = torch.empty((0, ), device=device, dtype=torch.long)
return empty, empty, empty, t_empty, t_empty
starts = cu_seqlens[:B].to(torch.long)
ends = cu_seqlens[1:B + 1].to(torch.long)
lengths = ends - starts
if total_tokens <= 0:
t_empty = torch.empty((0, ), device=device, dtype=torch.long)
return starts, ends, lengths, t_empty, t_empty
token_idx = torch.arange(total_tokens, device=device, dtype=torch.long)
req_ids = torch.bucketize(token_idx, ends, right=True) # [T]
start_per_token = starts.index_select(0, req_ids)
pos_in_req = token_idx - start_per_token
return starts, ends, lengths, req_ids, pos_in_req
def _topk_keep_mask_and_local_rank(
*,
token_scores: Optional[torch.Tensor], # [T] float32
must_keep: torch.Tensor, # [T] bool
topk_budget: torch.Tensor, # [B] int32
starts: torch.Tensor, # [B] int64
lengths: torch.Tensor, # [B] int64
req_ids: torch.Tensor, # [T] int64
pos_in_req: torch.Tensor, # [T] int64
max_len: Optional[int] = None,
topk_budget_max: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute keep_mask/local_rank for token-shared Top-K selection.
Returns:
keep_mask: [T] bool, selected tokens (includes must_keep)
local_rank: [T] int64, rank among kept tokens within each request
keep_len: [B] int32, number of kept tokens per request
"""
device = must_keep.device
T = int(must_keep.numel())
B = int(topk_budget.numel())
keep_mask = must_keep.clone()
if T == 0 or B == 0:
local_rank = torch.empty((T, ), device=device, dtype=torch.long)
keep_len = torch.zeros((B, ), device=device, dtype=torch.int32)
return keep_mask, local_rank, keep_len
if max_len is None:
L_max = int(lengths.max().item()) if lengths.numel() > 0 else 0
else:
L_max = int(max_len)
if L_max < 0:
L_max = 0
must_keep_counts = torch.zeros((B, ), device=device, dtype=torch.long)
must_keep_counts.scatter_add_(0, req_ids, must_keep.to(torch.long))
cand_counts = (lengths.to(torch.long) - must_keep_counts).clamp_min(0)
k_eff = torch.minimum(topk_budget.to(torch.long).clamp_min(0), cand_counts)
# CPU-known bound avoids a device->host sync; clamp for safety.
if topk_budget_max is None:
k_max = int(k_eff.max().item()) if k_eff.numel() > 0 else 0
else:
k_max = int(topk_budget_max)
if k_max < 0:
k_max = 0
if k_max > L_max:
k_max = L_max
if k_max > 0:
if token_scores is None:
raise ValueError("token_scores must be provided when k_max > 0.")
masked_scores = token_scores.to(torch.float32).masked_fill(
must_keep, float("-inf"))
scores_flat = masked_scores.new_full((B * L_max, ), float("-inf"))
linear = req_ids * L_max + pos_in_req
scores_flat[linear] = masked_scores
scores = scores_flat.view(B, L_max)
topk_pos = torch.topk(scores, k=k_max, dim=1).indices # [B, k_max]
col_mask = torch.arange(k_max,
device=device).unsqueeze(0) < k_eff.unsqueeze(1)
global_sel = starts.unsqueeze(1) + topk_pos.to(torch.long) # [B,k_max]
flat_idx = global_sel.reshape(-1).clamp_(0, T - 1)
flat_val = col_mask.reshape(-1).to(torch.int32)
tmp = torch.zeros((T, ), device=device, dtype=torch.int32)
tmp.scatter_add_(0, flat_idx, flat_val)
keep_mask |= tmp > 0
keep_len = torch.zeros((B, ), device=device, dtype=torch.long)
keep_len.scatter_add_(0, req_ids, keep_mask.to(torch.long))
# Stable, order-preserving local rank using segment-local prefix sums.
keep_prefix = torch.cumsum(keep_mask.to(torch.long), dim=0) # [T]
start_minus_1 = (starts - 1).clamp_min(0)
prefix_before_all = keep_prefix.index_select(0, start_minus_1)
prefix_before = torch.where(starts > 0, prefix_before_all,
torch.zeros_like(prefix_before_all)) # [B]
prefix_before_per_token = prefix_before.index_select(0, req_ids) # [T]
local_rank = keep_prefix - prefix_before_per_token - 1 # [T]
return keep_mask, local_rank, keep_len.to(torch.int32)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment