Commit 5036e878 authored by laibao's avatar laibao
Browse files

feat: kvpress新增 SnapKV 打分与 KV compaction Triton 内核

parent d3acd4a5
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from typing import Optional, Tuple
import torch
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
import triton
import triton.language as tl
def _require_triton() -> None:
    """Fail fast when Triton could not be imported.

    Raises:
        RuntimeError: if the module-level ``HAS_TRITON`` flag is falsy.
    """
    if HAS_TRITON:
        return
    raise RuntimeError("Triton is not available.")
def _check_cuda(*tensors: torch.Tensor) -> None:
for t in tensors:
if not isinstance(t, torch.Tensor):
raise TypeError("Expected torch.Tensor inputs.")
if t.device.type != "cuda":
raise RuntimeError("Triton KV cache ops require CUDA/ROCm tensors.")
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_T': 128, 'BLOCK_D': 64}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_T': 256, 'BLOCK_D': 64}, num_warps=4, num_stages=2),
        triton.Config({'BLOCK_T': 256, 'BLOCK_D': 128}, num_warps=8, num_stages=2),
    ],
    key=["D"],
)
@triton.jit
def _gather_k_to_packed_kernel(
    K_ptr,
    out_ptr,
    blk_ids_ptr,
    req_blk_starts_ptr,
    cu_seqlens_ptr,
    seq_lens_ptr,
    B,
    H,
    max_blocks,
    block_size,
    D,
    sKb,
    sKh,
    sKt,
    sKd,
    so_t,
    so_h,
    so_d,
    BLOCK_T: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Copy one (token tile, head-dim tile) of a request's keys into `out`.

    Program ids: axis 0 is `b * H + h`, axis 1 the token tile, axis 2 the
    head-dim tile. For request `b`, logical token `t` is read from block
    `blk_ids[req_blk_starts[b] + t // block_size]` at in-block slot
    `t % block_size` and written to output row `cu_seqlens[b] + t`.
    Strides `sK*` / `so_*` are element strides of the cache and the output.
    """
    pid_bh = tl.program_id(0)
    pid_t = tl.program_id(1)
    pid_d = tl.program_id(2)
    b = pid_bh // H
    h = pid_bh % H
    if b >= B:
        return
    seq_len = tl.load(seq_lens_ptr + b)
    if seq_len <= 0:
        return

    t_range = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)
    t_mask = t_range < seq_len
    d_range = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    d_mask = d_range < D

    # Map logical token indices -> (block slot in the table, offset in block).
    blk = t_range // block_size
    inb = t_range - blk * block_size

    # Only tokens whose block slot lies inside this request's row of the
    # block table may be dereferenced; anything else would read a neighbouring
    # request's block ids (or past the end of the table).
    in_table = t_mask & (blk < max_blocks)
    req_blk_start = tl.load(req_blk_starts_ptr + b)
    gblk = tl.where(in_table, req_blk_start + blk, 0)
    # Promote to int64 before multiplying by the block stride: the product
    # `block_id * sKb` exceeds int32 on large caches.
    bid = tl.load(blk_ids_ptr + gblk, mask=in_table, other=0).to(tl.int64)

    # Source: key cache layout [num_blocks, H, block_size, D]
    src_ptrs = (K_ptr + bid[:, None] * sKb + h * sKh + inb[:, None] * sKt +
                d_range[None, :] * sKd)

    # Destination: packed output layout [T, H, D]; row offsets also in int64.
    out_start = tl.load(cu_seqlens_ptr + b)
    out_rows = (out_start + t_range).to(tl.int64)
    dst_ptrs = (out_ptr + out_rows[:, None] * so_t + h * so_h +
                d_range[None, :] * so_d)

    tile = tl.load(src_ptrs,
                   mask=(in_table[:, None] & d_mask[None, :]),
                   other=0)
    # Tokens inside seq_len but outside the block table are written as zeros
    # so the output never carries data from an unrelated block.
    tl.store(dst_ptrs, tile, mask=(t_mask[:, None] & d_mask[None, :]))
@torch.inference_mode()
def gather_k_to_packed_triton(
    key_cache: torch.Tensor,
    block_table: torch.Tensor,
    seq_lens: torch.Tensor,
    cu_seqlens: torch.Tensor,
    *,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Gather a block-wise KV key cache into a packed [T, H, D] tensor.

    Expected layouts:
    - key_cache: [num_blocks, H, block_size, D]
    - block_table: [B, max_blocks] int32 physical block ids
    - seq_lens: [B] int32 logical lengths (tokens) to gather
    - cu_seqlens: [B+1] int32 cumulative offsets into the packed output

    Args:
        out: optional preallocated destination of shape
            ``(cu_seqlens[-1], H, D)``; allocated with ``key_cache``'s dtype
            when omitted.

    Returns:
        The packed tensor (``out`` itself when it was supplied). For an empty
        batch a fresh ``(0, H, D)`` tensor is returned.

    Raises:
        RuntimeError: Triton is unavailable or a tensor is not on a CUDA/ROCm
            device.
        TypeError: a non-tensor was passed.
        ValueError: a tensor has the wrong rank, ``block_table`` /
            ``cu_seqlens`` are too short for ``B``, or ``out`` has the wrong
            shape.
    """
    _require_triton()
    _check_cuda(key_cache, block_table, seq_lens, cu_seqlens)
    if key_cache.ndim != 4:
        raise ValueError("key_cache must be a 4D tensor [num_blocks, H, Tb, D].")
    if block_table.ndim != 2:
        raise ValueError("block_table must be 2D [B, max_blocks].")
    if seq_lens.ndim != 1:
        raise ValueError("seq_lens must be 1D [B].")
    if cu_seqlens.ndim != 1:
        raise ValueError("cu_seqlens must be 1D [B+1].")

    device = key_cache.device
    H = int(key_cache.shape[1])
    block_size = int(key_cache.shape[2])
    D = int(key_cache.shape[3])
    B = int(seq_lens.numel())
    if B == 0:
        return torch.empty((0, H, D), device=device, dtype=key_cache.dtype)

    # The kernel indexes block_table row b and cu_seqlens[b] for every b < B,
    # and the output is sized from the last cu_seqlens entry; shorter inputs
    # would be read out of bounds on the device.
    if int(block_table.shape[0]) < B:
        raise ValueError(
            "block_table must have at least one row per entry of seq_lens.")
    if int(cu_seqlens.numel()) < B + 1:
        raise ValueError("cu_seqlens must have at least B+1 entries.")

    max_blocks = int(block_table.shape[1])
    seq_lens_i32 = seq_lens.to(device=device, dtype=torch.int32)
    cu_i32 = cu_seqlens.to(device=device, dtype=torch.int32)
    total_tokens = int(cu_i32[-1].item())

    if out is None:
        out = torch.empty((total_tokens, H, D),
                          device=device,
                          dtype=key_cache.dtype)
    else:
        _check_cuda(out)
        if out.shape != (total_tokens, H, D):
            raise ValueError(
                f"out has shape {tuple(out.shape)}, expected {(total_tokens, H, D)}."
            )

    L_max = int(seq_lens_i32.max().item())
    if total_tokens == 0 or L_max == 0 or D == 0 or H == 0:
        return out

    # Flatten the table; request b's row starts at b * max_blocks.
    blk_ids = block_table.to(device=device, dtype=torch.int32).reshape(-1)
    req_starts = (torch.arange(B, device=device, dtype=torch.int32) *
                  max_blocks)
    sKb, sKh, sKt, sKd = [int(s) for s in key_cache.stride()]
    so_t, so_h, so_d = [int(s) for s in out.stride()]

    def grid(meta):
        # Size the launch from the tile sizes of the config actually chosen
        # by the autotuner, so larger tiles do not launch fully masked
        # programs.
        return (
            B * H,
            triton.cdiv(L_max, meta["BLOCK_T"]),
            triton.cdiv(D, meta["BLOCK_D"]),
        )

    _gather_k_to_packed_kernel[grid](
        key_cache,
        out,
        blk_ids,
        req_starts,
        cu_i32,
        seq_lens_i32,
        B,
        H,
        max_blocks,
        block_size,
        D,
        sKb,
        sKh,
        sKt,
        sKd,
        so_t,
        so_h,
        so_d,
    )
    return out
@triton.autotune(
    # NOTE: deliberately a single config. This kernel rewrites the key cache
    # in place and is not idempotent; with several configs the autotuner
    # benchmarks each one by launching the kernel repeatedly on the live
    # cache, corrupting it. With one config no benchmarking takes place.
    # (`restore_value=['K_ptr']` would also be correct, but it clones the
    # whole cache tensor for every benchmark run.)
    configs=[
        triton.Config({'BLOCK_T': 128, 'BLOCK_D': 64}, num_warps=4, num_stages=2),
    ],
    key=['K_max', 'Dk'],
)
@triton.jit
def _front_compact_inplace_fa_k_kernel(
    K_ptr,
    blk_ids_ptr,
    req_blk_starts_ptr,
    idx_ptr,
    keep_ptr,
    B,
    H,
    K_max,
    block_size,
    Dk,
    sKb,
    sKh,
    sKt,
    sKd,
    si_b,
    si_h,
    si_k,
    BLOCK_T: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Move the selected key rows of one (request, head) to the front, in place.

    Program ids: axis 0 is `b * H + h`, axis 1 the head-dim tile. For each
    destination slot `k < min(keep[b], K_max)` the row at logical token
    `idx[b, h, k]` is copied to logical token `k`. Logical tokens are mapped
    to physical blocks through `blk_ids[req_blk_starts[b] + t // block_size]`.
    Key layout is [num_blocks, H, block_size, Dk] (element strides `sK*`).
    """
    pid_bh = tl.program_id(0)
    pid_d = tl.program_id(1)
    b = pid_bh // H
    h = pid_bh % H
    if b >= B:
        return
    d_range = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    d_mask = d_range < Dk
    d_safe = tl.where(d_mask, d_range, 0)
    keep_b = tl.load(keep_ptr + b)
    if keep_b <= 0:
        return
    req_blk_start = tl.load(req_blk_starts_ptr + b)

    k0 = 0
    while k0 < keep_b:
        k_range = k0 + tl.arange(0, BLOCK_T)
        k_mask = (k_range < K_max) & (k_range < keep_b)
        k_safe = tl.where(k_mask, k_range, 0)
        idx_base = idx_ptr + b * si_b + h * si_h + k_safe * si_k
        t_src = tl.load(idx_base, mask=k_mask, other=0)
        # No-op copies (src == dst) can be skipped safely because idx_sorted is
        # ascending, so we always copy from later/equal positions to earlier.
        t_dst = k_safe
        copy_mask = k_mask & (t_src != t_dst)

        blk_src = t_src // block_size
        inb_src = t_src % block_size
        gblk_src = req_blk_start + blk_src
        # Block ids are widened to int64 so that `block_id * sKb` cannot wrap
        # on large caches.
        bid_src = tl.load(blk_ids_ptr + gblk_src, mask=copy_mask,
                          other=0).to(tl.int64)

        blk_dst = t_dst // block_size
        inb_dst = t_dst % block_size
        gblk_dst = req_blk_start + blk_dst
        bid_dst = tl.load(blk_ids_ptr + gblk_dst, mask=copy_mask,
                          other=0).to(tl.int64)

        src_ptrs = (K_ptr + bid_src[:, None] * sKb + h * sKh +
                    inb_src[:, None] * sKt + d_safe[None, :] * sKd)
        dst_ptrs = (K_ptr + bid_dst[:, None] * sKb + h * sKh +
                    inb_dst[:, None] * sKt + d_safe[None, :] * sKd)

        tile_mask = copy_mask[:, None] & d_mask[None, :]
        tile = tl.load(src_ptrs, mask=tile_mask, other=0)
        # A row read by one thread may be a destination row of another thread
        # (in this tile or a later one). Make sure every thread of the program
        # has finished reading before anyone starts overwriting.
        tl.debug_barrier()
        tl.store(dst_ptrs, tile, mask=tile_mask)
        k0 += BLOCK_T
@triton.autotune(
    # NOTE: deliberately a single config. This kernel rewrites the value
    # cache in place and is not idempotent; with several configs the
    # autotuner benchmarks each one by launching the kernel repeatedly on the
    # live cache, corrupting it. With one config no benchmarking takes place.
    # (`restore_value=['V_ptr']` would also be correct, but it clones the
    # whole cache tensor for every benchmark run.)
    configs=[
        triton.Config({'BLOCK_T': 128, 'BLOCK_D': 64}, num_warps=4, num_stages=2),
    ],
    key=['K_max', 'Dv'],
)
@triton.jit
def _front_compact_inplace_fa_v_kernel(
    V_ptr,
    blk_ids_ptr,
    req_blk_starts_ptr,
    idx_ptr,
    keep_ptr,
    B,
    H,
    K_max,
    block_size,
    Dv,
    sv_b,
    sv_h,
    sv_d,
    sv_t,
    si_b,
    si_h,
    si_k,
    BLOCK_T: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Move the selected value rows of one (request, head) to the front, in place.

    Same scheme as the key kernel: destination slot `k < min(keep[b], K_max)`
    receives the entry at logical token `idx[b, h, k]`. The only difference
    is the addressing: value layout is [num_blocks, H, Dv, block_size]
    (element strides `sv_b, sv_h, sv_d, sv_t`).
    """
    pid_bh = tl.program_id(0)
    pid_d = tl.program_id(1)
    b = pid_bh // H
    h = pid_bh % H
    if b >= B:
        return
    d_range = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)
    d_mask = d_range < Dv
    d_safe = tl.where(d_mask, d_range, 0)
    keep_b = tl.load(keep_ptr + b)
    if keep_b <= 0:
        return
    req_blk_start = tl.load(req_blk_starts_ptr + b)

    k0 = 0
    while k0 < keep_b:
        k_range = k0 + tl.arange(0, BLOCK_T)
        k_mask = (k_range < K_max) & (k_range < keep_b)
        k_safe = tl.where(k_mask, k_range, 0)
        idx_base = idx_ptr + b * si_b + h * si_h + k_safe * si_k
        t_src = tl.load(idx_base, mask=k_mask, other=0)
        # Slots whose source equals their destination need no copy.
        t_dst = k_safe
        copy_mask = k_mask & (t_src != t_dst)

        blk_src = t_src // block_size
        inb_src = t_src % block_size
        gblk_src = req_blk_start + blk_src
        # Block ids are widened to int64 so that `block_id * sv_b` cannot wrap
        # on large caches.
        bid_src = tl.load(blk_ids_ptr + gblk_src, mask=copy_mask,
                          other=0).to(tl.int64)

        blk_dst = t_dst // block_size
        inb_dst = t_dst % block_size
        gblk_dst = req_blk_start + blk_dst
        bid_dst = tl.load(blk_ids_ptr + gblk_dst, mask=copy_mask,
                          other=0).to(tl.int64)

        # value layout: [num_blocks, H, Dv, block_size]
        v_src_ptrs = (V_ptr + bid_src[:, None] * sv_b + h * sv_h +
                      d_safe[None, :] * sv_d + inb_src[:, None] * sv_t)
        v_dst_ptrs = (V_ptr + bid_dst[:, None] * sv_b + h * sv_h +
                      d_safe[None, :] * sv_d + inb_dst[:, None] * sv_t)

        tile_mask = copy_mask[:, None] & d_mask[None, :]
        tile = tl.load(v_src_ptrs, mask=tile_mask, other=0)
        # A slot read by one thread may be a destination slot of another
        # thread (in this tile or a later one). Make sure every thread of the
        # program has finished reading before anyone starts overwriting.
        tl.debug_barrier()
        tl.store(v_dst_ptrs, tile, mask=tile_mask)
        k0 += BLOCK_T
@torch.inference_mode()
def front_compact_inplace_fa_triton(
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    block_table: torch.Tensor,
    idx_sorted: torch.Tensor,
    keep: torch.Tensor,
) -> None:
    """In-place front compaction for FlashAttention KV cache.

    Moves selected time indices to the front [0..keep[b]) per request for both
    key_cache and value_cache in-place.

    Expected layouts:
    - key_cache: [num_blocks, H, block_size, Dk]
    - value_cache: [num_blocks, H, Dv, block_size]
    - block_table: [B, max_blocks] int32 physical block ids
    - idx_sorted: [B, K] int32 or [B, H, K] int32 (ascending indices); a
      [B, 1, K] tensor is broadcast over heads
    - keep: [B] int32 (<= K), number of kept tokens per request

    Raises:
        RuntimeError: Triton is unavailable or a tensor is not on a CUDA/ROCm
            device.
        TypeError: a non-tensor was passed.
        ValueError: a tensor has the wrong rank, the value cache does not
            match the key cache's H / block_size, or ``keep`` / ``idx_sorted``
            do not cover every request and head.
    """
    _require_triton()
    _check_cuda(key_cache, value_cache, block_table, idx_sorted, keep)
    if key_cache.ndim != 4 or value_cache.ndim != 4:
        raise ValueError("key_cache/value_cache must be 4D tensors.")
    if block_table.ndim != 2:
        raise ValueError("block_table must be 2D [B, max_blocks].")
    if idx_sorted.ndim not in (2, 3):
        raise ValueError("idx_sorted must be 2D [B,K] or 3D [B,H,K].")
    if keep.ndim != 1:
        raise ValueError("keep must be 1D [B].")

    device = key_cache.device
    B = int(block_table.shape[0])
    if B == 0:
        return
    H = int(key_cache.shape[1])
    block_size = int(key_cache.shape[2])
    Dk = int(key_cache.shape[3])
    Dv = int(value_cache.shape[2])

    # The kernels address value_cache with the key cache's H and block_size
    # and read keep[b] / idx[b, h, :] for every b < B, h < H. Reject inputs
    # that would make those accesses run out of bounds.
    if (int(value_cache.shape[1]) != H
            or int(value_cache.shape[3]) != block_size):
        raise ValueError(
            "value_cache must be [num_blocks, H, Dv, block_size] with the "
            "same H and block_size as key_cache.")
    if int(keep.numel()) < B:
        raise ValueError("keep must have one entry per block_table row.")
    if int(idx_sorted.shape[0]) < B:
        raise ValueError("idx_sorted must have one row per block_table row.")

    if idx_sorted.ndim == 2:
        idx_sorted = idx_sorted[:, None, :].expand(-1, H, -1)
    elif int(idx_sorted.shape[1]) == 1:
        idx_sorted = idx_sorted.expand(-1, H, -1)
    elif int(idx_sorted.shape[1]) < H:
        raise ValueError("idx_sorted must be 2D [B,K] or 3D [B,H,K].")

    K_max = int(idx_sorted.shape[2])
    if K_max == 0:
        return

    # Flatten the table; request b's row starts at b * max_blocks.
    blk_ids = block_table.to(device=device, dtype=torch.int32).reshape(-1)
    max_blocks = int(block_table.shape[1])
    req_starts = (torch.arange(B, device=device, dtype=torch.int32) *
                  max_blocks)
    idx_i32 = idx_sorted.to(device=device, dtype=torch.int32)
    keep_i32 = keep.to(device=device, dtype=torch.int32)
    sKb, sKh, sKt, sKd = [int(s) for s in key_cache.stride()]
    sv_b, sv_h, sv_d, sv_t = [int(s) for s in value_cache.stride()]
    si_b, si_h, si_k = [int(s) for s in idx_i32.stride()]

    if Dk > 0:

        def grid_k(meta):
            # One program per (request, head) and per head-dim tile of the
            # tile width the kernel is actually compiled with.
            return (B * H, triton.cdiv(Dk, meta["BLOCK_D"]))

        _front_compact_inplace_fa_k_kernel[grid_k](
            key_cache,
            blk_ids,
            req_starts,
            idx_i32,
            keep_i32,
            B,
            H,
            K_max,
            block_size,
            Dk,
            sKb,
            sKh,
            sKt,
            sKd,
            si_b,
            si_h,
            si_k,
        )

    if Dv > 0:

        def grid_v(meta):
            return (B * H, triton.cdiv(Dv, meta["BLOCK_D"]))

        _front_compact_inplace_fa_v_kernel[grid_v](
            value_cache,
            blk_ids,
            req_starts,
            idx_i32,
            keep_i32,
            B,
            H,
            K_max,
            block_size,
            Dv,
            sv_b,
            sv_h,
            sv_d,
            sv_t,
            si_b,
            si_h,
            si_k,
        )
def make_fa_cache_view(
    *,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return (K_view, V_view) in the canonical FA compaction layout.

    - K_view: [num_blocks, H, block_size, D]
    - V_view: [num_blocks, H, D, block_size]

    The input layout is inferred from shapes only: when the value cache's
    last two dims are the key cache's last two dims swapped, both tensors are
    taken to be canonical already and are returned unchanged. Otherwise both
    are treated as [num_blocks, block_size, H, D] and permuted views are
    returned (no data is copied).

    NOTE(review): the shape test cannot tell the two layouts apart when the
    relevant dims coincide (e.g. a [*, T, H, D] cache with H == D); confirm
    callers never hit that case.

    Raises:
        ValueError: either tensor is not 4D.
    """
    if not (key_cache.ndim == 4 and value_cache.ndim == 4):
        raise ValueError("key_cache/value_cache must be 4D tensors.")

    k_shape = key_cache.shape
    v_shape = value_cache.shape
    # ROCm path (FlashAttention v1): K=[B,H,T,D] and V=[B,H,D,T]
    already_canonical = (v_shape[3] == k_shape[2]
                         and v_shape[2] == k_shape[3])
    if already_canonical:
        return key_cache, value_cache

    # CUDA path: K=[B,T,H,D] and V=[B,T,H,D]
    return key_cache.permute(0, 2, 1, 3), value_cache.permute(0, 2, 3, 1)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import math
from typing import Optional, Union
import torch
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
import triton
import triton.language as tl
if HAS_TRITON:
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_Q": bq, "BLOCK_K": bk},
            num_warps=num_warps,
            num_stages=num_stages,
        ) for bq in [32, 64] for bk in [32, 64] for num_warps in [4, 8]
        for num_stages in [3, 4]
    ],
    key=["QUERY_GROUP_SIZE", "D", "ROWS_MAX"],
)
@triton.jit
def _lse_and_store_logits_kernel(
    Q,
    K,
    cu_q,
    cu_k,
    w_b,
    out_m,
    out_S,
    LOGITS,
    sm_scale,
    QUERY_GROUP_SIZE: tl.constexpr,
    D: tl.constexpr,
    STRIDE_Q_NQ,
    STRIDE_Q_HQ,
    STRIDE_K_NK,
    STRIDE_K_HK,
    STRIDE_M_B,
    STRIDE_M_H,
    STRIDE_M_R,
    STRIDE_S_B,
    STRIDE_S_H,
    STRIDE_S_R,
    STRIDE_LG_NK,
    STRIDE_LG_HK,
    STRIDE_LG_R,
    BLOCK_Q: tl.constexpr,
    BLOCK_K: tl.constexpr,
    ROWS_MAX,
):
    """Store window-query logits and their running softmax statistics.

    Program ids: (request b, KV head hk, row tile). A "row" is one
    (window query, query head of the group) pair; request b has
    `w_b[b] * QUERY_GROUP_SIZE` rows taken from the last `w_b[b]` queries.
    For every key before the window (`[k_beg, k_end - win)`) the scaled
    logit, expressed in base 2, is written to LOGITS[key, hk, row]; the
    per-row running max and sum of `exp2(logit - max)` are written to
    `out_m` / `out_S`.
    """
    b = tl.program_id(0)
    hk = tl.program_id(1)
    rid = tl.program_id(2)
    q_end = tl.load(cu_q + b + 1)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    win = tl.load(w_b + b)
    q_win_beg = q_end - win
    k_eff_end = k_end - win
    if (win <= 0) or (k_eff_end <= k_beg):
        return
    rows_b = win * QUERY_GROUP_SIZE
    row0 = rid * BLOCK_Q
    if row0 >= rows_b:
        return
    qk_scale = sm_scale * 1.4426950408889634  # exp -> exp2.

    offs_qrow = row0 + tl.arange(0, BLOCK_Q)
    row_mask = offs_qrow < rows_b
    hq_local = offs_qrow % QUERY_GROUP_SIZE
    q_off = offs_qrow // QUERY_GROUP_SIZE
    # Token indices are widened to int64 before being multiplied by a stride;
    # the int32 products wrap once a buffer exceeds 2**31 elements.
    q_idx = (q_win_beg + q_off).to(tl.int64)
    hq_glob = hk * QUERY_GROUP_SIZE + hq_local
    offs_d = tl.arange(0, D)
    q_ptrs = (Q + q_idx[:, None] * STRIDE_Q_NQ +
              hq_glob[:, None] * STRIDE_Q_HQ + offs_d[None, :])
    q_rows = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)

    m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
    S = tl.zeros([BLOCK_Q], dtype=tl.float32)
    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_eff_end
        nk64 = nk.to(tl.int64)
        k_ptrs = (K + nk64[:, None] * STRIDE_K_NK + hk * STRIDE_K_HK +
                  offs_d[None, :])
        k_blk = tl.load(k_ptrs, mask=kmask[:, None], other=0.0)
        s = tl.dot(q_rows, k_blk.T) * qk_scale
        # Keys past the effective end must not contribute to max / sum.
        s = tl.where(kmask[None, :], s, -float("inf"))

        log_ptrs = (LOGITS + nk64[:, None] * STRIDE_LG_NK +
                    hk * STRIDE_LG_HK + offs_qrow[None, :] * STRIDE_LG_R)
        tl.store(log_ptrs, s.T, mask=kmask[:, None] & row_mask[None, :])

        # Online-softmax update of the running max and rescaled sum.
        cur_max = tl.max(s, 1)
        n_m = tl.maximum(m, cur_max)
        rescale = tl.math.exp2(m - n_m)
        S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
        m = n_m

    m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
    S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
    tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R, m, mask=row_mask)
    tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R, S, mask=row_mask)
@triton.jit
def _lse_and_store_logits_kernel_rocm_safe(
    Q,
    K,
    cu_q,
    cu_k,
    w_b,
    out_m,
    out_S,
    LOGITS,
    sm_scale,
    QUERY_GROUP_SIZE: tl.constexpr,
    D: tl.constexpr,
    STRIDE_Q_NQ,
    STRIDE_Q_HQ,
    STRIDE_K_NK,
    STRIDE_K_HK,
    STRIDE_M_B,
    STRIDE_M_H,
    STRIDE_M_R,
    STRIDE_S_B,
    STRIDE_S_H,
    STRIDE_S_R,
    STRIDE_LG_NK,
    STRIDE_LG_HK,
    STRIDE_LG_R,
    BLOCK_Q: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """ROCm-safe variant of `_lse_and_store_logits_kernel`.

    On some ROCm + Triton (HIP) stacks we have observed memory corruption
    from the tl.dot-based implementation. This variant avoids `tl.dot` and
    instead computes dot-products via explicit outer-product accumulation,
    walking the head dimension in masked chunks of BLOCK_D (so D need not be
    a multiple of BLOCK_D).

    Outputs are the same as the tl.dot kernel: base-2 scaled logits in
    LOGITS[key, hk, row] plus the per-row running max (`out_m`) and sum of
    `exp2(logit - max)` (`out_S`).
    """
    b = tl.program_id(0)
    hk = tl.program_id(1)
    rid = tl.program_id(2)
    q_end = tl.load(cu_q + b + 1)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    win = tl.load(w_b + b)
    q_win_beg = q_end - win
    k_eff_end = k_end - win
    if (win <= 0) or (k_eff_end <= k_beg):
        return
    rows_b = win * QUERY_GROUP_SIZE
    row0 = rid * BLOCK_Q
    if row0 >= rows_b:
        return
    qk_scale = sm_scale * 1.4426950408889634  # exp -> exp2.

    offs_qrow = row0 + tl.arange(0, BLOCK_Q)
    row_mask = offs_qrow < rows_b
    hq_local = offs_qrow % QUERY_GROUP_SIZE
    q_off = offs_qrow // QUERY_GROUP_SIZE
    # Token indices are widened to int64 before being multiplied by a stride;
    # the int32 products wrap once a buffer exceeds 2**31 elements.
    q_idx = (q_win_beg + q_off).to(tl.int64)
    hq_glob = hk * QUERY_GROUP_SIZE + hq_local

    m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
    S = tl.zeros([BLOCK_Q], dtype=tl.float32)
    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_eff_end
        nk64 = nk.to(tl.int64)

        # Accumulate s = Q @ K^T in fp32 via outer products.
        s = tl.zeros([BLOCK_Q, BLOCK_K], dtype=tl.float32)
        for ds in tl.static_range(0, D, BLOCK_D):
            offs_d = ds + tl.arange(0, BLOCK_D)
            dmask = offs_d < D
            q_ptrs = (Q + q_idx[:, None] * STRIDE_Q_NQ +
                      hq_glob[:, None] * STRIDE_Q_HQ + offs_d[None, :])
            q_chunk = tl.load(
                q_ptrs,
                mask=row_mask[:, None] & dmask[None, :],
                other=0.0,
            ).to(tl.float32)  # [BQ, BD]
            k_ptrs = (K + nk64[:, None] * STRIDE_K_NK + hk * STRIDE_K_HK +
                      offs_d[None, :])
            k_chunk = tl.load(
                k_ptrs,
                mask=kmask[:, None] & dmask[None, :],
                other=0.0,
            ).to(tl.float32)  # [BK, BD]
            s += tl.sum(q_chunk[:, None, :] * k_chunk[None, :, :], axis=2)

        s = s * qk_scale
        # Keys past the effective end must not contribute to max / sum.
        s = tl.where(kmask[None, :], s, -float("inf"))

        log_ptrs = (LOGITS + nk64[:, None] * STRIDE_LG_NK +
                    hk * STRIDE_LG_HK + offs_qrow[None, :] * STRIDE_LG_R)
        tl.store(log_ptrs, s.T, mask=kmask[:, None] & row_mask[None, :])

        # Online-softmax update of the running max and rescaled sum.
        cur_max = tl.max(s, 1)
        n_m = tl.maximum(m, cur_max)
        rescale = tl.math.exp2(m - n_m)
        S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
        m = n_m

    m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
    S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
    tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
             m,
             mask=row_mask)
    tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R,
             S,
             mask=row_mask)
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_Q": bq, "BLOCK_K": bk})
        for bq in [16, 32, 64] for bk in [32, 64, 128]
    ],
    # The tuning key must name real kernel arguments. The previous key
    # ("HK", "HQ") matched none of them, which either fails when the kernel
    # is decorated or is silently ignored, depending on the Triton version.
    key=["QUERY_GROUP_SIZE"],
)
@triton.jit
def _scores_from_logits_kernel(
    cu_k,
    w_b,
    in_m,
    in_S,
    LOGITS,
    OUT,
    QUERY_GROUP_SIZE: tl.constexpr,
    STRIDE_M_B,
    STRIDE_M_H,
    STRIDE_M_R,
    STRIDE_S_B,
    STRIDE_S_H,
    STRIDE_S_R,
    STRIDE_LG_NK,
    STRIDE_LG_HK,
    STRIDE_LG_R,
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    BLOCK_Q: tl.constexpr,
    BLOCK_K: tl.constexpr,
    DO_POOL: tl.constexpr,
    KPOOL: tl.constexpr,
    PROTECT_LAST: tl.constexpr,
):
    """Turn stored logits into one score per (key, KV head).

    Program ids: (request b, KV head hk). For every key before the window,
    the softmax probability `exp2(logit - m) / S` is summed over all
    `win * QUERY_GROUP_SIZE` rows and written to OUT[key, hk]. With DO_POOL
    and KPOOL > 1 each score is replaced by the mean of itself and up to
    KPOOL - 1 preceding scores. With PROTECT_LAST the window keys
    `[k_end - win, k_end)` are set to +inf.

    NOTE(review): pooling only looks at scores inside the current BLOCK_K
    tile, so the first keys of each tile average over fewer neighbours and
    the pooled output depends on the autotuned BLOCK_K.
    """
    b = tl.program_id(0)
    hk = tl.program_id(1)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    win = tl.load(w_b + b)
    k_eff_end = k_end - win
    if (win <= 0) or (k_eff_end <= k_beg):
        return
    rows_b = win * QUERY_GROUP_SIZE

    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_eff_end
        # Widen key indices before multiplying by a stride; the int32
        # products wrap once the logits buffer exceeds 2**31 elements.
        nk64 = nk.to(tl.int64)
        scores = tl.zeros([BLOCK_K], dtype=tl.float32)

        for row0 in tl.range(0, rows_b, BLOCK_Q):
            r_idx = row0 + tl.arange(0, BLOCK_Q)
            rmask = r_idx < rows_b
            m_ptr = (in_m + b * STRIDE_M_B + hk * STRIDE_M_H +
                     row0 * STRIDE_M_R)
            S_ptr = (in_S + b * STRIDE_S_B + hk * STRIDE_S_H +
                     row0 * STRIDE_S_R)
            m = tl.load(m_ptr + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
                        mask=rmask,
                        other=-float("inf"))
            S = tl.load(S_ptr + tl.arange(0, BLOCK_Q) * STRIDE_S_R,
                        mask=rmask,
                        other=0.0)
            # Rows with an empty sum get neutral stats so the division below
            # stays finite; their probabilities are zeroed afterwards.
            valid_row = S > 0
            m = tl.where(valid_row, m, 0.0)
            S = tl.where(valid_row, S, 1.0)

            log_ptrs = (LOGITS + nk64[:, None] * STRIDE_LG_NK +
                        hk * STRIDE_LG_HK + r_idx[None, :] * STRIDE_LG_R)
            s_T = tl.load(log_ptrs,
                          mask=kmask[:, None] & rmask[None, :],
                          other=-float("inf"))
            probs_T = tl.math.exp2(s_T - m[None, :]) / S[None, :]
            probs_T = tl.where(valid_row[None, :], probs_T, 0.0)
            scores += tl.sum(probs_T, 1)

        if DO_POOL and (KPOOL > 1):
            # Trailing mean over the current and up to KPOOL-1 previous keys
            # of this tile.
            i = tl.arange(0, BLOCK_K)[:, None]
            j = tl.arange(0, BLOCK_K)[None, :]
            band = (j <= i) & ((i - j) < KPOOL)
            band = band & kmask[None, :]
            sums = tl.sum(tl.where(band, scores[None, :], 0.0), 1)
            # Count in int32: reducing the boolean mask directly relies on
            # implicit promotion inside tl.sum.
            denom = tl.sum(band.to(tl.int32), 1).to(tl.float32)
            denom = tl.where(denom > 0, denom, 1.0)
            scores = sums / denom

        out_ptrs = OUT + nk64 * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
        tl.store(out_ptrs, scores, mask=kmask)

    if PROTECT_LAST:
        pad_beg = k_eff_end
        pad_end = k_end
        if pad_end > pad_beg:
            for ks in tl.range(pad_beg, pad_end, BLOCK_K):
                nk = ks + tl.arange(0, BLOCK_K)
                kmask = nk < pad_end
                out_ptrs = (OUT + nk.to(tl.int64) * STRIDE_OUT_NK +
                            hk * STRIDE_OUT_HK)
                tl.store(out_ptrs,
                         tl.full([BLOCK_K], float("inf"), dtype=tl.float32),
                         mask=kmask)
@triton.autotune(
    configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
    key=["HK"],
    # The kernel normalises OUT in place. While tuning, every config is
    # launched repeatedly; restoring OUT after each benchmark run keeps the
    # final launch operating on the original scores.
    restore_value=["OUT"],
)
@triton.jit
def _zscore_per_batch_epilogue(
    OUT,
    cu_k,
    w_b,
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    HK: tl.constexpr,
    EPS: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Z-score the pre-window scores of one request in place.

    Program id 0 is the request. Mean and variance are taken jointly over
    all HK heads and all keys in `[k_beg, k_end - win)`; each of those
    entries becomes `(x - mean) / sqrt(var + EPS)`. Window keys are left
    untouched.
    """
    b = tl.program_id(0)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    win = tl.load(w_b + b)
    k_eff_end = k_end - win
    # Same guard as the scoring kernels: a non-positive window means no
    # scores were produced for this request, and a negative one would push
    # k_eff_end past the request's keys.
    if (win <= 0) or (k_eff_end <= k_beg):
        return

    sumv = tl.zeros([], dtype=tl.float32)
    sumsq = tl.zeros([], dtype=tl.float32)
    count = ((k_eff_end - k_beg) * HK).to(tl.float32)

    # Pass 1: accumulate sum and sum of squares.
    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_eff_end
        nk64 = nk.to(tl.int64)
        for h in tl.range(0, HK):
            ptrs = OUT + nk64 * STRIDE_OUT_NK + h * STRIDE_OUT_HK
            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
            sumv += tl.sum(vals, 0)
            sumsq += tl.sum(vals * vals, 0)

    mean = sumv / count
    # Clamp at zero: E[x^2] - E[x]^2 can go slightly negative in fp32.
    var = tl.maximum(sumsq / count - mean * mean, 0.0)
    invstd = 1.0 / tl.sqrt(var + EPS)

    # Pass 2: normalise in place.
    for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_eff_end
        nk64 = nk.to(tl.int64)
        for h in tl.range(0, HK):
            ptrs = OUT + nk64 * STRIDE_OUT_NK + h * STRIDE_OUT_HK
            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
            vals = (vals - mean) * invstd
            tl.store(ptrs, vals, mask=kmask)
def query_aware_key_scores(
    q: torch.Tensor,  # [N_q, Hq, D]
    k: torch.Tensor,  # [N_k, Hk, D]
    cu_seqlens_q: torch.Tensor,  # [B+1] int32
    cu_seqlens_k: torch.Tensor,  # [B+1] int32
    w: Union[int, torch.Tensor],  # [B] int32 or scalar
    sm_scale: Optional[float] = None,
    *,
    pool: bool = True,
    kpool: int = 5,
    protect_last: bool = True,
    normalize: bool = False,
) -> torch.Tensor:
    """SnapKV query-aware key scores (Triton), returns [N_k, Hk] float32.

    For each request, the last ``w`` queries (the observation window) attend
    to the keys that precede the window; the softmax probabilities are summed
    over window queries and over the query heads sharing a KV head.

    Args:
        q: packed queries ``[N_q, Hq, D]``, last dim contiguous.
        k: packed keys ``[N_k, Hk, D]``, last dim contiguous; ``Hq`` must be
            a multiple of ``Hk``.
        cu_seqlens_q / cu_seqlens_k: ``[B+1]`` cumulative sequence offsets.
        w: window length, one int for all requests or a ``[B]`` tensor.
            NOTE(review): the kernels take the window as the last ``w[b]``
            query rows before ``cu_seqlens_q[b+1]`` without checking that the
            request has that many queries - confirm callers guarantee it.
        sm_scale: logit scale; defaults to ``1 / sqrt(D)``.
        pool: average each score with up to ``kpool - 1`` preceding scores.
        kpool: pooling width, must be >= 1.
        protect_last: set the scores of the window keys to ``+inf``.
        normalize: z-score the pre-window scores per request, in place.

    Returns:
        ``[N_k, Hk]`` float32 scores; all zeros when the batch is empty or no
        request has a positive window.

    Raises:
        RuntimeError: Triton is unavailable or inputs are not on CUDA/ROCm.
        ValueError: on inconsistent shapes, non-contiguous last dim,
            ``kpool < 1``, or (non-ROCm path) a head size that is not a power
            of two.
    """
    if not HAS_TRITON:
        raise RuntimeError("Triton is not available.")
    if q.device.type != "cuda" or k.device.type != "cuda":
        raise RuntimeError("Triton SnapKV requires CUDA/ROCm tensors.")
    if q.ndim != 3 or k.ndim != 3:
        raise ValueError("q and k must be 3D tensors.")
    if q.stride(-1) != 1 or k.stride(-1) != 1:
        raise ValueError("Last dim must be contiguous for Triton SnapKV.")

    device = q.device
    N_q, Hq, D = q.shape
    N_k, Hk, Dk = k.shape
    if D != Dk:
        raise ValueError("q and k must have the same head size.")
    if (Hq % Hk) != 0:
        raise ValueError("Hq must be a multiple of Hk.")
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    B = int(cu_seqlens_q.numel() - 1)
    if B != int(cu_seqlens_k.numel() - 1):
        raise ValueError("cu_seqlens_q and cu_seqlens_k must match.")
    if B <= 0:
        # Empty batch: nothing to score. Returning here also avoids taking
        # max() of an empty window tensor below.
        return torch.zeros((N_k, Hk), dtype=torch.float32, device=device)

    G = Hq // Hk
    if isinstance(w, int):
        max_w = int(w)
        w = torch.full((B, ),
                       fill_value=max_w,
                       device=device,
                       dtype=torch.int32)
    else:
        if w.numel() != B:
            raise ValueError("w must have shape [B].")
        w = w.to(device=device, dtype=torch.int32)
        max_w = int(w.max().item())

    rows_max = max_w * G
    if rows_max <= 0:
        return torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
    if kpool < 1:
        raise ValueError("kpool must be >= 1.")

    # NOTE: On ROCm/HIP, we prefer a dot-free kernel variant to avoid known
    # correctness issues (silent memory corruption) observed with the tl.dot
    # implementation on some stacks.
    is_rocm = getattr(torch.version, "hip", None) is not None
    if not is_rocm and (D & (D - 1)) != 0:
        # The tl.dot kernel builds tl.arange(0, D), which Triton only accepts
        # for power-of-two ranges; fail with a clear message instead of a
        # compile error from inside the kernel.
        raise ValueError(
            "Head size must be a power of two for Triton SnapKV.")

    out = torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
    # Per-row softmax statistics and the full logits matrix shared between
    # the two kernel passes.
    m_scratch = torch.empty((B, Hk, rows_max),
                            dtype=torch.float32,
                            device=device)
    S_scratch = torch.empty((B, Hk, rows_max),
                            dtype=torch.float32,
                            device=device)
    logits_buf = torch.empty((N_k, Hk, rows_max),
                             dtype=torch.float32,
                             device=device)

    STRIDE_Q_NQ, STRIDE_Q_HQ, _ = q.stride()
    STRIDE_K_NK, STRIDE_K_HK, _ = k.stride()
    STRIDE_M_B, STRIDE_M_H, STRIDE_M_R = m_scratch.stride()
    STRIDE_S_B, STRIDE_S_H, STRIDE_S_R = S_scratch.stride()
    STRIDE_LG_NK, STRIDE_LG_HK, STRIDE_LG_R = logits_buf.stride()
    STRIDE_OUT_NK, STRIDE_OUT_HK = out.stride()

    cu_q = cu_seqlens_q.to(device=device, dtype=torch.int32)
    cu_k = cu_seqlens_k.to(device=device, dtype=torch.int32)

    def grid(meta):
        return B, Hk, triton.cdiv(rows_max, meta["BLOCK_Q"])

    # Strides of the scratch buffers, shared by both passes.
    scratch_strides = dict(
        STRIDE_M_B=STRIDE_M_B,
        STRIDE_M_H=STRIDE_M_H,
        STRIDE_M_R=STRIDE_M_R,
        STRIDE_S_B=STRIDE_S_B,
        STRIDE_S_H=STRIDE_S_H,
        STRIDE_S_R=STRIDE_S_R,
        STRIDE_LG_NK=STRIDE_LG_NK,
        STRIDE_LG_HK=STRIDE_LG_HK,
        STRIDE_LG_R=STRIDE_LG_R,
    )
    lse_args = (q, k, cu_q, cu_k, w, m_scratch, S_scratch, logits_buf,
                sm_scale)
    lse_kwargs = dict(
        QUERY_GROUP_SIZE=G,
        D=D,
        STRIDE_Q_NQ=STRIDE_Q_NQ,
        STRIDE_Q_HQ=STRIDE_Q_HQ,
        STRIDE_K_NK=STRIDE_K_NK,
        STRIDE_K_HK=STRIDE_K_HK,
        **scratch_strides,
    )

    # Pass 1: logits + running softmax statistics.
    if is_rocm:
        _lse_and_store_logits_kernel_rocm_safe[grid](
            *lse_args,
            BLOCK_Q=32,
            BLOCK_K=32,
            BLOCK_D=16,
            num_warps=4,
            num_stages=1,
            **lse_kwargs,
        )
    else:
        _lse_and_store_logits_kernel[grid](
            *lse_args,
            ROWS_MAX=rows_max,
            **lse_kwargs,
        )

    # Pass 2: probabilities summed per key, optional pooling / protection.
    _scores_from_logits_kernel[(B, Hk)](
        cu_k,
        w,
        m_scratch,
        S_scratch,
        logits_buf,
        out,
        QUERY_GROUP_SIZE=G,
        STRIDE_OUT_NK=STRIDE_OUT_NK,
        STRIDE_OUT_HK=STRIDE_OUT_HK,
        DO_POOL=pool,
        KPOOL=kpool,
        PROTECT_LAST=protect_last,
        **scratch_strides,
    )

    if normalize:
        _zscore_per_batch_epilogue[(B, )](
            out,
            cu_k,
            w,
            STRIDE_OUT_NK,
            STRIDE_OUT_HK,
            HK=Hk,
            EPS=1e-12,
        )
    return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment