Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5c1e8006
Commit
5c1e8006
authored
May 07, 2026
by
chenzk
Browse files
vllm kvprune for tritonx:v1.1.4
parent
40c2c5e5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
0 additions
and
579 deletions
+0
-579
vllm/kvprune/compression/snapkv.py.bak
vllm/kvprune/compression/snapkv.py.bak
+0
-579
No files found.
vllm/kvprune/compression/snapkv.py.bak
deleted
100644 → 0
View file @
40c2c5e5
import math
from typing import Optional
from packaging import version
import torch
import triton
from triton import language as tl
from vllm.kvprune.compression.common import BaseCompressionMethod
from vllm.platforms import current_platform
from vllm.kvprune.utils.helpers import maybe_execute_in_stream
from vllm.kvprune.utils.triton_compat import autotune as triton_autotune
# SnapKV defaults aligned with kvpress `SnapKVPress` (snapkv_press.py).
DEFAULT_SNAPKV_WINDOW_SIZE = 64
DEFAULT_SNAPKV_KERNEL_SIZE = 5
_USE_ROCM_TRITON_DOT_WORKAROUND = current_platform.is_rocm() and (
version.parse(triton.__version__) >= version.parse("3.2.0")
)
_ROCM_TRITON_DOT_WORKAROUND_BLOCK_Q = 32
_ROCM_TRITON_DOT_WORKAROUND_BLOCK_K = 32
class SnapKVCompression(BaseCompressionMethod):
@staticmethod
def pre_rope_scoring(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
) -> Optional[torch.Tensor]:
return None
@staticmethod
def post_rope_scoring(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
pre_rope_scores: torch.Tensor,
context,
) -> Optional[torch.Tensor]:
scores = maybe_execute_in_stream(
query_aware_key_scores,
q,
k,
context.cu_seqlens_q,
context.cu_seqlens_k,
w=DEFAULT_SNAPKV_WINDOW_SIZE,
kernel_size=DEFAULT_SNAPKV_KERNEL_SIZE,
STORE_STREAM=context.STORE_STREAM,
)
return scores
@triton.jit
def _lse_and_store_logits_kernel(
Q,
K,
cu_q,
cu_k,
w_b, # int32 pointers
out_m,
out_S, # [B, Hk, ROWS_MAX] float32
LOGITS, # [Nk, Hk, ROWS_MAX] float32
sm_scale, # float
QUERY_GROUP_SIZE: tl.constexpr,
D: tl.constexpr,
STRIDE_Q_NQ,
STRIDE_Q_HQ,
STRIDE_K_NK,
STRIDE_K_HK,
STRIDE_M_B,
STRIDE_M_H,
STRIDE_M_R,
STRIDE_S_B,
STRIDE_S_H,
STRIDE_S_R,
STRIDE_LG_NK,
STRIDE_LG_HK,
STRIDE_LG_R,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
ROWS_MAX,
ROCM_DOT_LAYOUT_WORKAROUND: tl.constexpr,
):
# program ids
b = tl.program_id(0)
hk = tl.program_id(1)
rid = tl.program_id(2) # row-tile id
# batch segment bounds
q_end = tl.load(cu_q + b + 1)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
q_win_beg = q_end - win
k_eff_end = k_end - win
if (win <= 0) or (k_eff_end <= k_beg):
return
# rows for this (b,hk)
rows_b = win * QUERY_GROUP_SIZE
row0 = rid * BLOCK_Q
if row0 >= rows_b:
return
# exp(x) = exp2(x * 1/ln2)
qk_scale = sm_scale * 1.4426950408889634
offs_qrow = row0 + tl.arange(0, BLOCK_Q)
row_mask = offs_qrow < rows_b
# map row -> (q_idx, hq_local)
hq_local = offs_qrow % QUERY_GROUP_SIZE
q_off = offs_qrow // QUERY_GROUP_SIZE
q_idx = q_win_beg + q_off
hq_glob = hk * QUERY_GROUP_SIZE + hq_local
offs_d = tl.arange(0, D)
q_ptrs = (
Q
+ q_idx[:, None] * STRIDE_Q_NQ
+ hq_glob[:, None] * STRIDE_Q_HQ
+ offs_d[None, :]
)
q_rows = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)
m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
S = tl.zeros([BLOCK_Q], dtype=tl.float32)
# Full-sequence causal attention (matches kvpress softmax), then use prefix columns only.
for ks in tl.range(k_beg, k_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_end
if ROCM_DOT_LAYOUT_WORKAROUND:
# Triton > 3.1 on ROCm/HCU can fail during LDS optimization when
# a dot uses a transposed blocked operand. Load K as [D, BK] and
# feed it directly into tl.dot to avoid generating that layout path.
k_ptrs = (
K + nk[None, :] * STRIDE_K_NK + hk * STRIDE_K_HK + offs_d[:, None]
)
k_blk = tl.load(k_ptrs, mask=kmask[None, :], other=0.0) # [D, BK]
s = tl.dot(q_rows, k_blk) * qk_scale # [BQ, BK]
else:
k_ptrs = (
K + nk[:, None] * STRIDE_K_NK + hk * STRIDE_K_HK + offs_d[None, :]
)
k_blk = tl.load(k_ptrs, mask=kmask[:, None], other=0.0) # [BK, D]
s = tl.dot(q_rows, k_blk.T) * qk_scale # [BQ, BK]
s = tl.where(kmask[None, :], s, -float("inf"))
# Causal: key j only if j <= q_idx (same as kvpress triu mask on the window×k_len grid).
causal_ok = nk[None, :] <= q_idx[:, None]
s = tl.where(causal_ok, s, -float("inf"))
# store prefix logits only (for marginal probs on prefix keys)
log_ptrs = (
LOGITS
+ nk[:, None] * STRIDE_LG_NK
+ hk * STRIDE_LG_HK
+ (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
)
store_mask = kmask & (nk < k_eff_end)
tl.store(log_ptrs, s.T, mask=store_mask[:, None] & row_mask[None, :])
# log2 streaming LSE over all keys in [k_beg, k_end) (after causal mask)
cur_max = tl.max(s, 1) # [BQ]
n_m = tl.maximum(m, cur_max)
rescale = tl.math.exp2(m - n_m)
S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
m = n_m
# store m,S for these rows
m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R, m, mask=row_mask)
tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R, S, mask=row_mask)
_lse_and_store_logits_kernel_autotuned = triton_autotune(
configs=[
triton.Config(
{"BLOCK_Q": bq, "BLOCK_K": bk}, num_warps=num_warps, num_stages=num_stages
)
for bq in [32, 64]
for bk in [32, 64]
for num_warps in [4, 8]
for num_stages in [3, 4]
],
key=["QUERY_GROUP_SIZE", "D", "ROWS_MAX"],
cache_results=True,
)(_lse_and_store_logits_kernel)
@triton_autotune(
configs=[
triton.Config({"BLOCK_Q": bq, "BLOCK_K": bk})
for bq in [16, 32, 64]
for bk in [32, 64, 128]
],
key=["HK", "HQ"],
cache_results=True,
)
@triton.jit
def _prefix_probs_kernel(
cu_k,
w_b,
in_m,
in_S, # [B, Hk, ROWS_MAX] f32
LOGITS, # [Nk, Hk, ROWS_MAX] f32, base-2 logits (prefix keys only)
PROBS, # [Nk, Hk, ROWS_MAX] f32 — per-row prefix marginal probs
#
QUERY_GROUP_SIZE: tl.constexpr,
STRIDE_M_B,
STRIDE_M_H,
STRIDE_M_R,
STRIDE_S_B,
STRIDE_S_H,
STRIDE_S_R,
STRIDE_LG_NK,
STRIDE_LG_HK,
STRIDE_LG_R,
STRIDE_PB_NK,
STRIDE_PB_HK,
STRIDE_PB_R,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
):
b = tl.program_id(0)
hk = tl.program_id(1)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
k_eff_end = k_end - win
if (win <= 0) or (k_eff_end <= k_beg):
return
rows_b = win * QUERY_GROUP_SIZE
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
for row0 in tl.range(0, rows_b, BLOCK_Q):
r_idx = row0 + tl.arange(0, BLOCK_Q)
rmask = r_idx < rows_b
m_ptr = in_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
S_ptr = in_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
m = tl.load(
m_ptr + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
mask=rmask,
other=-float("inf"),
)
S = tl.load(
S_ptr + tl.arange(0, BLOCK_Q) * STRIDE_S_R, mask=rmask, other=0.0
)
valid_row = S > 0
m = tl.where(valid_row, m, 0.0)
S = tl.where(valid_row, S, 1.0)
log_ptrs = (
LOGITS
+ nk[:, None] * STRIDE_LG_NK
+ hk * STRIDE_LG_HK
+ (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
)
s_T = tl.load(
log_ptrs, mask=kmask[:, None] & rmask[None, :], other=-float("inf")
) # [BK, BQ]
probs_T = tl.math.exp2(s_T - m[None, :]) / S[None, :]
probs_T = tl.where(valid_row[None, :], probs_T, 0.0)
prob_ptrs = (
PROBS
+ nk[:, None] * STRIDE_PB_NK
+ hk * STRIDE_PB_HK
+ (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_PB_R
)
tl.store(prob_ptrs, probs_T, mask=kmask[:, None] & rmask[None, :])
@triton_autotune(
configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
key=["HK"],
cache_results=True,
)
@triton.jit
def _zscore_per_batch_epilogue(
OUT, # [Nk, Hk], float32
cu_k,
w_b, # [B+1], [B] int32
STRIDE_OUT_NK,
STRIDE_OUT_HK,
HK: tl.constexpr, # Hk
EPS: tl.constexpr, # e.g., 1e-12
BLOCK_K: tl.constexpr, # e.g., 128
):
b = tl.program_id(0)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
k_eff_end = k_end - win
if k_eff_end <= k_beg:
return
sumv = tl.zeros([], dtype=tl.float32)
sumsq = tl.zeros([], dtype=tl.float32)
count = ((k_eff_end - k_beg) * HK).to(tl.float32)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
for h in tl.range(0, HK):
ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
sumv += tl.sum(vals, 0)
sumsq += tl.sum(vals * vals, 0)
mean = sumv / count
var = tl.maximum(sumsq / count - mean * mean, 0.0)
invstd = 1.0 / tl.sqrt(var + EPS)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
for h in tl.range(0, HK):
ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
vals = (vals - mean) * invstd
tl.store(ptrs, vals, mask=kmask)
@triton_autotune(
configs=[triton.Config({"BLOCK_T": bt}) for bt in [32, 64, 128, 256]],
key=["KERNEL_SIZE"],
cache_results=True,
)
@triton.jit
def _snapkv_avg_pool1d_kernel(
IN,
OUT,
Lp,
STRIDE_IN_C,
STRIDE_IN_L,
STRIDE_OUT_C,
STRIDE_OUT_L,
KERNEL_SIZE: tl.constexpr,
PAD: tl.constexpr,
BLOCK_T: tl.constexpr,
):
"""
Symmetric 1D average pool on the last dimension, matching
`F.avg_pool1d(x, kernel_size=K, padding=K//2, stride=1)` on `x` shaped [C, Lp]
(equivalent to PyTorch [C, 1, Lp] avg_pool1d with divisor = kernel size).
"""
c = tl.program_id(0)
t0 = tl.program_id(1) * BLOCK_T + tl.arange(0, BLOCK_T)
mask = t0 < Lp
acc = tl.zeros([BLOCK_T], dtype=tl.float32)
for j in tl.static_range(KERNEL_SIZE):
idx = t0 - PAD + j
valid = (idx >= 0) & (idx < Lp)
ptrs = IN + c * STRIDE_IN_C + idx * STRIDE_IN_L
v = tl.load(ptrs, mask=valid & mask, other=0.0).to(tl.float32)
acc += v
acc = acc / tl.cast(KERNEL_SIZE, tl.float32)
out_ptrs = OUT + c * STRIDE_OUT_C + t0 * STRIDE_OUT_L
tl.store(out_ptrs, acc, mask=mask)
def _snapkv_avg_pool1d_triton(x: torch.Tensor, kernel_size: int) -> torch.Tensor:
"""
kvpress-equivalent smoothing: same as `F.avg_pool1d` on [Hk*G, 1, Lp].
`x` must be float32 and contiguous along Lp (shape [Hk, G, Lp]).
"""
assert x.dtype == torch.float32
Hk, G, Lp = x.shape
if Lp == 0:
return x
pad = kernel_size // 2
x2 = x.reshape(Hk * G, Lp).contiguous()
out = torch.empty_like(x2)
C = Hk * G
si_c, si_l = x2.stride()
so_c, so_l = out.stride()
def grid(meta):
return (C, triton.cdiv(Lp, meta["BLOCK_T"]))
_snapkv_avg_pool1d_kernel[grid](
x2,
out,
Lp,
si_c,
si_l,
so_c,
so_l,
KERNEL_SIZE=kernel_size,
PAD=pad,
)
return out.view(Hk, G, Lp)
def _snapkv_kvpress_epilogue(
probs_buf: torch.Tensor,
out: torch.Tensor,
cu_seqlens_k: torch.Tensor,
w: torch.Tensor,
G: int,
Hk: int,
kernel_size: int,
) -> None:
"""
Match kvpress SnapKV order: mean over window queries → symmetric avg_pool1d
→ mean over GQA groups → pad tail with global max of prefix scores.
"""
B = cu_seqlens_k.numel() - 1
for b in range(B):
k_beg = int(cu_seqlens_k[b].item())
k_end = int(cu_seqlens_k[b + 1].item())
win = int(w[b].item())
k_eff_end = k_end - win
if win <= 0 or k_eff_end <= k_beg:
continue
Lp = k_eff_end - k_beg
rows_b = win * G
p = probs_buf[k_beg:k_eff_end, :, :rows_b]
# [Lp, Hk, win, G] — rows are (q_off, g) order per Triton row layout
x = p.view(Lp, Hk, win, G).mean(dim=2)
x = x.permute(1, 2, 0).contiguous() # [Hk, G, Lp]
x = _snapkv_avg_pool1d_triton(x, kernel_size)
x = x.mean(dim=1)
seg = x.permute(1, 0).contiguous()
out[k_beg:k_eff_end, :] = seg
pad_val = seg.max()
out[k_eff_end:k_end, :] = pad_val
def query_aware_key_scores(
q: torch.Tensor, # [N_q, Hq, D]
k: torch.Tensor, # [N_k, Hk, D]
cu_seqlens_q: torch.Tensor, # [B+1], int32
cu_seqlens_k: torch.Tensor, # [B+1], int32
w: torch.Tensor | int, # [B], int32
sm_scale: float = None, # defaults to 1/sqrt(D)
*,
kernel_size: int = DEFAULT_SNAPKV_KERNEL_SIZE,
accum_scores: torch.Tensor = None,
accum_blending: float = None,
normalize: bool = False,
) -> Optional[torch.Tensor]:
assert q.stride(-1) == 1 and k.stride(-1) == 1, "last dim must be contiguous"
device = q.device
N_q, Hq, D = q.shape
N_k, Hk, Dk = k.shape
assert (Hq % Hk) == 0, "Hq must be a multiple of Hk"
if sm_scale is None:
sm_scale = 1.0 / math.sqrt(D)
B = cu_seqlens_q.numel() - 1
assert B == cu_seqlens_k.numel() - 1
G = Hq // Hk
if type(w) is int:
max_w = w
w = torch.full((B,), fill_value=w, device=device, dtype=torch.int32)
else:
max_w = int(w.max().item())
assert w.numel() == B
ROWS_MAX = max_w * G
if ROWS_MAX == 0:
return torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
out = torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
m_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
S_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
logits_buf = torch.empty((N_k, Hk, ROWS_MAX), dtype=torch.float32, device=device)
probs_buf = torch.empty((N_k, Hk, ROWS_MAX), dtype=torch.float32, device=device)
# strides
STRIDE_Q_NQ, STRIDE_Q_HQ, _ = q.stride()
STRIDE_K_NK, STRIDE_K_HK, _ = k.stride()
STRIDE_M_B, STRIDE_M_H, STRIDE_M_R = m_scratch.stride()
STRIDE_S_B, STRIDE_S_H, STRIDE_S_R = S_scratch.stride()
STRIDE_LG_NK, STRIDE_LG_HK, STRIDE_LG_R = logits_buf.stride()
STRIDE_PB_NK, STRIDE_PB_HK, STRIDE_PB_R = probs_buf.stride()
STRIDE_OUT_NK, STRIDE_OUT_HK = out.stride()
def grid(META):
return B, Hk, triton.cdiv(ROWS_MAX, META["BLOCK_Q"])
lse_kernel = _lse_and_store_logits_kernel_autotuned
lse_kernel_kwargs = {}
if _USE_ROCM_TRITON_DOT_WORKAROUND:
lse_kernel = _lse_and_store_logits_kernel
lse_kernel_kwargs = {
"BLOCK_Q": _ROCM_TRITON_DOT_WORKAROUND_BLOCK_Q,
"BLOCK_K": _ROCM_TRITON_DOT_WORKAROUND_BLOCK_K,
"num_warps": 4,
"num_stages": 1,
}
lse_kernel[grid](
q,
k,
cu_seqlens_q,
cu_seqlens_k,
w,
m_scratch,
S_scratch,
logits_buf,
sm_scale,
QUERY_GROUP_SIZE=Hq // Hk,
D=D,
STRIDE_Q_NQ=STRIDE_Q_NQ,
STRIDE_Q_HQ=STRIDE_Q_HQ,
STRIDE_K_NK=STRIDE_K_NK,
STRIDE_K_HK=STRIDE_K_HK,
STRIDE_M_B=STRIDE_M_B,
STRIDE_M_H=STRIDE_M_H,
STRIDE_M_R=STRIDE_M_R,
STRIDE_S_B=STRIDE_S_B,
STRIDE_S_H=STRIDE_S_H,
STRIDE_S_R=STRIDE_S_R,
STRIDE_LG_NK=STRIDE_LG_NK,
STRIDE_LG_HK=STRIDE_LG_HK,
STRIDE_LG_R=STRIDE_LG_R,
ROWS_MAX=ROWS_MAX,
ROCM_DOT_LAYOUT_WORKAROUND=_USE_ROCM_TRITON_DOT_WORKAROUND,
**lse_kernel_kwargs,
)
_prefix_probs_kernel[(B, Hk)](
cu_seqlens_k,
w,
m_scratch,
S_scratch,
logits_buf,
probs_buf,
QUERY_GROUP_SIZE=Hq // Hk,
STRIDE_M_B=STRIDE_M_B,
STRIDE_M_H=STRIDE_M_H,
STRIDE_M_R=STRIDE_M_R,
STRIDE_S_B=STRIDE_S_B,
STRIDE_S_H=STRIDE_S_H,
STRIDE_S_R=STRIDE_S_R,
STRIDE_LG_NK=STRIDE_LG_NK,
STRIDE_LG_HK=STRIDE_LG_HK,
STRIDE_LG_R=STRIDE_LG_R,
STRIDE_PB_NK=STRIDE_PB_NK,
STRIDE_PB_HK=STRIDE_PB_HK,
STRIDE_PB_R=STRIDE_PB_R,
)
_snapkv_kvpress_epilogue(
probs_buf, out, cu_seqlens_k, w, G, Hk, kernel_size
)
if normalize:
_zscore_per_batch_epilogue[(B,)](
out,
cu_seqlens_k,
w,
STRIDE_OUT_NK,
STRIDE_OUT_HK,
HK=Hk,
EPS=1e-12,
)
if accum_scores is not None:
if accum_blending is not None:
accum_scores.mul_(accum_blending)
accum_scores.add_(out)
return accum_scores
else:
return out
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment