Commit 2b7160c6 authored by chenzk's avatar chenzk
Browse files

vllm kvprune:v1.0.0

parent fa718036
from collections.abc import Callable
import torch
def maybe_execute_in_stream(
    fn: Callable, *args, STORE_STREAM: torch.cuda.Stream = None, **kwargs
):
    """Call ``fn(*args, **kwargs)``, optionally on a side CUDA stream.

    With ``STORE_STREAM`` left as ``None`` this is a plain call. Otherwise the
    side stream first waits for the default stream, ``fn`` runs with the side
    stream active, and the tensors involved are registered with the streams
    that touch them via ``Tensor.record_stream``.

    Args:
        fn: Callable to run. If it is a bound method of a tensor, that tensor
            is registered with ``STORE_STREAM`` as well.
        *args: Positional arguments forwarded to ``fn``.
        STORE_STREAM: Stream to run ``fn`` on, or ``None`` to call directly.
        **kwargs: Keyword arguments forwarded to ``fn``.

    Returns:
        Whatever ``fn`` returns.
    """
    # Fast path: no side stream requested.
    if STORE_STREAM is None:
        return fn(*args, **kwargs)

    # Inputs consumed on the side stream: positional tensors, then keyword
    # tensors, then the tensor `fn` is bound to (if any).
    side_stream_inputs = [
        value
        for value in (*args, *kwargs.values())
        if isinstance(value, torch.Tensor)
    ]
    bound_to = getattr(fn, "__self__", None)
    if isinstance(bound_to, torch.Tensor):
        side_stream_inputs.append(bound_to)

    # Work already queued on the default stream must finish before `fn` runs.
    STORE_STREAM.wait_stream(torch.cuda.default_stream())

    # Some PyTorch builds don't make `torch.cuda.Stream` a context manager.
    # The portable API is `torch.cuda.stream(stream)`.
    if hasattr(STORE_STREAM, "__enter__"):
        stream_ctx = STORE_STREAM
    else:
        stream_ctx = torch.cuda.stream(STORE_STREAM)

    with stream_ctx:
        output = fn(*args, **kwargs)

    for tensor in side_stream_inputs:
        tensor.record_stream(STORE_STREAM)

    # Outputs are registered with the default stream. Only a bare tensor or
    # the tensor members of a tuple are handled; other containers are left
    # untouched.
    produced = output if isinstance(output, tuple) else (output,)
    for item in produced:
        if isinstance(item, torch.Tensor):
            item.record_stream(torch.cuda.default_stream())
    return output
from dataclasses import dataclass, field
from enum import Enum, auto
from itertools import count
from typing import List
from compactor_vllm.compression.compression_config import SequenceCompressionParams
from compactor_vllm.config.sampling_params import SamplingParams
class SequenceStatus(Enum):
    """Status values a :class:`Sequence` can carry."""

    # Explicit values match what `auto()` assigns (1, 2, 3 in declaration order).
    WAITING = 1
    RUNNING = 2
    FINISHED = 3
@dataclass
class Sequence:
    """
    A single user request: its prompt, the tokens generated so far, and the
    sampling / compression settings attached to it.
    """

    # Shared id source. It carries no annotation, so @dataclass does not treat
    # it as a field.
    _counter = count()

    prompt_token_ids: List[int]
    completion_token_ids: List[int] = field(default_factory=list)
    sampling_params: SamplingParams = field(default_factory=SamplingParams)
    compression_params: SequenceCompressionParams = field(
        default_factory=SequenceCompressionParams
    )
    status: SequenceStatus = SequenceStatus.WAITING
    # Drawn from `_counter` at construction time; not accepted by __init__.
    seq_id: int = field(default_factory=lambda: next(Sequence._counter), init=False)
    num_tokens_processed: int = 0

    # ---- length accessors -------------------------------------------------

    @property
    def num_prompt_tokens(self) -> int:
        """Number of prompt tokens."""
        return len(self.prompt_token_ids)

    @property
    def num_generated_tokens(self) -> int:
        """Number of completion tokens appended so far."""
        return len(self.completion_token_ids)

    @property
    def prompt_len(self) -> int:
        """Same value as :attr:`num_prompt_tokens`."""
        return self.num_prompt_tokens

    @property
    def completion_len(self) -> int:
        """Same value as :attr:`num_generated_tokens`."""
        return self.num_generated_tokens

    # ---- mutation -----------------------------------------------------------

    def add_new_token(self, token_id: int) -> None:
        """Append ``token_id`` to the completion and advance the processed count.

        On the first generated token the whole prompt length is added to
        ``num_tokens_processed`` before the token itself is counted.
        """
        is_first_completion_token = not self.completion_token_ids
        if is_first_completion_token:
            self.num_tokens_processed += self.num_prompt_tokens
        self.completion_token_ids.append(token_id)
        self.num_tokens_processed += 1

    def tokens_to_retain_per_layer(self, num_kv_heads: int) -> int:
        """Return ``int(compression_ratio * num_prompt_tokens * num_kv_heads)``, at least 1."""
        budget = int(
            self.compression_params.compression_ratio
            * self.num_prompt_tokens
            * num_kv_heads
        )
        return budget if budget > 1 else 1

    # ---- pickling -----------------------------------------------------------

    def __getstate__(self):
        """Return the instance state as a plain dict; token lists are copied."""
        return {
            "prompt_token_ids": list(self.prompt_token_ids),
            "completion_token_ids": list(self.completion_token_ids),
            "sampling_params": self.sampling_params,
            "compression_params": self.compression_params,
            "status": self.status,
            "seq_id": self.seq_id,
            "num_tokens_processed": self.num_tokens_processed,
        }

    def __setstate__(self, state):
        """Restore the attributes written by :meth:`__getstate__`."""
        self.prompt_token_ids = list(state["prompt_token_ids"])
        self.completion_token_ids = list(state["completion_token_ids"])
        self.sampling_params = state["sampling_params"]
        self.compression_params = state["compression_params"]
        self.status = state["status"]
        self.seq_id = state["seq_id"]
        self.num_tokens_processed = state["num_tokens_processed"]
from __future__ import annotations
import inspect
from typing import Any, Callable, Mapping
import torch
def _filter_kwargs_for_callable(
fn: Callable[..., Any], kwargs: Mapping[str, Any]
) -> dict[str, Any]:
try:
params = inspect.signature(fn).parameters
except (TypeError, ValueError):
return dict(kwargs)
return {k: v for k, v in kwargs.items() if k in params}
def autotune(*, configs, key, **kwargs):
    """
    Compatibility wrapper around `triton.autotune`.

    Some Triton builds (e.g., custom vendor builds) may not support newer
    keyword arguments like `cache_results`. Extra keyword arguments are
    passed through `_filter_kwargs_for_callable` against the runtime
    `triton.autotune` before being forwarded; `configs` and `key` are always
    forwarded.
    """
    # Imported lazily so this module can be imported without Triton present.
    import triton

    triton_autotune = triton.autotune
    supported_kwargs = _filter_kwargs_for_callable(triton_autotune, kwargs)
    return triton_autotune(configs=configs, key=key, **supported_kwargs)
def maybe_set_allocator(alloc_fn: Callable[[int, int, int | None], Any]) -> bool:
    """
    Install `alloc_fn` through `triton.set_allocator` when that hook exists.

    Returns True if the allocator was set, False when the installed Triton
    has no `set_allocator` attribute (in which case nothing happens).
    """
    import triton

    set_allocator = getattr(triton, "set_allocator", None)
    if set_allocator is not None:
        set_allocator(alloc_fn)
        return True
    return False
def cuda_capability_geq(major: int, minor: int = 0, device: int | None = None) -> bool:
"""
Host-side CUDA capability check that works even when `tl.target_info` is absent.
"""
if not torch.cuda.is_available():
return False
if device is None:
try:
device = torch.cuda.current_device()
except Exception:
device = 0
cap = torch.cuda.get_device_capability(device)
return cap >= (major, minor)
import collections
import logging
from dataclasses import dataclass
from typing import List
import pytest
import torch
import triton
from compactor_vllm.compression.common import scores_to_retain_indices
from src.compactor_vllm.kv_cache.store_kv_cache import prefill_store_topk_kv
logger = logging.getLogger(__name__)
@dataclass
class Workload:
    """One parametrized case for the `prefill_store_topk_kv` test."""

    name: str  # pytest id for the case
    batch_size: int  # number of sequences
    nk_heads: int  # number of KV heads
    head_dim: int  # per-head feature size
    frac: float  # fraction of (token, head) entries to retain
    page_size: int  # slots per cache page
    cache_lens: List[int]  # per-sequence cached context length
# Full cross product of the parameter grids below. Every sequence in a
# workload shares the same cache length.
WORKLOADS: List[Workload] = [
    Workload(
        name=(
            f"batch_size={batch} kv_cache_len={cache_len} "
            f"FRAC={keep_frac} HKV={kv_heads} HEAD_DIM={dim}"
        ),
        batch_size=batch,
        nk_heads=kv_heads,
        head_dim=dim,
        cache_lens=[cache_len] * batch,
        frac=keep_frac,
        page_size=page,
    )
    for batch in [1, 2, 3, 8]
    for keep_frac in [0.10, 0.20, 0.30, 0.40]
    for kv_heads in [2, 4, 8]
    for dim in [32, 64, 128]
    for cache_len in [10, 20, 30, 70, 1000]
    for page in [128, 256]
]
@pytest.mark.parametrize("workload", WORKLOADS, ids=lambda wl: wl.name)
def test_prefill_store_topk_kv(workload: Workload):
    """Check what `prefill_store_topk_kv` writes into a paged cache.

    After the call (with `PAD_TO_PAGE_SIZE=False`) the test verifies:
      1. `bh_lens[b, h]` equals the number of entries of `indices[b]` whose
         value modulo `H` is `h`.
      2. For every (b, h), the rows read back through `page_table` form the
         same multiset as the source key/value rows selected for that head.
    """
    B = workload.batch_size
    H = workload.nk_heads
    D = workload.head_dim
    # Retain budget: fraction of cache_len * heads, truncated to an int.
    TOP_K = int(workload.cache_lens[0] * workload.nk_heads * workload.frac)
    PAGE_SIZE = workload.page_size
    dtype = torch.float16
    device = triton.runtime.driver.active.get_active_torch_device()
    lens = torch.tensor(workload.cache_lens, dtype=torch.int32, device=device)
    # Cumulative sequence boundaries, shape [B + 1], starting at 0.
    cu = torch.zeros(B + 1, dtype=torch.int32, device=device)
    cu[1:] = torch.cumsum(lens, dim=0)
    N_total = int(cu[-1].item())
    keys = torch.randn((N_total, H, D), dtype=dtype, device=device)
    vals = torch.randn_like(keys)
    scores_flat = torch.randn((N_total, H), dtype=torch.float32, device=device)
    # Clamp the budget to the number of (token, head) entries in the longest sequence.
    top_k_eff = max(0, min(TOP_K, int(lens.max().item()) * H))
    max_k_len = cu.diff().max().item()
    indices = scores_to_retain_indices(
        scores_flat, cu, max_k_len, top_k_eff, H
    )  # [B, TOP_K]
    # Logical pages needed per (b, h) if one head received the whole budget.
    LP = max(1, (top_k_eff + PAGE_SIZE - 1) // PAGE_SIZE)
    N_LOGICAL_PAGES_MAX = LP
    N_PAGES = B * H * LP + 32
    S_LARGE = N_PAGES * PAGE_SIZE
    k_cache = torch.empty((S_LARGE, D), dtype=dtype, device=device)
    v_cache = torch.empty_like(k_cache)
    page_table = torch.empty(
        (B, H, N_LOGICAL_PAGES_MAX), dtype=torch.int32, device=device
    )
    # Give every (b, h, logical page) its own physical page, in order.
    phys = 0
    for b in range(B):
        for h in range(H):
            for lp in range(LP):
                page_table[b, h, lp] = phys
                phys += 1
    assert phys <= N_PAGES, "Not enough physical pages"
    local_lens = torch.zeros((B, H), dtype=torch.int32, device=device)
    batch_mapping = torch.arange(B, dtype=torch.int32, device=device)
    num_to_retain = torch.full((B,), top_k_eff, dtype=torch.int32, device=device)
    prefill_store_topk_kv(
        new_keys=keys,
        new_vals=vals,
        indices_topk=indices,
        num_tokens_to_retain=num_to_retain,
        page_table=page_table,
        batch_mapping=batch_mapping,
        bh_lens=local_lens,
        PAGE_SIZE=PAGE_SIZE,
        k_cache=k_cache,
        v_cache=v_cache,
        PAD_TO_PAGE_SIZE=False,
        TRITON_RESERVED_BATCH=-1,
    )
    torch.cuda.synchronize()
    # Move everything to the host once; all checks below run on CPU copies.
    local_lens_cpu = local_lens.cpu()
    page_table_cpu = page_table.cpu()
    k_cache_cpu = k_cache.cpu()
    v_cache_cpu = v_cache.cpu()
    keys_cpu = keys.cpu()
    vals_cpu = vals.cpu()
    indices_cpu = indices.cpu()
    # Check 1: per-head stored length equals the per-head selection count.
    for b in range(B):
        # The test reads the head of an index as `index % H`.
        hed = (indices_cpu[b] % H).numpy()
        counts = collections.Counter(hed.tolist())
        for h in range(H):
            expected = counts.get(h, 0)  # type: ignore
            got = int(local_lens_cpu[b, h].item())
            assert got == expected, (
                f"Length mismatch at (b={b}, h={h}): got {got}, expected {expected}"
            )

    def rows_for_head(b, h, L):
        """Return the list of cache row indices storing the first L logical positions for (b,h)."""
        rows = []
        for pos in range(L):
            lp = pos // PAGE_SIZE
            off = pos % PAGE_SIZE
            phys = int(page_table_cpu[b, h, lp].item())
            rows.append(phys * PAGE_SIZE + off)
        return rows

    # Check 2: stored content matches the selected source rows.
    for b in range(B):
        # which tokens per head were selected for this batch?
        tok = (indices_cpu[b] // H).numpy()
        hed = (indices_cpu[b] % H).numpy()
        per_head = collections.defaultdict(list)
        for t, h in zip(tok, hed):
            per_head[int(h)].append(int(t))
        for h in range(H):
            L = int(local_lens_cpu[b, h].item())
            if L == 0:
                continue
            # expected vectors (unordered) from source
            toks_h = per_head.get(h, [])
            assert len(toks_h) == L
            expK = keys_cpu[toks_h, h, :].contiguous().view(L, -1)
            expV = vals_cpu[toks_h, h, :].contiguous().view(L, -1)
            # actual vectors read back from cache rows
            rows = rows_for_head(b, h, L)
            actK = k_cache_cpu[rows, :].contiguous().view(L, -1)
            actV = v_cache_cpu[rows, :].contiguous().view(L, -1)
            # Compare as multisets of rows so storage order does not matter.
            expK_tuples = [tuple(row) for row in expK.numpy().tolist()]
            actK_tuples = [tuple(row) for row in actK.numpy().tolist()]
            expV_tuples = [tuple(row) for row in expV.numpy().tolist()]
            actV_tuples = [tuple(row) for row in actV.numpy().tolist()]
            assert collections.Counter(expK_tuples) == collections.Counter(
                actK_tuples
            ), f"K content mismatch at (b={b}, h={h})"
            assert collections.Counter(expV_tuples) == collections.Counter(
                actV_tuples
            ), f"V content mismatch at (b={b}, h={h})"
def test_prefill_store_topk_kv_pad_to_page_size():
    """Check `prefill_store_topk_kv` with `PAD_TO_PAGE_SIZE=True`.

    Only the resulting per-(batch, head) lengths are inspected: each must be a
    multiple of `PAGE_SIZE` and must not exceed the sequence length.
    """
    torch.manual_seed(0)
    B, H, D = 2, 2, 64
    PAGE_SIZE = 128
    RETAIN = 64
    dtype = torch.float16
    device = triton.runtime.driver.active.get_active_torch_device()
    # Every sequence is 256 tokens long (two full pages).
    lens = torch.full((B,), 256, dtype=torch.int32, device=device)
    cu = torch.zeros(B + 1, dtype=torch.int32, device=device)
    cu[1:] = torch.cumsum(lens, dim=0)
    N_total = int(cu[-1].item())
    keys = torch.randn((N_total, H, D), dtype=dtype, device=device)
    vals = torch.randn_like(keys)
    scores_flat = torch.randn((N_total, H), dtype=torch.float32, device=device)
    max_k_len = int(lens.max().item())
    # Request indices for every (token, head) entry; the retain budget passed
    # to the kernel below is RETAIN.
    max_sel = max_k_len * H
    indices = scores_to_retain_indices(scores_flat, cu, max_k_len, max_sel, H)
    N_LOGICAL_PAGES_MAX = 2
    N_PAGES = B * H * N_LOGICAL_PAGES_MAX + 32
    S_LARGE = N_PAGES * PAGE_SIZE
    k_cache = torch.empty((S_LARGE, D), dtype=dtype, device=device)
    v_cache = torch.empty_like(k_cache)
    page_table = torch.empty(
        (B, H, N_LOGICAL_PAGES_MAX), dtype=torch.int32, device=device
    )
    # Give every (b, h, logical page) its own physical page, in order.
    phys = 0
    for b in range(B):
        for h in range(H):
            for lp in range(N_LOGICAL_PAGES_MAX):
                page_table[b, h, lp] = phys
                phys += 1
    assert phys <= N_PAGES, "Not enough physical pages"
    local_lens = torch.zeros((B, H), dtype=torch.int32, device=device)
    batch_mapping = torch.arange(B, dtype=torch.int32, device=device)
    num_to_retain = torch.full((B,), RETAIN, dtype=torch.int32, device=device)
    prefill_store_topk_kv(
        new_keys=keys,
        new_vals=vals,
        indices_topk=indices,
        num_tokens_to_retain=num_to_retain,
        page_table=page_table,
        batch_mapping=batch_mapping,
        bh_lens=local_lens,
        PAGE_SIZE=PAGE_SIZE,
        k_cache=k_cache,
        v_cache=v_cache,
        PAD_TO_PAGE_SIZE=True,
        cu_seqlens_k=cu,
        TRITON_RESERVED_BATCH=-1,
    )
    torch.cuda.synchronize()
    local_lens_cpu = local_lens.cpu()
    lens_cpu = lens.cpu()
    # Lengths are whole pages ...
    assert (local_lens_cpu % PAGE_SIZE == 0).all()
    # ... and never longer than the sequence they came from.
    assert (local_lens_cpu <= lens_cpu[:, None]).all()
import logging
import math
from dataclasses import dataclass
from typing import List
import pytest
import torch
import triton
from flash_attn.flash_attn_interface import (
flash_attn_varlen_func,
flash_attn_with_kvcache,
)
from compactor_vllm.attention.sparse_decode_kernel import head_sparse_decode_attention
from compactor_vllm.attention.sparse_varlen_kernel import (
causal_sparse_varlen_with_cache,
)
logger = logging.getLogger(__name__)
@dataclass
class Workload:
    """One parametrized case for the sparse attention kernel tests."""

    name: str  # pytest id for the case
    batch_size: int  # number of sequences
    nq_heads: int  # number of query heads
    nk_heads: int  # number of KV heads
    head_dim: int  # per-head feature size
    cache_lens: List[int]  # per-sequence cached context length
    append_lens: List[int]  # per-sequence new tokens this step (Q_app, K_app, V_app)
# Prefill/append cases: cross product of batch size, cached length and number
# of appended tokens. All sequences in a workload share the same lengths.
WORKLOADS: List[Workload] = [
    Workload(
        name=(
            f"batch_size={batch} kv_cache_len={cache_len} append_len={append_len} "
            f"HQ={q_heads} HKV={kv_heads} HEAD_DIM={dim}"
        ),
        batch_size=batch,
        nq_heads=q_heads,
        nk_heads=kv_heads,
        head_dim=dim,
        cache_lens=[cache_len] * batch,
        append_lens=[append_len] * batch,
    )
    for batch in [1, 2, 3, 8]
    for q_heads in [32]
    for kv_heads in [8]
    for dim in [128]
    for cache_len in [0, 1, 70, 128, 8193]
    for append_len in [1, 2, 13, 8000]
]
# Decode cases: exactly one new token per sequence (`append_lens` is all ones).
WORKLOADS_DECODE: List[Workload] = [
    Workload(
        # The separating space after the cache length was missing, which
        # produced ids such as "kv_cache_len=70HQ=32".
        name=f"batch_size={BATCH} kv_cache_len={cache_lens} "
        f"HQ={NQ_HEADS} HKV={NK_HEADS} HEAD_DIM={HEAD_DIM}",
        batch_size=BATCH,
        nq_heads=NQ_HEADS,
        nk_heads=NK_HEADS,
        head_dim=HEAD_DIM,
        cache_lens=[cache_lens] * BATCH,
        append_lens=[1] * BATCH,
    )
    for BATCH in [1, 2, 3, 8]
    for NQ_HEADS in [32]
    for NK_HEADS in [8]
    for HEAD_DIM in [128]
    for cache_lens in [1, 2, 70, 128, 8000]
]
def build_paged_cache_from_lengths(
    B,
    H_kv,
    D,
    PAGE_SIZE,
    N_LOGICAL_PAGES_MAX,
    L_cache_per_b,  # int32 [B], per-batch cache length
    device,
    dtype,
):
    """
    Build a synthetic paged KV cache in which every KV head of batch ``b``
    holds ``L_cache_per_b[b]`` random tokens.

    Physical layout:
        physical_page_id = (b * H_kv + h) * N_LOGICAL_PAGES_MAX + lp
        CACHE_SIZE = B * H_kv * N_LOGICAL_PAGES_MAX * PAGE_SIZE

    Returns:
        ``(K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE)`` where
        - ``K_cache`` / ``V_cache`` are ``[CACHE_SIZE, D]``, zero except for
          the rows of valid cached tokens, which are drawn from a generator
          seeded with 1234;
        - ``page_table`` is int32 ``[B, H_kv, N_LOGICAL_PAGES_MAX]`` of
          physical page ids;
        - ``seq_lens_bh`` is int32 ``[B, H_kv]`` with
          ``seq_lens_bh[b, h] = L_cache_per_b[b]``.
    """
    assert L_cache_per_b.shape[0] == B
    max_len = PAGE_SIZE * N_LOGICAL_PAGES_MAX
    assert (L_cache_per_b <= max_len).all()

    # Read the lengths back once instead of once per loop iteration.
    cache_lens = [int(n) for n in L_cache_per_b.tolist()]

    seq_lens_bh = (
        L_cache_per_b.to(device=device, dtype=torch.int32)
        .reshape(B, 1)
        .repeat(1, H_kv)
    )

    num_phys_pages = B * H_kv * N_LOGICAL_PAGES_MAX
    CACHE_SIZE = num_phys_pages * PAGE_SIZE
    K_cache = torch.zeros((CACHE_SIZE, D), device=device, dtype=dtype)
    V_cache = torch.zeros((CACHE_SIZE, D), device=device, dtype=dtype)

    # Unique physical page per (b, h, lp), numbered in row-major order.
    page_table = torch.arange(
        num_phys_pages, device=device, dtype=torch.int32
    ).reshape(B, H_kv, N_LOGICAL_PAGES_MAX)

    # Pages of one (b, h) are physically consecutive, so logical position i
    # lives at row `first_row + i`; no page-table read-back is needed.
    rows_per_head = N_LOGICAL_PAGES_MAX * PAGE_SIZE
    g = torch.Generator(device=device).manual_seed(1234)
    for b, Lc in enumerate(cache_lens):
        for h in range(H_kv):
            first_row = (b * H_kv + h) * rows_per_head
            for i in range(Lc):
                # One draw for K, then one for V, per token: this order fixes
                # the generated values.
                K_cache[first_row + i] = torch.randn(
                    D, device=device, dtype=dtype, generator=g
                )
                V_cache[first_row + i] = torch.randn(
                    D, device=device, dtype=dtype, generator=g
                )
    return K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE
def materialize_kv_for_flash_mixed(
    K_cache,
    V_cache,
    page_table,
    L_cache_per_b,  # [B]
    k_append_raw,  # [N, H_kv, D]
    v_append_raw,  # [N, H_kv, D]
    cu_seqlens_qk,  # [B+1]
    H_kv,
    PAGE_SIZE,
):
    """
    Build (K_total, V_total, cu_seqlens_k) for flash_attn_varlen_func such that:
      For each batch b:
        seqlen_q[b] = L_app[b] = cu[b+1] - cu[b]
        seqlen_k[b] = L_cache_per_b[b] + L_app[b]
      Keys:
        - first L_cache_per_b[b] positions from paged cache
        - next  L_app[b] positions from k_append_raw for that batch

    Returns:
        K_total, V_total: ``[sum(seqlen_k), H_kv, D]`` with the device and
            dtype of ``K_cache``.
        cu_seqlens_k: int32 ``[B + 1]`` cumulative key lengths, starting at 0.
    """
    device = K_cache.device
    dtype = K_cache.dtype
    B = cu_seqlens_qk.numel() - 1
    _, H_kv_raw, D = k_append_raw.shape
    assert H_kv_raw == H_kv

    # Pull all lengths to the host once; the loop below is pure Python ints.
    q_bounds = [int(x) for x in cu_seqlens_qk.tolist()]
    cache_lens = [int(x) for x in L_cache_per_b.tolist()]
    app_lens = [q_bounds[b + 1] - q_bounds[b] for b in range(B)]

    k_bounds = [0]
    for b in range(B):
        k_bounds.append(k_bounds[-1] + cache_lens[b] + app_lens[b])
    cu_seqlens_k = torch.tensor(k_bounds, device=device, dtype=torch.int32)

    total_k = k_bounds[-1]
    K_total = torch.empty((total_k, H_kv, D), device=device, dtype=dtype)
    V_total = torch.empty((total_k, H_kv, D), device=device, dtype=dtype)
    has_append = k_append_raw.numel() > 0

    for b in range(B):
        Lc = cache_lens[b]
        La = app_lens[b]
        dst = k_bounds[b]

        # cache segment: gather all heads and positions in one indexing op
        if Lc > 0:
            pos = torch.arange(Lc, device=page_table.device)
            logical_page = torch.div(pos, PAGE_SIZE, rounding_mode="floor")
            # [H_kv, Lc] physical page of every (head, position)
            phys = page_table[b, :H_kv][:, logical_page].to(torch.int64)
            rows = (phys * PAGE_SIZE + pos % PAGE_SIZE).to(device)
            # Cache gathers are [H_kv, Lc, D]; the output is token-major.
            K_total[dst : dst + Lc] = K_cache[rows].transpose(0, 1)
            V_total[dst : dst + Lc] = V_cache[rows].transpose(0, 1)

        # appended segment: contiguous slice copy for all heads
        if has_append and La > 0:
            src = q_bounds[b]
            K_total[dst + Lc : dst + Lc + La] = k_append_raw[src : src + La]
            V_total[dst + Lc : dst + Lc + La] = v_append_raw[src : src + La]

    return K_total, V_total, cu_seqlens_k
@pytest.mark.parametrize("workload", WORKLOADS, ids=lambda wl: wl.name)
def test_causal_sparse_varlen_with_cache(workload: Workload):
    """Compare `causal_sparse_varlen_with_cache` against `flash_attn_varlen_func`.

    The Triton kernel receives the paged cache plus the appended K/V; the
    reference receives the same data materialized as one contiguous
    variable-length K/V. Outputs must agree within `atol=3e-3`.
    """
    dtype = torch.float16
    device = triton.runtime.driver.active.get_active_torch_device()
    DEFAULT_PAGE_SIZE = 256
    N_LOGICAL_PAGES_MAX = 256
    L_cache_per_b = torch.as_tensor(
        workload.cache_lens, device=device, dtype=torch.int32
    )
    K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE = (
        build_paged_cache_from_lengths(
            B=workload.batch_size,
            H_kv=workload.nk_heads,
            D=workload.head_dim,
            PAGE_SIZE=DEFAULT_PAGE_SIZE,
            N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
            L_cache_per_b=L_cache_per_b,
            device=device,
            dtype=dtype,
        )
    )
    assert len(workload.append_lens) == workload.batch_size
    # Cumulative boundaries of the appended (query) tokens, starting at 0.
    cu = [0]
    for L in workload.append_lens:
        cu.append(cu[-1] + L)
    cu_seqlens_qk = torch.tensor(cu, dtype=torch.int32, device=device)
    N = int(cu_seqlens_qk[-1].item())
    q_raw = torch.randn(
        N, workload.nq_heads, workload.head_dim, device=device, dtype=dtype
    )
    k_append_raw = torch.randn(
        N, workload.nk_heads, workload.head_dim, device=device, dtype=dtype
    )
    v_append_raw = torch.randn_like(k_append_raw)
    batch_mapping = torch.arange(workload.batch_size, device=device, dtype=torch.int32)
    sm_scale = 1.0 / math.sqrt(workload.head_dim)
    # Contiguous reference K/V: cached tokens followed by appended tokens.
    K_total, V_total, cu_seqlens_k = materialize_kv_for_flash_mixed(
        K_cache=K_cache,
        V_cache=V_cache,
        page_table=page_table,
        L_cache_per_b=L_cache_per_b,
        k_append_raw=k_append_raw,
        v_append_raw=v_append_raw,
        cu_seqlens_qk=cu_seqlens_qk,
        H_kv=workload.nk_heads,
        PAGE_SIZE=DEFAULT_PAGE_SIZE,
    )
    max_seqlen_q = int((cu_seqlens_qk[1:] - cu_seqlens_qk[:-1]).max().item())
    max_seqlen_k = int((cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item())
    max_seqlen_k_triton = seq_lens_bh.max().item()
    out_triton = causal_sparse_varlen_with_cache(
        q=q_raw,
        k_cache=K_cache,
        v_cache=V_cache,
        k=k_append_raw,
        v=v_append_raw,
        seq_lens_bh=seq_lens_bh,
        global_page_table=page_table,
        batch_mapping=batch_mapping,
        cu_seqlens_q=cu_seqlens_qk,
        HKV=workload.nk_heads,
        PAGE_SIZE=DEFAULT_PAGE_SIZE,
        sm_scale=sm_scale,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k_cache=max_seqlen_k_triton,
    )
    out_flash = flash_attn_varlen_func(
        q=q_raw,
        k=K_total,
        v=V_total,
        cu_seqlens_q=cu_seqlens_qk,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        dropout_p=0.0,
        softmax_scale=sm_scale,
        causal=True,
    )
    # Measure and log the error before asserting so a failing run still
    # reports how far apart the two outputs are.
    max_diff = (out_triton - out_flash).abs().max().item()
    logger.info(
        f"[causal_sparse_varlen_with_cache: {workload.name}]: max abs diff={max_diff: .5f}"
    )
    assert torch.allclose(out_triton, out_flash, rtol=1e-6, atol=3e-3), (
        f"[causal_sparse_varlen_with_cache: {workload.name}]: "
        f"outputs differ, max abs diff={max_diff:.5f}"
    )
def materialize_kv_cache_for_flash_decode(
    K_cache,
    V_cache,
    page_table,
    L_cache_per_b,  # [B] int32
    H_kv: int,
    PAGE_SIZE: int,
):
    """
    Build (K_flash, V_flash) suitable for flash_attn_with_kvcache, with shape:
        (B, seqlen_cache_max, H_kv, D)
    For each batch b:
      - cache_seqlen[b] = L_cache_per_b[b]
      - K_flash[b, :cache_seqlen[b], g] and V_flash[...] are filled from the paged KV cache.
      - Tokens beyond cache_seqlen[b] (if any) are left as zeros; the caller
        is expected to mask them by passing the lengths as cache_seqlens.
    """
    device = K_cache.device
    dtype = K_cache.dtype
    B = L_cache_per_b.shape[0]
    D = K_cache.shape[1]
    seqlen_cache_max = int(L_cache_per_b.max().item())
    K_flash = torch.zeros((B, seqlen_cache_max, H_kv, D), device=device, dtype=dtype)
    V_flash = torch.zeros_like(K_flash)

    # Read the lengths back once rather than once per batch entry.
    cache_lens = [int(n) for n in L_cache_per_b.tolist()]
    for b, Lc in enumerate(cache_lens):
        if Lc <= 0:
            continue
        pos = torch.arange(Lc, device=page_table.device)
        logical_page = torch.div(pos, PAGE_SIZE, rounding_mode="floor")
        # [H_kv, Lc] physical page of every (head, position)
        phys = page_table[b, :H_kv][:, logical_page].to(torch.int64)
        rows = (phys * PAGE_SIZE + pos % PAGE_SIZE).to(device)
        # Gathers are [H_kv, Lc, D]; the output layout is [Lc, H_kv, D].
        K_flash[b, :Lc] = K_cache[rows].transpose(0, 1)
        V_flash[b, :Lc] = V_cache[rows].transpose(0, 1)
    return K_flash, V_flash
@pytest.mark.parametrize("workload", WORKLOADS_DECODE, ids=lambda wl: wl.name)
def test_sparse_decode_attention(workload: Workload):
    """Compare `head_sparse_decode_attention` against `flash_attn_with_kvcache`.

    The Triton kernel reads the paged cache directly; the reference reads the
    same tokens materialized as a dense `[B, seqlen, HKV, D]` cache. Outputs
    must agree within `atol=3e-3`.
    """
    dtype = torch.float16
    device = triton.runtime.driver.active.get_active_torch_device()
    DEFAULT_PAGE_SIZE = 256
    N_LOGICAL_PAGES_MAX = 256
    # per-sequence cache lengths (all equal for WORKLOADS_DECODE)
    L_cache_per_b = torch.as_tensor(
        workload.cache_lens, device=device, dtype=torch.int32
    )
    # build paged KV cache used by the Triton kernel
    K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE = (
        build_paged_cache_from_lengths(
            B=workload.batch_size,
            H_kv=workload.nk_heads,
            D=workload.head_dim,
            PAGE_SIZE=DEFAULT_PAGE_SIZE,
            N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
            L_cache_per_b=L_cache_per_b,
            device=device,
            dtype=dtype,
        )
    )
    B = workload.batch_size
    HQ = workload.nq_heads
    HKV = workload.nk_heads
    D = workload.head_dim
    # Triton kernel expects q: [B, HQ, D]
    q_triton = torch.randn(B, HQ, D, device=device, dtype=dtype)
    batch_mapping = torch.arange(B, device=device, dtype=torch.int32)
    sm_scale = 1.0 / math.sqrt(D)
    out_triton = head_sparse_decode_attention(
        q=q_triton,
        k=K_cache,
        v=V_cache,
        seq_lens_bh=seq_lens_bh,
        global_page_table=page_table,
        batch_mapping=batch_mapping,
        HKV=HKV,
        PAGE_SIZE=DEFAULT_PAGE_SIZE,
        sm_scale=sm_scale,
    )  # [B, HQ, D]
    # materialize contiguous KV cache with shape [B, seqlen_cache_max, HKV, D]
    K_flash, V_flash = materialize_kv_cache_for_flash_decode(
        K_cache=K_cache,
        V_cache=V_cache,
        page_table=page_table,
        L_cache_per_b=L_cache_per_b,
        H_kv=HKV,
        PAGE_SIZE=DEFAULT_PAGE_SIZE,
    )
    # flash_attn_with_kvcache expects q: [B, seqlen_q, HQ, D]
    q_flash = q_triton.unsqueeze(1)  # seqlen_q = 1
    out_flash = flash_attn_with_kvcache(
        q=q_flash,
        k_cache=K_flash,
        v_cache=V_flash,
        cache_seqlens=L_cache_per_b,
        softmax_scale=sm_scale,
        causal=True,
    ).squeeze(1)  # [B, 1, HQ, D] -> [B, HQ, D]
    # Measure and log the error before asserting so a failing run still
    # reports how far apart the two outputs are.
    max_diff = (out_triton - out_flash).abs().max().item()
    logger.info(
        f"[head_sparse_decode_attention: {workload.name}]: max abs diff={max_diff: .5f}"
    )
    assert torch.allclose(out_triton, out_flash, rtol=1e-6, atol=3e-3), (
        f"[head_sparse_decode_attention: {workload.name}]: "
        f"outputs differ, max abs diff={max_diff:.5f}"
    )
......@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
import os
from collections.abc import Callable, Iterable, Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any
......@@ -95,6 +96,7 @@ from vllm.v1.engine.llm_engine import LLMEngine
from vllm.v1.sample.logits_processor import LogitsProcessor
if TYPE_CHECKING:
from vllm.kvprune.integration.compression_params import CompressionParams
from vllm.v1.metrics.reader import Metric
logger = init_logger(__name__)
......@@ -184,6 +186,15 @@ class LLM:
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
kvprune_compression: If True, sets ``enforce_eager=True`` for the **v1**
engine only (no v1 CUDA graph capture). If ``None`` (default), read
``VLLM_KVPRUNE_COMPRESSION_DEFAULT`` (``"0"`` = allow v1 graphs;
``"1"`` = skip v1 graphs). This is independent of the compactor's
``LLMConfig.enforce_eager`` (see ``VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH`` /
``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER``; default tries compactor graphs).
When True, v1's GPU KV pool defaults to **one** block (minimum allowed by
the scheduler) unless ``num_gpu_blocks_override`` is passed in ``**kwargs``
or ``VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS`` is set (``auto`` = profiled allocation).
enable_return_routed_experts: Whether to return routed experts.
disable_custom_all_reduce: See
[ParallelConfig][vllm.config.ParallelConfig].
......@@ -240,6 +251,7 @@ class LLM:
offload_prefetch_step: int = 1,
offload_params: set[str] | None = None,
enforce_eager: bool = False,
kvprune_compression: bool | None = None,
enable_return_routed_experts: bool = False,
disable_custom_all_reduce: bool = False,
hf_token: bool | str | None = None,
......@@ -339,6 +351,26 @@ class LLM:
"'examples/offline_inference/data_parallel.py'."
)
# v1 ``enforce_eager`` is independent of kvprune compactor ``LLMConfig.enforce_eager``.
if kvprune_compression is None:
_kvd = os.environ.get("VLLM_KVPRUNE_COMPRESSION_DEFAULT", "0").strip().lower()
kvprune_compression = _kvd in ("1", "true", "yes")
if kvprune_compression:
enforce_eager = True
# Reserve minimal v1 GPU KV so compactor can use the rest of VRAM. v1
# scheduler requires num_gpu_blocks >= 1; profiling would allocate a
# large pool from gpu_memory_utilization. Override:
# VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS unset -> 1 block (default)
# VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS=auto -> profiled (no override)
# VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS=<int> -> max(1, int)
if "num_gpu_blocks_override" not in kwargs:
_v1_kv = os.environ.get("VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS", "").strip()
if _v1_kv.lower() in ("auto", "profile"):
pass
elif not _v1_kv:
kwargs["num_gpu_blocks_override"] = 1
else:
kwargs["num_gpu_blocks_override"] = max(1, int(_v1_kv))
engine_args = EngineArgs(
model=model,
runner=runner,
......@@ -405,6 +437,9 @@ class LLM:
)
# Cache for __repr__ to avoid repeated collective_rpc calls
self._cached_repr: str | None = None
# Lazy compactor engine (``vllm.kvprune``) when :meth:`generate` uses compression.
self._kvprune_compactor_engine: Any = None
self._kvprune_compression_enabled = bool(kvprune_compression)
def get_tokenizer(self) -> TokenizerLike:
return self.llm_engine.get_tokenizer()
......@@ -446,6 +481,7 @@ class LLM:
lora_request: Sequence[LoRARequest] | LoRARequest | None = None,
priority: list[int] | None = None,
tokenization_kwargs: dict[str, Any] | None = None,
compression: "CompressionParams | Sequence[CompressionParams] | None" = None,
) -> list[RequestOutput]:
"""Generates the completions for the input prompts.
......@@ -473,6 +509,15 @@ class LLM:
of `prompts`, where each priority value corresponds to the prompt
at the same index.
tokenization_kwargs: Overrides for `tokenizer.encode`.
compression: Optional per-prompt KV compression (``vllm.kvprune``). If any
prompt has ``compression_ratio < 1.0``, the batch is run on the integrated
compactor engine with weights shared from this ``LLM``. Omit or use all
``compression_ratio >= 1`` to use the standard v1 engine only.
Use ``kvprune_compression=True`` or ``VLLM_KVPRUNE_COMPRESSION_DEFAULT=1``
so the v1 engine skips CUDA graph capture. Compactor decode graphs
default on (``VLLM_KVPRUNE_COMPACTOR_CUDA_GRAPH`` default ``1``) with
eager fallback if capture fails; set ``VLLM_KVPRUNE_COMPACTOR_ENFORCE_EAGER=1``
to skip compactor graph capture entirely.
Returns:
A list of `RequestOutput` objects containing the
......@@ -485,6 +530,41 @@ class LLM:
"Try passing `--runner generate` to use the model as a "
"generative model."
)
compression_eff = compression
if compression is None and getattr(self, "_kvprune_compression_enabled", False):
pc = self.llm_engine.vllm_config.parallel_config
if (
pc.tensor_parallel_size > 1
and pc.pipeline_parallel_size == 1
and pc.data_parallel_size == 1
):
from vllm.kvprune.integration.compression_params import CompressionParams
from vllm.kvprune.integration.compressed_generate import (
_normalize_prompt_list,
)
_plist = _normalize_prompt_list(prompts)
compression_eff = [
CompressionParams(compression_ratio=1.0) for _ in _plist
]
if compression_eff is not None:
from vllm.kvprune.integration.compressed_generate import (
try_compressed_generate,
)
compressed_out = try_compressed_generate(
self,
prompts,
sampling_params,
compression=compression_eff,
use_tqdm=use_tqdm,
lora_request=lora_request,
priority=priority,
tokenization_kwargs=tokenization_kwargs,
)
if compressed_out is not None:
return compressed_out
if sampling_params is None:
sampling_params = self.get_default_sampling_params()
......
......@@ -4,6 +4,60 @@
import importlib.util
import os
# KV-prune (compactor) shared-weight integration needs the v1 engine in-process
# (`worker.get_model()` in the parent). Upstream defaults to multiprocess workers
# (`VLLM_ENABLE_V1_MULTIPROCESSING=1`). If unset, default to in-process so
# `LLM.generate(..., compression=...)` works without requiring env to be set
# before `import vllm`. Set `VLLM_ENABLE_V1_MULTIPROCESSING=1` to restore
# multiprocess workers.
if "VLLM_ENABLE_V1_MULTIPROCESSING" not in os.environ:
os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
# In-process EngineCore (``VLLM_ENABLE_V1_MULTIPROCESSING=0``) shares the process with
# user code; ``import vllm`` already runs ``import torch`` below. TP workers are then
# created via multiprocessing. If we use ``fork`` after CUDA has been initialized in
# the parent, PyTorch raises ``Cannot re-initialize CUDA in forked subprocess``.
# ``_maybe_force_spawn()`` can miss this when CUDA is still uninitialized at the
# moment ``get_mp_context()`` runs, so default to ``spawn`` for worker processes unless
# the user set ``VLLM_WORKER_MULTIPROC_METHOD`` explicitly.
os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
# Tensor-parallel workers use NCCL, which queries **NVML for topology** (independent of
# PyTorch device counting). A faulty GPU on the host (e.g. ``nvidia-smi -L`` shows
# ``Unable to determine the device handle`` for one PCI address) often causes
# ``nvmlDeviceGetHandleByIndex(k) failed`` and then ``ncclCommInitRank`` errors.
# Mitigations: fix or isolate the bad GPU; or **before** ``import vllm`` restrict the
# container to healthy GPUs via UUID, e.g.
# export NVIDIA_VISIBLE_DEVICES=GPU-xxxx,GPU-yyyy,...
# (not only ``CUDA_VISIBLE_DEVICES=0,1,2,3``, which can still leave a dead GPU in
# NVML's enumeration). ``VLLM_KVPRUNE_NCCL_SAFE=1`` only tweaks P2P/IB, not NVML.
# For Docker, also consider ``--shm-size=10g`` or ``--ipc=host``.
if os.environ.get("VLLM_KVPRUNE_NCCL_SAFE", "").strip().lower() in (
"1",
"true",
"yes",
):
os.environ.setdefault("NCCL_P2P_DISABLE", "1")
os.environ.setdefault("NCCL_IB_DISABLE", "1")
# KV-prune: default ``LLM(kvprune_compression=None)`` to skip v1 CUDA graph capture
# (``enforce_eager=True`` on v1 only). Tests set ``VLLM_KVPRUNE_COMPRESSION_DEFAULT=0``
# in ``tests/conftest.py`` before importing vLLM.
os.environ.setdefault("VLLM_KVPRUNE_COMPRESSION_DEFAULT", "1")
# Opt-in: before the first compactor init, run sleep(level=1) + wake_up to discard the v1 KV
# cache (tests/conftest also sets this to 0). Off by default now that the kvprune path can use
# num_gpu_blocks_override=1 for v1.
os.environ.setdefault("VLLM_KVPRUNE_RELEASE_V1_KV", "0")
# Optional: ``VLLM_KVPRUNE_ATTENTION_SCHEDULE`` (fa_triton / pdtriton / pdfa), or the legacy
# ``VLLM_KVPRUNE_ATTENTION_BACKEND``; see ``vllm/kvprune/integration/config_adapter.py``.
# Optional: ``VLLM_KVPRUNE_SHARED_WEIGHT_GRAPH=1`` experimental compactor decode CUDA graphs.
#
# When ``LLM(..., kvprune_compression=True)`` (or default-on via
# ``VLLM_KVPRUNE_COMPRESSION_DEFAULT``), v1's ``num_gpu_blocks_override`` defaults
# to 1 in ``entrypoints/llm.py`` so the primary engine does not reserve a full
# profiled KV pool on the same GPU as the compactor. Use
# ``VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS=auto`` for profiled blocks, or a positive int.
def _get_torch_cuda_version():
"""Peripheral function to _maybe_set_cuda_compatibility_path().
......
......@@ -1030,6 +1030,21 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_ENABLE_V1_MULTIPROCESSING": lambda: bool(
int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))
),
# KV-prune / compactor integration (see ``vllm/env_override.py``, ``vllm/kvprune/``).
"VLLM_KVPRUNE_ATTENTION_SCHEDULE": lambda: os.getenv(
"VLLM_KVPRUNE_ATTENTION_SCHEDULE", ""
),
"VLLM_KVPRUNE_ATTENTION_BACKEND": lambda: os.getenv(
"VLLM_KVPRUNE_ATTENTION_BACKEND", ""
),
"VLLM_KVPRUNE_COMPRESSION_DEFAULT": lambda: os.getenv(
"VLLM_KVPRUNE_COMPRESSION_DEFAULT", ""
),
"VLLM_KVPRUNE_RELEASE_V1_KV": lambda: os.getenv("VLLM_KVPRUNE_RELEASE_V1_KV", ""),
"VLLM_KVPRUNE_NCCL_SAFE": lambda: os.getenv("VLLM_KVPRUNE_NCCL_SAFE", ""),
"VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS": lambda: os.getenv(
"VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS", ""
),
"VLLM_LOG_BATCHSIZE_INTERVAL": lambda: float(
os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1")
),
......@@ -1771,6 +1786,12 @@ def compile_factors() -> dict[str, object]:
"VLLM_ASSETS_CACHE_MODEL_CLEAN",
"VLLM_WORKER_MULTIPROC_METHOD",
"VLLM_ENABLE_V1_MULTIPROCESSING",
"VLLM_KVPRUNE_ATTENTION_SCHEDULE",
"VLLM_KVPRUNE_ATTENTION_BACKEND",
"VLLM_KVPRUNE_COMPRESSION_DEFAULT",
"VLLM_KVPRUNE_RELEASE_V1_KV",
"VLLM_KVPRUNE_NCCL_SAFE",
"VLLM_KVPRUNE_V1_NUM_GPU_BLOCKS",
"VLLM_V1_OUTPUT_PROC_CHUNK_SIZE",
"VLLM_CPU_KVCACHE_SPACE",
"VLLM_CPU_MOE_PREPACK",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KV-cache pruning (compactor-style) under ``vllm.kvprune``.
Use the standard :class:`~vllm.LLM` and pass ``compression=`` to :meth:`~vllm.LLM.generate`
with :class:`CompressionParams` when any prompt needs ``compression_ratio < 1``. The compactor
``LLMEngine`` + ``PagedKVCache`` shares weights with vLLM (no second checkpoint).
Subpackages (``attention``, ``kv_cache``, ``compression``, …) implement the compactor
engine.
"""
from vllm.kvprune.compression.compression_config import CompressionMethod
from vllm.kvprune.integration import CompressionParams
__all__ = [
"CompressionMethod",
"CompressionParams",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Sparse attention Triton kernels (varlen prefill, decode, compile helpers)."""
from vllm.kvprune.attention.sparse_varlen_kernel import causal_sparse_varlen_with_cache
__all__ = ["causal_sparse_varlen_with_cache"]
import argparse
import logging
import math
import torch
from vllm.kvprune.attention.sparse_varlen_kernel import (
causal_sparse_varlen_with_cache,
)
logger = logging.getLogger(__name__)
def build_mock_paged_cache_from_lengths(
    L_cache_per_b: torch.Tensor,
    HKV: int,
    D: int,
    PAGE_SIZE: int,
    N_LOGICAL_PAGES_MAX: int,
    device,
    dtype,
):
    """Build a synthetic paged KV cache for kernel tuning.

    Every ``(batch, kv_head)`` pair owns ``N_LOGICAL_PAGES_MAX`` consecutive
    physical pages, so the page table is simply ``0, 1, 2, ...`` laid out as
    ``[B, HKV, N_LOGICAL_PAGES_MAX]``. The first ``L_cache_per_b[b]`` logical
    token rows of each pair are filled with standard-normal noise; all other
    rows stay zero.

    Args:
        L_cache_per_b: 1-D integer tensor, cached length for each batch row.
        HKV: Number of KV heads.
        D: Head dimension (row width of the caches).
        PAGE_SIZE: Tokens per physical page.
        N_LOGICAL_PAGES_MAX: Logical pages reserved per ``(batch, kv_head)``.
        device: Torch device for all returned tensors.
        dtype: Floating dtype of the K/V caches.

    Returns:
        ``(K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE)`` where the
        caches are ``[CACHE_SIZE, D]``, ``page_table`` is int32
        ``[B, HKV, N_LOGICAL_PAGES_MAX]`` and ``seq_lens_bh`` is int32
        ``[B, HKV]``.

    Raises:
        AssertionError: If any requested length exceeds
            ``PAGE_SIZE * N_LOGICAL_PAGES_MAX``.
    """
    B = len(L_cache_per_b)
    max_len = PAGE_SIZE * N_LOGICAL_PAGES_MAX
    assert (L_cache_per_b <= max_len).all(), (
        "cached length exceeds PAGE_SIZE * N_LOGICAL_PAGES_MAX"
    )

    # Same length for every KV head of a batch row.
    seq_lens_bh = (
        L_cache_per_b.to(device=device, dtype=torch.int32)
        .reshape(B, 1)
        .expand(B, HKV)
        .contiguous()
    )

    num_phys_pages = B * HKV * N_LOGICAL_PAGES_MAX
    CACHE_SIZE = num_phys_pages * PAGE_SIZE
    K_cache = torch.zeros((CACHE_SIZE, D), device=device, dtype=dtype)
    V_cache = torch.zeros((CACHE_SIZE, D), device=device, dtype=dtype)

    # Unique physical page per (b, h, lp), numbered in row-major order.
    page_table = torch.arange(
        num_phys_pages, device=device, dtype=torch.int32
    ).view(B, HKV, N_LOGICAL_PAGES_MAX)

    # Because pages of one (b, h) are consecutive, logical token i of that pair
    # lives at row ``(b * HKV + h) * max_len + i``; a 4-D view lets us fill the
    # valid prefix of every head with one randn call per batch row instead of
    # one call (and one device sync) per token.
    K_view = K_cache.view(B, HKV, max_len, D)
    V_view = V_cache.view(B, HKV, max_len, D)
    for b, cached in enumerate(L_cache_per_b.tolist()):
        cached = int(cached)
        if cached <= 0:
            continue
        K_view[b, :, :cached] = torch.randn(
            HKV, cached, D, device=device, dtype=dtype
        )
        V_view[b, :, :cached] = torch.randn(
            HKV, cached, D, device=device, dtype=dtype
        )
    return K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE
def autotune_causal_sparse_varlen_with_cache(
    *,
    max_length: int = 16384,
    HKV: int = 8,
    HQ: int = 32,
    D: int = 128,
    PAGE_SIZE: int = 128,
    device: str = "cuda",
    dtype=torch.float16,
):
    """
    Autotune causal_sparse_varlen_with_cache over a sweep of cache/append lengths.

    The sweep covers every (cached length, appended length) pair drawn from
    ``{0, 256, 512, 1024, ...}`` (powers of two below ``max_length``), skipping
    the pair where both are zero. Each pair launches the kernel once on a mock
    batch of four identical sequences so that Triton's autotuner runs for it.

    Args:
        max_length: Upper bound (exclusive) for the power-of-two lengths swept.
        HKV: Number of KV heads.
        HQ: Number of query heads.
        D: Head dimension; must be a power of two.
        PAGE_SIZE: Tokens per physical KV page.
        device: Torch device to run on.
        dtype: Dtype of q/k/v and the mock caches.

    Raises:
        AssertionError: If ``D`` is not a power of two.
    """
    import itertools

    import tqdm

    B = 4
    # D must be a power of two (kernel requirement).
    assert (D & (D - 1)) == 0

    lengths_to_sweep = [0, 256]
    i = 9
    while (v := (1 << i)) < max_length:
        lengths_to_sweep.append(v)
        i += 1

    # Number of logical *pages* needed to hold the longest cached prefix in the
    # sweep. (Previously this was rounded up to a multiple of PAGE_SIZE, i.e. a
    # token count, which inflated the page table and the mock caches by a
    # factor of PAGE_SIZE.)
    longest_cached = max(max_length, max(lengths_to_sweep))
    N_LOGICAL_PAGES_MAX = (longest_cached + PAGE_SIZE - 1) // PAGE_SIZE

    combos = list(itertools.product(lengths_to_sweep, repeat=2))
    logger.info(
        "tuning kernels. this may take a few minutes, "
        "but only needs to be run once per LLMConfig"
    )

    # Loop-invariant inputs.
    # Identity batch mapping (local batch index == global)
    batch_mapping = torch.arange(B, device=device, dtype=torch.int32)
    sm_scale = 1.0 / math.sqrt(D)

    for cache_l, append_l in tqdm.tqdm(combos):
        if cache_l + append_l == 0:
            continue
        L_cache_per_b = torch.tensor(
            [cache_l] * B,
            device=device,
            dtype=torch.int32,
        )
        K_cache, V_cache, page_table, seq_lens_bh, _cache_size = (
            build_mock_paged_cache_from_lengths(
                L_cache_per_b=L_cache_per_b,
                HKV=HKV,
                D=D,
                PAGE_SIZE=PAGE_SIZE,
                N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
                device=device,
                dtype=dtype,
            )
        )

        # All B sequences append the same number of tokens, so the packed
        # layout is known without reading anything back from the device.
        cu_seqlens_qk = (
            torch.arange(B + 1, device=device, dtype=torch.int32) * append_l
        )
        N = B * append_l
        max_seqlen_q = append_l
        max_seqlen_k = cache_l

        q_raw = torch.randn(N, HQ, D, device=device, dtype=dtype)
        k_append_raw = torch.randn(N, HKV, D, device=device, dtype=dtype)
        v_append_raw = torch.randn(N, HKV, D, device=device, dtype=dtype)

        causal_sparse_varlen_with_cache(
            q=q_raw,
            k_cache=K_cache,
            v_cache=V_cache,
            k=k_append_raw,
            v=v_append_raw,
            seq_lens_bh=seq_lens_bh,
            global_page_table=page_table,
            batch_mapping=batch_mapping,
            cu_seqlens_q=cu_seqlens_qk,
            HKV=HKV,
            PAGE_SIZE=PAGE_SIZE,
            sm_scale=sm_scale,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k_cache=max_seqlen_k,
        )
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of the autotune script.

    Returns:
        Namespace with ``max_length``, ``HKV``, ``HQ``, ``D``, ``page_size``,
        ``device``, ``dtype`` (string) and ``log_level``.
    """
    # Each fragment ends with an explicit space: adjacent string literals are
    # concatenated verbatim, so omitting it glues sentences together in --help.
    parser = argparse.ArgumentParser(
        description="Autotune Triton kernels. "
        "Results are cached, so this should only need to be run once per configuration. "
        "This script doesn't need to be run, as the kernels will be autotuned at runtime "
        "if no cached autotuning data exists. Running this beforehand will prevent run-time "
        "autotuning, which will accelerate compactor-vllm at inference time."
    )
    parser.add_argument(
        "--max-length",
        type=int,
        default=16384,
        help="Maximum total sequence length to consider.",
    )
    parser.add_argument(
        "--HKV",
        type=int,
        default=8,
        help="Number of KV heads.",
    )
    parser.add_argument(
        "--HQ",
        type=int,
        default=32,
        help="Number of query heads.",
    )
    parser.add_argument(
        "--D",
        type=int,
        default=128,
        help="Per-head hidden dimension (must be power of 2).",
    )
    parser.add_argument(
        "--page-size",
        type=int,
        default=128,
        help="Page size (tokens per physical page).",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Torch device to run on (e.g. 'cuda', 'cuda:0', 'cpu').",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="float16",
        help="Dtype for tensors: one of "
        "{float16, fp16, half, bfloat16, bf16, float32, fp32}.",
    )
    parser.add_argument(
        "--log-level",
        type=str,
        default="INFO",
        choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"],
        help="Logging level.",
    )
    return parser.parse_args()
def _resolve_dtype(dtype_str: str):
    """Translate a dtype alias (case-insensitive) into a ``torch.dtype``.

    Accepted aliases: ``float16``/``fp16``/``half``, ``bfloat16``/``bf16`` and
    ``float32``/``fp32``.

    Raises:
        ValueError: If the alias is not recognised.
    """
    alias_table = {
        "float16": torch.float16,
        "fp16": torch.float16,
        "half": torch.float16,
        "bfloat16": torch.bfloat16,
        "bf16": torch.bfloat16,
        "float32": torch.float32,
        "fp32": torch.float32,
    }
    resolved = alias_table.get(dtype_str.lower())
    if resolved is None:
        raise ValueError(f"Unsupported dtype: {dtype_str}")
    return resolved
def main():
    """CLI entry point: parse arguments, configure logging, run the autotune sweep."""
    args = _parse_args()
    # ``force=True`` replaces any handlers already installed on the root logger.
    # Without it this call is a silent no-op whenever logging was configured
    # earlier (the ``__main__`` guard does exactly that), and ``--log-level``
    # would have no effect.
    logging.basicConfig(
        level=getattr(logging, args.log_level.upper()),
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        force=True,
    )
    dtype = _resolve_dtype(args.dtype)
    logger.info(
        "Starting autotune with max_length=%d, HKV=%d, HQ=%d, D=%d, page_size=%d, "
        "device=%s, dtype=%s",
        args.max_length,
        args.HKV,
        args.HQ,
        args.D,
        args.page_size,
        args.device,
        dtype,
    )
    autotune_causal_sparse_varlen_with_cache(
        max_length=args.max_length,
        HKV=args.HKV,
        HQ=args.HQ,
        D=args.D,
        PAGE_SIZE=args.page_size,
        device=args.device,
        dtype=dtype,
    )
# Script entry point: install a root logging configuration at INFO level, then
# hand off to main(), which parses the CLI arguments and runs the autotune sweep.
if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s: %(message)s",
    )
    main()
# SPDX-License-Identifier: Apache-2.0
"""FlashAttention paths over compactor paged KV (materialize + FA ops).
Used when :class:`~vllm.kvprune.config.engine_config.KvpruneAttentionSchedule`
selects FlashAttention for prefill and/or decode while KV **writes** remain on
Triton (``prefill_store_*``, ``decode_store_kv``).
**Why compactor-vllm looked fine but kvprune ``fa_triton`` + compression did not**
compactor-vllm ``layers/attention.py`` (prefill)::
use_flash_prefill = (backend == FLASH) or (COMPACTOR_TRITON and not do_compression)
if use_flash_prefill:
flash_attn_varlen_func(q, k, v, ...) # dense packed Q/K/V, one length per batch
elif COMPACTOR_TRITON:
causal_sparse_varlen_with_cache(..., seq_lens_bh=...) # paged KV, **per-(b,h)** lengths
So **with compression** (``do_compression``), compactor-vllm **never** runs FlashAttention on
paged top-K KV; it always uses Triton ``causal_sparse_varlen_with_cache``.
kvprune ``fa_triton`` (``FA_PREFILL_TRITON_DECODE``) keeps the intended split: **FA prefill**
+ **Triton decode**. For compressed prefill it calls :func:`flash_prefill_from_paged`, which
builds a dense ``[total_k, H_kv, D]`` tensor and calls ``flash_attn_varlen_func``. That layout
assumes **one cache prefix length per batch row shared by all KV heads** (same ``Lc`` for every
``g`` when copying from ``k_cache``). Top-K retention instead updates ``bh_lens`` with
**different** counts per head (``seq_lens_bh`` shape ``[B, HKV]``). Taking ``max(dim=1)``
(older code) used one ``Lc`` per batch but still filled ``K_total[offset+i, g]`` for every head
``g`` — heads with **shorter** real cache were **over-read**, corrupting attention.
We therefore **require** ``seq_lens_bh[b, :]`` to be constant in ``h`` for each ``b`` before
materializing for FA (see :func:`_require_uniform_kv_lens_per_batch_for_fa_materialize`). If your
retention policy yields unequal per-head lengths, use ``pdtriton`` (Triton prefill) for that
run, or disable compression while using ``fa_triton``.
"""
from __future__ import annotations
import math
from typing import TYPE_CHECKING
import torch
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
if TYPE_CHECKING:
pass
def _require_uniform_kv_lens_per_batch_for_fa_materialize(
    seq_lens_bh: torch.Tensor, *, caller: str
) -> None:
    """Reject per-head length tables that cannot be packed densely for FlashAttention.

    The dense ``[total_k, H_kv, D]`` layout carries a single K length per batch
    row, so every row of ``seq_lens_bh`` must hold one repeated value.

    Raises:
        ValueError: ``seq_lens_bh`` is not 2-D.
        RuntimeError: Some batch row has differing lengths across KV heads.
    """
    if seq_lens_bh.ndim != 2:
        raise ValueError(f"{caller}: expected seq_lens_bh [B, HKV], got {seq_lens_bh.shape}")

    # A row is uniform exactly when its smallest and largest entries coincide.
    row_min = seq_lens_bh.min(dim=1).values
    row_max = seq_lens_bh.max(dim=1).values
    has_ragged_row = bool((row_min != row_max).any().item())
    if not has_ragged_row:
        return

    raise RuntimeError(
        f"{caller}: FlashAttention materialization needs identical cached KV lengths "
        "across KV heads for each batch row (seq_lens_bh[b, :] constant in h). "
        f"Got per-batch min/max mismatch: min={row_min.tolist()} max={row_max.tolist()}. "
        "Typical top-K compression uses different counts per head; compactor-vllm uses "
        "Triton causal_sparse_varlen_with_cache in that case, not FA on materialized paged KV. "
        "Use schedule ``pdtriton`` (Triton prefill + Triton decode), or disable compression "
        "for this model run with ``fa_triton``."
    )
def materialize_kv_for_flash_prefill(
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    page_table: torch.Tensor,
    batch_mapping: torch.Tensor,
    L_cache_per_b: torch.Tensor,
    k_append: torch.Tensor,
    v_append: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    H_kv: int,
    PAGE_SIZE: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build packed K/V for :func:`flash_attn_varlen_func` (cache prefix + append).

    For each sequence ``b`` the output holds, back to back, the
    ``L_cache_per_b[b]`` cached rows gathered through the page table followed
    by that sequence's appended rows from ``k_append`` / ``v_append``.

    Args:
        k_cache, v_cache: Paged caches, ``[CACHE_SIZE, D]``.
        page_table: ``[*, H_kv, n_logical_pages]`` physical page ids, indexed by
            the *true* batch row.
        batch_mapping: Maps local batch index to the true batch row.
        L_cache_per_b: ``[B]`` cached prefix length per sequence (shared by all
            KV heads).
        k_append, v_append: ``[N, H_kv, D]`` packed new tokens.
        cu_seqlens_q: ``[B + 1]`` cumulative lengths of the appended tokens.
        H_kv: Number of KV heads; must equal ``k_append.shape[1]``.
        PAGE_SIZE: Tokens per physical page.

    Returns:
        ``(K_total, V_total, cu_seqlens_k)`` with K/V of shape
        ``[total_k, H_kv, D]`` (dtype/device of ``k_cache``) and int32
        ``cu_seqlens_k`` of shape ``[B + 1]``.
    """
    device = k_cache.device
    dtype = k_cache.dtype
    B = cu_seqlens_q.numel() - 1
    N, H_kv_raw, D = k_append.shape
    assert H_kv_raw == H_kv

    L_app = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).to(torch.int32)
    seqlen_k = L_cache_per_b.to(torch.int32) + L_app
    cu_seqlens_k = torch.zeros(B + 1, device=device, dtype=torch.int32)
    cu_seqlens_k[1:] = torch.cumsum(seqlen_k, dim=0, dtype=torch.int32)

    # Read all per-sequence scalars back once instead of one sync per token.
    cache_lens = [int(n) for n in L_cache_per_b.tolist()]
    append_lens = [int(n) for n in L_app.tolist()]
    q_starts = [int(n) for n in cu_seqlens_q[:-1].tolist()]
    true_rows = [int(n) for n in batch_mapping[:B].tolist()]

    total_k = sum(cache_lens) + sum(append_lens)
    K_total = torch.empty((total_k, H_kv, D), device=device, dtype=dtype)
    V_total = torch.empty((total_k, H_kv, D), device=device, dtype=dtype)

    offset_k = 0
    for b in range(B):
        Lc = cache_lens[b]
        La = append_lens[b]
        if Lc > 0:
            # Logical token i of head g lives at
            #   page_table[b_true, g, i // PAGE_SIZE] * PAGE_SIZE + i % PAGE_SIZE
            # Compute all (i, g) rows at once and gather them in one indexing op.
            pos = torch.arange(Lc, device=page_table.device)
            pages = page_table[true_rows[b], :H_kv].long()  # [H_kv, n_pages]
            rows = pages[:, pos // PAGE_SIZE] * PAGE_SIZE + pos % PAGE_SIZE
            rows = rows.t()  # [Lc, H_kv]
            K_total[offset_k : offset_k + Lc] = k_cache[rows.to(k_cache.device)]
            V_total[offset_k : offset_k + Lc] = v_cache[rows.to(v_cache.device)]
        if La > 0:
            src = slice(q_starts[b], q_starts[b] + La)
            dst = slice(offset_k + Lc, offset_k + Lc + La)
            K_total[dst] = k_append[src]
            V_total[dst] = v_append[src]
        offset_k += Lc + La
    return K_total, V_total, cu_seqlens_k
def flash_prefill_from_paged(
    q: torch.Tensor,
    k_append: torch.Tensor,
    v_append: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    *,
    seq_lens_bh_before: torch.Tensor,
    global_page_table: torch.Tensor,
    batch_mapping: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    max_seqlen_q: int,
    PAGE_SIZE: int,
    HKV: int,
    sm_scale: float | None,
) -> torch.Tensor:
    """Run causal varlen FlashAttention over ``[cached prefix || appended tokens]``.

    The paged prefix is first copied into a dense packed tensor together with
    the appended K/V; this requires every batch row of ``seq_lens_bh_before``
    to have the same length for all KV heads (checked up front).
    """
    _require_uniform_kv_lens_per_batch_for_fa_materialize(
        seq_lens_bh_before, caller="flash_prefill_from_paged"
    )
    # Rows are uniform after the check, so the per-row maximum is *the* length.
    prefix_lens = seq_lens_bh_before.max(dim=1).values.to(torch.int32)

    packed_k, packed_v, cu_seqlens_k = materialize_kv_for_flash_prefill(
        k_cache=k_cache,
        v_cache=v_cache,
        page_table=global_page_table,
        batch_mapping=batch_mapping,
        L_cache_per_b=prefix_lens,
        k_append=k_append,
        v_append=v_append,
        cu_seqlens_q=cu_seqlens_q,
        H_kv=HKV,
        PAGE_SIZE=PAGE_SIZE,
    )

    per_seq_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
    longest_k = int(per_seq_k.max().item())

    # ``None`` is forwarded unchanged for the scale.
    attn_out = flash_attn_varlen_func(
        q,
        packed_k,
        packed_v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=longest_k,
        softmax_scale=sm_scale,
        causal=True,
    )
    return attn_out
def materialize_kv_cache_for_flash_decode(
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    page_table: torch.Tensor,
    batch_mapping: torch.Tensor,
    L_cache_per_b: torch.Tensor,
    H_kv: int,
    PAGE_SIZE: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Dense ``[B, S, H_kv, D]`` cache for :func:`flash_attn_func` decode.

    ``S`` is the largest entry of ``L_cache_per_b``. Row ``b`` holds its
    ``L_cache_per_b[b]`` cached tokens (gathered through the page table of
    true batch row ``batch_mapping[b]``) followed by zero padding.

    Args:
        k_cache, v_cache: Paged caches, ``[CACHE_SIZE, D]``.
        page_table: ``[*, H_kv, n_logical_pages]`` physical page ids.
        batch_mapping: Maps local batch index to the true batch row.
        L_cache_per_b: ``[B]`` cached length per sequence.
        H_kv: Number of KV heads.
        PAGE_SIZE: Tokens per physical page.

    Returns:
        ``(K_flash, V_flash)``, both ``[B, S, H_kv, D]`` with the dtype/device
        of ``k_cache``.
    """
    device = k_cache.device
    dtype = k_cache.dtype
    B = L_cache_per_b.shape[0]
    D = k_cache.shape[1]
    seqlen_cache_max = int(L_cache_per_b.max().item())
    K_flash = torch.zeros((B, seqlen_cache_max, H_kv, D), device=device, dtype=dtype)
    V_flash = torch.zeros_like(K_flash)

    # Read the per-sequence scalars back once instead of one sync per token.
    cache_lens = [int(n) for n in L_cache_per_b.tolist()]
    true_rows = [int(n) for n in batch_mapping[:B].tolist()]

    for b in range(B):
        Lc = cache_lens[b]
        if Lc <= 0:
            continue
        # Logical token i of head g lives at
        #   page_table[b_true, g, i // PAGE_SIZE] * PAGE_SIZE + i % PAGE_SIZE
        # Gather all (i, g) rows of this sequence with one indexing op.
        pos = torch.arange(Lc, device=page_table.device)
        pages = page_table[true_rows[b], :H_kv].long()  # [H_kv, n_pages]
        rows = pages[:, pos // PAGE_SIZE] * PAGE_SIZE + pos % PAGE_SIZE
        rows = rows.t()  # [Lc, H_kv]
        K_flash[b, :Lc] = k_cache[rows.to(k_cache.device)]
        V_flash[b, :Lc] = v_cache[rows.to(v_cache.device)]
    return K_flash, V_flash
def flash_decode_from_paged(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    *,
    seq_lens_bh: torch.Tensor,
    global_page_table: torch.Tensor,
    batch_mapping: torch.Tensor,
    PAGE_SIZE: int,
    HKV: int,
    sm_scale: float | None,
) -> torch.Tensor:
    """Decode step via FA: ``decode_store_kv`` has already appended the new K/V row.

    Args:
        q: ``[B, HQ, D]`` one query token per sequence.
        k_cache, v_cache: Paged caches, ``[CACHE_SIZE, D]``.
        seq_lens_bh: ``[B, HKV]`` cached lengths; each row must be constant
            across heads.
        global_page_table: Page table indexed by true batch row.
        batch_mapping: Maps local batch index to the true batch row.
        PAGE_SIZE: Tokens per physical page.
        HKV: Number of KV heads.
        sm_scale: Softmax scale; ``None`` means ``1 / sqrt(D)``.

    Returns:
        Attention output of shape ``[B, HQ, D]``.
    """
    _require_uniform_kv_lens_per_batch_for_fa_materialize(
        seq_lens_bh, caller="flash_decode_from_paged"
    )
    L_cache_per_b = seq_lens_bh.max(dim=1).values.to(torch.int32)
    K_flash, V_flash = materialize_kv_cache_for_flash_decode(
        k_cache,
        v_cache,
        global_page_table,
        batch_mapping,
        L_cache_per_b,
        HKV,
        PAGE_SIZE,
    )
    B, HQ, D = q.shape
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)

    padded_len = K_flash.shape[1]
    lens_on_kv = L_cache_per_b.to(K_flash.device)
    if bool((lens_on_kv == padded_len).all().item()):
        # Every sequence fills the dense buffer completely, so there is no
        # padding and the plain dense kernel is exact.
        # One query position attends to all L keys already materialized in K/V
        # (no causal mask).
        out = flash_attn_func(
            q.unsqueeze(1),
            K_flash,
            V_flash,
            softmax_scale=sm_scale,
            causal=False,
        )
        return out.squeeze(1)

    # Ragged batch: the dense buffer is zero-padded up to the longest sequence.
    # ``flash_attn_func`` has no key mask, so padded keys (logit 0) would take
    # part in the softmax and shrink the output of the shorter rows. Pack only
    # the valid rows and use the varlen kernel, which bounds each sequence by
    # ``cu_seqlens_k``.
    valid = (
        torch.arange(padded_len, device=K_flash.device)[None, :]
        < lens_on_kv[:, None]
    )
    cu_seqlens_k = torch.zeros(B + 1, device=K_flash.device, dtype=torch.int32)
    cu_seqlens_k[1:] = torch.cumsum(lens_on_kv, dim=0, dtype=torch.int32)
    # Exactly one query token per sequence.
    cu_seqlens_q = torch.arange(B + 1, device=q.device, dtype=torch.int32)
    return flash_attn_varlen_func(
        q,
        K_flash[valid],
        V_flash[valid],
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=1,
        max_seqlen_k=padded_len,
        softmax_scale=sm_scale,
        causal=False,
    )
import functools
import math
import torch
import triton
import triton.language as tl
from vllm.kvprune.utils.triton_compat import (
autotune as triton_autotune,
maybe_set_allocator,
)
def head_sparse_decode_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_lens_bh: torch.Tensor,
    global_page_table: torch.Tensor,
    batch_mapping: torch.Tensor,
    HKV: int,
    PAGE_SIZE: int,
    sm_scale: float = None,
    key_split: int = None,
):
    """
    Decode-time head-sparse attention over a paged KV cache.

    This is a wrapper around the Triton decode kernel used during incremental
    generation. For each batch, we read the cached keys
    and values from a global paged KV buffer, apply attention with one
    new query token, and return the attention output.

    The KV cache is stored in a single global K/V tensor of shape
    ``[CACHE_SIZE, D]`` and indexed via a per-layer page table. Each logical
    (batch, kv_head, token_idx) is mapped to a physical row in the cache by:
        1. Looking up the logical page index in ``global_page_table[b, h, lp]``,
        2. Computing ``phys_row = page_id * PAGE_SIZE + (token_idx % PAGE_SIZE)``.

    Grouped-query attention (GQA / MQA) is supported by passing more query
    heads than KV heads (``HQ`` must be a multiple of ``HKV``).

    Args:
        :param q: Query tensor of shape ``[B, HQ, D]`` or ``[B, HQ, 1, D]``
            containing the new decode tokens for each sequence in the launch
            batch. Must be contiguous once the singleton axis is removed.
        :param k: Global key cache of shape ``[CACHE_SIZE, D]``. This is the shared
            backing buffer for all (batch, head) KV pages.
        :param v: Global value cache of shape ``[CACHE_SIZE, D]``.
        :param seq_lens_bh: Tensor of shape ``[B, HKV]`` (int32) giving, for each
            local batch index and KV head, the number of valid cached tokens
            in the paged KV cache.
        :param global_page_table: Tensor of shape
            ``[MAX_NUM_BATCHES, HKV, N_LOGICAL_PAGES_MAX]`` (int32) mapping
            ``(true_batch_idx, kv_head, logical_page)`` to a physical page id
            in the global cache.
        :param batch_mapping: Tensor of shape ``[B]`` (int32) mapping the launch-batch
            index used by this call to the true batch row used to index
            ``global_page_table``.
        :param HKV: Number of KV heads.
        :param PAGE_SIZE: Number of tokens stored per physical KV page
            (must be divisible by 32).
        :param sm_scale: Optional scaling factor applied to the attention logits
            before softmax. If ``None``, ``1 / sqrt(D)`` is used.
        :param key_split: Optional number of splits along the key sequence length.
            If > 1, the kernel will process the KV sequence in ``key_split``
            chunks and a second kernel merges the partial results. If ``None``
            or 0, a heuristic is used.

    Returns:
        :return torch.Tensor: Attention output of shape ``[B, HQ, D]`` on the same
            device and dtype as ``q``.

    Note:
        The stage-1 kernel returns early for any ``(batch, kv_head)`` whose
        length is 0 without writing its slice of the scratch buffers, which are
        allocated uninitialised here. Callers are presumably expected to pass
        lengths >= 1 (the token being decoded is already stored) -- TODO confirm.
    """
    with torch.cuda.device(q.device):
        if q.ndim != 3:
            assert q.ndim == 4
            B, HQ, S, D = q.shape
            assert S == 1, "head_sparse_decode_attention only supports q_len=1"
            q = q.squeeze(-2)
        else:
            B, HQ, D = q.shape
        CACHE_SIZE = k.shape[0]
        assert PAGE_SIZE % 32 == 0, "PAGE_SIZE must be divisible by 32"
        GROUP_M = HQ // HKV
        assert GROUP_M * HKV == HQ, "HQ must be divisible by H_kv"
        FP8 = hasattr(torch, "float8_e5m2") and q.dtype == torch.float8_e5m2
        seq_lens_bh = seq_lens_bh.to(torch.int32)
        assert B <= 32767, "too many batches"
        assert global_page_table.shape[1] == HKV
        assert q.is_contiguous()
        k = k.contiguous()
        v = v.contiguous()
        global_page_table = global_page_table.contiguous()
        batch_mapping = batch_mapping.contiguous()
        assert (D & (D - 1)) == 0, "D must be a power of 2"
        N_LOGICAL_PAGES_MAX = global_page_table.shape[-1]
        sm_scale = 1 / math.sqrt(D) if sm_scale is None else sm_scale

        # ``None`` and 0 both mean "choose for me". Passing 0 through would
        # create an empty launch grid and zero-sized scratch buffers.
        if not key_split:
            # round max_seq_len to the next power of two to maximize cache hits
            key_split = num_splits_heuristic(
                B * HKV,
                max_seq_len=1 << int(seq_lens_bh.max()).bit_length(),
                num_sms=torch.cuda.get_device_properties(
                    q.device
                ).multi_processor_count,
                max_splits=12,
            )

        maybe_set_allocator(
            lambda size, align, _: torch.empty(size, dtype=torch.int8, device=q.device)
        )

        # Element type the kernels compute in; anything that is neither FP8 nor
        # bfloat16 is treated as float16.
        kernel_dtype = (
            tl.float8e5
            if FP8
            else (tl.bfloat16 if q.dtype == torch.bfloat16 else tl.float16)
        )

        # stage 1 scratch
        mid_o = torch.empty((B, key_split, HQ, D), device=q.device, dtype=q.dtype)
        mid_lse = torch.empty((B, key_split, HQ), device=q.device, dtype=torch.float32)

        # processes all queries for a KV head together
        # pointers are lowercase, CONSTANTS are upper
        grid1 = (B, HKV, key_split)
        _varkv_stage1_groupM[grid1](
            q=q,
            k=k,
            v=v,
            mid_o=mid_o,
            mid_lse=mid_lse,
            page_table_bhl=global_page_table,
            batch_mapping=batch_mapping,
            seq_lens_bh=seq_lens_bh.contiguous(),
            SM_SCALE=sm_scale,
            B=B,
            HKV=HKV,
            HQ=HQ,
            CACHE_SIZE=CACHE_SIZE,
            STRIDE_LBS=mid_lse.stride(0),
            STRIDE_LS=mid_lse.stride(1),
            STRIDE_LH=mid_lse.stride(2),
            N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
            D=D,
            KEY_SPLIT=key_split,
            GROUP_M=GROUP_M,
            DTYPE=kernel_dtype,
            PAGE_SIZE=PAGE_SIZE,
        )

        # With a single split the stage-1 output already is the final result.
        if key_split == 1:
            return mid_o.squeeze(1).contiguous()

        # reduce partial results across splits
        output = torch.empty_like(q)
        grid2 = (B, HQ)
        _varkv_stage2_reduce[grid2](
            mid_o=mid_o,
            mid_lse=mid_lse,
            output=output,
            STRIDE_LBS=mid_lse.stride(0),
            STRIDE_LS=mid_lse.stride(1),
            STRIDE_LH=mid_lse.stride(2),
            STRIDE_OBS=output.stride(0),
            STRIDE_OH=output.stride(1),
            B=B,
            HQ=HQ,
            D=D,  # type: ignore
            KEY_SPLIT=key_split,  # type: ignore
            DTYPE=kernel_dtype,
        )
        return output
# similar to flash attention split heuristic
@functools.lru_cache(maxsize=128)
def num_splits_heuristic(
    total_mblocks: int,
    max_seq_len: int,
    num_sms: int,
    max_splits: int,
) -> int:
    """Pick how many key-axis splits to use for the decode kernel.

    Returns 1 when the launch already occupies at least 80% of the SMs or the
    sequence is at most 1024 tokens. Otherwise each candidate split count is
    scored by wave efficiency (fraction of the last wave that is busy), and the
    smallest count reaching 75% of the best score wins. Candidates stop as soon
    as a split would cover 512 tokens or fewer.
    """
    already_saturated = total_mblocks >= 0.8 * num_sms
    if already_saturated or max_seq_len <= 1024:
        return 1

    efficiencies = []
    upper = min(max_splits, num_sms)
    for n_split in range(1, upper + 1):
        if (max_seq_len / n_split) <= 512:
            break
        waves = float(total_mblocks * n_split) / float(num_sms)
        if waves > 0:
            efficiencies.append(waves / math.ceil(waves))
        else:
            efficiencies.append(0.0)

    best = max(efficiencies, default=0.0)
    cutoff = 0.75 * best
    candidates = (
        n_split
        for n_split, score in enumerate(efficiencies, start=1)
        if score >= cutoff
    )
    return next(candidates, 1)
def prune_invalid_configs(configs, _, **kwargs):
    """Keep only autotune configs whose ``BLOCK_N`` does not exceed ``PAGE_SIZE``.

    A config that does not define ``BLOCK_N`` is treated as 0 and always kept.
    The second positional argument is ignored.
    """
    page_size = kwargs["PAGE_SIZE"]
    survivors = []
    for candidate in configs:
        block_n = candidate.kwargs.get("BLOCK_N", 0)
        if block_n <= page_size:
            survivors.append(candidate)
    return survivors
# Stage 1 of split-K decode attention.
#
# One program handles one (batch, kv_head, split) triple: it loads the GROUP_M
# query heads that share this KV head, walks the slice of the paged KV sequence
# that belongs to the split, and keeps a streaming-softmax accumulator. It
# writes the normalised partial output for the split to ``mid_o`` and the
# matching log-sum-exp to ``mid_lse``.
#
# Autotuning sweeps BLOCK_N / num_stages / num_warps / WARPSPEC; configs whose
# BLOCK_N exceeds PAGE_SIZE are pruned before timing. WARPSPEC is part of each
# config but is not referenced in the kernel body below.
@triton_autotune(
    configs=[
        triton.Config(
            {"BLOCK_N": BLOCK_N, "MIN_BLOCK_KV": MIN_BLOCK_KV, "WARPSPEC": ws},
            num_warps=w,
            num_stages=s,
        )
        for BLOCK_N in [32, 64, 128]
        for MIN_BLOCK_KV in [8]
        for s in [2, 3, 4]
        for w in [4, 8]
        for ws in [True, False]
    ],
    key=[
        "HKV",
        "GROUP_M",
        "D",
        "PAGE_SIZE",  # "B"
    ],
    cache_results=True,
    prune_configs_by={"early_config_prune": prune_invalid_configs},
)
@triton.jit
def _varkv_stage1_groupM(
    q,  # [B, HQ, D] contiguous
    k,  # GLOBAL cache: [CACHE_SIZE, D], contiguous
    v,  # GLOBAL cache: [CACHE_SIZE, D], contiguous
    mid_o,  # partial outputs, addressed as [B, KEY_SPLIT, HQ, D] contiguous
    mid_lse,  # partial log-sum-exp, addressed through STRIDE_L* below
    page_table_bhl,  # int32 [B*H_kv*N_LOGICAL_PAGES_MAX] (flattened)
    batch_mapping,  # int32 [B] maps local pid_b -> true batch index
    seq_lens_bh,  # int32 [B*H_kv] valid tokens per (b,h)
    SM_SCALE,
    B,
    HKV,
    HQ,
    CACHE_SIZE,  # CACHE_SIZE = N_PAGES * PAGE_SIZE
    STRIDE_LBS,  # mid_lse stride along the batch axis
    STRIDE_LS,  # mid_lse stride along the split axis
    STRIDE_LH,  # mid_lse stride along the query-head axis
    # constexprs
    N_LOGICAL_PAGES_MAX: tl.constexpr,  # page table width per (b,h)
    D: tl.constexpr,
    KEY_SPLIT: tl.constexpr,
    GROUP_M: tl.constexpr,
    DTYPE: tl.constexpr,
    BLOCK_N: tl.constexpr,
    MIN_BLOCK_KV: tl.constexpr,
    WARPSPEC: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
):
    pid_b = tl.program_id(0)  # batch
    pid_kvh = tl.program_id(1)  # kv head
    pid_s = tl.program_id(2)  # split

    # valid length L for this (b,h)
    bh_stride = HKV
    L = tl.load(seq_lens_bh + pid_b * bh_stride + pid_kvh)
    # NOTE: this early exit happens before any store, so mid_o / mid_lse are
    # left untouched for a (b, h) with no cached tokens.
    if L == 0:
        return
    tl.assume(L > 0)

    # split sizing on logical token axis [0..L)
    # per_split_len is rounded up to a multiple of MIN_BLOCK_KV, so trailing
    # splits can end up empty (handled by the else-branch below).
    base = tl.cdiv(L, KEY_SPLIT)
    per_split_len = tl.cdiv(base, MIN_BLOCK_KV) * MIN_BLOCK_KV
    split_start = pid_s * per_split_len
    split_end = tl.minimum(split_start + per_split_len, L)

    # query heads mapped to this kv head
    base_qh = pid_kvh * GROUP_M
    # The query tile is padded to at least 16 rows; rows >= GROUP_M are masked.
    GROUP_M_PAD: tl.constexpr = 16 if GROUP_M < 16 else GROUP_M
    offs_m = tl.arange(0, GROUP_M_PAD)
    mask_m = offs_m < GROUP_M
    offs_d = tl.arange(0, D)

    # load Q tile [M, D]
    q_ptrs = q + (pid_b * HQ + base_qh + offs_m)[:, None] * D + offs_d[None, :]
    q = tl.load(q_ptrs, mask=mask_m[:, None], other=0.0).to(DTYPE)  # [M, D]

    # streaming softmax state per query
    e_max = tl.zeros([GROUP_M_PAD], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([GROUP_M_PAD], dtype=tl.float32)
    acc = tl.zeros([GROUP_M_PAD, D], dtype=tl.float32)

    if split_end > split_start:
        # logical pages covering [split_start, split_end)
        lp0 = split_start // PAGE_SIZE
        lp1 = tl.cdiv(split_end, PAGE_SIZE)  # exclusive
        mapped_b = tl.load(batch_mapping + pid_b)
        tl.assume(mapped_b >= 0)

        # page table base for this (b,h)
        pt_stride = N_LOGICAL_PAGES_MAX
        pt_base = (mapped_b * HKV + pid_kvh) * pt_stride

        for lp in tl.range(lp0, lp1):
            phys = tl.load(
                page_table_bhl + pt_base + lp, cache_modifier=".cg"
            )  # physical page id

            # bounds within the logical page
            # Only the first/last page of the split can be partially covered.
            local_start = tl.where(lp == lp0, split_start - lp * PAGE_SIZE, 0)
            local_end = tl.where(lp == (lp1 - 1), split_end - lp * PAGE_SIZE, PAGE_SIZE)
            page_base = phys * PAGE_SIZE
            page_base = tl.multiple_of(page_base, BLOCK_N)

            for s in tl.range(local_start, local_end, BLOCK_N):
                s = tl.multiple_of(s, MIN_BLOCK_KV)
                offs_bn = tl.arange(0, BLOCK_N)
                key_idx = page_base + s + offs_bn
                k_ptrs = k + key_idx[:, None] * D + offs_d[None, :]
                # The load mask only guards the end of the global buffer;
                # positions past local_end are excluded via mask_n below.
                k_blk = tl.load(k_ptrs, mask=(key_idx < CACHE_SIZE)[:, None], other=0.0)
                qk = tl.dot(q, k_blk.T) * SM_SCALE  # [M, BN]
                offs_n = s + tl.arange(0, BLOCK_N)
                mask_n = offs_n < local_end
                qk = tl.where(mask_n[None, :], qk, -float("inf"))

                # Online-softmax update: rescale the running sums to the new max.
                n_e_max = tl.maximum(tl.max(qk, 1), e_max)  # [M]
                re_scale = tl.exp(e_max - n_e_max)  # [M]
                acc = acc * re_scale[:, None]  # [M, D]
                v_ptrs = v + key_idx[:, None] * D + offs_d[None, :]
                v_blk = tl.load(v_ptrs, mask=(key_idx < CACHE_SIZE)[:, None], other=0.0)
                p = tl.exp(qk - n_e_max[:, None])  # [M, BN]
                acc = tl.dot(p.to(DTYPE), v_blk, acc)
                e_sum = e_sum * re_scale + tl.sum(p, 1)
                e_max = n_e_max

        # write mid outputs [M, D] for this split
        tmp = (acc / e_sum[:, None]).to(DTYPE)
        row_mid = pid_b * (KEY_SPLIT * HQ) + pid_s * HQ + base_qh + offs_m
        mid_ptrs = mid_o + row_mid[:, None] * D + offs_d[None, :]
        tl.store(mid_ptrs, tmp, mask=mask_m[:, None])
        ml_ptrs = (
            mid_lse
            + pid_b * STRIDE_LBS
            + pid_s * STRIDE_LS
            + (base_qh + offs_m) * STRIDE_LH
        )
        # Padded rows get a dummy sum of 1 so the log stays finite; they are
        # masked out of the store anyway.
        safe_sum = tl.where(mask_m, e_sum, 1.0)
        # log-sum-exp of this split's logits: max + log(sum of exp(logit - max)).
        tl.store(ml_ptrs, e_max + tl.log(safe_sum), mask=mask_m)
    else:
        # empty split
        # Zero output with LSE = -inf so it carries no weight in the reduction.
        zero_md = tl.zeros([GROUP_M_PAD, D], dtype=DTYPE)
        row_mid = pid_b * (KEY_SPLIT * HQ) + pid_s * HQ + base_qh + offs_m
        mid_ptrs = mid_o + row_mid[:, None] * D + offs_d[None, :]
        tl.store(mid_ptrs, zero_md, mask=mask_m[:, None])
        ml_ptrs = (
            mid_lse
            + pid_b * STRIDE_LBS
            + pid_s * STRIDE_LS
            + (base_qh + offs_m) * STRIDE_LH
        )
        tl.store(ml_ptrs, -float("inf"), mask=mask_m)
# Stage 2 of split-K decode attention.
#
# One program handles one (batch, query_head) pair and merges the KEY_SPLIT
# partial results written by stage 1: each partial output is weighted by
# exp(lse_split - running_max) and the total is divided by the summed weights.
@triton.jit
def _varkv_stage2_reduce(
    mid_o,  # partial outputs, addressed as [B, KEY_SPLIT, HQ, D] contiguous
    mid_lse,  # partial log-sum-exp values, addressed through STRIDE_L*
    output,  # final output, addressed through STRIDE_OBS / STRIDE_OH
    STRIDE_LBS,  # mid_lse stride along the batch axis
    STRIDE_LS,  # mid_lse stride along the split axis
    STRIDE_LH,  # mid_lse stride along the query-head axis
    STRIDE_OBS,  # output stride along the batch axis
    STRIDE_OH,  # output stride along the query-head axis
    B,
    HQ,
    D: tl.constexpr,
    KEY_SPLIT: tl.constexpr,
    DTYPE: tl.constexpr,
):
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    offs_d = tl.arange(0, D)

    # across split LSE combine
    e_sum = 0.0
    e_max = -float("inf")
    acc = tl.zeros([D], dtype=tl.float32)
    for s in tl.range(KEY_SPLIT):
        row_mid = pid_b * (KEY_SPLIT * HQ) + s * HQ + pid_h
        tv = tl.load(mid_o + row_mid * D + offs_d).to(DTYPE)
        tl_ptr = mid_lse + pid_b * STRIDE_LBS + s * STRIDE_LS + pid_h * STRIDE_LH
        tlogic = tl.load(tl_ptr)  # log-sum-exp of split s
        # Rescale what has been accumulated so far to the new running maximum,
        # then add this split with weight exp(lse - max). A split whose LSE is
        # -inf gets weight 0 once the running maximum is finite.
        n_e_max = tl.maximum(e_max, tlogic)
        old_scale = tl.exp(e_max - n_e_max)
        acc = acc * old_scale + tl.exp(tlogic - n_e_max) * tv.to(tl.float32)
        e_sum = e_sum * old_scale + tl.exp(tlogic - n_e_max)
        e_max = n_e_max
    o = (acc / e_sum).to(DTYPE)
    o_ptr = output + pid_b * STRIDE_OBS + pid_h * STRIDE_OH + offs_d
    tl.store(o_ptr, o)
import logging
import math
import torch
import triton
import triton.language as tl
from flash_attn.flash_attn_interface import flash_attn_varlen_func
from vllm.kvprune.utils.triton_compat import (
autotune as triton_autotune,
cuda_capability_geq,
maybe_set_allocator,
)
logger = logging.getLogger(__name__)
def _causal_appended_only_exact(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    *,
    sm_scale: float,
    max_seqlen_q: int,
) -> torch.Tensor:
    """Causal self-attention over the appended q/k/v only (no cached prefix).

    This covers the case of :func:`causal_sparse_varlen_with_cache` in which
    nothing has been cached yet. Keys and queries then describe the same
    tokens, so the query sequence boundaries and maximum length are reused
    for the key side, and the work is delegated to
    ``flash_attn_varlen_func`` rather than the Triton kernel.
    """
    # With no prefix, the key layout is identical to the query layout.
    varlen_layout = dict(
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_q,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_q,
    )
    attn_out = flash_attn_varlen_func(
        q,
        k,
        v,
        softmax_scale=sm_scale,
        causal=True,
        **varlen_layout,
    )
    return attn_out
def causal_sparse_varlen_with_cache(
q,
k,
v,
k_cache,
v_cache,
seq_lens_bh,
global_page_table,
batch_mapping,
cu_seqlens_q,
max_seqlen_q: int,
max_seqlen_k_cache: int,
HKV: int,
PAGE_SIZE: int,
sm_scale=None,
):
"""
Causal prefill attention over a paged KV cache plus a block of newly
appended tokens in a packed batch format.
This function wraps the Triton kernel
``_causal_head_sparse_varlen_with_cache`` to compute prefill attention for
a batch of variable-length sequences, where:
- Past keys/values are stored in a paged global KV cache
(``k_cache``, ``v_cache``) with a (per-layer) page table.
- New tokens for this step are given as K/V blocks
(``k``, ``v``), together with a packed query block ``q``.
- The result is equivalent to applying causal attention over the
concatenation of:
[ cached KV prefix || (K_app, V_app) for this step ]
for each sequence in the batch.
Grouped-query attention (GQA / MQA) is supported by allowing more query
heads than KV heads: ``HQ`` must be divisible by ``HKV``.
Args:
:param q:
Query tensor of shape ``[N, HQ, D]`` (float16 / bfloat16/float32).
``N`` is the total number of new tokens across the batch
(i.e. ``N = sum_b seqlen_q[b]``), packed according to
``cu_seqlens_q``. ``HQ`` is the number of query heads, ``D`` the
head dimension (must be a power of two).
:param k:
New key tensor of shape ``[N, HKV, D]`` for the same tokens as
``q``. These are the K values appended to the cache for this
prefill step.
:param v:
New value tensor of shape ``[N, HKV, D]`` for the same tokens as
``q``.
:param k_cache:
Global key cache backing buffer of shape ``[CACHE_SIZE, D]``.
Keys for all cached tokens and heads are stored here; the mapping
from (batch, head, token index) to a row in this buffer is
given by ``global_page_table``.
:param v_cache:
Global value cache of shape ``[CACHE_SIZE, D]``. Must have the
same layout as ``k_cache`` (same ``CACHE_SIZE`` and ``D``).
:param seq_lens_bh:
Tensor of shape ``[B, HKV]`` (int32) giving, for each local batch
index and KV head, the number of cached tokens already present
in the paged KV cache before this prefill step.
:param global_page_table:
Tensor of shape ``[MAX_NUM_BATCHES, HKV, N_LOGICAL_PAGES_MAX]`` (int32)
mapping ``(true_batch_idx, kv_head, logical_page)`` to a physical
page id in the global KV cache. A physical page id `p` refers to
the slice:
``k_cache[p * PAGE_SIZE : (p + 1) * PAGE_SIZE]``.
:param batch_mapping:
Tensor of shape ``[B]`` (int16 / int32) mapping the local batch
index used in this kernel launch to the global batch index used
to index ``global_page_table``. This allows the same global cache
to be shared across multiple microbatches.
:param cu_seqlens_q:
Tensor of shape ``[B + 1]`` (int32) with cumulative sequence
lengths for the *new* tokens (q/k/v) in packed form. For batch
element ``b``:
``seqlen_q[b] = cu_seqlens_q[b + 1] - cu_seqlens_q[b]``.
The total number of tokens satisfies
``N = cu_seqlens_q[-1]``.
:param max_seqlen_q:
Maximum new query sequence length across the batch, i.e.
``max_b seqlen_q[b]``.
:param max_seqlen_k_cache:
Maximum cached sequence length across (batch, KV head), i.e.
``max_{b,h} seq_lens_bh[b, h]``.
:param HKV:
Number of KV heads. Must divide ``HQ``.
:param PAGE_SIZE:
Number of tokens stored per physical page in the paged KV cache.
``CACHE_SIZE`` must be divisible by ``PAGE_SIZE``.
:param sm_scale:
Optional scaling factor applied to the attention logits before
softmax. If ``None``, defaults to ``1.0 / sqrt(D)``.
:returns torch.Tensor:
Attention output of shape ``[N, HQ, D]``, with the same dtype and
device as ``q``. The output is laid out in the same packed
varlen format as the input queries, i.e. the first
``seqlen_q[0]`` rows correspond to batch 0, the next
``seqlen_q[1]`` rows to batch 1, etc.
"""
assert q.ndim == 3, "q should be [N, HQ, D]"
N, HQ, D = q.shape
assert (D & (D - 1)) == 0, "D must be power of two"
B = cu_seqlens_q.numel() - 1
assert B > 0
assert HQ % HKV == 0, "Number of query heads must divide number of keys heads"
if max_seqlen_k_cache == 0:
# Zero-prefix compressed prefill on DCU produced repeated-character output in
# the Triton on-band appended branch; use exact varlen FA for this subcase.
if sm_scale is None:
sm_scale = 1.0 / math.sqrt(D)
return _causal_appended_only_exact(
q,
k,
v,
cu_seqlens_q,
sm_scale=sm_scale,
max_seqlen_q=max_seqlen_q,
)
H_g = HQ // HKV
# view Q as [HKV, N, QUERY_GROUP_SIZE, D]
out = torch.empty_like(q)
q = q.view(N, HKV, H_g, D).permute(1, 0, 2, 3)
out = out.view(N, HKV, H_g, D).permute(1, 0, 2, 3)
# K_app/V_app: [N, HKV, D] -> [HKV, N, D]
k_app = k.view(N, HKV, D).permute(1, 0, 2)
v_app = v.view(N, HKV, D).permute(1, 0, 2)
q = q.contiguous()
out = out.contiguous()
k_app = k_app.contiguous()
v_app = v_app.contiguous()
cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=q.device)
seq_lens_bh = seq_lens_bh.to(dtype=torch.int32, device=q.device)
batch_mapping = batch_mapping.to(dtype=torch.int16, device=q.device)
N_LOGICAL_PAGES_MAX = global_page_table.shape[-1]
CACHE_SIZE = k_cache.shape[0]
assert v_cache.shape[0] == CACHE_SIZE
assert k_cache.shape[1] == D and v_cache.shape[1] == D
assert PAGE_SIZE > 0 and CACHE_SIZE % PAGE_SIZE == 0
k_cache = k_cache.contiguous()
v_cache = v_cache.contiguous()
global_page_table = global_page_table.contiguous()
if sm_scale is None:
sm_scale = 1.0 / math.sqrt(D)
# strides for Q [G, N, QUERY_GROUP_SIZE, D]
STRIDE_Q_G, STRIDE_Q_N, STRIDE_Q_H, STRIDE_Q_D = q.stride()
STRIDE_KC, STRIDE_VC = k_cache.stride(0), v_cache.stride(0)
# [G, N, D]
STRIDE_KA_G, STRIDE_KA_N, STRIDE_KA_D = k_app.stride()
STRIDE_VA_G, STRIDE_VA_N, STRIDE_VA_D = v_app.stride()
# OUT [G, N, QUERY_GROUP_SIZE, D]
STRIDE_OUT_G, STRIDE_OUT_N, STRIDE_OUT_H, STRIDE_OUT_D = out.stride()
# launch grid
maybe_set_allocator(
lambda size, align, _: torch.empty(size, dtype=torch.int8, device=q.device)
)
assert STRIDE_KA_D == STRIDE_VA_D == STRIDE_Q_D == STRIDE_OUT_D == 1, (
"final dimension must be contiguous"
)
def grid(META):
return HKV, B, triton.cdiv(max_seqlen_q, META["BLOCK_M"])
# Autotune key must reflect the **total** K length seen by the kernel:
# cached prefix + appended tokens from the current prefill chunk.
#
# Using only `max_seqlen_k_cache` is wrong for the first compressed prefill
# step in `pdtriton`: the cache prefix is 0, but the kernel actually attends
# over the entire appended prompt (`seq_len_append`). On DCU this can cause
# Triton to autotune/select a kernel as if K==1 while executing on a long K,
# which has been observed to produce incorrect outputs. We still clamp to 1
# to avoid `next_power_of_2(0)`.
_k_max_autotune = max(int(max_seqlen_k_cache) + int(max_seqlen_q), 1)
AUTOTUNE_MAX_Q_LEN = triton.next_power_of_2(max_seqlen_q)
AUTOTUNE_MAX_K_LEN = triton.next_power_of_2(_k_max_autotune)
_causal_head_sparse_varlen_with_cache[grid](
Q=q,
K_cache=k_cache,
V_cache=v_cache,
K_app=k_app,
V_app=v_app,
cu_seqlens_qk=cu_seqlens_q,
seq_lens_bh=seq_lens_bh,
page_table=global_page_table,
batch_mapping=batch_mapping,
OUT=out,
HKV=HKV,
QUERY_GROUP_SIZE=H_g,
PAGE_SIZE=PAGE_SIZE,
N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
STRIDE_Q_G=STRIDE_Q_G,
STRIDE_Q_N=STRIDE_Q_N,
STRIDE_Q_H=STRIDE_Q_H,
STRIDE_KC=STRIDE_KC,
STRIDE_VC=STRIDE_VC,
STRIDE_KA_G=STRIDE_KA_G,
STRIDE_KA_N=STRIDE_KA_N,
STRIDE_VA_G=STRIDE_VA_G,
STRIDE_VA_N=STRIDE_VA_N,
STRIDE_OUT_G=STRIDE_OUT_G,
STRIDE_OUT_N=STRIDE_OUT_N,
STRIDE_OUT_H=STRIDE_OUT_H,
sm_scale=sm_scale,
D=D,
AUTOTUNE_MAX_Q_LEN=AUTOTUNE_MAX_Q_LEN,
AUTOTUNE_MAX_K_LEN=AUTOTUNE_MAX_K_LEN,
)
# permute breaks contiguity; view() requires a single contiguous span.
return out.permute(1, 0, 2, 3).reshape(N, HQ, D)
# Candidate launch configurations returned by ``get_autotune_configs`` when
# ``cuda_capability_geq(9, 0)`` holds. Each row of the table is
# (BLOCK_N, BLOCK_M, WARPSPEC, num_warps, num_stages).
autotune_configs_cc9 = [
    triton.Config(
        {"BLOCK_N": block_n, "BLOCK_M": block_m, "WARPSPEC": warpspec},
        num_warps=warps,
        num_stages=stages,
    )
    for block_n, block_m, warpspec, warps, stages in (
        (64, 64, True, 16, 3),
        (64, 64, True, 8, 3),
        (64, 32, True, 8, 4),
        (64, 32, True, 8, 3),
        (64, 32, False, 4, 3),
        (64, 16, True, 8, 3),
        (64, 16, True, 8, 4),
        (64, 16, False, 4, 4),
        (32, 32, True, 8, 4),
        (32, 32, False, 8, 4),
        (32, 16, False, 8, 3),
        (32, 16, False, 4, 4),
    )
]

# Candidate launch configurations used otherwise: BLOCK_M is fixed at 64 and
# WARPSPEC at True, swept over BLOCK_N x num_warps x num_stages.
autotune_configs_cc8 = [
    triton.Config(
        {"BLOCK_N": block_n, "BLOCK_M": 64, "WARPSPEC": True},
        num_warps=warps,
        num_stages=stages,
    )
    for block_n in (16, 32)
    for warps in (4, 8)
    for stages in (2, 3)
]
def prune_invalid_configs(configs, _, **kwargs):
    """
    Drop autotune configs that combine ``BLOCK_N == 32`` with ``num_stages == 4``.

    ``triton.Config`` keeps ``num_stages`` as an attribute of the config
    object, not inside ``Config.kwargs`` (which only holds the meta-parameters
    such as ``BLOCK_N``). The attribute is therefore consulted first; the
    ``kwargs`` entry is still honoured as a fallback for config-like objects
    that carry it there.

    :param configs: iterable of config objects exposing ``kwargs`` (a dict).
    :param _: unused positional argument.
    :param kwargs: unused.
    :return: list of the configs that survive the filter, in input order.
    """

    def _num_stages(conf):
        stages = getattr(conf, "num_stages", None)
        if stages is None:
            stages = conf.kwargs.get("num_stages")
        return stages

    return [
        conf
        for conf in configs
        if not (conf.kwargs.get("BLOCK_N") == 32 and _num_stages(conf) == 4)
    ]
def get_autotune_configs():
    """
    Pick the autotune candidate list for the current device.

    :return: ``autotune_configs_cc9`` when ``cuda_capability_geq(9, 0)`` is
        true, otherwise ``autotune_configs_cc8``.
    """
    use_cc9 = cuda_capability_geq(9, 0)
    return autotune_configs_cc9 if use_cc9 else autotune_configs_cc8
# Triton kernel launched by ``causal_sparse_varlen_with_cache``.
#
# One program handles one (KV head, batch element, query tile) triple and
# processes all QUERY_GROUP_SIZE query heads of that KV head at once, so a
# tile holds BLOCK_M * QUERY_GROUP_SIZE query rows. Attention is accumulated
# with a running max / running sum (online softmax) in the base-2 domain:
#   1) over the cached keys/values, walked page by page through ``page_table``;
#   2) over the newly appended keys/values with a causal mask;
#   3) the normalised result is stored to ``OUT``.
# Masked logits are set to -1.0e6 (a large negative sentinel, not -inf).
@triton_autotune(
    configs=get_autotune_configs(),
    key=[
        "HKV",
        "QUERY_GROUP_SIZE",
        "D",
        "PAGE_SIZE",
        "AUTOTUNE_MAX_K_LEN",
        "AUTOTUNE_MAX_Q_LEN",
    ],
    cache_results=True,
)
@triton.jit
def _causal_head_sparse_varlen_with_cache(
    Q,  # [HKV, N, QUERY_GROUP_SIZE, D]; last dim is addressed with stride 1
    K_cache,
    V_cache,  # [CACHE_SIZE, D]
    K_app,
    V_app,  # [HKV, N, D]
    cu_seqlens_qk,  # [B+1]
    seq_lens_bh,  # [B, HKV]
    page_table,  # [B_total, HKV, N_LOGICAL_PAGES_MAX]
    batch_mapping,  # [B], maps local b -> global batch index
    OUT,  # [HKV, N, QUERY_GROUP_SIZE, D]
    #
    HKV: tl.constexpr,
    QUERY_GROUP_SIZE: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
    N_LOGICAL_PAGES_MAX,
    STRIDE_Q_G,
    STRIDE_Q_N,
    STRIDE_Q_H,
    STRIDE_KC,
    STRIDE_VC,
    STRIDE_KA_G,
    STRIDE_KA_N,
    STRIDE_VA_G,
    STRIDE_VA_N,
    STRIDE_OUT_G,
    STRIDE_OUT_N,
    STRIDE_OUT_H,
    sm_scale,
    #
    D: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    WARPSPEC: tl.constexpr,
    AUTOTUNE_MAX_Q_LEN: tl.constexpr,  # used for autotune key
    AUTOTUNE_MAX_K_LEN: tl.constexpr,  # used for autotune key
):
    TOTAL_N_QUERIES: tl.constexpr = BLOCK_M * QUERY_GROUP_SIZE
    pid_g = tl.program_id(0)  # kv_head id in [0, HKV)
    pid_b = tl.program_id(1)  # batch id
    pid_m = tl.program_id(2)  # query-tile id within batch
    # batch segment [off_b, off_b1) in N
    off_b = tl.load(cu_seqlens_qk + pid_b)
    off_b1 = tl.load(cu_seqlens_qk + pid_b + 1)
    seq_len_append = off_b1 - off_b
    q_start = off_b + pid_m * BLOCK_M
    q_end = tl.minimum(q_start + BLOCK_M, off_b1)
    # number of queries in this tile for this batch
    M = q_end - q_start
    # The grid is sized by max_seqlen_q, so shorter sequences get empty tiles.
    if M <= 0:
        return
    # cached length for (b, kv_head=pid_g)
    L_cache = tl.load(seq_lens_bh + pid_b * HKV + pid_g)
    # row indices flattened over [QUERY_GROUP_SIZE, BLOCK_M]
    offs_row = tl.arange(0, TOTAL_N_QUERIES)
    row_m = offs_row % BLOCK_M
    row_h = offs_row // BLOCK_M
    # valid rows: only those with row_m < M
    row_mask = row_m < M
    # global query index per row
    q_idx = q_start + row_m
    offs_d = tl.arange(0, D)
    # Q tile: [TOTAL_N_QUERIES, D]
    # Q layout: [HKV, N, QUERY_GROUP_SIZE, D]
    q_ptrs = (
        Q
        + pid_g * STRIDE_Q_G
        + q_idx[:, None] * STRIDE_Q_N
        + row_h[:, None] * STRIDE_Q_H
        + offs_d[None, :]
    )
    q = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)
    # Online-softmax state per query row: running max, running sum, accumulator.
    e_max = tl.zeros([TOTAL_N_QUERIES], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([TOTAL_N_QUERIES], dtype=tl.float32)
    acc = tl.zeros([TOTAL_N_QUERIES, D], dtype=tl.float32)
    offs_block_n = tl.arange(0, BLOCK_N)
    # Convert natural-log softmax scale into log2 domain for exp2-based updates.
    # Use the full log2(e) constant; this is mathematically equivalent to exp and
    # not the source of the zero-prefix bug, but avoids avoidable rounding loss.
    qk_scale = sm_scale * 1.4426950408889634
    # 1) attend over cached K/V
    if L_cache > 0:
        # map local (b) to global batch index
        mapped_b = tl.load(batch_mapping + pid_b)
        pt_base = (mapped_b * HKV + pid_g) * N_LOGICAL_PAGES_MAX
        # iterate logical pages
        num_lp = tl.cdiv(L_cache, PAGE_SIZE)
        for lp in tl.range(0, num_lp):
            # can overflow in 32 bits so upcast
            phys = tl.load(page_table + pt_base + lp).to(tl.int64)
            page_start = phys * PAGE_SIZE
            # how many valid tokens in this page for this (b,g)
            remain = L_cache - lp * PAGE_SIZE
            page_len = tl.minimum(PAGE_SIZE, remain)
            # iterate over this page in BLOCK_N chunks
            for ks in tl.range(0, page_len, BLOCK_N):
                offs_n = ks + offs_block_n
                mask_n = offs_n < page_len
                key_idx = page_start + offs_n
                k_ptrs = K_cache + key_idx[:, None] * STRIDE_KC + offs_d[None, :]
                k = tl.load(k_ptrs, mask=mask_n[:, None], other=0.0)  # [BN, D]
                qk = tl.dot(q, k.T) * qk_scale  # [TOTAL_N_QUERIES, BN]
                # No causal mask here: every cached key is visible to every query.
                qk = tl.where(row_mask[:, None] & mask_n[None, :], qk, -1.0e6)
                # softmax update
                cur_max = tl.max(qk, 1)
                n_e_max = tl.maximum(e_max, cur_max)
                re_scale = tl.math.exp2(e_max - n_e_max)
                p = tl.math.exp2(qk - n_e_max[:, None])
                v_ptrs = V_cache + key_idx[:, None] * STRIDE_VC + offs_d[None, :]
                v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0)  # [BN, D]
                acc = acc * re_scale[:, None]
                acc = tl.dot(p.to(v.dtype), v, acc)
                e_sum = e_sum * re_scale + tl.sum(p, 1)
                e_max = n_e_max
    # 2) attend over appended K_app/V_app (causal)
    # appended tokens for batch b are in [off_b, off_b1)
    # query tile is [q_start, q_end)
    # for each query at index q_idx, valid appended keys k satisfy off_b <= k <= q_idx
    if q_end > off_b:
        # exactly one appended token
        if seq_len_append == 1:
            # Single key/value row: use an elementwise product instead of tl.dot.
            ka_ptrs = K_app + pid_g * STRIDE_KA_G + off_b * STRIDE_KA_N + offs_d
            k = tl.load(ka_ptrs)  # [D]
            qk = tl.sum(q * k[None, :], 1) * qk_scale
            qk = tl.where(row_mask, qk, -1.0e6)
            n_e_max = tl.maximum(e_max, qk)
            re_scale = tl.math.exp2(e_max - n_e_max)
            p = tl.math.exp2(qk - n_e_max)
            va_ptrs = V_app + pid_g * STRIDE_VA_G + off_b * STRIDE_VA_N + offs_d
            v = tl.load(va_ptrs)  # [D]
            acc = acc * re_scale[:, None] + p[:, None] * v[None, :]
            # e_max is not updated here; it is not read again after this branch.
            e_sum = e_sum * re_scale + p
        else:
            # off-band: k in [off_b, q_start)
            # for all queries t in [q_start, q_end), any k < q_start satisfies k <= t.
            # so no causal mask needed.
            off_band_start = off_b
            off_band_end = q_start
            if off_band_end > off_band_start:
                for ks in tl.range(off_band_start, off_band_end, BLOCK_N):
                    offs_n = ks + offs_block_n
                    mask_n = offs_n < off_band_end
                    ka_ptrs = (
                        K_app
                        + pid_g * STRIDE_KA_G
                        + offs_n[:, None] * STRIDE_KA_N
                        + offs_d[None, :]
                    )
                    k = tl.load(ka_ptrs, mask=mask_n[:, None], other=0.0)
                    qk = tl.dot(q, k.T) * qk_scale
                    qk = tl.where(row_mask[:, None] & mask_n[None, :], qk, -1.0e6)
                    cur_max = tl.max(qk, 1)
                    n_e_max = tl.maximum(e_max, cur_max)
                    re_scale = tl.math.exp2(e_max - n_e_max)
                    p = tl.math.exp2(qk - n_e_max[:, None])
                    va_ptrs = (
                        V_app
                        + pid_g * STRIDE_VA_G
                        + offs_n[:, None] * STRIDE_VA_N
                        + offs_d[None, :]
                    )
                    v = tl.load(va_ptrs, mask=mask_n[:, None], other=0.0)
                    acc = acc * re_scale[:, None]
                    acc = tl.dot(p.to(v.dtype), v, acc)
                    e_sum = e_sum * re_scale + tl.sum(p, 1)
                    e_max = n_e_max
            # on-band remaining k
            on_band_start = tl.maximum(q_start, off_b)
            if on_band_start < q_end:
                for ks in tl.range(on_band_start, q_end, BLOCK_N):
                    offs_n = ks + tl.arange(0, BLOCK_N)
                    mask_n = offs_n < q_end
                    ka_ptrs = (
                        K_app
                        + pid_g * STRIDE_KA_G
                        + offs_n[:, None] * STRIDE_KA_N
                        + offs_d[None, :]
                    )
                    k = tl.load(ka_ptrs, mask=mask_n[:, None], other=0.0)
                    qk = tl.dot(q, k.T) * qk_scale
                    # DCU/ROCm: using a single fused boolean expression here can lead
                    # to early query rows in the tile behaving as if they could attend
                    # to later appended keys in the same on-band block. That shows up
                    # as token-0 output deviating from V[0] while the last token in the
                    # batch remains almost exact. Apply the three masks explicitly.
                    #
                    # Use local positions within the current query tile for the causal
                    # relation: all off-band keys (< q_start) were already handled
                    # above, so the on-band block only needs a lower-triangular mask
                    # relative to q_start.
                    qk = tl.where(row_mask[:, None], qk, -1.0e6)
                    qk = tl.where(mask_n[None, :], qk, -1.0e6)
                    local_q = row_m
                    local_k = offs_n - q_start
                    caus_mask = local_k[None, :] <= local_q[:, None]
                    qk = tl.where(caus_mask, qk, -1.0e6)
                    cur_max = tl.max(qk, 1)
                    n_e_max = tl.maximum(e_max, cur_max)
                    re_scale = tl.math.exp2(e_max - n_e_max)
                    p = tl.math.exp2(qk - n_e_max[:, None])
                    va_ptrs = (
                        V_app
                        + pid_g * STRIDE_VA_G
                        + offs_n[:, None] * STRIDE_VA_N
                        + offs_d[None, :]
                    )
                    v = tl.load(va_ptrs, mask=mask_n[:, None], other=0.0)
                    acc = acc * re_scale[:, None]
                    acc = tl.dot(p.to(v.dtype), v, acc)
                    e_sum = e_sum * re_scale + tl.sum(p, 1)
                    e_max = n_e_max
    # 3) write outputs: normalise by the running sum and store only valid rows
    o = (acc / e_sum[:, None]).to(q.dtype)
    out_ptrs = (
        OUT
        + pid_g * STRIDE_OUT_G
        + q_idx[:, None] * STRIDE_OUT_N
        + row_h[:, None] * STRIDE_OUT_H
        + offs_d[None, :]
    )
    tl.store(out_ptrs, o, mask=row_mask[:, None])
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark helpers for kv-prune / compactor kernels.
Upstream snapshot (``compactor-vllm/src/compactor_vllm/benchmark``) contained **only**
an empty ``__init__.py`` — no additional ``.py`` scripts. Those files are merged here
as-is; there is nothing else to list under that directory in upstream.
Use :data:`BENCHMARK_REGISTRY` to register microbenchmarks or CLI entrypoints you
add under ``vllm.kvprune.benchmark``.
"""
from __future__ import annotations
from typing import Any, Callable
# Files copied from upstream ``compactor_vllm/benchmark/`` (paths are relative
# to that directory). Returned verbatim by ``list_upstream_benchmark_files``.
UPSTREAM_BENCHMARK_FILES: tuple[str, ...] = ("__init__.py",)
# Optional: name -> benchmark callable or import path string (e.g. "mymod:main").
# Starts empty; populated through ``register_benchmark`` when you add real
# benchmarks beside this package.
BENCHMARK_REGISTRY: dict[str, Callable[..., Any] | str] = {}
def list_upstream_benchmark_files() -> tuple[str, ...]:
    """
    Report which files the upstream ``benchmark/`` directory contained.

    :return: the module-level ``UPSTREAM_BENCHMARK_FILES`` tuple (not a copy).
    """
    upstream_files = UPSTREAM_BENCHMARK_FILES
    return upstream_files
def register_benchmark(name: str, target: Callable[..., Any] | str) -> None:
    """
    Add (or overwrite) an entry in ``BENCHMARK_REGISTRY``.

    :param name: key under which the benchmark is stored.
    :param target: a callable, or a ``"module:attr"`` import path string.
    """
    registry = BENCHMARK_REGISTRY
    registry[name] = target
def iter_registered_benchmarks() -> list[tuple[str, Callable[..., Any] | str]]:
    """
    Snapshot the current contents of ``BENCHMARK_REGISTRY``.

    :return: a new list of ``(name, target)`` pairs in registry order.
    """
    return [(name, target) for name, target in BENCHMARK_REGISTRY.items()]
# Public API of the benchmark package: the two module-level tables and the
# three helper functions defined above.
__all__ = [
    "BENCHMARK_REGISTRY",
    "UPSTREAM_BENCHMARK_FILES",
    "iter_registered_benchmarks",
    "list_upstream_benchmark_files",
    "register_benchmark",
]
from vllm.kvprune.compression.common import (
BaseCompressionMethod,
NoCompression,
)
from vllm.kvprune.compression.criticalkv import CriticalAdaKVCompression
from vllm.kvprune.compression.compactor import CompactorCompression
from vllm.kvprune.compression.compression_config import (
BatchCompressionParams,
CompressionMethod,
SequenceCompressionParams,
)
from vllm.kvprune.compression.snapkv import SnapKVCompression
# Dispatch table from a ``CompressionMethod`` member to the class implementing
# it. ``apply_prerope_compression`` and ``apply_postrope_compression`` look the
# active method up here and call its static scoring hooks.
COMPRESSION_REGISTRY: dict[CompressionMethod, type[BaseCompressionMethod]] = {
    CompressionMethod.CRITICALADAKV: CriticalAdaKVCompression,
    CompressionMethod.COMPACTOR: CompactorCompression,
    CompressionMethod.SNAPKV: SnapKVCompression,
    CompressionMethod.NONE: NoCompression,
}
def apply_prerope_compression(q, k, v, context):
    """
    Run the pre-RoPE scoring hook of the configured compression method.

    The method is read from ``context.compression_context.compression_method``
    and resolved through ``COMPRESSION_REGISTRY``.

    :return: whatever the selected class's ``pre_rope_scoring`` returns.
    """
    impl = COMPRESSION_REGISTRY[context.compression_context.compression_method]
    return impl.pre_rope_scoring(q, k, v, context=context)
def apply_postrope_compression(q, k, v, prerope_scores, context):
    """
    Run the post-RoPE scoring hook of the configured compression method.

    The method is read from ``context.compression_context.compression_method``
    and resolved through ``COMPRESSION_REGISTRY``; ``prerope_scores`` is
    forwarded as the hook's fourth positional argument.

    :return: whatever the selected class's ``post_rope_scoring`` returns.
    """
    impl = COMPRESSION_REGISTRY[context.compression_context.compression_method]
    return impl.post_rope_scoring(q, k, v, prerope_scores, context=context)
# Public API of the compression package: the two dispatch helpers, the
# configuration types re-exported from ``compression_config``, and the registry.
__all__ = [
    "apply_prerope_compression",
    "apply_postrope_compression",
    "CompressionMethod",
    "BatchCompressionParams",
    "SequenceCompressionParams",
    "COMPRESSION_REGISTRY"
]
from abc import ABC, abstractmethod
import os
from typing import Optional
import torch
from vllm.kvprune.kv_cache.store_kv_cache import prefill_store_topk_kv
class BaseCompressionMethod(ABC):
    """
    Abstract interface every KV cache compression method implements.

    A method contributes up to two scoring phases, placed around the point
    where rotary position embedding (RoPE) is applied:

    1. ``pre_rope_scoring`` sees the pre-RoPE Q/K.
    2. ``post_rope_scoring`` sees the post-RoPE Q/K and may either
       refine / reweight the pre-RoPE scores, or compute new, potentially
       position-aware scores.

    Subclasses implement both hooks as static methods. Each returns a single
    score tensor (or ``None`` when the phase is a no-op), which the caller
    feeds into the shared "scores -> top-k indices -> KV extraction" pipeline.
    """

    @staticmethod
    @abstractmethod
    def pre_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """
        Score tokens using the queries/keys before RoPE is applied.

        Args:
            :param q:
                Pre-RoPE query tensor. Shape ``[total_tokens, HQ, D]``.
            :param k:
                Pre-RoPE key tensor. Shape ``[total_tokens, HKV, D]``.
            :param v:
                Value tensor. Shape ``[total_tokens, HKV, D]``.
            :param context:
                vllm.kvprune.utils.context.Context object carrying additional
                metadata, such as batch mappings or temporary buffers.

        Returns:
            :return Optional[torch.Tensor]:
                Scores (e.g. per-token, per-head importance values) that are
                handed to ``post_rope_scoring`` or directly to the top-k
                selection step. Shape ``[total_tokens, HKV]``. Implementations
                for which this phase is a no-op should return ``None``.
        """

    @staticmethod
    @abstractmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: Optional[torch.Tensor],
        context,
    ) -> Optional[torch.Tensor]:
        """
        Produce the final importance scores once RoPE has been applied.

        The hook may combine the post-RoPE Q/K with whatever
        ``pre_rope_scoring`` returned. Typical patterns:

        * Start from ``pre_rope_scores`` and apply a position-aware
          correction.
        * Compute only scores that depend on absolute or relative positions.
        * Pass ``pre_rope_scores`` through unchanged.

        Args:
            :param q:
                Post-RoPE query tensor. Shape ``[total_tokens, HQ, D]``.
            :param k:
                Post-RoPE key tensor. Shape ``[total_tokens, HKV, D]``.
            :param v:
                Value tensor. Shape ``[total_tokens, HKV, D]``.
            :param pre_rope_scores:
                Scores returned by ``pre_rope_scoring``; ``None`` when that
                phase returned ``None``.
            :param context:
                vllm.kvprune.utils.context.Context object carrying additional
                metadata, such as batch mappings or temporary buffers.

        Returns:
            :return Optional[torch.Tensor]:
                Final scores consumed by the compression pipeline for top-k
                token selection. A no-op implementation may return
                ``pre_rope_scores``. Returning ``None`` means no compression
                is applied.
        """
class NoCompression(BaseCompressionMethod):
    """
    Pass-through method used when KV cache compression is turned off.

    Neither hook computes anything, so the scores stay ``None`` unless the
    caller supplies some to ``post_rope_scoring``.
    """

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """Return ``None``: this method produces no pre-RoPE scores."""
        return None

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """Return ``pre_rope_scores`` untouched."""
        return pre_rope_scores
def extract_and_store_top_kv(
    scores: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_k_len: int,
    top_k: int,
    H: int,
    new_keys: torch.Tensor,  # [N_total, H, D]
    new_vals: torch.Tensor,  # [N_total, H, D]
    num_tokens_to_retain: torch.Tensor,  # [B] int32
    page_table: torch.Tensor,  # [B_total, H, N_LOGICAL_PAGES_MAX] int32
    batch_mapping: torch.Tensor,  # [B] int32 (local -> true batch rows)
    bh_lens: torch.Tensor,  # [B, H] int32 (contiguous), UPDATED atomically
    k_cache: torch.Tensor,  # [N_PAGES * PAGE_SIZE, D]
    v_cache: torch.Tensor,  # [N_PAGES * PAGE_SIZE, D]
    PAGE_SIZE: int,
    PAD_TO_PAGE_SIZE: bool = True,
    K_TILE: int = 16,
    padding: float = -float("inf"),
):
    """
    Select the (token, head) pairs to keep and write them into the KV cache.

    Selection (``scores_to_retain_indices``) and storage
    (``prefill_store_topk_kv``) are bundled in one helper so that both can be
    executed in a single stream.

    The padding strategy is read from the ``VLLM_KVPRUNE_PADDING_MODE``
    environment variable (default ``per_head``):

    * ``per_head``: per-head highest-scoring remaining tokens for page padding.
    * ``global_scan``: legacy global ranking order, padded by scanning forward
      in-kernel.

    ``num_tokens_to_retain`` is clamped per batch row to ``seqlen_k[b] * H``
    before it is used. The function returns ``None``.
    """
    raw_mode = os.environ.get("VLLM_KVPRUNE_PADDING_MODE", "per_head")
    padding_mode = raw_mode.strip().lower()

    # A row cannot retain more pairs than it has: seqlen * H.
    seq_lens = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
    pair_capacity = (
        seq_lens.to(
            device=num_tokens_to_retain.device, dtype=num_tokens_to_retain.dtype
        )
        * H
    )
    num_tokens_to_retain = torch.minimum(num_tokens_to_retain, pair_capacity)

    indices_topk, candidate_counts = scores_to_retain_indices(
        scores,
        cu_seqlens_k=cu_seqlens_k,
        max_k_len=max_k_len,
        top_k=top_k,
        H=H,
        num_tokens_to_retain=num_tokens_to_retain,
        page_size=PAGE_SIZE,
        padding_mode=padding_mode,
        padding=padding,
    )

    prefill_store_topk_kv(
        new_keys=new_keys,
        new_vals=new_vals,
        indices_topk=indices_topk,
        candidate_counts=candidate_counts,
        num_tokens_to_retain=num_tokens_to_retain,
        page_table=page_table,
        batch_mapping=batch_mapping,
        bh_lens=bh_lens,
        k_cache=k_cache,
        v_cache=v_cache,
        cu_seqlens_k=cu_seqlens_k,
        PAGE_SIZE=PAGE_SIZE,
        PAD_TO_PAGE_SIZE=PAD_TO_PAGE_SIZE,
        K_TILE=K_TILE,
    )
def _per_head_padded_row(
    seq_scores: torch.Tensor,
    keep: int,
    seq_len: int,
    H: int,
    page_size: int,
) -> torch.Tensor:
    """Return the global top-``keep`` flat indices followed by per-head padding picks."""
    device = seq_scores.device
    prefix = torch.topk(
        seq_scores.reshape(-1), k=keep, dim=0, largest=True, sorted=True
    ).indices

    selected = torch.zeros(seq_len * H, dtype=torch.bool, device=device)
    selected[prefix] = True
    selected = selected.view(seq_len, H)

    # Extra entries each head needs to reach a multiple of page_size, capped by
    # the number of tokens that head has not selected yet.
    head_counts = torch.bincount(prefix % H, minlength=H)
    need_per_head = (page_size - (head_counts % page_size)) % page_size
    need_per_head = torch.minimum(need_per_head, seq_len - head_counts)

    tails: list[torch.Tensor] = []
    # One host transfer for all heads instead of one ``.item()`` per head.
    for h, need in enumerate(need_per_head.tolist()):
        if need <= 0:
            continue
        # Rank only the tokens this head has NOT selected. Masking selected
        # entries with -inf and ranking everything could re-pick a selected
        # token whenever the remaining scores are themselves -inf.
        free_tok = torch.nonzero(~selected[:, h]).squeeze(1)
        best = torch.topk(
            seq_scores[free_tok, h], k=need, dim=0, largest=True, sorted=True
        ).indices
        tails.append(free_tok[best] * H + h)

    return torch.cat([prefix, *tails], dim=0) if tails else prefix


def scores_to_retain_indices(
    scores: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_k_len: int,
    top_k: int,
    H: int,
    num_tokens_to_retain: torch.Tensor,
    page_size: int,
    padding_mode: str = "per_head",
    padding: float = -float("inf"),
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Build candidate token-head indices for compression writes.

    For each batch element, this helper returns:

    1. a prefix of the true global top-k ``(token, head)`` pairs, and
    2. a suffix of additional padding candidates according to ``padding_mode``:

       - ``per_head``: choose each head's highest-scoring remaining tokens.
       - ``global_scan``: keep the legacy global ranking order (the full
         descending ranking of the row is returned) and let the store kernel
         scan forward until it finds enough entries for that head.

    The page-alignment requirement comes from the paged KV cache, but the
    padding candidates themselves do not need to be discovered inside the
    Triton store kernel. Choosing them here avoids the older "scan the global
    candidate list until you stumble across enough entries for this head"
    behavior, which could distort the retained set even though the page-table
    / reclaim logic only cares about the final per-head counts.

    Args:
        :param scores:
            Tensor of shape ``[N_total, HKV]`` containing scores for each
            (token, head) pair in packed varlen format.
        :param cu_seqlens_k:
            Tensor of shape ``[B + 1]`` (int32) with cumulative key sequence
            lengths for each batch element. The total number of tokens
            satisfies ``N_total = cu_seqlens_k[-1]``.
        :param max_k_len:
            Kept for API compatibility; not used.
        :param top_k:
            Kept for API compatibility with the caller. The retained prefix is
            determined by ``num_tokens_to_retain``; the tail is built from
            per-head padding needs.
        :param H:
            Number of key heads; must match ``scores.shape[1]``.
        :param num_tokens_to_retain:
            The true number of token-head pairs to keep for each batch element
            before page padding.
        :param page_size:
            Page size of the KV cache. Determines how many extra candidates
            are needed per head to reach page alignment.
        :param padding_mode:
            ``per_head`` for per-head optimal padding candidates, or
            ``global_scan`` for the legacy "scan the global ranking" behavior.
        :param padding:
            Kept for backward compatibility; no longer used.

    Returns:
        A tuple ``(indices, counts)`` where:

        - ``indices`` is ``[B, MAX_SEL]`` int64, containing global flattened
          ``token * H + head`` indices, zero-padded on the right
          (``[B, 1]`` zeros when no row has any candidate).
        - ``counts`` is ``[B]`` int32, the number of valid candidates for each
          batch row inside ``indices``.

    Raises:
        ValueError: if ``padding_mode`` is neither ``per_head`` nor
            ``global_scan``.
    """
    del max_k_len, top_k, padding
    if padding_mode not in ("per_head", "global_scan"):
        raise ValueError(
            "Unsupported VLLM_KVPRUNE_PADDING_MODE. "
            f"Expected 'per_head' or 'global_scan', got {padding_mode!r}."
        )

    B, device = cu_seqlens_k.numel() - 1, scores.device
    # Pull the per-row metadata to the host once rather than issuing several
    # ``.item()`` synchronisations per batch row.
    bounds = [int(x) for x in cu_seqlens_k.tolist()]
    retain = [int(x) for x in num_tokens_to_retain.tolist()] if B > 0 else []

    row_indices: list[torch.Tensor] = []
    for b in range(B):
        s, e = bounds[b], bounds[b + 1]
        seq_len = e - s
        total_pairs = seq_len * H
        keep = min(retain[b], total_pairs)
        if total_pairs == 0 or keep == 0:
            row_indices.append(torch.empty(0, dtype=torch.int64, device=device))
            continue
        seq_scores = scores[s:e, :]  # [L, H]
        if padding_mode == "global_scan":
            row = torch.argsort(seq_scores.reshape(-1), dim=0, descending=True)
        else:
            row = _per_head_padded_row(seq_scores, keep, seq_len, H, page_size)
        # Shift row-local flat indices to global ``token * H + head`` indices.
        row_indices.append(row + s * H)

    candidate_counts = torch.tensor(
        [int(row.numel()) for row in row_indices], dtype=torch.int32, device=device
    )
    max_sel = max((int(row.numel()) for row in row_indices), default=0)
    if max_sel == 0:
        return (
            torch.zeros((B, 1), dtype=torch.int64, device=device),
            candidate_counts,
        )

    indices = torch.zeros((B, max_sel), dtype=torch.int64, device=device)
    for b, row in enumerate(row_indices):
        if row.numel():
            indices[b, : row.numel()] = row
    return indices, candidate_counts
"""
Compactor 压缩:与 kvpress ``CompactorPress`` / ``LeverageScorePress`` / ``NonCausalAttnPress``
算法对齐(Cholesky 杠杆分、右高斯 sketch、非因果分块注意力无 1/sqrt(d) 缩放、×||V||、avg_pool、
全局 z-score、blending 与首尾 sink pad)。
非因果分块注意力与 ``×||V||``+``avg_pool1d(k=3)`` 在 CUDA 上为 Triton;非 CUDA 回退 PyTorch。
杠杆分路径使用 batched ``torch.matmul``;在 transpose 与进入线性代数前对张量 ``.contiguous()``。
CUDA 上用 ``cholesky_solve``;在 HIP/ROCm 上对小的 sketch 维 ``k`` 用 ``linalg.inv(G+λI) @ X^T``
代替 ``cholesky_solve``,避开 rocBLAS TRSM 的 launch-bounds 告警与部分栈上的不稳定行为。
非因果 PyTorch 回退同理。
"""
from __future__ import annotations
import math
from typing import List, Optional
import torch
import triton
import triton.language as tl
from transformers.models.llama.modeling_llama import repeat_kv
from vllm.kvprune.compression.common import BaseCompressionMethod
from vllm.kvprune.utils.helpers import maybe_execute_in_stream
def resolve_kvpress_compactor_blending(compression_context) -> float:
    """
    Resolve the blending factor the same way kvpress ``CompactorPress.score`` does.

    Lookup order: ``compression_context.compactor_blending``, then
    ``compression_context.compression_ratio``; the first one that is not
    ``None`` is returned as a float. If neither is set, or the context itself
    is ``None``, the result is 0.35.
    """
    if compression_context is not None:
        for attr_name in ("compactor_blending", "compression_ratio"):
            candidate = getattr(compression_context, attr_name, None)
            if candidate is not None:
                return float(candidate)
    return 0.35
class CompactorCompression(BaseCompressionMethod):
    """
    Compactor scoring, matching kvpress ``CompactorPress`` / ``NonCausalAttnPress``.

    ``chunk_size`` uses the same default (256) as kvpress. Both hooks run
    their scoring function through ``maybe_execute_in_stream`` with
    ``context.STORE_STREAM``.
    """

    chunk_size: int = 256

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """Compute leverage scores from the pre-RoPE keys (``q`` and ``v`` are unused)."""
        comp_ctx = context.compression_context
        return maybe_execute_in_stream(
            kvpress_leverage_scores_packed,
            k,
            context.cu_seqlens_q,
            comp_ctx,
            STORE_STREAM=context.STORE_STREAM,
        )

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """
        Combine the pre-RoPE scores with the post-RoPE scoring.

        The blending factor comes from ``resolve_kvpress_compactor_blending``;
        the actual computation is done by ``kvpress_compactor_post_rope``.
        """
        comp_ctx = context.compression_context
        blend = resolve_kvpress_compactor_blending(comp_ctx)
        return maybe_execute_in_stream(
            kvpress_compactor_post_rope,
            q,
            k,
            v,
            context.cu_seqlens_q,
            pre_rope_scores,
            comp_ctx,
            context.max_seqlen_q,
            chunk_size=CompactorCompression.chunk_size,
            blending=float(blend),
            STORE_STREAM=context.STORE_STREAM,
        )
# ---------------------------------------------------------------------------
# Cholesky leverage scores (kvpress ``LeverageScorePress``)
# ---------------------------------------------------------------------------
def chol_with_jitter(
    G: torch.Tensor, jitter: float = 0.0, max_tries: int = 5
) -> torch.Tensor:
    """
    Lower-triangular Cholesky factor of ``G + shift * I`` with a growing shift.

    The first attempt uses ``shift = jitter``. After every failed attempt the
    shift becomes ``1e-2`` if it was exactly zero, otherwise ten times its
    previous value (never below ``1e-8``). An attempt succeeds only when
    ``cholesky_ex`` reports ``info == 0`` for every matrix in the batch.

    :param G: square matrix or batch of square matrices.
    :param jitter: initial diagonal shift.
    :param max_tries: maximum number of factorisation attempts.
    :return: the factor from the first successful attempt.
    :raises RuntimeError: if no attempt succeeds.
    """
    eye = torch.eye(G.shape[-1], device=G.device, dtype=G.dtype)
    shift = float(jitter)
    for _attempt in range(max_tries):
        shifted = (G + shift * eye).contiguous()
        factor, info = torch.linalg.cholesky_ex(shifted, upper=False)
        if bool((info == 0).all()):
            return factor
        if shift == 0.0:
            shift = max(1e-8, 1e-2)
        else:
            shift = max(1e-8, 10.0 * shift)
    raise RuntimeError(f"Cholesky failed after {max_tries} tries.")
def compute_leverage_scores_mid(
    key_states: torch.Tensor, sketch_dimension: int
) -> torch.Tensor:
    """
    Leverage scores of the keys, following kvpress
    ``LeverageScorePress.compute_leverage_scores``.

    The keys are centred along the token axis, projected with a random
    Gaussian sketch of width ``sketch_dimension``, and scored as
    ``diag(X (G + 1e-2 I)^-1 X^T)`` with ``G = X^T X``, clamped at zero.
    No z-score normalisation is applied here.

    The tensor is rearranged to a ``(1, H, L, D)`` layout (kvpress uses
    ``(B, H, S, D)``); batched GEMMs and explicit ``.contiguous()`` calls are
    used for the benefit of the backend libraries.

    :param key_states: keys of shape ``[L, H, D]``.
    :param sketch_dimension: width ``k`` of the random sketch.
    :return: scores of shape ``[L, H]`` in float32.
    """
    head_dim, sketch_k = key_states.shape[-1], sketch_dimension
    n_heads = key_states.shape[1]

    # Random Gaussian sketch, one [D, k] matrix per head, scaled by 1/sqrt(k).
    sketch = torch.randn(
        1,
        n_heads,
        head_dim,
        sketch_k,
        device=key_states.device,
        dtype=key_states.dtype,
    ) * (1.0 / math.sqrt(sketch_k))

    # [L, H, D] -> [1, H, L, D], then centre over the token axis.
    keys_bhld = key_states.transpose(0, 1).unsqueeze(0).contiguous()
    centred = (keys_bhld - keys_bhld.mean(dim=-2, keepdim=True)).contiguous()
    sketch = sketch.contiguous()

    # Sketched keys [1, H, L, k]; everything from here on runs in float32.
    X = torch.matmul(centred, sketch).to(torch.float32).contiguous()
    XT = X.transpose(-2, -1).contiguous()
    gram = torch.matmul(XT, X)
    gram_sym = 0.5 * (gram + gram.transpose(-2, -1)).contiguous()

    # HIP: avoid batched cholesky_solve -> rocBLAS TRSM (launch_bounds noise / edge cases).
    # k is sketch_dim (typically modest); inv is O(k^3) but batched over heads.
    if torch.version.hip is not None:
        k_dim = gram_sym.shape[-1]
        identity = torch.eye(
            k_dim, device=gram_sym.device, dtype=gram_sym.dtype, requires_grad=False
        )
        gram_reg = gram_sym + 1e-2 * identity
        solved = torch.linalg.inv(gram_reg) @ XT
    else:
        chol = chol_with_jitter(gram_sym, jitter=1e-2, max_tries=5)
        solved = torch.cholesky_solve(XT, chol, upper=False)

    solved_T = solved.transpose(-2, -1).contiguous()
    # Row-wise dot product gives the diagonal of X (G + λI)^-1 X^T.
    leverage = (X * solved_T).sum(dim=-1).clamp_min(0)
    # [1, H, L] -> [L, H]
    return leverage.squeeze(0).transpose(0, 1).contiguous()
def kvpress_leverage_scores_packed(
    key_states: torch.Tensor,
    cu_seqlens: torch.Tensor,
    compression_ctx,
) -> torch.Tensor:
    """Leverage scores for the middle tokens of every packed sequence.

    For each sequence delimited by consecutive entries of ``cu_seqlens`` the
    leading ``sink_size_start`` and trailing ``sink_size_end`` tokens are
    skipped, and leverage scores are computed on what remains. The scores of
    all sequences are concatenated, z-scored together, and written back.

    Args:
        key_states: Packed keys ``[N, Hkv, D]``.
        cu_seqlens: Sequence boundaries; sequence ``b`` spans
            ``cu_seqlens[b] : cu_seqlens[b + 1]``.
        compression_ctx: Object optionally carrying ``sketch_dimension``
            (default 48), ``sink_size_start`` (default 8) and
            ``sink_size_end`` (default 4).

    Returns:
        Float32 tensor ``[N, Hkv]``; rows outside the middle ranges stay 0.
    """
    device = key_states.device
    N, Hkv, _D = key_states.shape
    sketch_dim = int(getattr(compression_ctx, "sketch_dimension", 48))
    sink_start = int(getattr(compression_ctx, "sink_size_start", 8))
    sink_end = int(getattr(compression_ctx, "sink_size_end", 4))
    out = torch.zeros(N, Hkv, device=device, dtype=torch.float32)

    # One host transfer for all boundaries instead of two `.item()` syncs
    # per sequence.
    bounds = [int(x) for x in cu_seqlens.reshape(-1).tolist()]

    mids_flat: list[torch.Tensor] = []
    mid_ranges: list[tuple[int, int]] = []
    for k_beg, k_end in zip(bounds, bounds[1:]):
        seq_len = k_end - k_beg
        if seq_len == 0:
            continue
        left_keep = min(sink_start, seq_len)
        right_keep = min(sink_end, max(0, seq_len - left_keep))
        mid_start = k_beg + left_keep
        mid_end = k_end - right_keep
        if mid_start >= mid_end:
            continue
        k_mid = key_states[mid_start:mid_end, :, :].contiguous()
        raw = compute_leverage_scores_mid(k_mid, sketch_dim)
        mids_flat.append(raw.reshape(-1))
        mid_ranges.append((mid_start, mid_end))

    if not mids_flat:
        return out

    # Normalize across every sequence at once, then scatter back.
    z = _zscore_flat_f32_global(torch.cat(mids_flat, dim=0))
    offset = 0
    for mid_start, mid_end in mid_ranges:
        rows = mid_end - mid_start
        n = rows * Hkv
        out[mid_start:mid_end, :] = z[offset : offset + n].view(rows, Hkv)
        offset += n
    return out
# ---------------------------------------------------------------------------
# Non-causal chunked attention (kvpress ``NonCausalAttnPress.non_causal_chunked_attn``) — Triton
# ---------------------------------------------------------------------------
def _non_causal_chunked_attn_pytorch(
    q: torch.Tensor, k: torch.Tensor, chunk_size: int
) -> torch.Tensor:
    """Reference implementation, intended to match kvpress operator by operator.

    The sequence is zero-padded to a multiple of ``chunk_size`` and split into
    chunks. Within each chunk the unscaled logits ``q @ k^T`` go through a
    softmax over keys, and the attention each key receives is summed over the
    chunk's queries.

    Args:
        q: Queries ``[L, H, d]``.
        k: Keys with the same shape as ``q``.
        chunk_size: Positive chunk length.

    Returns:
        Contiguous float32 tensor ``[L, H]``. An empty input (``L == 0``)
        yields an empty ``[0, H]`` result.
    """
    assert chunk_size > 0 and q.shape == k.shape
    S, H, d = q.shape
    if S == 0:
        # No chunk exists, so indexing the "last chunk" below would raise.
        return torch.zeros(0, H, device=q.device, dtype=torch.float32)

    B = 1
    q = q.permute(1, 0, 2).unsqueeze(0).contiguous()
    k = k.permute(1, 0, 2).unsqueeze(0).contiguous()

    S_pad = math.ceil(S / chunk_size) * chunk_size
    pad_len = S_pad - S
    if pad_len > 0:
        q = torch.cat(
            [q, torch.zeros(B, H, pad_len, d, device=q.device, dtype=q.dtype)], dim=2
        )
        k = torch.cat(
            [k, torch.zeros(B, H, pad_len, d, device=k.device, dtype=k.dtype)], dim=2
        )

    # Padding can only live in the last chunk; with S > 0 that chunk starts at
    # S_pad - chunk_size whether or not padding was added.
    last_chunk_start = ((S - 1) // chunk_size) * chunk_size
    in_valid = torch.arange(last_chunk_start, S_pad, device=q.device) >= S
    pad_mask = in_valid.view(1, 1, chunk_size).expand(B, H, chunk_size)

    num_chunks = S_pad // chunk_size
    q_chunks = q.contiguous().view(B, H, num_chunks, chunk_size, d)
    k_chunks = k.contiguous().view(B, H, num_chunks, chunk_size, d)
    dots = torch.matmul(q_chunks, k_chunks.transpose(-2, -1).contiguous())

    # Padded query rows get zero logits; padded key columns get -1e-9.
    # NOTE(review): -1e-9 is a tiny negative number, not a large one, so padded
    # keys are barely suppressed. The Triton kernel uses the same constant.
    dots[:, :, -1].masked_fill_(pad_mask.unsqueeze(-1), 0)
    dots[:, :, -1].masked_fill_(pad_mask.unsqueeze(-2), -1e-9)

    attn = torch.softmax(dots.to(torch.float32), dim=-1)
    out = attn.sum(dim=-2).view(B, H, S_pad)[..., :S]
    return out.squeeze(0).transpose(0, 1).contiguous()
@triton.jit
def _non_causal_chunk_row_kernel(
    Q_ptr,
    K_ptr,
    Out_ptr,
    stride_qh,
    stride_qs,
    stride_qd,
    stride_kh,
    stride_ks,
    stride_kd,
    stride_oh,
    stride_os,
    S,
    S_pad,
    num_chunks,
    CHUNK_SIZE: tl.constexpr,
    D: tl.constexpr,
    BLOCK_D: tl.constexpr,
    ND: tl.constexpr,
):
    """
    Each program handles one head, one chunk and one query row.

    It applies softmax(dim=-1) to that row of logits, then atomically adds the
    resulting probabilities into the output at each key column ``j`` (this is
    equivalent to summing over the query rows).

    ``S_pad`` and ``num_chunks`` are accepted but not read in the body.
    """
    h = tl.program_id(0)
    c = tl.program_id(1)
    iq = tl.program_id(2)
    # Global sequence index of this program's query row.
    g_i = c * CHUNK_SIZE + iq
    offs_j = tl.arange(0, CHUNK_SIZE)
    logits = tl.zeros([CHUNK_SIZE], dtype=tl.float32)
    # Accumulate q . k over the feature dimension in blocks of BLOCK_D.
    for db in range(ND):
        offs_d = tl.arange(0, BLOCK_D) + db * BLOCK_D
        mask_d = offs_d < D
        q_off = (
            h * stride_qh + g_i * stride_qs + offs_d * stride_qd
        )
        qd = tl.load(Q_ptr + q_off, mask=mask_d, other=0.0).to(tl.float32)
        g_j = c * CHUNK_SIZE + offs_j
        k_row_off = h * stride_kh + g_j[:, None] * stride_ks + offs_d[None, :] * stride_kd
        # Loads are masked only along the feature axis, so every row of the
        # chunk (including rows at or beyond S) must be addressable.
        kj = tl.load(K_ptr + k_row_off, mask=mask_d[None, :], other=0.0).to(tl.float32)
        logits += tl.sum(qd[None, :] * kj, axis=1)
    row_invalid = g_i >= S
    g_j_all = c * CHUNK_SIZE + offs_j
    col_invalid = g_j_all >= S
    # A query row at or beyond S gets all-zero logits (uniform softmax).
    logits = tl.where(row_invalid, tl.zeros([CHUNK_SIZE], dtype=tl.float32), logits)
    # For valid rows, key columns at or beyond S are set to -1e-9.
    # NOTE(review): -1e-9 is a tiny negative value, not a large one, so these
    # columns are barely suppressed before the softmax.
    logits = tl.where(
        row_invalid,
        logits,
        tl.where(col_invalid, tl.full([CHUNK_SIZE], -1e-9, dtype=tl.float32), logits),
    )
    # Softmax over the key axis with max subtraction.
    m = tl.max(logits)
    logits = logits - m
    exp_v = tl.exp(logits)
    denom = tl.sum(exp_v)
    p = exp_v / denom
    out_base = h * stride_oh + g_j_all * stride_os
    # Only columns below S are written; rows at or beyond S still contribute
    # their uniform probabilities to those columns.
    tl.atomic_add(Out_ptr + out_base, p, mask=g_j_all < S)
def _non_causal_chunked_attn_triton(
    q: torch.Tensor, k: torch.Tensor, chunk_size: int
) -> torch.Tensor:
    """CUDA/Triton variant of ``_non_causal_chunked_attn_pytorch`` (same algorithm).

    Zero-pads ``q`` and ``k`` to a multiple of ``chunk_size``, converts them to
    head-major float32, and launches one kernel program per
    (head, chunk, query row) that accumulates into a ``[H, S_pad]`` buffer.

    Args:
        q: CUDA queries ``[L, H, d]``.
        k: CUDA keys with the same shape as ``q``.
        chunk_size: Positive chunk length.

    Returns:
        Contiguous float32 tensor ``[L, H]``.
    """
    assert q.is_cuda and k.is_cuda and q.shape == k.shape
    seq_len, num_heads, head_dim = q.shape
    assert chunk_size > 0

    num_chunks = math.ceil(seq_len / chunk_size)
    padded_len = num_chunks * chunk_size
    tail = padded_len - seq_len
    if tail > 0:
        # Append zero rows so every chunk is full.
        q = torch.cat([q, q.new_zeros((tail, num_heads, head_dim))], dim=0)
        k = torch.cat([k, k.new_zeros((tail, num_heads, head_dim))], dim=0)

    q_hsd = q.transpose(0, 1).contiguous().to(dtype=torch.float32)
    k_hsd = k.transpose(0, 1).contiguous().to(dtype=torch.float32)
    acc = torch.zeros(num_heads, padded_len, device=q.device, dtype=torch.float32)

    block_d = 32 if head_dim <= 128 else 64
    num_d_blocks = (head_dim + block_d - 1) // block_d
    launch_grid = (num_heads, num_chunks, chunk_size)
    _non_causal_chunk_row_kernel[launch_grid](
        q_hsd,
        k_hsd,
        acc,
        q_hsd.stride(0),
        q_hsd.stride(1),
        q_hsd.stride(2),
        k_hsd.stride(0),
        k_hsd.stride(1),
        k_hsd.stride(2),
        acc.stride(0),
        acc.stride(1),
        int(seq_len),
        padded_len,
        int(num_chunks),
        CHUNK_SIZE=chunk_size,
        D=head_dim,
        BLOCK_D=block_d,
        ND=num_d_blocks,
        num_warps=4,
    )
    # Drop the padded columns and return token-major.
    return acc[:, : int(seq_len)].transpose(0, 1).contiguous()
def non_causal_chunked_attn(q: torch.Tensor, k: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Chunked non-causal attention mass per key; ``[L, H, d]`` -> ``[L, H]``.

    Logits are **not** scaled by ``1/sqrt(d)``. The Triton kernel is used when
    both tensors are on CUDA and ``chunk_size`` is a power of two; otherwise
    the PyTorch reference is used.

    Args:
        q: Queries ``[L, H, d]``.
        k: Keys with the same shape as ``q``.
        chunk_size: Positive chunk length.

    Returns:
        Tensor ``[L, H]``.
    """
    # The kernel builds `tl.arange(0, CHUNK_SIZE)`, which Triton only accepts
    # for power-of-two sizes; other chunk sizes would fail to compile.
    is_pow2 = chunk_size > 0 and (chunk_size & (chunk_size - 1)) == 0
    if q.is_cuda and k.is_cuda and is_pow2:
        return _non_causal_chunked_attn_triton(q, k, chunk_size)
    return _non_causal_chunked_attn_pytorch(q, k, chunk_size)
# ---------------------------------------------------------------------------
# ×||V|| + avg_pool1d(k=3) — Triton (CUDA)
# ---------------------------------------------------------------------------
@triton.jit
def _mul_vnorm_avgpool3_kernel(
    A_ptr,
    V_ptr,
    OUT_ptr,
    stride_al,
    stride_ah,
    stride_vl,
    stride_vh,
    stride_vd,
    stride_ol,
    stride_oh,
    L,
    D: tl.constexpr,
):
    """Fused ``a * ||v||`` followed by a width-3 average along the position axis.

    Each program writes one output element for position ``l`` and head ``h``.
    Triton does not support nested defs, so the ``t_at`` logic is unrolled
    once for each of ``l-1``, ``l`` and ``l+1``.
    """
    l = tl.program_id(0)
    h = tl.program_id(1)
    # The V row is loaded with all D features at once, without a feature mask.
    offs = tl.arange(0, D)
    # --- neighbour l - 1 ---
    pos_m1 = l - 1
    inb_m1 = (pos_m1 >= 0) & (pos_m1 < L)
    # Clamp to 0 so the address stays valid even when the load is masked off.
    ps_m1 = tl.where(inb_m1, pos_m1, 0)
    a_m1 = tl.load(
        A_ptr + ps_m1 * stride_al + h * stride_ah,
        mask=inb_m1,
        other=0.0,
    ).to(tl.float32)
    v_m1 = tl.load(
        V_ptr + ps_m1 * stride_vl + h * stride_vh + offs * stride_vd,
        mask=inb_m1,
        other=0.0,
    ).to(tl.float32)
    # a * L2-norm(v) in float32; out-of-range neighbours contribute 0.
    s_m1 = tl.where(inb_m1, a_m1 * tl.sqrt(tl.sum(v_m1 * v_m1)), 0.0)
    # --- centre position l ---
    # NOTE(review): presumably always true, assuming grid dim 0 equals L.
    inb_0 = (l >= 0) & (l < L)
    ps0 = tl.where(inb_0, l, 0)
    a0 = tl.load(
        A_ptr + ps0 * stride_al + h * stride_ah,
        mask=inb_0,
        other=0.0,
    ).to(tl.float32)
    v0 = tl.load(
        V_ptr + ps0 * stride_vl + h * stride_vh + offs * stride_vd,
        mask=inb_0,
        other=0.0,
    ).to(tl.float32)
    s_0 = tl.where(inb_0, a0 * tl.sqrt(tl.sum(v0 * v0)), 0.0)
    # --- neighbour l + 1 ---
    pos_p1 = l + 1
    inb_p1 = (pos_p1 >= 0) & (pos_p1 < L)
    ps_p1 = tl.where(inb_p1, pos_p1, 0)
    a_p1 = tl.load(
        A_ptr + ps_p1 * stride_al + h * stride_ah,
        mask=inb_p1,
        other=0.0,
    ).to(tl.float32)
    v_p1 = tl.load(
        V_ptr + ps_p1 * stride_vl + h * stride_vh + offs * stride_vd,
        mask=inb_p1,
        other=0.0,
    ).to(tl.float32)
    s_p1 = tl.where(inb_p1, a_p1 * tl.sqrt(tl.sum(v_p1 * v_p1)), 0.0)
    # The divisor is always 3, so sequence edges behave as if zero-padded.
    out = (s_m1 + s_0 + s_p1) * (1.0 / 3.0)
    tl.store(OUT_ptr + l * stride_ol + h * stride_oh, out)
def _mul_vnorm_avgpool3_fused(
a: torch.Tensor, v: torch.Tensor, out: torch.Tensor | None = None
) -> torch.Tensor:
assert a.dim() == 2 and v.dim() == 3 and a.shape[0] == v.shape[0] and a.shape[1] == v.shape[1]
L, H, D = v.shape
a = a.contiguous()
v = v.contiguous()
if a.dtype != torch.float32:
a = a.float()
if out is None:
out = torch.empty((L, H), device=v.device, dtype=torch.float32)
if L == 0 or H == 0:
return out
grid = (L, H)
_mul_vnorm_avgpool3_kernel[grid](
a,
v,
out,
a.stride(0),
a.stride(1),
v.stride(0),
v.stride(1),
v.stride(2),
out.stride(0),
out.stride(1),
L,
D=D,
num_warps=4,
)
return out
def _maybe_mul_vnorm_avgpool3_fused(a: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Compute ``avg_pool1d(a * ||v||, kernel=3)`` along the position axis.

    The fused Triton kernel is used when both tensors are on CUDA and the
    kernel can handle the feature size; otherwise plain PyTorch ops are used.

    Args:
        a: Scores ``[L, H]``.
        v: Values ``[L, H, D]``.

    Returns:
        Tensor ``[L, H]``.
    """
    on_gpu = a.is_cuda and v.is_cuda
    head_dim = v.shape[-1]
    # The kernel builds `tl.arange(0, D)`, which Triton only accepts for
    # power-of-two sizes (head dims such as 80 or 96 would not compile).
    is_pow2 = head_dim > 0 and (head_dim & (head_dim - 1)) == 0
    # The fused wrapper returns before launching when L or H is zero, so the
    # feature size only matters when a launch would actually happen.
    would_launch = v.shape[0] > 0 and v.shape[1] > 0
    if on_gpu and (is_pow2 or not would_launch):
        return _mul_vnorm_avgpool3_fused(a, v)

    import torch.nn.functional as F

    s = a * v.norm(dim=-1)
    # avg_pool1d pools the last axis, so pool over positions as [1, H, L].
    pooled = F.avg_pool1d(
        s.transpose(0, 1).unsqueeze(0), kernel_size=3, padding=1, stride=1
    )
    return pooled.squeeze(0).transpose(0, 1)
@triton.jit
def _zscore_elem_1d_kernel(
    X_ptr,
    OUT_ptr,
    n,
    mean,
    inv_std,
    BLOCK: tl.constexpr,
):
    """Write ``(x - mean) * inv_std`` for a flat buffer of ``n`` elements.

    Each program handles one block of ``BLOCK`` consecutive elements; indices
    at or beyond ``n`` are masked out of both the load and the store.
    """
    block_start = tl.program_id(0) * BLOCK
    idx = block_start + tl.arange(0, BLOCK)
    in_range = idx < n
    vals = tl.load(X_ptr + idx, mask=in_range, other=0.0)
    scaled = (vals - mean) * inv_std
    tl.store(OUT_ptr + idx, scaled, mask=in_range)
def _zscore_flat_f32_global(x: torch.Tensor) -> torch.Tensor:
    """Global z-score of a flat tensor, following kvpress ``(t - t.mean()) / t.std()``.

    The mean and standard deviation are computed with PyTorch. On CUDA the
    element-wise scaling is written by a Triton kernel; on other devices it is
    a plain tensor expression. The standard deviation is floored at ``1e-6``.

    Args:
        x: One-dimensional tensor of scores.

    Returns:
        Tensor with the same number of elements as ``x``. An empty input is
        returned as is; a single-element input yields 0.
    """
    n = x.numel()
    if n == 0:
        return x
    mu = x.mean()
    # The unbiased std of one element is NaN, and `clamp_min` does not remove
    # NaN, so a single-element input is treated as having zero spread.
    sig = x.std() if n > 1 else torch.zeros_like(mu)
    inv = 1.0 / sig.clamp_min(1e-6)
    if not x.is_cuda:
        return (x - mu) * inv
    x = x.contiguous()
    out = torch.empty_like(x)
    BLOCK = 1024
    grid = (triton.cdiv(n, BLOCK),)
    _zscore_elem_1d_kernel[grid](
        x,
        out,
        n,
        float(mu.item()),
        float(inv.item()),
        BLOCK=BLOCK,
        num_warps=4,
    )
    return out
def _attn_scores_kvpress_middle(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    sink_start: int,
    sink_end: int,
    chunk_size: int,
    do_zscore: bool = True,
) -> torch.Tensor:
    """Non-causal attention scores × ||V|| + avg-pool, on middle tokens only.

    For every packed sequence the first ``sink_start`` and last ``sink_end``
    tokens are skipped. On the remaining middle span the chunked non-causal
    attention mass is averaged over each KV head's query group, multiplied by
    the value norm and smoothed with a width-3 average. Scores of all
    sequences are concatenated and, if requested, z-scored together.

    Args:
        q: Packed queries ``[N, HQ, D]``.
        k: Packed keys ``[N, Hkv, D]``.
        v: Packed values, sliced with the same token ranges as ``k``.
        cu_seqlens: Sequence boundaries; sequence ``b`` spans
            ``cu_seqlens[b] : cu_seqlens[b + 1]``.
        sink_start: Number of leading tokens excluded per sequence.
        sink_end: Number of trailing tokens excluded per sequence.
        chunk_size: Chunk length for the non-causal attention.
        do_zscore: Whether to apply the global z-score.

    Returns:
        Float32 tensor ``[N, Hkv]``; rows outside the middle spans are 0.
    """
    N, HQ, _D = q.shape
    Hkv = k.shape[1]
    G = HQ // Hkv
    attn_out = torch.zeros(N, Hkv, device=q.device, dtype=torch.float32)

    # One host transfer for all boundaries; the middle ranges found here are
    # reused for the scatter below instead of being recomputed with `.item()`.
    bounds = [int(x) for x in cu_seqlens.reshape(-1).tolist()]

    parts: list[torch.Tensor] = []
    mid_ranges: list[tuple[int, int]] = []
    for k_beg, k_end in zip(bounds, bounds[1:]):
        seq_len = k_end - k_beg
        if seq_len == 0:
            continue
        left_keep = min(sink_start, seq_len)
        right_keep = min(sink_end, max(0, seq_len - left_keep))
        mid_start = k_beg + left_keep
        mid_end = k_end - right_keep
        if mid_start >= mid_end:
            continue
        q_m = q[mid_start:mid_end, :, :].contiguous()
        k_m = k[mid_start:mid_end, :, :].contiguous()
        v_m = v[mid_start:mid_end, :, :].contiguous()
        # HF ``repeat_kv`` convention: ``[batch, num_kv_heads, seq_len, head_dim]``
        k_4d = k_m.unsqueeze(0).transpose(1, 2).contiguous()  # [1, Hkv, Lm, D]
        k_rep = repeat_kv(k_4d, G)[0].transpose(0, 1).contiguous()  # [Lm, HQ, D]
        A = non_causal_chunked_attn(q_m, k_rep, chunk_size)
        Lm, HQa = A.shape
        assert HQa == HQ
        # Average the G query heads that share each KV head.
        A = A.view(Lm, Hkv, G).mean(dim=-1)
        scores = _maybe_mul_vnorm_avgpool3_fused(A, v_m)
        parts.append(scores.reshape(-1))
        mid_ranges.append((mid_start, mid_end))

    if not parts:
        return attn_out

    flat_a = torch.cat(parts, dim=0)
    z_a = _zscore_flat_f32_global(flat_a) if do_zscore else flat_a

    offset = 0
    for mid_start, mid_end in mid_ranges:
        rows = mid_end - mid_start
        n = rows * Hkv
        attn_out[mid_start:mid_end, :] = z_a[offset : offset + n].view(rows, Hkv)
        offset += n
    return attn_out
def non_causal_attn_scores(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_qk: torch.Tensor,
    max_seqlen_qk: int,
    chunk_size: int,
    sm_scale: Optional[float] = None,
    normalize: bool = True,
    context_lens: Optional[List[int]] = None,
    protected_first_tokens: Optional[List[int]] = None,
    protected_last_tokens: Optional[List[int]] = None,
    *,
    accum_scores: Optional[torch.Tensor] = None,
    accum_blending: Optional[float] = None,
) -> torch.Tensor:
    """Non-causal attention scores matching the kvpress non-causal branch.

    ``sm_scale`` and ``max_seqlen_qk`` are **ignored**: dot products are not
    multiplied by ``1/sqrt(d)``. The sink sizes are fixed at 8 leading and
    4 trailing tokens per sequence.

    Args:
        q, k, v: Packed query / key / value tensors.
        cu_seqlens_qk: Sequence boundaries of the packed batch.
        max_seqlen_qk: Unused.
        chunk_size: Chunk length for the non-causal attention.
        sm_scale: Unused.
        normalize: When true, the concatenated middle-span scores are z-scored
            globally (as in the stand-alone non-causal press).
        context_lens: Per-sequence lengths used to locate protected tokens.
        protected_first_tokens: Per-sequence count of leading tokens set to
            ``inf``.
        protected_last_tokens: Per-sequence count of trailing tokens set to
            ``inf``.
        accum_scores: Optional scores added as
            ``out += accum_blending * accum_scores``.
        accum_blending: Weight for ``accum_scores``; 0.5 when omitted.

    Returns:
        Score tensor ``[N, Hkv]``.
    """
    del sm_scale, max_seqlen_qk
    sink_start, sink_end = 8, 4
    out = _attn_scores_kvpress_middle(
        q,
        k,
        v,
        cu_seqlens_qk,
        sink_start,
        sink_end,
        chunk_size,
        do_zscore=normalize,
    )
    if accum_scores is not None:
        w = 0.5 if accum_blending is None else float(accum_blending)
        out = out + w * accum_scores.to(device=out.device, dtype=out.dtype)
    if protected_first_tokens is not None and protected_last_tokens is not None and context_lens:
        start = 0
        for first, last, Lc in zip(
            protected_first_tokens, protected_last_tokens, context_lens
        ):
            seq_len = int(Lc)
            end = start + seq_len
            # Clamp to the sequence: a protected count larger than the context
            # would otherwise spill into a neighbouring sequence or, for the
            # first sequence, wrap around through a negative index.
            n_first = max(0, min(int(first), seq_len))
            n_last = max(0, min(int(last), seq_len))
            out[start : start + n_first].fill_(torch.inf)
            out[end - n_last : end].fill_(torch.inf)
            start = end
    return out
def kvpress_compactor_post_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    pre_rope_scores: torch.Tensor,
    compression_ctx,
    max_seqlen_q: int,
    chunk_size: int,
    blending: float,
) -> torch.Tensor:
    """Blend pre-RoPE scores with post-RoPE non-causal attention scores.

    On each sequence's middle span the result is
    ``blending * pre_rope_scores + attention_scores``. Sink tokens (the first
    ``sink_size_start`` and last ``sink_size_end`` of each sequence) receive
    the maximum blended value, or 1.0 when that maximum is zero or non-finite.
    Protected first/last tokens, when provided on the context, are finally set
    to ``inf``.

    Args:
        q, k, v: Packed query / key / value tensors.
        cu_seqlens: Sequence boundaries; sequence ``b`` spans
            ``cu_seqlens[b] : cu_seqlens[b + 1]``.
        pre_rope_scores: Scores blended with the attention scores; converted
            to float32 on ``q``'s device.
        compression_ctx: Object optionally carrying ``sink_size_start``
            (default 8), ``sink_size_end`` (default 4), ``context_lens``,
            ``protected_first_tokens`` and ``protected_last_tokens``.
        max_seqlen_q: Unused.
        chunk_size: Chunk length for the non-causal attention.
        blending: Weight applied to ``pre_rope_scores``.

    Returns:
        Float32 tensor shaped like ``pre_rope_scores``.
    """
    del max_seqlen_q
    device = q.device
    sink_start = int(getattr(compression_ctx, "sink_size_start", 8))
    sink_end = int(getattr(compression_ctx, "sink_size_end", 4))
    context_lens: Optional[List[int]] = getattr(
        compression_ctx, "context_lens", None
    )
    protected_first: Optional[List[int]] = getattr(
        compression_ctx, "protected_first_tokens", None
    )
    protected_last: Optional[List[int]] = getattr(
        compression_ctx, "protected_last_tokens", None
    )
    attn_out = _attn_scores_kvpress_middle(
        q, k, v, cu_seqlens, sink_start, sink_end, chunk_size
    )
    lev = pre_rope_scores.to(device=device, dtype=torch.float32)
    blended = torch.zeros_like(lev)
    if blended.numel() == 0:
        # Nothing to score; `max()` on an empty tensor would raise.
        return blended

    # One host transfer for all boundaries; spans are computed once and shared
    # by both passes below.
    bounds = [int(x) for x in cu_seqlens.reshape(-1).tolist()]
    spans: list[tuple[int, int, int, int]] = []
    for k_beg, k_end in zip(bounds, bounds[1:]):
        seq_len = k_end - k_beg
        if seq_len == 0:
            continue
        left_keep = min(sink_start, seq_len)
        right_keep = min(sink_end, max(0, seq_len - left_keep))
        spans.append((k_beg, k_beg + left_keep, k_end - right_keep, k_end))

    # Pass 1: blended score on the middle span of every sequence.
    for _k_beg, mid_start, mid_end, _k_end in spans:
        if mid_start >= mid_end:
            continue
        blended[mid_start:mid_end, :] = (
            blending * lev[mid_start:mid_end, :] + attn_out[mid_start:mid_end, :]
        )

    # Pass 2: sink tokens get the largest blended score.
    pad_val = blended.max()
    if not torch.isfinite(pad_val) or pad_val == 0:
        pad_val = torch.tensor(1.0, device=device, dtype=torch.float32)
    for k_beg, mid_start, mid_end, k_end in spans:
        if mid_start > k_beg:
            blended[k_beg:mid_start, :] = pad_val
        if k_end > mid_end:
            blended[mid_end:k_end, :] = pad_val

    if protected_first is not None and protected_last is not None and context_lens:
        start = 0
        for first, last, Lc in zip(
            protected_first, protected_last, context_lens
        ):
            seq_len = int(Lc)
            end = start + seq_len
            # Clamp to the sequence: a protected count larger than the context
            # would otherwise spill into a neighbouring sequence or, for the
            # first sequence, wrap around through a negative index.
            n_first = max(0, min(int(first), seq_len))
            n_last = max(0, min(int(last), seq_len))
            blended[start : start + n_first].fill_(torch.inf)
            blended[end - n_last : end].fill_(torch.inf)
            start = end
    return blended
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment