Commit 40c2c5e5 authored by chenzk's avatar chenzk
Browse files

vllm kvprune for tritonx:v1.1.3

parent 58666cd7
...@@ -49,7 +49,21 @@ More related libraries: ...@@ -49,7 +49,21 @@ More related libraries:
- flash_attn-2.8.3+das.opt1.dtk2604.torch290-cp310-cp310-manylinux_2_28_x86_64.whl - flash_attn-2.8.3+das.opt1.dtk2604.torch290-cp310-cp310-manylinux_2_28_x86_64.whl
- torchvision-0.24.0+das.opt1.dtk2604.torch290-cp310-cp310-manylinux_2_28_x86_64.whl - torchvision-0.24.0+das.opt1.dtk2604.torch290-cp310-cp310-manylinux_2_28_x86_64.whl
- triton-3.1.0+das.opt1.dtk2604.torch271-cp310-cp310-manylinux_2_28_x86_64.whl - triton-3.3.0+das.opt1.dtk2604.torch290-cp310-cp310-manylinux_2_28_x86_64.whl
This project is compatible with triton-3.1.0, triton-3.3.0, and triton-3.5.1. However, for triton-3.5.1, when the underlying environment uses clang 17 and LLVM 22.0, the following modifications are required due to triton's own compatibility issues:
In /usr/local/lib/python3.10/dist-packages/triton/backends/amd/compiler.py, locate the make_llir(src, metadata, options) function within the HIPBackend(BaseBackend) class. Replace `return str(llvm_mod)` with
```
# compatibility fix for clang 17 + LLVM 22.0
llir = str(llvm_mod)
llir = re.sub(r"getelementptr inbounds\s+nuw\s+", "getelementptr inbounds ", llir)
llir = re.sub(r"getelementptr\s+nuw\s+", "getelementptr ", llir)
llir = re.sub(r"getelementptr inbounds\s+nusw\s+", "getelementptr inbounds ", llir)
llir = re.sub(r"getelementptr\s+nusw\s+", "getelementptr ", llir)
return llir
```
## Quick Start ## Quick Start
Basic Chat Generation with Compression: Basic Chat Generation with Compression:
...@@ -140,7 +154,7 @@ def main() -> None: ...@@ -140,7 +154,7 @@ def main() -> None:
compression = [ compression = [
CompressionParams( CompressionParams(
compression_ratio=0.5, compression_ratio=0.5,
compression_method="compactor", compression_method="snapkv",
), ),
] ]
......
import math import math
from typing import Optional from typing import Optional
from packaging import version
import torch import torch
import triton import triton
from triton import language as tl from triton import language as tl
from vllm.kvprune.compression.common import BaseCompressionMethod from vllm.kvprune.compression.common import BaseCompressionMethod
from vllm.platforms import current_platform
from vllm.kvprune.utils.helpers import maybe_execute_in_stream from vllm.kvprune.utils.helpers import maybe_execute_in_stream
from vllm.kvprune.utils.triton_compat import autotune as triton_autotune from vllm.kvprune.utils.triton_compat import autotune as triton_autotune
# SnapKV defaults aligned with kvpress `SnapKVPress` (snapkv_press.py). # SnapKV defaults aligned with kvpress `SnapKVPress` (snapkv_press.py).
DEFAULT_SNAPKV_WINDOW_SIZE = 64 DEFAULT_SNAPKV_WINDOW_SIZE = 64
DEFAULT_SNAPKV_KERNEL_SIZE = 5 DEFAULT_SNAPKV_KERNEL_SIZE = 5
# ROCm + Triton >= 3.2.0 may fail on this kernel when tl.dot consumes a
# transposed blocked operand. Work around it here by loading K as [D, BK] and
# using a fixed launch config. Triton 3.5.1 has a separate LLVM/DTK toolchain
# incompatibility that should be hot-patched in triton/backends/amd/compiler.py
# instead of adding more SnapKV-side branching.
_USE_ROCM_TRITON_DOT_WORKAROUND = current_platform.is_rocm() and (
version.parse(triton.__version__) >= version.parse("3.2.0")
)
_ROCM_TRITON_DOT_WORKAROUND_BLOCK_Q = 32
_ROCM_TRITON_DOT_WORKAROUND_BLOCK_K = 32
class SnapKVCompression(BaseCompressionMethod): class SnapKVCompression(BaseCompressionMethod):
...@@ -42,19 +54,6 @@ class SnapKVCompression(BaseCompressionMethod): ...@@ -42,19 +54,6 @@ class SnapKVCompression(BaseCompressionMethod):
return scores return scores
@triton_autotune(
configs=[
triton.Config(
{"BLOCK_Q": bq, "BLOCK_K": bk}, num_warps=num_warps, num_stages=num_stages
)
for bq in [32, 64]
for bk in [32, 64]
for num_warps in [4, 8]
for num_stages in [3, 4]
],
key=["QUERY_GROUP_SIZE", "D", "ROWS_MAX"],
cache_results=True,
)
@triton.jit @triton.jit
def _lse_and_store_logits_kernel( def _lse_and_store_logits_kernel(
Q, Q,
...@@ -84,6 +83,7 @@ def _lse_and_store_logits_kernel( ...@@ -84,6 +83,7 @@ def _lse_and_store_logits_kernel(
BLOCK_Q: tl.constexpr, BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr, BLOCK_K: tl.constexpr,
ROWS_MAX, ROWS_MAX,
ROCM_DOT_LAYOUT_WORKAROUND: tl.constexpr,
): ):
# program ids # program ids
b = tl.program_id(0) b = tl.program_id(0)
...@@ -135,10 +135,21 @@ def _lse_and_store_logits_kernel( ...@@ -135,10 +135,21 @@ def _lse_and_store_logits_kernel(
nk = ks + tl.arange(0, BLOCK_K) nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_end kmask = nk < k_end
k_ptrs = K + nk[:, None] * STRIDE_K_NK + hk * STRIDE_K_HK + offs_d[None, :] if ROCM_DOT_LAYOUT_WORKAROUND:
k_blk = tl.load(k_ptrs, mask=kmask[:, None], other=0.0) # [BK, D] # Triton > 3.1 on ROCm/HCU can fail during LDS optimization when
# a dot uses a transposed blocked operand. Load K as [D, BK] and
s = tl.dot(q_rows, k_blk.T) * qk_scale # [BQ, BK] # feed it directly into tl.dot to avoid generating that layout path.
k_ptrs = (
K + nk[None, :] * STRIDE_K_NK + hk * STRIDE_K_HK + offs_d[:, None]
)
k_blk = tl.load(k_ptrs, mask=kmask[None, :], other=0.0) # [D, BK]
s = tl.dot(q_rows, k_blk) * qk_scale # [BQ, BK]
else:
k_ptrs = (
K + nk[:, None] * STRIDE_K_NK + hk * STRIDE_K_HK + offs_d[None, :]
)
k_blk = tl.load(k_ptrs, mask=kmask[:, None], other=0.0) # [BK, D]
s = tl.dot(q_rows, k_blk.T) * qk_scale # [BQ, BK]
s = tl.where(kmask[None, :], s, -float("inf")) s = tl.where(kmask[None, :], s, -float("inf"))
# Causal: key j only if j <= q_idx (same as kvpress triu mask on the window×k_len grid). # Causal: key j only if j <= q_idx (same as kvpress triu mask on the window×k_len grid).
causal_ok = nk[None, :] <= q_idx[:, None] causal_ok = nk[None, :] <= q_idx[:, None]
...@@ -168,6 +179,21 @@ def _lse_and_store_logits_kernel( ...@@ -168,6 +179,21 @@ def _lse_and_store_logits_kernel(
tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R, S, mask=row_mask) tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R, S, mask=row_mask)
_lse_and_store_logits_kernel_autotuned = triton_autotune(
configs=[
triton.Config(
{"BLOCK_Q": bq, "BLOCK_K": bk}, num_warps=num_warps, num_stages=num_stages
)
for bq in [32, 64]
for bk in [32, 64]
for num_warps in [4, 8]
for num_stages in [3, 4]
],
key=["QUERY_GROUP_SIZE", "D", "ROWS_MAX"],
cache_results=True,
)(_lse_and_store_logits_kernel)
@triton_autotune( @triton_autotune(
configs=[ configs=[
triton.Config({"BLOCK_Q": bq, "BLOCK_K": bk}) triton.Config({"BLOCK_Q": bq, "BLOCK_K": bk})
...@@ -474,7 +500,18 @@ def query_aware_key_scores( ...@@ -474,7 +500,18 @@ def query_aware_key_scores(
def grid(META): def grid(META):
return B, Hk, triton.cdiv(ROWS_MAX, META["BLOCK_Q"]) return B, Hk, triton.cdiv(ROWS_MAX, META["BLOCK_Q"])
_lse_and_store_logits_kernel[grid]( lse_kernel = _lse_and_store_logits_kernel_autotuned
lse_kernel_kwargs = {}
if _USE_ROCM_TRITON_DOT_WORKAROUND:
lse_kernel = _lse_and_store_logits_kernel
lse_kernel_kwargs = {
"BLOCK_Q": _ROCM_TRITON_DOT_WORKAROUND_BLOCK_Q,
"BLOCK_K": _ROCM_TRITON_DOT_WORKAROUND_BLOCK_K,
"num_warps": 4,
"num_stages": 1,
}
lse_kernel[grid](
q, q,
k, k,
cu_seqlens_q, cu_seqlens_q,
...@@ -500,6 +537,8 @@ def query_aware_key_scores( ...@@ -500,6 +537,8 @@ def query_aware_key_scores(
STRIDE_LG_HK=STRIDE_LG_HK, STRIDE_LG_HK=STRIDE_LG_HK,
STRIDE_LG_R=STRIDE_LG_R, STRIDE_LG_R=STRIDE_LG_R,
ROWS_MAX=ROWS_MAX, ROWS_MAX=ROWS_MAX,
ROCM_DOT_LAYOUT_WORKAROUND=_USE_ROCM_TRITON_DOT_WORKAROUND,
**lse_kernel_kwargs,
) )
_prefix_probs_kernel[(B, Hk)]( _prefix_probs_kernel[(B, Hk)](
...@@ -543,4 +582,3 @@ def query_aware_key_scores( ...@@ -543,4 +582,3 @@ def query_aware_key_scores(
return accum_scores return accum_scores
else: else:
return out return out
import math
from typing import Optional
from packaging import version
import torch
import triton
from triton import language as tl
from vllm.kvprune.compression.common import BaseCompressionMethod
from vllm.platforms import current_platform
from vllm.kvprune.utils.helpers import maybe_execute_in_stream
from vllm.kvprune.utils.triton_compat import autotune as triton_autotune
# SnapKV defaults aligned with kvpress `SnapKVPress` (snapkv_press.py).
DEFAULT_SNAPKV_WINDOW_SIZE = 64
DEFAULT_SNAPKV_KERNEL_SIZE = 5
_USE_ROCM_TRITON_DOT_WORKAROUND = current_platform.is_rocm() and (
version.parse(triton.__version__) >= version.parse("3.2.0")
)
_ROCM_TRITON_DOT_WORKAROUND_BLOCK_Q = 32
_ROCM_TRITON_DOT_WORKAROUND_BLOCK_K = 32
class SnapKVCompression(BaseCompressionMethod):
@staticmethod
def pre_rope_scoring(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
) -> Optional[torch.Tensor]:
return None
@staticmethod
def post_rope_scoring(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
pre_rope_scores: torch.Tensor,
context,
) -> Optional[torch.Tensor]:
scores = maybe_execute_in_stream(
query_aware_key_scores,
q,
k,
context.cu_seqlens_q,
context.cu_seqlens_k,
w=DEFAULT_SNAPKV_WINDOW_SIZE,
kernel_size=DEFAULT_SNAPKV_KERNEL_SIZE,
STORE_STREAM=context.STORE_STREAM,
)
return scores
@triton.jit
def _lse_and_store_logits_kernel(
Q,
K,
cu_q,
cu_k,
w_b, # int32 pointers
out_m,
out_S, # [B, Hk, ROWS_MAX] float32
LOGITS, # [Nk, Hk, ROWS_MAX] float32
sm_scale, # float
QUERY_GROUP_SIZE: tl.constexpr,
D: tl.constexpr,
STRIDE_Q_NQ,
STRIDE_Q_HQ,
STRIDE_K_NK,
STRIDE_K_HK,
STRIDE_M_B,
STRIDE_M_H,
STRIDE_M_R,
STRIDE_S_B,
STRIDE_S_H,
STRIDE_S_R,
STRIDE_LG_NK,
STRIDE_LG_HK,
STRIDE_LG_R,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
ROWS_MAX,
ROCM_DOT_LAYOUT_WORKAROUND: tl.constexpr,
):
# program ids
b = tl.program_id(0)
hk = tl.program_id(1)
rid = tl.program_id(2) # row-tile id
# batch segment bounds
q_end = tl.load(cu_q + b + 1)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
q_win_beg = q_end - win
k_eff_end = k_end - win
if (win <= 0) or (k_eff_end <= k_beg):
return
# rows for this (b,hk)
rows_b = win * QUERY_GROUP_SIZE
row0 = rid * BLOCK_Q
if row0 >= rows_b:
return
# exp(x) = exp2(x * 1/ln2)
qk_scale = sm_scale * 1.4426950408889634
offs_qrow = row0 + tl.arange(0, BLOCK_Q)
row_mask = offs_qrow < rows_b
# map row -> (q_idx, hq_local)
hq_local = offs_qrow % QUERY_GROUP_SIZE
q_off = offs_qrow // QUERY_GROUP_SIZE
q_idx = q_win_beg + q_off
hq_glob = hk * QUERY_GROUP_SIZE + hq_local
offs_d = tl.arange(0, D)
q_ptrs = (
Q
+ q_idx[:, None] * STRIDE_Q_NQ
+ hq_glob[:, None] * STRIDE_Q_HQ
+ offs_d[None, :]
)
q_rows = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)
m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
S = tl.zeros([BLOCK_Q], dtype=tl.float32)
# Full-sequence causal attention (matches kvpress softmax), then use prefix columns only.
for ks in tl.range(k_beg, k_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_end
if ROCM_DOT_LAYOUT_WORKAROUND:
# Triton > 3.1 on ROCm/HCU can fail during LDS optimization when
# a dot uses a transposed blocked operand. Load K as [D, BK] and
# feed it directly into tl.dot to avoid generating that layout path.
k_ptrs = (
K + nk[None, :] * STRIDE_K_NK + hk * STRIDE_K_HK + offs_d[:, None]
)
k_blk = tl.load(k_ptrs, mask=kmask[None, :], other=0.0) # [D, BK]
s = tl.dot(q_rows, k_blk) * qk_scale # [BQ, BK]
else:
k_ptrs = (
K + nk[:, None] * STRIDE_K_NK + hk * STRIDE_K_HK + offs_d[None, :]
)
k_blk = tl.load(k_ptrs, mask=kmask[:, None], other=0.0) # [BK, D]
s = tl.dot(q_rows, k_blk.T) * qk_scale # [BQ, BK]
s = tl.where(kmask[None, :], s, -float("inf"))
# Causal: key j only if j <= q_idx (same as kvpress triu mask on the window×k_len grid).
causal_ok = nk[None, :] <= q_idx[:, None]
s = tl.where(causal_ok, s, -float("inf"))
# store prefix logits only (for marginal probs on prefix keys)
log_ptrs = (
LOGITS
+ nk[:, None] * STRIDE_LG_NK
+ hk * STRIDE_LG_HK
+ (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
)
store_mask = kmask & (nk < k_eff_end)
tl.store(log_ptrs, s.T, mask=store_mask[:, None] & row_mask[None, :])
# log2 streaming LSE over all keys in [k_beg, k_end) (after causal mask)
cur_max = tl.max(s, 1) # [BQ]
n_m = tl.maximum(m, cur_max)
rescale = tl.math.exp2(m - n_m)
S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
m = n_m
# store m,S for these rows
m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R, m, mask=row_mask)
tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R, S, mask=row_mask)
_lse_and_store_logits_kernel_autotuned = triton_autotune(
configs=[
triton.Config(
{"BLOCK_Q": bq, "BLOCK_K": bk}, num_warps=num_warps, num_stages=num_stages
)
for bq in [32, 64]
for bk in [32, 64]
for num_warps in [4, 8]
for num_stages in [3, 4]
],
key=["QUERY_GROUP_SIZE", "D", "ROWS_MAX"],
cache_results=True,
)(_lse_and_store_logits_kernel)
@triton_autotune(
configs=[
triton.Config({"BLOCK_Q": bq, "BLOCK_K": bk})
for bq in [16, 32, 64]
for bk in [32, 64, 128]
],
key=["HK", "HQ"],
cache_results=True,
)
@triton.jit
def _prefix_probs_kernel(
cu_k,
w_b,
in_m,
in_S, # [B, Hk, ROWS_MAX] f32
LOGITS, # [Nk, Hk, ROWS_MAX] f32, base-2 logits (prefix keys only)
PROBS, # [Nk, Hk, ROWS_MAX] f32 — per-row prefix marginal probs
#
QUERY_GROUP_SIZE: tl.constexpr,
STRIDE_M_B,
STRIDE_M_H,
STRIDE_M_R,
STRIDE_S_B,
STRIDE_S_H,
STRIDE_S_R,
STRIDE_LG_NK,
STRIDE_LG_HK,
STRIDE_LG_R,
STRIDE_PB_NK,
STRIDE_PB_HK,
STRIDE_PB_R,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
):
b = tl.program_id(0)
hk = tl.program_id(1)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
k_eff_end = k_end - win
if (win <= 0) or (k_eff_end <= k_beg):
return
rows_b = win * QUERY_GROUP_SIZE
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
for row0 in tl.range(0, rows_b, BLOCK_Q):
r_idx = row0 + tl.arange(0, BLOCK_Q)
rmask = r_idx < rows_b
m_ptr = in_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
S_ptr = in_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
m = tl.load(
m_ptr + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
mask=rmask,
other=-float("inf"),
)
S = tl.load(
S_ptr + tl.arange(0, BLOCK_Q) * STRIDE_S_R, mask=rmask, other=0.0
)
valid_row = S > 0
m = tl.where(valid_row, m, 0.0)
S = tl.where(valid_row, S, 1.0)
log_ptrs = (
LOGITS
+ nk[:, None] * STRIDE_LG_NK
+ hk * STRIDE_LG_HK
+ (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
)
s_T = tl.load(
log_ptrs, mask=kmask[:, None] & rmask[None, :], other=-float("inf")
) # [BK, BQ]
probs_T = tl.math.exp2(s_T - m[None, :]) / S[None, :]
probs_T = tl.where(valid_row[None, :], probs_T, 0.0)
prob_ptrs = (
PROBS
+ nk[:, None] * STRIDE_PB_NK
+ hk * STRIDE_PB_HK
+ (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_PB_R
)
tl.store(prob_ptrs, probs_T, mask=kmask[:, None] & rmask[None, :])
@triton_autotune(
configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
key=["HK"],
cache_results=True,
)
@triton.jit
def _zscore_per_batch_epilogue(
OUT, # [Nk, Hk], float32
cu_k,
w_b, # [B+1], [B] int32
STRIDE_OUT_NK,
STRIDE_OUT_HK,
HK: tl.constexpr, # Hk
EPS: tl.constexpr, # e.g., 1e-12
BLOCK_K: tl.constexpr, # e.g., 128
):
b = tl.program_id(0)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
k_eff_end = k_end - win
if k_eff_end <= k_beg:
return
sumv = tl.zeros([], dtype=tl.float32)
sumsq = tl.zeros([], dtype=tl.float32)
count = ((k_eff_end - k_beg) * HK).to(tl.float32)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
for h in tl.range(0, HK):
ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
sumv += tl.sum(vals, 0)
sumsq += tl.sum(vals * vals, 0)
mean = sumv / count
var = tl.maximum(sumsq / count - mean * mean, 0.0)
invstd = 1.0 / tl.sqrt(var + EPS)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
for h in tl.range(0, HK):
ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
vals = (vals - mean) * invstd
tl.store(ptrs, vals, mask=kmask)
@triton_autotune(
configs=[triton.Config({"BLOCK_T": bt}) for bt in [32, 64, 128, 256]],
key=["KERNEL_SIZE"],
cache_results=True,
)
@triton.jit
def _snapkv_avg_pool1d_kernel(
IN,
OUT,
Lp,
STRIDE_IN_C,
STRIDE_IN_L,
STRIDE_OUT_C,
STRIDE_OUT_L,
KERNEL_SIZE: tl.constexpr,
PAD: tl.constexpr,
BLOCK_T: tl.constexpr,
):
"""
Symmetric 1D average pool on the last dimension, matching
`F.avg_pool1d(x, kernel_size=K, padding=K//2, stride=1)` on `x` shaped [C, Lp]
(equivalent to PyTorch [C, 1, Lp] avg_pool1d with divisor = kernel size).
"""
c = tl.program_id(0)
t0 = tl.program_id(1) * BLOCK_T + tl.arange(0, BLOCK_T)
mask = t0 < Lp
acc = tl.zeros([BLOCK_T], dtype=tl.float32)
for j in tl.static_range(KERNEL_SIZE):
idx = t0 - PAD + j
valid = (idx >= 0) & (idx < Lp)
ptrs = IN + c * STRIDE_IN_C + idx * STRIDE_IN_L
v = tl.load(ptrs, mask=valid & mask, other=0.0).to(tl.float32)
acc += v
acc = acc / tl.cast(KERNEL_SIZE, tl.float32)
out_ptrs = OUT + c * STRIDE_OUT_C + t0 * STRIDE_OUT_L
tl.store(out_ptrs, acc, mask=mask)
def _snapkv_avg_pool1d_triton(x: torch.Tensor, kernel_size: int) -> torch.Tensor:
"""
kvpress-equivalent smoothing: same as `F.avg_pool1d` on [Hk*G, 1, Lp].
`x` must be float32 and contiguous along Lp (shape [Hk, G, Lp]).
"""
assert x.dtype == torch.float32
Hk, G, Lp = x.shape
if Lp == 0:
return x
pad = kernel_size // 2
x2 = x.reshape(Hk * G, Lp).contiguous()
out = torch.empty_like(x2)
C = Hk * G
si_c, si_l = x2.stride()
so_c, so_l = out.stride()
def grid(meta):
return (C, triton.cdiv(Lp, meta["BLOCK_T"]))
_snapkv_avg_pool1d_kernel[grid](
x2,
out,
Lp,
si_c,
si_l,
so_c,
so_l,
KERNEL_SIZE=kernel_size,
PAD=pad,
)
return out.view(Hk, G, Lp)
def _snapkv_kvpress_epilogue(
probs_buf: torch.Tensor,
out: torch.Tensor,
cu_seqlens_k: torch.Tensor,
w: torch.Tensor,
G: int,
Hk: int,
kernel_size: int,
) -> None:
"""
Match kvpress SnapKV order: mean over window queries → symmetric avg_pool1d
→ mean over GQA groups → pad tail with global max of prefix scores.
"""
B = cu_seqlens_k.numel() - 1
for b in range(B):
k_beg = int(cu_seqlens_k[b].item())
k_end = int(cu_seqlens_k[b + 1].item())
win = int(w[b].item())
k_eff_end = k_end - win
if win <= 0 or k_eff_end <= k_beg:
continue
Lp = k_eff_end - k_beg
rows_b = win * G
p = probs_buf[k_beg:k_eff_end, :, :rows_b]
# [Lp, Hk, win, G] — rows are (q_off, g) order per Triton row layout
x = p.view(Lp, Hk, win, G).mean(dim=2)
x = x.permute(1, 2, 0).contiguous() # [Hk, G, Lp]
x = _snapkv_avg_pool1d_triton(x, kernel_size)
x = x.mean(dim=1)
seg = x.permute(1, 0).contiguous()
out[k_beg:k_eff_end, :] = seg
pad_val = seg.max()
out[k_eff_end:k_end, :] = pad_val
def query_aware_key_scores(
q: torch.Tensor, # [N_q, Hq, D]
k: torch.Tensor, # [N_k, Hk, D]
cu_seqlens_q: torch.Tensor, # [B+1], int32
cu_seqlens_k: torch.Tensor, # [B+1], int32
w: torch.Tensor | int, # [B], int32
sm_scale: float = None, # defaults to 1/sqrt(D)
*,
kernel_size: int = DEFAULT_SNAPKV_KERNEL_SIZE,
accum_scores: torch.Tensor = None,
accum_blending: float = None,
normalize: bool = False,
) -> Optional[torch.Tensor]:
assert q.stride(-1) == 1 and k.stride(-1) == 1, "last dim must be contiguous"
device = q.device
N_q, Hq, D = q.shape
N_k, Hk, Dk = k.shape
assert (Hq % Hk) == 0, "Hq must be a multiple of Hk"
if sm_scale is None:
sm_scale = 1.0 / math.sqrt(D)
B = cu_seqlens_q.numel() - 1
assert B == cu_seqlens_k.numel() - 1
G = Hq // Hk
if type(w) is int:
max_w = w
w = torch.full((B,), fill_value=w, device=device, dtype=torch.int32)
else:
max_w = int(w.max().item())
assert w.numel() == B
ROWS_MAX = max_w * G
if ROWS_MAX == 0:
return torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
out = torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
m_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
S_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
logits_buf = torch.empty((N_k, Hk, ROWS_MAX), dtype=torch.float32, device=device)
probs_buf = torch.empty((N_k, Hk, ROWS_MAX), dtype=torch.float32, device=device)
# strides
STRIDE_Q_NQ, STRIDE_Q_HQ, _ = q.stride()
STRIDE_K_NK, STRIDE_K_HK, _ = k.stride()
STRIDE_M_B, STRIDE_M_H, STRIDE_M_R = m_scratch.stride()
STRIDE_S_B, STRIDE_S_H, STRIDE_S_R = S_scratch.stride()
STRIDE_LG_NK, STRIDE_LG_HK, STRIDE_LG_R = logits_buf.stride()
STRIDE_PB_NK, STRIDE_PB_HK, STRIDE_PB_R = probs_buf.stride()
STRIDE_OUT_NK, STRIDE_OUT_HK = out.stride()
def grid(META):
return B, Hk, triton.cdiv(ROWS_MAX, META["BLOCK_Q"])
lse_kernel = _lse_and_store_logits_kernel_autotuned
lse_kernel_kwargs = {}
if _USE_ROCM_TRITON_DOT_WORKAROUND:
lse_kernel = _lse_and_store_logits_kernel
lse_kernel_kwargs = {
"BLOCK_Q": _ROCM_TRITON_DOT_WORKAROUND_BLOCK_Q,
"BLOCK_K": _ROCM_TRITON_DOT_WORKAROUND_BLOCK_K,
"num_warps": 4,
"num_stages": 1,
}
lse_kernel[grid](
q,
k,
cu_seqlens_q,
cu_seqlens_k,
w,
m_scratch,
S_scratch,
logits_buf,
sm_scale,
QUERY_GROUP_SIZE=Hq // Hk,
D=D,
STRIDE_Q_NQ=STRIDE_Q_NQ,
STRIDE_Q_HQ=STRIDE_Q_HQ,
STRIDE_K_NK=STRIDE_K_NK,
STRIDE_K_HK=STRIDE_K_HK,
STRIDE_M_B=STRIDE_M_B,
STRIDE_M_H=STRIDE_M_H,
STRIDE_M_R=STRIDE_M_R,
STRIDE_S_B=STRIDE_S_B,
STRIDE_S_H=STRIDE_S_H,
STRIDE_S_R=STRIDE_S_R,
STRIDE_LG_NK=STRIDE_LG_NK,
STRIDE_LG_HK=STRIDE_LG_HK,
STRIDE_LG_R=STRIDE_LG_R,
ROWS_MAX=ROWS_MAX,
ROCM_DOT_LAYOUT_WORKAROUND=_USE_ROCM_TRITON_DOT_WORKAROUND,
**lse_kernel_kwargs,
)
_prefix_probs_kernel[(B, Hk)](
cu_seqlens_k,
w,
m_scratch,
S_scratch,
logits_buf,
probs_buf,
QUERY_GROUP_SIZE=Hq // Hk,
STRIDE_M_B=STRIDE_M_B,
STRIDE_M_H=STRIDE_M_H,
STRIDE_M_R=STRIDE_M_R,
STRIDE_S_B=STRIDE_S_B,
STRIDE_S_H=STRIDE_S_H,
STRIDE_S_R=STRIDE_S_R,
STRIDE_LG_NK=STRIDE_LG_NK,
STRIDE_LG_HK=STRIDE_LG_HK,
STRIDE_LG_R=STRIDE_LG_R,
STRIDE_PB_NK=STRIDE_PB_NK,
STRIDE_PB_HK=STRIDE_PB_HK,
STRIDE_PB_R=STRIDE_PB_R,
)
_snapkv_kvpress_epilogue(
probs_buf, out, cu_seqlens_k, w, G, Hk, kernel_size
)
if normalize:
_zscore_per_batch_epilogue[(B,)](
out,
cu_seqlens_k,
w,
STRIDE_OUT_NK,
STRIDE_OUT_HK,
HK=Hk,
EPS=1e-12,
)
if accum_scores is not None:
if accum_blending is not None:
accum_scores.mul_(accum_blending)
accum_scores.add_(out)
return accum_scores
else:
return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment