Commit a9ebf337 authored by laibao's avatar laibao
Browse files

feat: kvpress flash_attn(scheme 3)生成 prompt-end payload

parent b6a27380
......@@ -8,6 +8,7 @@ import numpy as np
import torch
import vllm.envs as envs
from vllm.forward_context import get_forward_context
from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType,
......@@ -646,6 +647,33 @@ class FlashAttentionImpl(AttentionImpl):
layer._q_scale)
query = query.reshape((num_tokens, num_heads, head_size))
# Scheme 3 (chunked prefill): on the last prompt chunk, compute global
# prompt indices (score/topk) and cache them in the forward context for
# the model runner to consume before the first decode step.
if (envs.VLLM_ENABLE_KV_COMPRESSION
and self.kv_sharing_target_layer_name is None
and attn_metadata.kv_compression_prompt_end is not None
and attn_metadata.kv_compression_prompt_lens is not None
and attn_metadata.kv_compression_prompt_topk_keep is not None):
forward_context = get_forward_context()
payload = getattr(forward_context, "_kv_compression_prompt_payload",
None)
if payload is None:
payload = _compute_prompt_end_indices(
query=query[:num_actual_tokens],
key_cache=key_cache,
query_start_loc=attn_metadata.query_start_loc,
block_table=attn_metadata.block_table,
prompt_end=attn_metadata.kv_compression_prompt_end,
prompt_lens=attn_metadata.kv_compression_prompt_lens,
topk_keep=attn_metadata.kv_compression_prompt_topk_keep,
topk_keep_max=attn_metadata.kv_compression_prompt_topk_keep_max,
sm_scale=self.scale,
)
if payload is not None:
setattr(forward_context, "_kv_compression_prompt_payload",
payload)
# Compute attention and update output up to `num_actual_tokens`.
use_local_attn = \
(self.use_irope and attn_metadata.local_attn_metadata is not None)
......@@ -781,6 +809,251 @@ class FlashAttentionImpl(AttentionImpl):
return output
def _prompt_end_topk_keep_indices(
*,
token_scores: torch.Tensor, # [T] float32
prompt_lens: torch.Tensor, # [B] int32
topk_keep: torch.Tensor, # [B] int32 (candidates only)
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
topk_keep_max: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Select kept prompt indices (ascending) for one-shot compaction.
Returns:
idx_sorted: [B, K_max] int32, per-request kept token indices (0..L-1)
keep_len: [B] int32, number of kept tokens per request
"""
device = token_scores.device
B = int(prompt_lens.numel())
if B == 0:
empty = torch.empty((0, 0), device=device, dtype=torch.int32)
return empty, torch.empty((0, ), device=device, dtype=torch.int32)
prompt_lens_i64 = prompt_lens.to(torch.long)
cu = torch.zeros((B + 1, ), device=device, dtype=torch.long)
cu[1:] = torch.cumsum(prompt_lens_i64, dim=0)
starts = cu[:B]
ends = cu[1:]
T = int(token_scores.numel())
if T == 0:
empty = torch.empty((B, 0), device=device, dtype=torch.int32)
return empty, torch.zeros((B, ), device=device, dtype=torch.int32)
token_idx = torch.arange(T, device=device, dtype=torch.long)
req_ids = torch.bucketize(token_idx, ends, right=True) # [T]
start_per_token = starts.index_select(0, req_ids)
pos_in_req = token_idx - start_per_token # [T]
# Must-keep mask (protected prefix/suffix + optional last prompt token).
prefix_len = torch.clamp(prompt_lens_i64,
min=0).clamp_max(max(protected_prefix, 0))
suffix = torch.clamp(prompt_lens_i64,
min=0).clamp_max(max(protected_suffix, 0))
suffix_start = (prompt_lens_i64 - suffix).clamp_min(0)
prefix_len_t = prefix_len.index_select(0, req_ids)
suffix_start_t = suffix_start.index_select(0, req_ids)
must_keep = (pos_in_req < prefix_len_t) | (pos_in_req >= suffix_start_t)
if keep_last_token:
last = (prompt_lens_i64 - 1).clamp_min(0)
last_t = last.index_select(0, req_ids)
must_keep |= pos_in_req == last_t
cand_counts = torch.zeros((B, ), device=device, dtype=torch.long)
cand_counts.scatter_add_(0, req_ids, (~must_keep).to(torch.long))
k_eff = torch.minimum(topk_keep.to(torch.long).clamp_min(0), cand_counts)
# CPU-known bound avoids a device->host sync; clamp for safety.
if topk_keep_max is None:
k_max = int(k_eff.max().item())
else:
k_max = int(topk_keep_max)
if k_max < 0:
k_max = 0
keep_mask = must_keep.clone()
if k_max > 0:
L_max = int(prompt_lens_i64.max().item())
masked_scores = token_scores.to(torch.float32).masked_fill(
must_keep, float("-inf"))
scores_flat = masked_scores.new_full((B * L_max, ), float("-inf"))
linear = req_ids * L_max + pos_in_req
scores_flat[linear] = masked_scores
scores = scores_flat.view(B, L_max)
topk_pos = torch.topk(scores, k=k_max, dim=1).indices # [B, k_max]
col_mask = (torch.arange(k_max, device=device).unsqueeze(0) <
k_eff.unsqueeze(1))
global_sel = starts.unsqueeze(1) + topk_pos.to(torch.long) # [B,k_max]
flat_idx = global_sel.reshape(-1).clamp_(0, T - 1)
flat_val = col_mask.reshape(-1).to(torch.int32)
tmp = torch.zeros((T, ), device=device, dtype=torch.int32)
tmp.scatter_add_(0, flat_idx, flat_val)
keep_mask |= tmp > 0
keep_len = torch.zeros((B, ), device=device, dtype=torch.long)
keep_len.scatter_add_(0, req_ids, keep_mask.to(torch.long))
keep_max_len = int(keep_len.max().item()) if B > 0 else 0
if keep_max_len <= 0:
empty = torch.empty((B, 0), device=device, dtype=torch.int32)
return empty, keep_len.to(torch.int32)
# Stable, order-preserving index list using segment-local ranks.
keep_prefix = torch.cumsum(keep_mask.to(torch.long), dim=0) # [T]
start_minus_1 = (starts - 1).clamp_min(0)
prefix_before_all = keep_prefix.index_select(0, start_minus_1)
prefix_before = torch.where(starts > 0, prefix_before_all,
torch.zeros_like(prefix_before_all))
prefix_before_t = prefix_before.index_select(0, req_ids)
local_rank = keep_prefix - prefix_before_t - 1 # [T]
idx_sorted = torch.zeros((B, keep_max_len), device=device, dtype=torch.int32)
lin_out = (req_ids * keep_max_len + local_rank).masked_select(keep_mask)
vals = pos_in_req.to(torch.int32).masked_select(keep_mask)
idx_sorted.view(-1).scatter_(0, lin_out, vals)
return idx_sorted, keep_len.to(torch.int32)
def _compute_prompt_end_indices(
*,
query: torch.Tensor, # [T, Hq, D] scheduled tokens for this step
key_cache: torch.Tensor, # layer KV cache view (platform-dependent)
query_start_loc: torch.Tensor, # [B+1] int32
block_table: torch.Tensor, # [B, max_blocks] int32
prompt_end: torch.Tensor, # [B] bool
prompt_lens: torch.Tensor, # [B] int32
topk_keep: torch.Tensor, # [B] int32
topk_keep_max: Optional[int],
sm_scale: float,
) -> Optional[dict[str, torch.Tensor]]:
"""Compute one-shot prompt compaction indices on the last prefill chunk."""
device = query.device
if prompt_end.numel() == 0:
return None
sel = torch.nonzero(prompt_end, as_tuple=False).flatten()
if int(sel.numel()) == 0:
return None
window = int(envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW)
keep_last = bool(envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN)
protected_prefix = int(envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX)
protected_suffix = int(envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX)
# Build packed Q window (last `window` queries per selected request).
sel_list = sel.to(device="cpu", dtype=torch.int64).tolist()
qsl = query_start_loc.to(device="cpu", dtype=torch.int64).tolist()
q_chunks = []
cu_q = [0]
w_list = []
for b in sel_list:
s = int(qsl[b])
e = int(qsl[b + 1])
q_len = max(0, e - s)
win = min(window, q_len)
w_list.append(int(win))
if win > 0:
q_chunks.append(query[e - win:e])
cu_q.append(cu_q[-1] + int(win))
if cu_q[-1] <= 0:
return None
q_packed = torch.cat(q_chunks, dim=0) if q_chunks else query[:0]
cu_seqlens_q = torch.tensor(cu_q, device=device, dtype=torch.int32)
w = torch.tensor(w_list, device=device, dtype=torch.int32)
# Gather full prompt keys for the selected requests into a packed [T, Hk, D].
prompt_lens_sel = prompt_lens.index_select(0, sel).to(torch.int32)
topk_keep_sel = topk_keep.index_select(0, sel).to(torch.int32)
cu_seqlens_k = torch.zeros((int(prompt_lens_sel.numel()) + 1, ),
device=device,
dtype=torch.int32)
if int(prompt_lens_sel.numel()) > 0:
cu_seqlens_k[1:] = torch.cumsum(prompt_lens_sel, dim=0)
block_table_sel = block_table.index_select(0, sel).to(torch.int32)
if not current_platform.is_rocm():
# CUDA cache view: [num_blocks, block_size, H, D] -> [num_blocks, H, block_size, D]
key_cache_view = key_cache.permute(0, 2, 1, 3)
else:
key_cache_view = key_cache
from vllm.v1.attention.kv_compression.kv_cache_triton import (
gather_k_to_packed_triton)
k_packed = gather_k_to_packed_triton(
key_cache_view,
block_table_sel,
prompt_lens_sel,
cu_seqlens_k,
)
# SnapKV Triton scoring (token-shared via sum over KV heads).
from vllm.v1.attention.kv_compression.snapkv_triton import (
query_aware_key_scores)
try:
scores_per_head = query_aware_key_scores(
q=q_packed,
k=k_packed,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
w=w,
sm_scale=float(sm_scale),
pool=False,
protect_last=False,
normalize=False,
)
token_scores = scores_per_head.sum(dim=1)
except Exception:
# Fallback: PyTorch reference scoring (slow but correctness-oriented).
Hq = q_packed.shape[1]
Hk = k_packed.shape[1]
D = q_packed.shape[2]
if Hq % Hk != 0:
raise
group = Hq // Hk
token_scores = torch.zeros((k_packed.shape[0], ),
device=device,
dtype=torch.float32)
for i in range(len(sel_list)):
qs = int(cu_q[i])
qe = int(cu_q[i + 1])
ks = int(cu_seqlens_k[i].item())
ke = int(cu_seqlens_k[i + 1].item())
if qe <= qs or ke <= ks:
continue
q_win = q_packed[qs:qe] # [win, Hq, D]
q_win = q_win.reshape(q_win.shape[0], Hk, group, D).mean(dim=2)
k_all = k_packed[ks:ke]
qh = q_win.permute(1, 0, 2).to(torch.float32)
kh = k_all.permute(1, 0, 2).to(torch.float32)
logits = torch.matmul(qh, kh.transpose(1, 2)) * float(sm_scale)
probs = torch.softmax(logits, dim=-1)
token_scores[ks:ke] = probs.sum(dim=1).sum(dim=0)
from vllm.distributed.parallel_state import get_tp_group
token_scores = get_tp_group().all_reduce(token_scores)
idx_sorted, keep_len = _prompt_end_topk_keep_indices(
token_scores=token_scores,
prompt_lens=prompt_lens_sel,
topk_keep=topk_keep_sel,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
topk_keep_max=topk_keep_max,
)
return {
"req_indices": sel.to(torch.int32),
"idx_sorted": idx_sorted, # [B_sel, K_max] int32
"keep_len": keep_len, # [B_sel] int32
"prompt_lens": prompt_lens_sel, # [B_sel] int32
}
def use_cascade_attention(
common_prefix_len: int,
query_lens: np.ndarray,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment