import math
from typing import Optional

import torch
import triton
from packaging import version
from triton import language as tl

from vllm.kvprune.compression.common import BaseCompressionMethod
from vllm.kvprune.utils.helpers import maybe_execute_in_stream
from vllm.kvprune.utils.triton_compat import autotune as triton_autotune
from vllm.platforms import current_platform

# SnapKV defaults aligned with kvpress `SnapKVPress` (snapkv_press.py).
DEFAULT_SNAPKV_WINDOW_SIZE = 64
DEFAULT_SNAPKV_KERNEL_SIZE = 5

# On ROCm with Triton >= 3.2.0 the LSE kernel is launched without autotuning,
# with fixed tile sizes, and with K loaded pre-transposed (see the kernel).
_USE_ROCM_TRITON_DOT_WORKAROUND = current_platform.is_rocm() and (
    version.parse(triton.__version__) >= version.parse("3.2.0")
)
_ROCM_TRITON_DOT_WORKAROUND_BLOCK_Q = 32
_ROCM_TRITON_DOT_WORKAROUND_BLOCK_K = 32


class SnapKVCompression(BaseCompressionMethod):
    """SnapKV-style key scoring: no pre-RoPE score, query-aware post-RoPE score."""

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
    ) -> Optional[torch.Tensor]:
        """Return ``None``; this method does not score before RoPE."""
        return None

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """Score keys with `query_aware_key_scores` using the SnapKV defaults.

        `v` and `pre_rope_scores` are not used. `context` must expose
        `cu_seqlens_q`, `cu_seqlens_k` and `STORE_STREAM`.
        """
        # NOTE(review): `maybe_execute_in_stream` is defined elsewhere; it
        # presumably consumes `STORE_STREAM` and forwards the remaining
        # arguments to `query_aware_key_scores` — confirm against the helper.
        scores = maybe_execute_in_stream(
            query_aware_key_scores,
            q,
            k,
            context.cu_seqlens_q,
            context.cu_seqlens_k,
            w=DEFAULT_SNAPKV_WINDOW_SIZE,
            kernel_size=DEFAULT_SNAPKV_KERNEL_SIZE,
            STORE_STREAM=context.STORE_STREAM,
        )
        return scores


# Pass 1: for every (sequence, kv-head, row-tile) compute base-2 logits of the
# last `win` queries against all keys of the sequence, keep a streaming
# log-sum-exp (m, S) per row, and store the logits of the prefix keys
# (keys before the window) for pass 2.
@triton.jit
def _lse_and_store_logits_kernel(
    Q,
    K,
    cu_q,
    cu_k,
    w_b,  # int32 pointers
    out_m,
    out_S,  # [B, Hk, ROWS_MAX] float32
    LOGITS,  # [Nk, Hk, ROWS_MAX] float32
    sm_scale,  # float
    QUERY_GROUP_SIZE: tl.constexpr,
    D: tl.constexpr,
    STRIDE_Q_NQ,
    STRIDE_Q_HQ,
    STRIDE_K_NK,
    STRIDE_K_HK,
    STRIDE_M_B,
    STRIDE_M_H,
    STRIDE_M_R,
    STRIDE_S_B,
    STRIDE_S_H,
    STRIDE_S_R,
    STRIDE_LG_NK,
    STRIDE_LG_HK,
    STRIDE_LG_R,
    BLOCK_Q: tl.constexpr,
    BLOCK_K: tl.constexpr,
    ROWS_MAX,
    ROCM_DOT_LAYOUT_WORKAROUND: tl.constexpr,
):
    # program ids
    b = tl.program_id(0)
    hk = tl.program_id(1)
    rid = tl.program_id(2)  # row-tile id

    # batch segment bounds
    q_end = tl.load(cu_q + b + 1)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    win = tl.load(w_b + b)
    q_win_beg = q_end - win
    k_eff_end = k_end - win

    # Nothing to score when there is no window or no prefix before it.
    if (win <= 0) or (k_eff_end <= k_beg):
        return

    # rows for this (b,hk)
    rows_b = win * QUERY_GROUP_SIZE
    row0 = rid * BLOCK_Q
    if row0 >= rows_b:
        return

    # exp(x) = exp2(x * 1/ln2)
    qk_scale = sm_scale * 1.4426950408889634

    offs_qrow = row0 + tl.arange(0, BLOCK_Q)
    row_mask = offs_qrow < rows_b

    # map row -> (q_idx, hq_local)
    hq_local = offs_qrow % QUERY_GROUP_SIZE
    q_off = offs_qrow // QUERY_GROUP_SIZE
    q_idx = q_win_beg + q_off
    hq_glob = hk * QUERY_GROUP_SIZE + hq_local

    # Position of each query expressed in the packed K index space. The
    # window is aligned to the end of the sequence on both sides (q_end - win
    # and k_end - win), so the shift is k_end - q_end. It is zero when the
    # Q and K offsets coincide.
    q_kpos = q_idx + (k_end - q_end)

    offs_d = tl.arange(0, D)
    q_ptrs = (
        Q
        + q_idx[:, None] * STRIDE_Q_NQ
        + hq_glob[:, None] * STRIDE_Q_HQ
        + offs_d[None, :]
    )
    q_rows = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)

    m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
    S = tl.zeros([BLOCK_Q], dtype=tl.float32)

    # Full-sequence causal attention (matches kvpress softmax), then use
    # prefix columns only.
    for ks in tl.range(k_beg, k_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_end

        if ROCM_DOT_LAYOUT_WORKAROUND:
            # Triton > 3.1 on ROCm/HCU can fail during LDS optimization when
            # a dot uses a transposed blocked operand. Load K as [D, BK] and
            # feed it directly into tl.dot to avoid generating that layout path.
            k_ptrs = (
                K
                + nk[None, :] * STRIDE_K_NK
                + hk * STRIDE_K_HK
                + offs_d[:, None]
            )
            k_blk = tl.load(k_ptrs, mask=kmask[None, :], other=0.0)  # [D, BK]
            s = tl.dot(q_rows, k_blk) * qk_scale  # [BQ, BK]
        else:
            k_ptrs = (
                K
                + nk[:, None] * STRIDE_K_NK
                + hk * STRIDE_K_HK
                + offs_d[None, :]
            )
            k_blk = tl.load(k_ptrs, mask=kmask[:, None], other=0.0)  # [BK, D]
            s = tl.dot(q_rows, k_blk.T) * qk_scale  # [BQ, BK]

        s = tl.where(kmask[None, :], s, -float("inf"))
        # Causal: key j only if j is not after the query's own position
        # (same as kvpress triu mask on the window×k_len grid). Compared in
        # K index space so it also holds when Q and K offsets differ.
        causal_ok = nk[None, :] <= q_kpos[:, None]
        s = tl.where(causal_ok, s, -float("inf"))

        # store prefix logits only (for marginal probs on prefix keys)
        log_ptrs = (
            LOGITS
            + nk[:, None] * STRIDE_LG_NK
            + hk * STRIDE_LG_HK
            + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
        )
        store_mask = kmask & (nk < k_eff_end)
        tl.store(log_ptrs, s.T, mask=store_mask[:, None] & row_mask[None, :])

        # log2 streaming LSE over all keys in [k_beg, k_end) (after causal mask)
        cur_max = tl.max(s, 1)  # [BQ]
        n_m = tl.maximum(m, cur_max)
        rescale = tl.math.exp2(m - n_m)
        S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
        m = n_m

    # store m,S for these rows
    m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
    S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
    tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R, m, mask=row_mask)
    tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R, S, mask=row_mask)


_lse_and_store_logits_kernel_autotuned = triton_autotune(
    configs=[
        triton.Config(
            {"BLOCK_Q": bq, "BLOCK_K": bk},
            num_warps=num_warps,
            num_stages=num_stages,
        )
        for bq in [32, 64]
        for bk in [32, 64]
        for num_warps in [4, 8]
        for num_stages in [3, 4]
    ],
    key=["QUERY_GROUP_SIZE", "D", "ROWS_MAX"],
    cache_results=True,
)(_lse_and_store_logits_kernel)


# Pass 2: turn the stored prefix logits into softmax probabilities using the
# per-row (m, S) from pass 1. One program per (sequence, kv-head).
@triton_autotune(
    configs=[
        triton.Config({"BLOCK_Q": bq, "BLOCK_K": bk})
        for bq in [16, 32, 64]
        for bk in [32, 64, 128]
    ],
    # The key must name arguments this kernel actually takes; it has no
    # "HK"/"HQ" parameters, so key on the GQA group size instead.
    key=["QUERY_GROUP_SIZE"],
    cache_results=True,
)
@triton.jit
def _prefix_probs_kernel(
    cu_k,
    w_b,
    in_m,
    in_S,  # [B, Hk, ROWS_MAX] f32
    LOGITS,  # [Nk, Hk, ROWS_MAX] f32, base-2 logits (prefix keys only)
    PROBS,  # [Nk, Hk, ROWS_MAX] f32 — per-row prefix marginal probs
    #
    QUERY_GROUP_SIZE: tl.constexpr,
    STRIDE_M_B,
    STRIDE_M_H,
    STRIDE_M_R,
    STRIDE_S_B,
    STRIDE_S_H,
    STRIDE_S_R,
    STRIDE_LG_NK,
    STRIDE_LG_HK,
    STRIDE_LG_R,
    STRIDE_PB_NK,
    STRIDE_PB_HK,
    STRIDE_PB_R,
    BLOCK_Q: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    b = tl.program_id(0)
    hk = tl.program_id(1)

    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    win = tl.load(w_b + b)
    k_eff_end = k_end - win

    # Same skip condition as pass 1, so unwritten scratch is never read.
    if (win <= 0) or (k_eff_end <= k_beg):
        return

    rows_b = win * QUERY_GROUP_SIZE

    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_eff_end

        for row0 in tl.range(0, rows_b, BLOCK_Q):
            r_idx = row0 + tl.arange(0, BLOCK_Q)
            rmask = r_idx < rows_b

            m_ptr = in_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
            S_ptr = in_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
            m = tl.load(
                m_ptr + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
                mask=rmask,
                other=-float("inf"),
            )
            S = tl.load(
                S_ptr + tl.arange(0, BLOCK_Q) * STRIDE_S_R,
                mask=rmask,
                other=0.0,
            )

            # Rows whose sum is not positive (masked-out rows, or NaN) get
            # neutral (m, S) here and are forced to probability 0 below.
            valid_row = S > 0
            m = tl.where(valid_row, m, 0.0)
            S = tl.where(valid_row, S, 1.0)

            log_ptrs = (
                LOGITS
                + nk[:, None] * STRIDE_LG_NK
                + hk * STRIDE_LG_HK
                + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
            )
            s_T = tl.load(
                log_ptrs,
                mask=kmask[:, None] & rmask[None, :],
                other=-float("inf"),
            )  # [BK, BQ]

            probs_T = tl.math.exp2(s_T - m[None, :]) / S[None, :]
            probs_T = tl.where(valid_row[None, :], probs_T, 0.0)

            prob_ptrs = (
                PROBS
                + nk[:, None] * STRIDE_PB_NK
                + hk * STRIDE_PB_HK
                + (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_PB_R
            )
            tl.store(prob_ptrs, probs_T, mask=kmask[:, None] & rmask[None, :])


# Optional in-place z-score of the prefix scores of each sequence, taken over
# all prefix keys and all kv-heads together. One program per sequence. Only
# rows in [k_beg, k_end - win) are read or written.
@triton_autotune(
    configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
    key=["HK"],
    cache_results=True,
)
@triton.jit
def _zscore_per_batch_epilogue(
    OUT,  # [Nk, Hk], float32
    cu_k,
    w_b,  # [B+1], [B] int32
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    HK: tl.constexpr,  # Hk
    EPS: tl.constexpr,  # e.g., 1e-12
    BLOCK_K: tl.constexpr,  # e.g., 128
):
    b = tl.program_id(0)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    win = tl.load(w_b + b)
    k_eff_end = k_end - win

    if k_eff_end <= k_beg:
        return

    # First sweep: accumulate sum and sum of squares.
    sumv = tl.zeros([], dtype=tl.float32)
    sumsq = tl.zeros([], dtype=tl.float32)
    count = ((k_eff_end - k_beg) * HK).to(tl.float32)

    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_eff_end
        for h in tl.range(0, HK):
            ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
            sumv += tl.sum(vals, 0)
            sumsq += tl.sum(vals * vals, 0)

    mean = sumv / count
    # Clamp at zero: E[x^2] - E[x]^2 can go slightly negative in floating point.
    var = tl.maximum(sumsq / count - mean * mean, 0.0)
    invstd = 1.0 / tl.sqrt(var + EPS)

    # Second sweep: normalise in place.
    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_eff_end
        for h in tl.range(0, HK):
            ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
            vals = (vals - mean) * invstd
            tl.store(ptrs, vals, mask=kmask)


@triton_autotune(
    configs=[triton.Config({"BLOCK_T": bt}) for bt in [32, 64, 128, 256]],
    key=["KERNEL_SIZE"],
    cache_results=True,
)
@triton.jit
def _snapkv_avg_pool1d_kernel(
    IN,
    OUT,
    Lp,
    STRIDE_IN_C,
    STRIDE_IN_L,
    STRIDE_OUT_C,
    STRIDE_OUT_L,
    KERNEL_SIZE: tl.constexpr,
    PAD: tl.constexpr,
    BLOCK_T: tl.constexpr,
):
    """
    Symmetric 1D average pool on the last dimension, matching
    `F.avg_pool1d(x, kernel_size=K, padding=K//2, stride=1)` on `x` shaped [C, Lp]
    (equivalent to PyTorch [C, 1, Lp] avg_pool1d with divisor = kernel size).
    """
    c = tl.program_id(0)
    t0 = tl.program_id(1) * BLOCK_T + tl.arange(0, BLOCK_T)
    mask = t0 < Lp
    acc = tl.zeros([BLOCK_T], dtype=tl.float32)
    for j in tl.static_range(KERNEL_SIZE):
        idx = t0 - PAD + j
        # Taps outside [0, Lp) contribute zero but still count in the divisor.
        valid = (idx >= 0) & (idx < Lp)
        ptrs = IN + c * STRIDE_IN_C + idx * STRIDE_IN_L
        v = tl.load(ptrs, mask=valid & mask, other=0.0).to(tl.float32)
        acc += v
    acc = acc / tl.cast(KERNEL_SIZE, tl.float32)
    out_ptrs = OUT + c * STRIDE_OUT_C + t0 * STRIDE_OUT_L
    tl.store(out_ptrs, acc, mask=mask)


def _snapkv_avg_pool1d_triton(x: torch.Tensor, kernel_size: int) -> torch.Tensor:
    """
    kvpress-equivalent smoothing: same as `F.avg_pool1d` on [Hk*G, 1, Lp].

    `x` must be float32 with shape [Hk, G, Lp]. Returns a new tensor of the
    same shape, or `x` itself when `Lp == 0`. The output always has length Lp.
    NOTE(review): with an even `kernel_size`, `F.avg_pool1d` with
    `padding=K//2` would presumably return Lp + 1 elements, so the equivalence
    is only expected for odd kernel sizes — confirm if even sizes are needed.
    """
    assert x.dtype == torch.float32
    Hk, G, Lp = x.shape
    if Lp == 0:
        return x
    pad = kernel_size // 2
    # Flatten (Hk, G) into one channel axis; one program row per channel.
    x2 = x.reshape(Hk * G, Lp).contiguous()
    out = torch.empty_like(x2)
    C = Hk * G
    si_c, si_l = x2.stride()
    so_c, so_l = out.stride()

    def grid(meta):
        return (C, triton.cdiv(Lp, meta["BLOCK_T"]))

    _snapkv_avg_pool1d_kernel[grid](
        x2,
        out,
        Lp,
        si_c,
        si_l,
        so_c,
        so_l,
        KERNEL_SIZE=kernel_size,
        PAD=pad,
    )
    return out.view(Hk, G, Lp)


def _snapkv_kvpress_epilogue(
    probs_buf: torch.Tensor,
    out: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    w: torch.Tensor,
    G: int,
    Hk: int,
    kernel_size: int,
) -> None:
    """
    Match kvpress SnapKV order: mean over window queries → symmetric avg_pool1d →
    mean over GQA groups → pad tail with global max of prefix scores.

    Writes into `out` ([Nk, Hk]) in place. Sequences with a non-positive
    window or with no prefix before the window are left untouched.
    """
    # One host transfer per tensor instead of three `.item()` syncs per sequence.
    k_bounds = cu_seqlens_k.tolist()
    windows = w.tolist()
    B = len(k_bounds) - 1
    for b in range(B):
        k_beg = int(k_bounds[b])
        k_end = int(k_bounds[b + 1])
        win = int(windows[b])
        k_eff_end = k_end - win
        if win <= 0 or k_eff_end <= k_beg:
            continue
        Lp = k_eff_end - k_beg
        rows_b = win * G
        p = probs_buf[k_beg:k_eff_end, :, :rows_b]
        # [Lp, Hk, win, G] — rows are (q_off, g) order per Triton row layout
        x = p.view(Lp, Hk, win, G).mean(dim=2)
        x = x.permute(1, 2, 0).contiguous()  # [Hk, G, Lp]
        x = _snapkv_avg_pool1d_triton(x, kernel_size)
        x = x.mean(dim=1)  # [Hk, Lp]
        seg = x.permute(1, 0).contiguous()  # [Lp, Hk]
        out[k_beg:k_eff_end, :] = seg
        # Window keys get the largest prefix score of this sequence.
        pad_val = seg.max()
        out[k_eff_end:k_end, :] = pad_val


def query_aware_key_scores(
    q: torch.Tensor,  # [N_q, Hq, D]
    k: torch.Tensor,  # [N_k, Hk, D]
    cu_seqlens_q: torch.Tensor,  # [B+1], int32
    cu_seqlens_k: torch.Tensor,  # [B+1], int32
    w: torch.Tensor | int,  # [B], int32
    sm_scale: Optional[float] = None,  # defaults to 1/sqrt(D)
    *,
    kernel_size: int = DEFAULT_SNAPKV_KERNEL_SIZE,
    accum_scores: Optional[torch.Tensor] = None,
    accum_blending: Optional[float] = None,
    normalize: bool = False,
) -> Optional[torch.Tensor]:
    """Compute SnapKV per-key importance scores for packed variable-length input.

    For each sequence, the last `w` queries attend causally to all keys of
    that sequence. The attention mass on the keys before the window is then
    averaged over the window queries, smoothed with an average pool of width
    `kernel_size`, and averaged over the GQA group. The last `w` keys of the
    sequence receive the maximum prefix score of that sequence.

    Args:
        q: Packed queries, [N_q, Hq, D], last dim contiguous.
        k: Packed keys, [N_k, Hk, D], last dim contiguous; Hq % Hk == 0.
        cu_seqlens_q: Cumulative query offsets, [B+1].
        cu_seqlens_k: Cumulative key offsets, [B+1].
        w: Window size, either one int for all sequences or a [B] tensor.
        sm_scale: Softmax scale; 1/sqrt(D) when None.
        kernel_size: Width of the smoothing average pool.
        accum_scores: Optional [N_k, Hk] accumulator updated in place.
        accum_blending: Optional factor applied to `accum_scores` before the
            new scores are added.
        normalize: If True, z-score the prefix scores per sequence.

    Returns:
        `accum_scores` (mutated) when it is given, otherwise a new float32
        [N_k, Hk] tensor. When the batch is empty or the largest window is
        not positive, a fresh zero tensor is returned and `accum_scores` is
        left untouched.
    """
    assert q.stride(-1) == 1 and k.stride(-1) == 1, "last dim must be contiguous"
    device = q.device
    N_q, Hq, D = q.shape
    N_k, Hk, Dk = k.shape
    assert D == Dk, "q and k must share the head dimension"
    assert (Hq % Hk) == 0, "Hq must be a multiple of Hk"
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    B = cu_seqlens_q.numel() - 1
    assert B == cu_seqlens_k.numel() - 1

    G = Hq // Hk
    if isinstance(w, int):
        max_w = w
        w = torch.full((B,), fill_value=w, device=device, dtype=torch.int32)
    else:
        assert w.numel() == B
        # `max()` raises on an empty tensor, so guard the empty batch.
        max_w = int(w.max().item()) if B > 0 else 0

    ROWS_MAX = max_w * G
    if B == 0 or ROWS_MAX <= 0:
        return torch.zeros((N_k, Hk), dtype=torch.float32, device=device)

    out = torch.zeros((N_k, Hk), dtype=torch.float32, device=device)

    # Scratch is left uninitialised: both kernels skip the same sequences and
    # mask the same rows, so only written entries are ever read back.
    m_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
    S_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
    logits_buf = torch.empty(
        (N_k, Hk, ROWS_MAX), dtype=torch.float32, device=device
    )
    probs_buf = torch.empty(
        (N_k, Hk, ROWS_MAX), dtype=torch.float32, device=device
    )

    # strides
    STRIDE_Q_NQ, STRIDE_Q_HQ, _ = q.stride()
    STRIDE_K_NK, STRIDE_K_HK, _ = k.stride()
    STRIDE_M_B, STRIDE_M_H, STRIDE_M_R = m_scratch.stride()
    STRIDE_S_B, STRIDE_S_H, STRIDE_S_R = S_scratch.stride()
    STRIDE_LG_NK, STRIDE_LG_HK, STRIDE_LG_R = logits_buf.stride()
    STRIDE_PB_NK, STRIDE_PB_HK, STRIDE_PB_R = probs_buf.stride()
    STRIDE_OUT_NK, STRIDE_OUT_HK = out.stride()

    def grid(META):
        return B, Hk, triton.cdiv(ROWS_MAX, META["BLOCK_Q"])

    lse_kernel = _lse_and_store_logits_kernel_autotuned
    lse_kernel_kwargs = {}
    if _USE_ROCM_TRITON_DOT_WORKAROUND:
        # Skip autotuning and pin the launch configuration on affected ROCm.
        lse_kernel = _lse_and_store_logits_kernel
        lse_kernel_kwargs = {
            "BLOCK_Q": _ROCM_TRITON_DOT_WORKAROUND_BLOCK_Q,
            "BLOCK_K": _ROCM_TRITON_DOT_WORKAROUND_BLOCK_K,
            "num_warps": 4,
            "num_stages": 1,
        }

    lse_kernel[grid](
        q,
        k,
        cu_seqlens_q,
        cu_seqlens_k,
        w,
        m_scratch,
        S_scratch,
        logits_buf,
        sm_scale,
        QUERY_GROUP_SIZE=G,
        D=D,
        STRIDE_Q_NQ=STRIDE_Q_NQ,
        STRIDE_Q_HQ=STRIDE_Q_HQ,
        STRIDE_K_NK=STRIDE_K_NK,
        STRIDE_K_HK=STRIDE_K_HK,
        STRIDE_M_B=STRIDE_M_B,
        STRIDE_M_H=STRIDE_M_H,
        STRIDE_M_R=STRIDE_M_R,
        STRIDE_S_B=STRIDE_S_B,
        STRIDE_S_H=STRIDE_S_H,
        STRIDE_S_R=STRIDE_S_R,
        STRIDE_LG_NK=STRIDE_LG_NK,
        STRIDE_LG_HK=STRIDE_LG_HK,
        STRIDE_LG_R=STRIDE_LG_R,
        ROWS_MAX=ROWS_MAX,
        ROCM_DOT_LAYOUT_WORKAROUND=_USE_ROCM_TRITON_DOT_WORKAROUND,
        **lse_kernel_kwargs,
    )

    _prefix_probs_kernel[(B, Hk)](
        cu_seqlens_k,
        w,
        m_scratch,
        S_scratch,
        logits_buf,
        probs_buf,
        QUERY_GROUP_SIZE=G,
        STRIDE_M_B=STRIDE_M_B,
        STRIDE_M_H=STRIDE_M_H,
        STRIDE_M_R=STRIDE_M_R,
        STRIDE_S_B=STRIDE_S_B,
        STRIDE_S_H=STRIDE_S_H,
        STRIDE_S_R=STRIDE_S_R,
        STRIDE_LG_NK=STRIDE_LG_NK,
        STRIDE_LG_HK=STRIDE_LG_HK,
        STRIDE_LG_R=STRIDE_LG_R,
        STRIDE_PB_NK=STRIDE_PB_NK,
        STRIDE_PB_HK=STRIDE_PB_HK,
        STRIDE_PB_R=STRIDE_PB_R,
    )

    _snapkv_kvpress_epilogue(
        probs_buf, out, cu_seqlens_k, w, G, Hk, kernel_size
    )

    if normalize:
        _zscore_per_batch_epilogue[(B,)](
            out,
            cu_seqlens_k,
            w,
            STRIDE_OUT_NK,
            STRIDE_OUT_HK,
            HK=Hk,
            EPS=1e-12,
        )

    if accum_scores is not None:
        if accum_blending is not None:
            accum_scores.mul_(accum_blending)
        accum_scores.add_(out)
        return accum_scores
    else:
        return out