""" CriticalAdaKV: 在 Compactor(pre RoPE 杠杆分 + post RoPE 非因果注意力融合)基础上, 用输出投影 Wo 对 Value 的 L1 范数做 Stage-2 重加权;Stage-1 在 Compactor 基础分上做预算内 top-k 保护。 预算与 vllm.kvprune 引擎一致:使用 ``compression_context.batch_tokens_to_retain``(flatten 的 (token, head) 对数量)及首/尾保护段长度。 注意:不得在 import 时加载 ``vllm.kvprune.utils.context``(其会再 import ``CompressionMethod``, 与 ``compression/__init__.py`` 导入本模块形成环)。运行时只使用与 ``CompressionContext`` 同字段的 duck 对象。 """ from __future__ import annotations from typing import Any, Optional, Tuple import torch import triton from triton import language as tl from vllm.kvprune.compression.common import BaseCompressionMethod from vllm.kvprune.compression.compactor import ( CompactorCompression, non_causal_attn_scores, ) from vllm.kvprune.compression.snapkv import SnapKVCompression from vllm.kvprune.utils.helpers import maybe_execute_in_stream from vllm.kvprune.utils.triton_compat import autotune as triton_autotune # ============================================================================ # Triton Kernel 1: 计算 ||Wo @ V||₁ (L1 范数) # ============================================================================ @triton_autotune( configs=[ triton.Config({"BLOCK_K": bk, "BLOCK_D": bd}, num_warps=nw, num_stages=ns) for bk in [32, 64, 128] for bd in [32, 64] for nw in [4, 8] for ns in [3, 4] ], key=["Hk", "D", "HIDDEN"], cache_results=True, ) @triton.jit def _compute_wo_v_l1_kernel( V, WO, cu_k, OUT, STRIDE_V_NK, STRIDE_V_HK, STRIDE_V_D, STRIDE_WO_HQ, STRIDE_WO_D, STRIDE_WO_HID, STRIDE_OUT_NK, STRIDE_OUT_HK, Hk: tl.constexpr, Hq: tl.constexpr, D: tl.constexpr, HIDDEN: tl.constexpr, QUERY_GROUP_SIZE: tl.constexpr, BLOCK_K: tl.constexpr, BLOCK_D: tl.constexpr, ): b = tl.program_id(0) hk = tl.program_id(1) ks = tl.program_id(2) k_beg = tl.load(cu_k + b) k_end = tl.load(cu_k + b + 1) nk_off = ks * BLOCK_K + tl.arange(0, BLOCK_K) nk = k_beg + nk_off k_mask = nk < k_end out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK l1_sum = tl.zeros([BLOCK_K], dtype=tl.float32) for g in range(QUERY_GROUP_SIZE): hq = hk * QUERY_GROUP_SIZE + g v_ptrs = ( V + nk[:, None] * STRIDE_V_NK + hk * STRIDE_V_HK + tl.arange(0, D)[None, :] * STRIDE_V_D ) v_blk = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32) for hid_off in range(0, HIDDEN, BLOCK_D): hid_idx = hid_off + tl.arange(0, BLOCK_D) hid_mask = hid_idx < HIDDEN wo_ptrs = ( WO + hq * STRIDE_WO_HQ + tl.arange(0, D)[:, None] * STRIDE_WO_D + hid_idx[None, :] * STRIDE_WO_HID ) wo_tile = tl.load(wo_ptrs, mask=hid_mask[None, :], other=0.0).to(tl.float32) wov_tile = tl.dot(v_blk, wo_tile) l1_sum += tl.sum(tl.abs(wov_tile), axis=1) l1_sum = l1_sum / QUERY_GROUP_SIZE tl.store(out_ptrs, l1_sum, mask=k_mask) # ============================================================================ # Triton Kernel 2: Stage 1 保护 + Stage 2 加权融合 # ============================================================================ @triton_autotune( configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128, 256]], key=["Hk"], cache_results=True, ) @triton.jit def _critical_ada_fuse_kernel( BASE_SCORES, WO_V_NORM, STAGE1_MASK, cu_k, OUT, EPSILON: tl.constexpr, STRIDE_BS_NK, STRIDE_BS_HK, STRIDE_WN_NK, STRIDE_WN_HK, STRIDE_S1_NK, STRIDE_S1_HK, STRIDE_OUT_NK, STRIDE_OUT_HK, Hk: tl.constexpr, BLOCK_K: tl.constexpr, ): b = tl.program_id(0) hk = tl.program_id(1) k_beg = tl.load(cu_k + b) k_end = tl.load(cu_k + b + 1) for ks in tl.range(k_beg, k_end, BLOCK_K): nk = ks + tl.arange(0, BLOCK_K) kmask = nk < k_end bs_ptrs = BASE_SCORES + nk * STRIDE_BS_NK + hk * STRIDE_BS_HK wn_ptrs = WO_V_NORM + nk * STRIDE_WN_NK + hk * STRIDE_WN_HK s1_ptrs = STAGE1_MASK + nk * STRIDE_S1_NK + hk * STRIDE_S1_HK base = tl.load(bs_ptrs, mask=kmask, other=0.0) wnorm = tl.load(wn_ptrs, mask=kmask, other=1.0) stage1_protect = tl.load(s1_ptrs, mask=kmask, other=0).to(tl.int32) fused = (base + EPSILON) * wnorm fused = tl.where(stage1_protect == 1, float("inf"), fused) out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK tl.store(out_ptrs, fused, mask=kmask) def critical_ada_key_scores( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, wo_weight: torch.Tensor, cu_seqlens: torch.Tensor, base_scores: torch.Tensor, compression_ctx: Any, *, store_stream: Optional[torch.cuda.Stream] = None, ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]: """ 使用与引擎一致的保留预算 ``batch_tokens_to_retain``(每条序列的 (token, head) 对数), 在每条序列上尽量贴近 kvpress 的 CriticalAdaKV 语义: 1) alpha_safeguard 安全预算(每头至少保留一部分); 2) 基于 base_scores 的 head-wise 自适应预算分配(head_budgets); 3) Stage-1 按 head_budgets * first_stage_ratio 保护; 4) Stage-2 计算 ``(base + eps) * ||Wo@V||_1``,再按 head_budgets 做每头 top-k 保护。 Args: compression_ctx: 与 ``CompressionContext`` 相同字段即可(duck typing),须含 ``batch_tokens_to_retain``、``protected_first_tokens``、``protected_last_tokens``; 可选 ``critical_ada_epsilon``、``critical_ada_first_stage_ratio``、 ``critical_ada_alpha_safeguard``。 """ assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1 device = q.device _, Hq, D = q.shape N_k, Hk, Dk = k.shape assert D == Dk and Hq % Hk == 0 # 与 non_causal_attn_scores 使用同一 cu(prefill 下即 context.cu_seqlens_q), # 保证 base_scores 行与 Triton 分段一致;勿与 cu_seqlens_k 混用。 B = cu_seqlens.numel() - 1 G = Hq // Hk k_lengths = cu_seqlens[1:] - cu_seqlens[:-1] btr = compression_ctx.batch_tokens_to_retain assert btr is not None and btr.numel() == B btr = btr.to(device=device, dtype=torch.int32) prot_first = compression_ctx.protected_first_tokens or [0] * B prot_last = compression_ctx.protected_last_tokens or [0] * B epsilon = compression_ctx.critical_ada_epsilon first_stage_ratio = compression_ctx.critical_ada_first_stage_ratio alpha_safeguard = float(getattr(compression_ctx, "critical_ada_alpha_safeguard", 0.2)) alpha_safeguard = max(0.0, min(1.0, alpha_safeguard)) if wo_weight.dim() == 2: hidden_size, _ = wo_weight.shape wo = wo_weight.transpose(0, 1).view(Hq, D, hidden_size).contiguous() else: wo = wo_weight.contiguous() hidden_size = wo.size(-1) wo_v_norm = torch.empty((N_k, Hk), dtype=torch.float32, device=device) def grid_wo(META): max_k_len = int(k_lengths.max().item()) return (B, Hk, triton.cdiv(max_k_len, META["BLOCK_K"])) _compute_wo_v_l1_kernel[grid_wo]( v, wo, cu_seqlens, wo_v_norm, *v.stride(), *wo.stride(), *wo_v_norm.stride(), Hk=Hk, Hq=Hq, D=D, HIDDEN=hidden_size, QUERY_GROUP_SIZE=G, ) stage1_mask = torch.zeros((N_k, Hk), dtype=torch.int32, device=device) # kvpress 风格的每头预算(按序列自适应),用于 Stage-1/Stage-2。 head_budgets_by_batch = [] for b in range(B): k_len = int(k_lengths[b].item()) if k_len == 0: head_budgets_by_batch.append(None) continue k_beg = int(cu_seqlens[b].item()) k_end = int(cu_seqlens[b + 1].item()) s = int(prot_first[b]) if b < len(prot_first) else 0 e = int(prot_last[b]) if b < len(prot_last) else 0 lo, hi = k_beg + s, k_end - e compressible = max(0, hi - lo) keep_pairs = int(btr[b].item()) if compressible <= 0: head_budgets_by_batch.append(None) continue # 每头 token 预算(kvpress 的 n_kept) n_kept_tokens = max(1, keep_pairs // Hk) n_kept_tokens = min(n_kept_tokens, compressible) # 安全预算(每头至少保留 n_safe) n_safe = int(n_kept_tokens * alpha_safeguard) if n_safe > 0: tk_safe = min(n_safe, compressible) for hk in range(Hk): safe_idx = torch.topk(base_scores[lo:hi, hk], tk_safe, sorted=False).indices stage1_mask[lo + safe_idx, hk] = 1 # 自适应预算分配:在扁平 (token, head) 空间取 top n_kept_tokens*Hk,统计每个 head 的预算 budget_scores = base_scores[lo:hi, :].clone() if n_safe > 0: budget_scores[stage1_mask[lo:hi, :] == 1] = float("inf") top_pairs = min(n_kept_tokens * Hk, budget_scores.numel()) if top_pairs <= 0: head_budgets_by_batch.append(None) continue top_idx_flat = torch.topk( budget_scores.reshape(-1), top_pairs, sorted=False ).indices top_head_idx = top_idx_flat % Hk head_budgets = torch.bincount(top_head_idx, minlength=Hk).to(torch.int32) head_budgets_by_batch.append(head_budgets) # Stage-1:按 head_budgets 的 first_stage_ratio 分头保护(kvpress 语义) for hk in range(Hk): phase1_budget = int(head_budgets[hk].item() * first_stage_ratio) if phase1_budget <= 0: continue tk = min(phase1_budget, compressible) top_idx = torch.topk(base_scores[lo:hi, hk], tk, sorted=False).indices stage1_mask[lo + top_idx, hk] = 1 final_scores = torch.empty((N_k, Hk), dtype=torch.float32, device=device) def grid_fuse(_META): return (B, Hk) _critical_ada_fuse_kernel[grid_fuse]( base_scores, wo_v_norm, stage1_mask, cu_seqlens, final_scores, EPSILON=epsilon, *base_scores.stride(), *wo_v_norm.stride(), *stage1_mask.stride(), *final_scores.stride(), Hk=Hk, ) # Stage-2(kvpress 语义):在融合后按每头预算再做一次 top-k 保护。 for b in range(B): hb = head_budgets_by_batch[b] if hb is None: continue k_beg = int(cu_seqlens[b].item()) k_end = int(cu_seqlens[b + 1].item()) s = int(prot_first[b]) if b < len(prot_first) else 0 e = int(prot_last[b]) if b < len(prot_last) else 0 lo, hi = k_beg + s, k_end - e if hi <= lo: continue region_len = hi - lo for hk in range(Hk): budget = int(hb[hk].item()) if budget <= 0: continue tk = min(budget, region_len) idx = torch.topk(final_scores[lo:hi, hk], tk, sorted=False).indices final_scores[lo + idx, hk] = float("inf") masked_key_indices = None for b in range(B): k_len = int(k_lengths[b].item()) if k_len == 0: continue keep_pairs = int(btr[b].item()) total_pairs = k_len * Hk if keep_pairs >= total_pairs: continue k_beg = int(cu_seqlens[b].item()) k_end = int(cu_seqlens[b + 1].item()) n_prune_pairs = min(total_pairs - keep_pairs, total_pairs) if n_prune_pairs <= 0: continue flat_scores = final_scores[k_beg:k_end, :].reshape(-1) prune_idx = torch.topk( -flat_scores, min(n_prune_pairs, flat_scores.numel()), sorted=False ).indices batch_idx = torch.full_like(prune_idx, b, dtype=torch.int64) head_idx = prune_idx % Hk seq_idx = prune_idx // Hk + k_beg if masked_key_indices is None: masked_key_indices = (batch_idx, head_idx, seq_idx) else: masked_key_indices = ( torch.cat([masked_key_indices[0], batch_idx]), torch.cat([masked_key_indices[1], head_idx]), torch.cat([masked_key_indices[2], seq_idx]), ) if store_stream is not None: final_scores.record_stream(store_stream) return final_scores, masked_key_indices class CriticalAdaKVCompression(BaseCompressionMethod): """ 以 CompactorCompression 为基分(pre RoPE 杠杆 + post RoPE 非因果融合), 再应用 CriticalAda 两阶段加权;须由 Attention 在 post-RoPE 前注入 ``compression_context.wo_weight``。 """ @staticmethod def pre_rope_scoring( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context ) -> Optional[torch.Tensor]: cc = context.compression_context base = getattr(cc, "critical_ada_base_scorer", "compactor") if cc is not None else "compactor" if str(base).lower() == "snapkv": return SnapKVCompression.pre_rope_scoring(q, k, v, context) return CompactorCompression.pre_rope_scoring(q, k, v, context) @staticmethod def post_rope_scoring( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, pre_rope_scores: Optional[torch.Tensor], context, ) -> Optional[torch.Tensor]: compression_context = context.compression_context assert compression_context is not None base = str(getattr(compression_context, "critical_ada_base_scorer", "compactor")).lower() if base == "snapkv": base_scores = SnapKVCompression.post_rope_scoring(q, k, v, pre_rope_scores, context) else: # 与 compactor.py 中 CompactorCompression.post_rope_scoring 逐字一致: # maybe_execute_in_stream(non_causal_attn_scores, q,k,v, cu_seqlens_q, max_seqlen_q, ...) # 不得改为其它封装,否则与单独使用 COMPACTOR 时分数字不一致。 if context.STORE_STREAM is not None: torch.cuda.current_stream().wait_stream(context.STORE_STREAM) base_scores = maybe_execute_in_stream( non_causal_attn_scores, q, k, v, context.cu_seqlens_q, context.max_seqlen_q, chunk_size=CompactorCompression.chunk_size, sm_scale=1.0, normalize=True, accum_scores=pre_rope_scores, context_lens=compression_context.context_lens, protected_first_tokens=compression_context.protected_first_tokens, protected_last_tokens=compression_context.protected_last_tokens, accum_blending=0.5, ) wo_weight = compression_context.wo_weight if wo_weight is None: return base_scores scores, _masked = maybe_execute_in_stream( critical_ada_key_scores, q, k, v, wo_weight, context.cu_seqlens_q, base_scores, compression_context, STORE_STREAM=context.STORE_STREAM, store_stream=context.STORE_STREAM, ) return scores @staticmethod def prepare_layer(module: torch.nn.Module, device: torch.device, dtype: torch.dtype): """可选:预计算并缓存 Wo;实际推理以 Attention.forward 中注入的 ``cc.wo_weight`` 为准。""" if not hasattr(module, "o_proj") or module.o_proj.weight is None: return if not hasattr(module, "num_heads") or not hasattr(module, "head_dim"): return wo_raw = module.o_proj.weight.data hidden_size, _ = wo_raw.shape Hq = module.num_heads head_dim = module.head_dim wo = ( wo_raw.transpose(0, 1) .view(Hq, head_dim, hidden_size) .to(device=device, dtype=torch.float32) ) module._critical_ada_wo_weight = wo