Commit 2b7160c6 authored by chenzk's avatar chenzk
Browse files

vllm kvprune:v1.0.0

parent fa718036
"""Tensor-parallel collectives for kvprune (match vLLM TP process group when embedded)."""
from __future__ import annotations
import torch.distributed as dist
def tensor_parallel_all_reduce(tensor: torch.Tensor) -> torch.Tensor:
    """All-reduce ``tensor`` across tensor-parallel ranks and return it.

    The result is always materialized in ``tensor`` itself, so call sites that
    ignore the return value still observe the reduced values.

    * If ``torch.distributed`` is not initialized, or the world size is <= 1,
      ``tensor`` is returned untouched.
    * If vLLM's :mod:`vllm.distributed.parallel_state` reports that model
      parallelism is initialized, vLLM's
      :func:`~vllm.distributed.communication_op.tensor_model_parallel_all_reduce`
      is used. When that call returns a different tensor object (out-of-place
      reduction), the reduced values are copied back into ``tensor``.
    * Otherwise :func:`torch.distributed.all_reduce` is run on the default
      process group.

    Only the *probe* for vLLM's parallel state is best-effort. Errors raised by
    the collective itself propagate to the caller: silently retrying with a
    second collective on a different process group could double-reduce or
    desynchronize ranks.

    Args:
        tensor: Per-rank partial result; updated in place.

    Returns:
        The same ``tensor`` object, holding the reduced values.
    """
    if not dist.is_initialized() or dist.get_world_size() <= 1:
        return tensor

    vllm_tp_all_reduce = None
    try:
        from vllm.distributed.parallel_state import model_parallel_is_initialized

        if model_parallel_is_initialized():
            from vllm.distributed.communication_op import (
                tensor_model_parallel_all_reduce as vllm_tp_all_reduce,
            )
    except Exception:
        # vLLM parallel state is unavailable or unusable: use the default group.
        vllm_tp_all_reduce = None

    if vllm_tp_all_reduce is None:
        dist.all_reduce(tensor)
        return tensor

    reduced = vllm_tp_all_reduce(tensor)
    if reduced is not tensor:
        # Out-of-place reduction: `reduced` holds the cross-rank result, while
        # callers expect `tensor` to be updated.
        tensor.copy_(reduced)
    return tensor
"""Tensor-parallel helpers for kvprune when embedded in a vLLM worker."""
from __future__ import annotations
import torch.distributed as dist
def tensor_parallel_rank_for_sharding() -> int:
    """Return this process's rank for tensor-parallel weight sharding.

    vLLM's ``get_tensor_model_parallel_rank`` is tried first. If importing or
    calling it raises, the rank in the default ``torch.distributed`` group is
    used when that group is initialized, otherwise ``0``.
    """
    try:
        from vllm.distributed.parallel_state import (
            get_tensor_model_parallel_rank as _vllm_tp_rank,
        )

        rank = int(_vllm_tp_rank())
    except Exception:
        rank = int(dist.get_rank()) if dist.is_initialized() else 0
    return rank
def tensor_parallel_world_size_for_sharding() -> int:
    """Return the number of ranks that weights are sharded across.

    vLLM's ``get_tensor_model_parallel_world_size`` is tried first. If
    importing or calling it raises, the default ``torch.distributed`` world
    size is used when initialized, otherwise ``1``.
    """
    try:
        from vllm.distributed.parallel_state import (
            get_tensor_model_parallel_world_size as _vllm_tp_world_size,
        )

        world_size = int(_vllm_tp_world_size())
    except Exception:
        world_size = int(dist.get_world_size()) if dist.is_initialized() else 1
    return world_size
def kv_heads_shard_divisor() -> int:
    """Return the divisor used to shard KV heads.

    This is exactly the value of :func:`tensor_parallel_world_size_for_sharding`.
    """
    divisor = tensor_parallel_world_size_for_sharding()
    return divisor
from __future__ import annotations
import inspect
import os
from typing import Any, Callable, Mapping
import torch
from vllm.logger import init_logger
logger = init_logger(__name__)
_cache_results_warned = False
def _ensure_kvprune_triton_cache_dir() -> None:
"""Set a stable Triton cache dir for kvprune kernels unless already set."""
if os.environ.get("TRITON_CACHE_DIR"):
return
cache_root = os.environ.get("VLLM_CACHE_ROOT", os.path.expanduser("~/.cache/vllm"))
triton_cache = os.path.join(cache_root, "kvprune_triton_cache")
os.makedirs(triton_cache, exist_ok=True)
os.environ["TRITON_CACHE_DIR"] = triton_cache
_ensure_kvprune_triton_cache_dir()
def _filter_kwargs_for_callable(
fn: Callable[..., Any], kwargs: Mapping[str, Any]
) -> dict[str, Any]:
try:
params = inspect.signature(fn).parameters
except (TypeError, ValueError):
return dict(kwargs)
return {k: v for k, v in kwargs.items() if k in params}
def autotune(*, configs, key, **kwargs):
    """
    Forward to `triton.autotune`, dropping keyword arguments it does not accept.

    Extra ``kwargs`` are filtered against the runtime signature of
    `triton.autotune`, so a Triton build lacking newer options (for example
    `cache_results`) still works. The first time `cache_results` is requested
    but dropped, a single warning is emitted.
    """
    global _cache_results_warned
    import triton

    filtered = _filter_kwargs_for_callable(triton.autotune, kwargs)
    cache_results_dropped = (
        "cache_results" in kwargs and "cache_results" not in filtered
    )
    if cache_results_dropped and not _cache_results_warned:
        logger.warning_once(
            "Current Triton build does not accept cache_results in triton.autotune; "
            "kvprune autotune results may not persist across runs."
        )
        _cache_results_warned = True
    return triton.autotune(configs=configs, key=key, **filtered)
def maybe_set_allocator(alloc_fn: Callable[[int, int, int | None], Any]) -> bool:
    """
    Install ``alloc_fn`` through `triton.set_allocator` when that hook exists.

    Returns True if the hook was found and called, False when the running
    Triton has no `set_allocator` attribute (nothing is done in that case).
    """
    import triton

    setter = getattr(triton, "set_allocator", None)
    if setter is not None:
        setter(alloc_fn)
        return True
    return False
def cuda_capability_geq(major: int, minor: int = 0, device: int | None = None) -> bool:
    """
    Report whether a CUDA device's compute capability is at least ``(major, minor)``.

    Returns False when CUDA is unavailable. When ``device`` is None the current
    CUDA device is queried, falling back to device 0 if that lookup raises.
    """
    if torch.cuda.is_available():
        if device is None:
            try:
                device = torch.cuda.current_device()
            except Exception:
                device = 0
        required = (major, minor)
        return torch.cuda.get_device_capability(device) >= required
    return False
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
KV-cache pruning (compactor-style) under ``vllm.kvprune``.
Use the standard :class:`~vllm.LLM` and pass ``compression=`` to :meth:`~vllm.LLM.generate`
with :class:`CompressionParams` when any prompt needs ``compression_ratio < 1``. The compactor
``LLMEngine`` + ``PagedKVCache`` shares weights with vLLM (no second checkpoint).
Subpackages (``attention``, ``kv_cache``, ``compression``, …) implement the compactor
engine.
"""
from vllm.kvprune.compression.compression_config import CompressionMethod
from vllm.kvprune.integration import CompressionParams
__all__ = [
"CompressionMethod",
"CompressionParams",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Sparse attention Triton kernels (varlen prefill, decode, compile helpers)."""
from vllm.kvprune.attention.sparse_varlen_kernel import causal_sparse_varlen_with_cache
__all__ = ["causal_sparse_varlen_with_cache"]
import argparse
import logging
import math
import torch
from vllm.kvprune.attention.sparse_varlen_kernel import (
causal_sparse_varlen_with_cache,
)
logger = logging.getLogger(__name__)
def build_mock_paged_cache_from_lengths(
    L_cache_per_b: torch.Tensor,
    HKV: int,
    D: int,
    PAGE_SIZE: int,
    N_LOGICAL_PAGES_MAX: int,
    device,
    dtype,
):
    """Build a random paged KV cache for benchmarking/autotuning.

    Each ``(batch, kv_head)`` pair owns ``N_LOGICAL_PAGES_MAX`` consecutive
    physical pages, so page ids are simply ``0 .. B*HKV*N_LOGICAL_PAGES_MAX-1``
    in row-major ``(b, h, logical_page)`` order. The first ``L_cache_per_b[b]``
    token rows of every head of batch ``b`` are filled with ``randn`` values;
    all other rows stay zero.

    The cache is filled with a few vectorized writes instead of one Python
    iteration (and several device syncs) per token.

    Args:
        L_cache_per_b: 1-D integer tensor of cached token counts per batch.
        HKV: Number of KV heads.
        D: Head dimension.
        PAGE_SIZE: Tokens per physical page.
        N_LOGICAL_PAGES_MAX: Page-table width per ``(batch, head)``.
        device: Device for all returned tensors.
        dtype: Dtype of the K/V caches.

    Returns:
        ``(K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE)`` where the
        caches are ``[CACHE_SIZE, D]``, ``page_table`` is int32
        ``[B, HKV, N_LOGICAL_PAGES_MAX]`` and ``seq_lens_bh`` is int32
        ``[B, HKV]``.
    """
    B = len(L_cache_per_b)
    tokens_per_bh = PAGE_SIZE * N_LOGICAL_PAGES_MAX
    assert (L_cache_per_b <= tokens_per_bh).all()

    # Every KV head of a batch entry shares the same cached length.
    seq_lens_bh = (
        L_cache_per_b.to(device=device, dtype=torch.int32)
        .view(B, 1)
        .expand(B, HKV)
        .contiguous()
    )

    num_phys_pages = B * HKV * N_LOGICAL_PAGES_MAX
    CACHE_SIZE = num_phys_pages * PAGE_SIZE
    K_cache = torch.zeros((CACHE_SIZE, D), device=device, dtype=dtype)
    V_cache = torch.zeros((CACHE_SIZE, D), device=device, dtype=dtype)

    # Unique physical page per (b, h, lp), assigned sequentially.
    page_table = torch.arange(
        num_phys_pages, device=device, dtype=torch.int32
    ).view(B, HKV, N_LOGICAL_PAGES_MAX)

    # Pages of one (b, h) are contiguous, so logical token i of (b, h) lives
    # at row (b*HKV + h) * tokens_per_bh + i; expose that as a 4-D view.
    K_view = K_cache.view(B, HKV, tokens_per_bh, D)
    V_view = V_cache.view(B, HKV, tokens_per_bh, D)
    for b, Lc in enumerate(int(x) for x in L_cache_per_b.tolist()):
        if Lc > 0:
            K_view[b, :, :Lc] = torch.randn(HKV, Lc, D, device=device, dtype=dtype)
            V_view[b, :, :Lc] = torch.randn(HKV, Lc, D, device=device, dtype=dtype)
    return K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE
def autotune_causal_sparse_varlen_with_cache(
    *,
    max_length: int = 16384,
    HKV: int = 8,
    HQ: int = 32,
    D: int = 128,
    PAGE_SIZE: int = 128,
    device: str = "cuda",
    dtype=torch.float16,
):
    """
    Autotune causal_sparse_varlen_with_cache over a sweep of cache/append lengths.

    For every ``(cache_len, append_len)`` pair drawn from
    ``[0, 256, 512, 1024, ...]`` (powers of two strictly below ``max_length``),
    a mock paged cache for a batch of 4 identical sequences is built and the
    kernel wrapper is invoked once. The pair ``(0, 0)`` is skipped.

    Args:
        max_length: Upper bound (exclusive) for the swept power-of-two lengths.
        HKV: Number of KV heads.
        HQ: Number of query heads.
        D: Head dimension; must be a power of two.
        PAGE_SIZE: Tokens per physical page.
        device: Torch device for all tensors.
        dtype: Dtype for Q/K/V tensors.
    """
    import itertools

    import tqdm

    B = 4
    # D must be a power of two (kernel requirement).
    assert (D & (D - 1)) == 0
    lengths_to_sweep = [0, 256]
    i = 9
    while (v := (1 << i)) < max_length:
        lengths_to_sweep.append(v)
        i += 1

    # Page-table width = number of logical pages needed for the longest cached
    # sequence (ceil division). This is a page count, not a token count:
    # multiplying by PAGE_SIZE again would inflate the page table and the mock
    # K/V caches by a factor of PAGE_SIZE.
    max_cache_len = max(max_length, max(lengths_to_sweep))
    N_LOGICAL_PAGES_MAX = (max_cache_len + PAGE_SIZE - 1) // PAGE_SIZE

    combos = list(itertools.product(lengths_to_sweep, repeat=2))
    logger.info(
        "tuning kernels. this may take a few minutes, "
        "but only needs to be run once per LLMConfig"
    )
    for cache_l, append_l in tqdm.tqdm(combos):
        if cache_l + append_l == 0:
            continue
        L_cache_per_b = torch.tensor(
            [cache_l] * B,
            device=device,
            dtype=torch.int32,
        )
        assert (L_cache_per_b <= PAGE_SIZE * N_LOGICAL_PAGES_MAX).all()
        K_cache, V_cache, page_table, seq_lens_bh, CACHE_SIZE = (
            build_mock_paged_cache_from_lengths(
                L_cache_per_b=L_cache_per_b,
                HKV=HKV,
                D=D,
                PAGE_SIZE=PAGE_SIZE,
                N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
                device=device,
                dtype=dtype,
            )
        )
        # Packed (varlen) layout: B sequences of append_l new tokens each.
        L_app_list = [append_l] * B
        cu = [0]
        for L in L_app_list:
            cu.append(cu[-1] + L)
        cu_seqlens_qk = torch.tensor(cu, dtype=torch.int32, device=device)
        N = int(cu_seqlens_qk[-1].item())
        max_seqlen_q = int((cu_seqlens_qk[1:] - cu_seqlens_qk[:-1]).max().item())
        max_seqlen_k = int(seq_lens_bh.max().item())
        q_raw = torch.randn(N, HQ, D, device=device, dtype=dtype)
        k_append_raw = torch.randn(N, HKV, D, device=device, dtype=dtype)
        v_append_raw = torch.randn(N, HKV, D, device=device, dtype=dtype)
        # Identity batch mapping (local batch index == global)
        batch_mapping = torch.arange(B, device=device, dtype=torch.int32)
        sm_scale = 1.0 / math.sqrt(D)
        causal_sparse_varlen_with_cache(
            q=q_raw,
            k_cache=K_cache,
            v_cache=V_cache,
            k=k_append_raw,
            v=v_append_raw,
            seq_lens_bh=seq_lens_bh,
            global_page_table=page_table,
            batch_mapping=batch_mapping,
            cu_seqlens_q=cu_seqlens_qk,
            HKV=HKV,
            PAGE_SIZE=PAGE_SIZE,
            sm_scale=sm_scale,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k_cache=max_seqlen_k,
        )
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Autotune Triton kernels. "
"Results are cached, so this should only need to be run once per configuration."
"This script doesn't need to be run, as the kernels will be autotuned at runtime"
"if no cached autotuning data exists. Running this before hand will prevent run-time"
"autotuning, which will accelerate compactor-vllm at inference time."
)
parser.add_argument(
"--max-length",
type=int,
default=16384,
help="Maximum total sequence length to consider.",
)
parser.add_argument(
"--HKV",
type=int,
default=8,
help="Number of KV heads.",
)
parser.add_argument(
"--HQ",
type=int,
default=32,
help="Number of query heads.",
)
parser.add_argument(
"--D",
type=int,
default=128,
help="Per-head hidden dimension (must be power of 2).",
)
parser.add_argument(
"--page-size",
type=int,
default=128,
help="Page size (tokens per physical page).",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Torch device to run on (e.g. 'cuda', 'cuda:0', 'cpu').",
)
parser.add_argument(
"--dtype",
type=str,
default="float16",
help="Dtype for tensors: one of {float16, fp16, bfloat16, bf16, float32, fp32}.",
)
parser.add_argument(
"--log-level",
type=str,
default="INFO",
choices=["CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG"],
help="Logging level.",
)
return parser.parse_args()
def _resolve_dtype(dtype_str: str):
s = dtype_str.lower()
if s in ("float16", "fp16", "half"):
return torch.float16
if s in ("bfloat16", "bf16"):
return torch.bfloat16
if s in ("float32", "fp32"):
return torch.float32
raise ValueError(f"Unsupported dtype: {dtype_str}")
def main():
    """Script entry point: parse CLI options, set up logging, run the sweep."""
    args = _parse_args()

    log_level = getattr(logging, args.log_level.upper())
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    dtype = _resolve_dtype(args.dtype)
    logger.info(
        "Starting autotune with max_length=%d, HKV=%d, HQ=%d, D=%d, page_size=%d, "
        "device=%s, dtype=%s",
        args.max_length,
        args.HKV,
        args.HQ,
        args.D,
        args.page_size,
        args.device,
        dtype,
    )

    sweep_options = {
        "max_length": args.max_length,
        "HKV": args.HKV,
        "HQ": args.HQ,
        "D": args.D,
        "PAGE_SIZE": args.page_size,
        "device": args.device,
        "dtype": dtype,
    }
    autotune_causal_sparse_varlen_with_cache(**sweep_options)
if __name__ == "__main__":
    # Logging is configured inside main() from --log-level. Calling
    # logging.basicConfig() here first would attach a root handler and turn the
    # call in main() into a no-op, silently ignoring the requested level.
    main()
# SPDX-License-Identifier: Apache-2.0
"""FlashAttention paths over compactor paged KV (materialize + FA ops).
Used when :class:`~vllm.kvprune.config.engine_config.KvpruneAttentionSchedule`
selects FlashAttention for prefill and/or decode while KV **writes** remain on
Triton (``prefill_store_*``, ``decode_store_kv``). Matches the reference checks
in ``vllm/compactor-vllm/tests/test_triton_attention.py``.
"""
from __future__ import annotations
import math
from typing import TYPE_CHECKING
import torch
from flash_attn.flash_attn_interface import flash_attn_func, flash_attn_varlen_func
if TYPE_CHECKING:
pass
def materialize_kv_for_flash_prefill(
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    page_table: torch.Tensor,
    batch_mapping: torch.Tensor,
    L_cache_per_b: torch.Tensor,
    k_append: torch.Tensor,
    v_append: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    H_kv: int,
    PAGE_SIZE: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build packed K/V for :func:`flash_attn_varlen_func` (cache prefix + append).

    For each launch-batch entry ``b`` the packed key/value rows are the first
    ``L_cache_per_b[b]`` cached tokens of every KV head (looked up through
    ``page_table[batch_mapping[b], head, token // PAGE_SIZE]``) followed by the
    ``cu_seqlens_q[b + 1] - cu_seqlens_q[b]`` newly appended tokens.

    Cached rows are gathered with one indexed read per sequence rather than one
    Python iteration (and device sync) per token.

    Args:
        k_cache, v_cache: Global caches indexed by physical row, ``[rows, D]``.
        page_table: Physical page ids indexed ``[true_batch, kv_head, logical_page]``.
        batch_mapping: Launch-batch index -> row of ``page_table``.
        L_cache_per_b: Cached token count per launch-batch entry.
        k_append, v_append: New tokens, ``[N, H_kv, D]``.
        cu_seqlens_q: Cumulative new-token counts, length ``B + 1``.
        H_kv: Number of KV heads (must match ``k_append.shape[1]``).
        PAGE_SIZE: Tokens per physical page.

    Returns:
        ``(K_total, V_total, cu_seqlens_k)``: packed tensors of shape
        ``[total_k, H_kv, D]`` and the int32 cumulative key lengths.
    """
    device = k_cache.device
    dtype = k_cache.dtype
    B = cu_seqlens_q.numel() - 1
    _, H_kv_raw, D = k_append.shape
    assert H_kv_raw == H_kv

    # Pull all small bookkeeping tensors to the host once.
    cu_q = [int(x) for x in cu_seqlens_q.tolist()]
    cache_lens = [int(x) for x in L_cache_per_b.tolist()]
    true_rows = [int(x) for x in batch_mapping.tolist()]

    cu_k = [0]
    for b in range(B):
        cu_k.append(cu_k[-1] + cache_lens[b] + (cu_q[b + 1] - cu_q[b]))
    total_k = cu_k[-1]
    cu_seqlens_k = torch.tensor(cu_k, device=device, dtype=torch.int32)

    K_total = torch.empty((total_k, H_kv, D), device=device, dtype=dtype)
    V_total = torch.empty((total_k, H_kv, D), device=device, dtype=dtype)
    for b in range(B):
        start = cu_k[b]
        Lc = cache_lens[b]
        La = cu_q[b + 1] - cu_q[b]
        q_start = cu_q[b]
        if Lc > 0:
            pos = torch.arange(Lc, device=page_table.device)
            logical_pages = torch.div(pos, PAGE_SIZE, rounding_mode="floor")
            # [H_kv, Lc] physical page id of every cached token of every head.
            pages = (
                page_table[true_rows[b], :H_kv]
                .index_select(1, logical_pages)
                .to(torch.long)
            )
            rows = (pages * PAGE_SIZE + torch.remainder(pos, PAGE_SIZE)).to(device)
            # Gather gives [H_kv, Lc, D]; packed layout wants [Lc, H_kv, D].
            K_total[start : start + Lc] = k_cache[rows].transpose(0, 1)
            V_total[start : start + Lc] = v_cache[rows].transpose(0, 1)
        if La > 0:
            dst = start + Lc
            K_total[dst : dst + La] = k_append[q_start : q_start + La]
            V_total[dst : dst + La] = v_append[q_start : q_start + La]
    return K_total, V_total, cu_seqlens_k
def flash_prefill_from_paged(
    q: torch.Tensor,
    k_append: torch.Tensor,
    v_append: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    *,
    seq_lens_bh_before: torch.Tensor,
    global_page_table: torch.Tensor,
    batch_mapping: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    max_seqlen_q: int,
    PAGE_SIZE: int,
    HKV: int,
    sm_scale: float | None,
) -> torch.Tensor:
    """Run causal varlen FlashAttention over paged cache prefix plus new tokens.

    The per-sequence cache length is the maximum of ``seq_lens_bh_before``
    across KV heads. The cached prefix and the appended K/V are first packed by
    :func:`materialize_kv_for_flash_prefill`, then passed to
    ``flash_attn_varlen_func`` with ``causal=True``. ``sm_scale`` is forwarded
    unchanged as ``softmax_scale`` (``None`` included).
    """
    cache_len_per_seq = seq_lens_bh_before.max(dim=1).values.to(torch.int32)
    packed_k, packed_v, cu_seqlens_k = materialize_kv_for_flash_prefill(
        k_cache=k_cache,
        v_cache=v_cache,
        page_table=global_page_table,
        batch_mapping=batch_mapping,
        L_cache_per_b=cache_len_per_seq,
        k_append=k_append,
        v_append=v_append,
        cu_seqlens_q=cu_seqlens_q,
        H_kv=HKV,
        PAGE_SIZE=PAGE_SIZE,
    )
    key_lens = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
    longest_k = int(key_lens.max().item())
    attn_out = flash_attn_varlen_func(
        q,
        packed_k,
        packed_v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=longest_k,
        softmax_scale=sm_scale,
        causal=True,
    )
    return attn_out
def materialize_kv_cache_for_flash_decode(
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    page_table: torch.Tensor,
    batch_mapping: torch.Tensor,
    L_cache_per_b: torch.Tensor,
    H_kv: int,
    PAGE_SIZE: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Dense ``[B, S, H_kv, D]`` cache for :func:`flash_attn_func` decode.

    ``S`` is ``L_cache_per_b.max()``. For batch entry ``b`` the first
    ``L_cache_per_b[b]`` positions of every head are gathered from the paged
    cache through ``page_table[batch_mapping[b], head, token // PAGE_SIZE]``;
    the remaining positions are zero.

    Rows are gathered with one indexed read per sequence rather than one Python
    iteration (and device sync) per token.

    Args:
        k_cache, v_cache: Global caches indexed by physical row, ``[rows, D]``.
        page_table: Physical page ids indexed ``[true_batch, kv_head, logical_page]``.
        batch_mapping: Launch-batch index -> row of ``page_table``.
        L_cache_per_b: Cached token count per launch-batch entry.
        H_kv: Number of KV heads.
        PAGE_SIZE: Tokens per physical page.

    Returns:
        ``(K_flash, V_flash)``, each of shape ``[B, S, H_kv, D]``.
    """
    device = k_cache.device
    dtype = k_cache.dtype
    B = L_cache_per_b.shape[0]
    D = k_cache.shape[1]
    seqlen_cache_max = int(L_cache_per_b.max().item())
    K_flash = torch.zeros((B, seqlen_cache_max, H_kv, D), device=device, dtype=dtype)
    V_flash = torch.zeros_like(K_flash)

    # Pull the small bookkeeping tensors to the host once.
    cache_lens = [int(x) for x in L_cache_per_b.tolist()]
    true_rows = [int(x) for x in batch_mapping.tolist()]
    for b in range(B):
        Lc = cache_lens[b]
        if Lc <= 0:
            continue
        pos = torch.arange(Lc, device=page_table.device)
        logical_pages = torch.div(pos, PAGE_SIZE, rounding_mode="floor")
        # [H_kv, Lc] physical page id of every cached token of every head.
        pages = (
            page_table[true_rows[b], :H_kv]
            .index_select(1, logical_pages)
            .to(torch.long)
        )
        rows = (pages * PAGE_SIZE + torch.remainder(pos, PAGE_SIZE)).to(device)
        # Gather gives [H_kv, Lc, D]; dense layout wants [Lc, H_kv, D].
        K_flash[b, :Lc] = k_cache[rows].transpose(0, 1)
        V_flash[b, :Lc] = v_cache[rows].transpose(0, 1)
    return K_flash, V_flash
def flash_decode_from_paged(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    *,
    seq_lens_bh: torch.Tensor,
    global_page_table: torch.Tensor,
    batch_mapping: torch.Tensor,
    PAGE_SIZE: int,
    HKV: int,
    sm_scale: float | None,
) -> torch.Tensor:
    """Single-token decode attention through ``flash_attn_func``.

    The paged cache is densified by
    :func:`materialize_kv_cache_for_flash_decode` using, per sequence, the
    maximum of ``seq_lens_bh`` across KV heads. ``q`` (``[B, HQ, D]``) gets a
    length-1 sequence axis, attention runs with ``causal=False`` over every
    materialized key, and the sequence axis is removed again. When
    ``sm_scale`` is None, ``1 / sqrt(D)`` is used.
    """
    cache_len_per_seq = seq_lens_bh.max(dim=1).values.to(torch.int32)
    dense_k, dense_v = materialize_kv_cache_for_flash_decode(
        k_cache,
        v_cache,
        global_page_table,
        batch_mapping,
        cache_len_per_seq,
        HKV,
        PAGE_SIZE,
    )
    _, _, head_dim = q.shape
    scale = sm_scale if sm_scale is not None else 1.0 / math.sqrt(head_dim)
    # The single query position sees every materialized key, so no causal mask.
    attn_out = flash_attn_func(
        q.unsqueeze(1),
        dense_k,
        dense_v,
        softmax_scale=scale,
        causal=False,
    )
    return attn_out.squeeze(1)
import functools
import math
import torch
import triton
import triton.language as tl
from vllm.kvprune.utils.triton_compat import (
autotune as triton_autotune,
maybe_set_allocator,
)
def head_sparse_decode_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_lens_bh: torch.Tensor,
    global_page_table: torch.Tensor,
    batch_mapping: torch.Tensor,
    HKV: int,
    PAGE_SIZE: int,
    sm_scale: float = None,
    key_split: int = None,
):
    """
    Decode-time head-sparse attention over a paged KV cache.

    Wrapper around the Triton decode kernels used during incremental
    generation. Stage 1 (``_varkv_stage1_groupM``) computes, per
    ``(batch, kv_head, split)``, a partial attention output and its
    log-sum-exp; stage 2 (``_varkv_stage2_reduce``) merges the splits. With a
    single split the stage-1 result is returned directly.

    The KV cache is a single global K/V tensor of shape ``[CACHE_SIZE, D]``
    indexed via a page table. A logical ``(batch, kv_head, token_idx)`` maps to
    a physical row by:
        1. Looking up ``page_id = global_page_table[b, h, token_idx // PAGE_SIZE]``,
        2. Computing ``phys_row = page_id * PAGE_SIZE + (token_idx % PAGE_SIZE)``.

    Grouped-query attention (GQA / MQA) is supported by passing more query
    heads than KV heads (``HQ`` must be a multiple of ``HKV``).

    Args:
        :param q: Contiguous query tensor of shape ``[B, HQ, D]``, or
            ``[B, HQ, 1, D]`` (the singleton query-length axis is squeezed).
        :param k: Global key cache of shape ``[CACHE_SIZE, D]``.
        :param v: Global value cache of shape ``[CACHE_SIZE, D]``.
        :param seq_lens_bh: Tensor of shape ``[B, HKV]`` giving, for each
            launch-batch index and KV head, the number of valid cached tokens.
            It is converted to int32.
        :param global_page_table: Tensor of shape
            ``[MAX_NUM_BATCHES, HKV, N_LOGICAL_PAGES_MAX]`` mapping
            ``(true_batch_idx, kv_head, logical_page)`` to a physical page id.
        :param batch_mapping: Tensor of shape ``[B]`` mapping the launch-batch
            index to the true batch row of ``global_page_table``.
        :param HKV: Number of KV heads.
        :param PAGE_SIZE: Tokens per physical KV page; must be divisible by 32.
        :param sm_scale: Optional scale applied to the attention logits before
            softmax. If ``None``, ``1 / sqrt(D)`` is used.
        :param key_split: Optional number of splits along the key sequence.
            If ``None`` or 0, :func:`num_splits_heuristic` picks the value.

    Returns:
        :return torch.Tensor: Attention output of shape ``[B, HQ, D]`` on the
            same device and dtype as ``q``.
    """
    with torch.cuda.device(q.device):
        if q.ndim == 3:
            B, HQ, D = q.shape
        else:
            assert q.ndim == 4
            B, HQ, S, D = q.shape
            assert S == 1, "head_sparse_decode_attention only supports q_len=1"
            q = q.squeeze(-2)
        CACHE_SIZE = k.shape[0]
        assert PAGE_SIZE % 32 == 0, "PAGE_SIZE must be divisible by 32"
        GROUP_M = HQ // HKV
        assert GROUP_M * HKV == HQ, "HQ must be divisible by H_kv"
        FP8 = hasattr(torch, "float8_e5m2") and q.dtype == torch.float8_e5m2
        seq_lens_bh = seq_lens_bh.to(torch.int32)
        assert B <= 32767, "too many batches"
        assert global_page_table.shape[1] == HKV
        assert q.is_contiguous()
        assert (D & (D - 1)) == 0, "D must be a power of 2"
        N_LOGICAL_PAGES_MAX = global_page_table.shape[-1]
        sm_scale = 1 / math.sqrt(D) if sm_scale is None else sm_scale
        # None and 0 both mean "choose for me": a literal 0 would otherwise
        # produce empty scratch buffers and an empty launch grid.
        if not key_split:
            # round max_seq_len to the next power of two to maximize cache hits
            key_split = num_splits_heuristic(
                B * HKV,
                max_seq_len=1 << int(seq_lens_bh.max()).bit_length(),
                num_sms=torch.cuda.get_device_properties(
                    q.device
                ).multi_processor_count,
                max_splits=12,
            )
        maybe_set_allocator(
            lambda size, align, _: torch.empty(size, dtype=torch.int8, device=q.device)
        )
        # Element type handed to both kernels (anything that is neither FP8 nor
        # bfloat16 is treated as float16).
        kernel_dtype = (
            tl.float8e5
            if FP8
            else (tl.bfloat16 if q.dtype == torch.bfloat16 else tl.float16)
        )
        # stage 1 scratch
        mid_o = torch.empty((B, key_split, HQ, D), device=q.device, dtype=q.dtype)
        mid_lse = torch.empty((B, key_split, HQ), device=q.device, dtype=torch.float32)
        # processes all queries for a KV head together
        # pointers are lowercase, CONSTANTS are upper
        grid1 = (B, HKV, key_split)
        _varkv_stage1_groupM[grid1](
            q=q,
            k=k,
            v=v,
            mid_o=mid_o,
            mid_lse=mid_lse,
            page_table_bhl=global_page_table,
            batch_mapping=batch_mapping,
            seq_lens_bh=seq_lens_bh.contiguous(),
            SM_SCALE=sm_scale,
            B=B,
            HKV=HKV,
            HQ=HQ,
            CACHE_SIZE=CACHE_SIZE,
            STRIDE_LBS=mid_lse.stride(0),
            STRIDE_LS=mid_lse.stride(1),
            STRIDE_LH=mid_lse.stride(2),
            N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
            D=D,
            KEY_SPLIT=key_split,
            GROUP_M=GROUP_M,
            DTYPE=kernel_dtype,
            PAGE_SIZE=PAGE_SIZE,
        )
        if key_split == 1:
            # A single split already holds the final result.
            return mid_o.squeeze(1).contiguous()
        # reduce partial results across splits
        output = torch.empty_like(q)
        grid2 = (B, HQ)
        _varkv_stage2_reduce[grid2](
            mid_o=mid_o,
            mid_lse=mid_lse,
            output=output,
            STRIDE_LBS=mid_lse.stride(0),
            STRIDE_LS=mid_lse.stride(1),
            STRIDE_LH=mid_lse.stride(2),
            STRIDE_OBS=output.stride(0),
            STRIDE_OH=output.stride(1),
            B=B,
            HQ=HQ,
            D=D,  # type: ignore
            KEY_SPLIT=key_split,  # type: ignore
            DTYPE=kernel_dtype,
        )
        return output
# similar to flash attention split heuristic
@functools.lru_cache(maxsize=128)
def num_splits_heuristic(
    total_mblocks: int,
    max_seq_len: int,
    num_sms: int,
    max_splits: int,
) -> int:
    """Pick how many key-sequence splits to use for the decode kernel.

    Returns 1 when the launch already covers at least 80% of the SMs or the
    sequence is at most 1024 tokens. Otherwise each candidate split count
    (up to ``min(max_splits, num_sms)``, stopping once a split would cover
    512 tokens or fewer) is scored by its wave efficiency
    ``n_waves / ceil(n_waves)``. The smallest count whose efficiency reaches
    75% of the best score is returned. Results are memoized.
    """
    if max_seq_len <= 1024 or total_mblocks >= 0.8 * num_sms:
        return 1

    efficiencies = []
    for n_splits in range(1, min(max_splits, num_sms) + 1):
        if (max_seq_len / n_splits) <= 512:
            break
        waves = float(total_mblocks * n_splits) / float(num_sms)
        efficiencies.append(waves / math.ceil(waves) if waves > 0 else 0.0)

    cutoff = 0.75 * max([0.0] + efficiencies)
    for n_splits, efficiency in enumerate(efficiencies, start=1):
        if efficiency >= cutoff:
            return n_splits
    return 1
def prune_invalid_configs(configs, _, **kwargs):
    """Keep only autotune configs whose ``BLOCK_N`` fits within ``PAGE_SIZE``.

    A config without a ``BLOCK_N`` entry is treated as ``BLOCK_N == 0`` and is
    therefore kept. ``PAGE_SIZE`` is read from the keyword arguments.
    """
    page_size = kwargs["PAGE_SIZE"]
    kept = []
    for candidate in configs:
        if candidate.kwargs.get("BLOCK_N", 0) <= page_size:
            kept.append(candidate)
    return kept
# Stage 1 of split-KV decode attention.
# One program per (batch, kv_head, split): it attends the GROUP_M query heads
# that share this KV head to the slice [split_start, split_end) of the cached
# tokens and writes a normalized partial output plus its log-sum-exp, which
# `_varkv_stage2_reduce` later merges across splits.
# Autotuned over BLOCK_N / num_warps / num_stages / WARPSPEC; configs with
# BLOCK_N > PAGE_SIZE are pruned by `prune_invalid_configs`.
@triton_autotune(
    configs=[
        triton.Config(
            {"BLOCK_N": BLOCK_N, "MIN_BLOCK_KV": MIN_BLOCK_KV, "WARPSPEC": ws},
            num_warps=w,
            num_stages=s,
        )
        for BLOCK_N in [32, 64, 128]
        for MIN_BLOCK_KV in [8]
        for s in [2, 3, 4]
        for w in [4, 8]
        for ws in [True, False]
    ],
    key=[
        "HKV",
        "GROUP_M",
        "D",
        "PAGE_SIZE",  # "B"
    ],
    cache_results=True,
    prune_configs_by={"early_config_prune": prune_invalid_configs},
)
@triton.jit
def _varkv_stage1_groupM(
    q,  # [B, HQ, D] contiguous
    k,  # GLOBAL cache: [CACHE_SIZE, D], contiguous
    v,  # GLOBAL cache: [CACHE_SIZE, D], contiguous
    mid_o,  # partial outputs, addressed as [B, KEY_SPLIT, HQ, D] contiguous
    mid_lse,  # partial log-sum-exp, addressed through STRIDE_L* below
    page_table_bhl,  # int32 [B*H_kv*N_LOGICAL_PAGES_MAX] (flattened)
    batch_mapping,  # int32 [B] maps local pid_b -> true batch index
    seq_lens_bh,  # int32 [B*H_kv] valid tokens per (b,h)
    SM_SCALE,
    B,
    HKV,
    HQ,
    CACHE_SIZE,  # CACHE_SIZE = N_PAGES * PAGE_SIZE
    STRIDE_LBS,  # mid_lse stride over batch
    STRIDE_LS,  # mid_lse stride over split
    STRIDE_LH,  # mid_lse stride over query head
    # constexprs
    N_LOGICAL_PAGES_MAX: tl.constexpr,  # page table width per (b,h)
    D: tl.constexpr,
    KEY_SPLIT: tl.constexpr,
    GROUP_M: tl.constexpr,
    DTYPE: tl.constexpr,
    BLOCK_N: tl.constexpr,
    MIN_BLOCK_KV: tl.constexpr,
    WARPSPEC: tl.constexpr,
    PAGE_SIZE: tl.constexpr,
):
    pid_b = tl.program_id(0)  # batch
    pid_kvh = tl.program_id(1)  # kv head
    pid_s = tl.program_id(2)  # split
    # valid length L for this (b,h)
    bh_stride = HKV
    L = tl.load(seq_lens_bh + pid_b * bh_stride + pid_kvh)
    # NOTE(review): with L == 0 the program exits before storing anything, so
    # the mid_o / mid_lse entries of this (b, h) keep whatever the caller
    # allocated them with — confirm callers never pass empty heads.
    if L == 0:
        return
    tl.assume(L > 0)
    # split sizing on logical token axis [0..L)
    # Each split covers ceil(L / KEY_SPLIT) tokens, rounded up to a multiple
    # of MIN_BLOCK_KV; trailing splits may therefore be empty.
    base = tl.cdiv(L, KEY_SPLIT)
    per_split_len = tl.cdiv(base, MIN_BLOCK_KV) * MIN_BLOCK_KV
    split_start = pid_s * per_split_len
    split_end = tl.minimum(split_start + per_split_len, L)
    # query heads mapped to this kv head
    base_qh = pid_kvh * GROUP_M
    # Pad the query-head tile to at least 16 rows; padded rows are masked out
    # of every load/store through mask_m.
    GROUP_M_PAD: tl.constexpr = 16 if GROUP_M < 16 else GROUP_M
    offs_m = tl.arange(0, GROUP_M_PAD)
    mask_m = offs_m < GROUP_M
    offs_d = tl.arange(0, D)
    # load Q tile [M, D]
    q_ptrs = q + (pid_b * HQ + base_qh + offs_m)[:, None] * D + offs_d[None, :]
    q = tl.load(q_ptrs, mask=mask_m[:, None], other=0.0).to(DTYPE)  # [M, D]
    # streaming softmax state per query
    e_max = tl.zeros([GROUP_M_PAD], dtype=tl.float32) - float("inf")
    e_sum = tl.zeros([GROUP_M_PAD], dtype=tl.float32)
    acc = tl.zeros([GROUP_M_PAD, D], dtype=tl.float32)
    if split_end > split_start:
        # logical pages covering [split_start, split_end)
        lp0 = split_start // PAGE_SIZE
        lp1 = tl.cdiv(split_end, PAGE_SIZE)  # exclusive
        mapped_b = tl.load(batch_mapping + pid_b)
        tl.assume(mapped_b >= 0)
        # page table base for this (b,h)
        pt_stride = N_LOGICAL_PAGES_MAX
        pt_base = (mapped_b * HKV + pid_kvh) * pt_stride
        for lp in tl.range(lp0, lp1):
            phys = tl.load(
                page_table_bhl + pt_base + lp, cache_modifier=".cg"
            )  # physical page id
            # bounds within the logical page
            local_start = tl.where(lp == lp0, split_start - lp * PAGE_SIZE, 0)
            local_end = tl.where(lp == (lp1 - 1), split_end - lp * PAGE_SIZE, PAGE_SIZE)
            page_base = phys * PAGE_SIZE
            # Compiler hint: asserts page_base is a multiple of BLOCK_N.
            page_base = tl.multiple_of(page_base, BLOCK_N)
            for s in tl.range(local_start, local_end, BLOCK_N):
                s = tl.multiple_of(s, MIN_BLOCK_KV)
                offs_bn = tl.arange(0, BLOCK_N)
                key_idx = page_base + s + offs_bn
                k_ptrs = k + key_idx[:, None] * D + offs_d[None, :]
                k_blk = tl.load(k_ptrs, mask=(key_idx < CACHE_SIZE)[:, None], other=0.0)
                qk = tl.dot(q, k_blk.T) * SM_SCALE  # [M, BN]
                # Tokens at or past local_end do not belong to this split.
                offs_n = s + tl.arange(0, BLOCK_N)
                mask_n = offs_n < local_end
                qk = tl.where(mask_n[None, :], qk, -float("inf"))
                # Online softmax: rescale the running sum/accumulator to the
                # new running maximum before adding this block.
                n_e_max = tl.maximum(tl.max(qk, 1), e_max)  # [M]
                re_scale = tl.exp(e_max - n_e_max)  # [M]
                acc = acc * re_scale[:, None]  # [M, D]
                v_ptrs = v + key_idx[:, None] * D + offs_d[None, :]
                v_blk = tl.load(v_ptrs, mask=(key_idx < CACHE_SIZE)[:, None], other=0.0)
                p = tl.exp(qk - n_e_max[:, None])  # [M, BN]
                acc = tl.dot(p.to(DTYPE), v_blk, acc)
                e_sum = e_sum * re_scale + tl.sum(p, 1)
                e_max = n_e_max
        # write mid outputs [M, D] for this split
        tmp = (acc / e_sum[:, None]).to(DTYPE)
        row_mid = pid_b * (KEY_SPLIT * HQ) + pid_s * HQ + base_qh + offs_m
        mid_ptrs = mid_o + row_mid[:, None] * D + offs_d[None, :]
        tl.store(mid_ptrs, tmp, mask=mask_m[:, None])
        ml_ptrs = (
            mid_lse
            + pid_b * STRIDE_LBS
            + pid_s * STRIDE_LS
            + (base_qh + offs_m) * STRIDE_LH
        )
        # Padded rows get a sum of 1.0 so the log stays finite; they are not
        # stored anyway (mask_m).
        safe_sum = tl.where(mask_m, e_sum, 1.0)
        tl.store(ml_ptrs, e_max + tl.log(safe_sum), mask=mask_m)
    else:
        # empty split
        # Zero output with LSE = -inf, so this split contributes zero weight
        # when stage 2 combines it with a non-empty split.
        zero_md = tl.zeros([GROUP_M_PAD, D], dtype=DTYPE)
        row_mid = pid_b * (KEY_SPLIT * HQ) + pid_s * HQ + base_qh + offs_m
        mid_ptrs = mid_o + row_mid[:, None] * D + offs_d[None, :]
        tl.store(mid_ptrs, zero_md, mask=mask_m[:, None])
        ml_ptrs = (
            mid_lse
            + pid_b * STRIDE_LBS
            + pid_s * STRIDE_LS
            + (base_qh + offs_m) * STRIDE_LH
        )
        tl.store(ml_ptrs, -float("inf"), mask=mask_m)
# Stage 2 of split-KV decode attention.
# One program per (batch, query_head): merges the KEY_SPLIT partial outputs
# written by stage 1, weighting each by exp(lse_split - running_max), and
# stores the normalized result into `output`.
@triton.jit
def _varkv_stage2_reduce(
    mid_o,  # partial outputs, addressed as [B, KEY_SPLIT, HQ, D] contiguous
    mid_lse,  # partial log-sum-exp, addressed through STRIDE_L* below
    output,  # final output, addressed through STRIDE_O* (last dim contiguous)
    STRIDE_LBS,  # mid_lse stride over batch
    STRIDE_LS,  # mid_lse stride over split
    STRIDE_LH,  # mid_lse stride over query head
    STRIDE_OBS,  # output stride over batch
    STRIDE_OH,  # output stride over query head
    B,
    HQ,
    D: tl.constexpr,
    KEY_SPLIT: tl.constexpr,
    DTYPE: tl.constexpr,
):
    pid_b = tl.program_id(0)
    pid_h = tl.program_id(1)
    offs_d = tl.arange(0, D)
    # across split LSE combine
    e_sum = 0.0
    e_max = -float("inf")
    acc = tl.zeros([D], dtype=tl.float32)
    for s in tl.range(KEY_SPLIT):
        row_mid = pid_b * (KEY_SPLIT * HQ) + s * HQ + pid_h
        tv = tl.load(mid_o + row_mid * D + offs_d).to(DTYPE)
        tl_ptr = mid_lse + pid_b * STRIDE_LBS + s * STRIDE_LS + pid_h * STRIDE_LH
        tlogic = tl.load(tl_ptr)  # log-sum-exp of split s
        # Rescale what has been accumulated so far to the new running maximum,
        # then add this split weighted by exp(lse - max).
        n_e_max = tl.maximum(e_max, tlogic)
        old_scale = tl.exp(e_max - n_e_max)
        acc = acc * old_scale + tl.exp(tlogic - n_e_max) * tv.to(tl.float32)
        e_sum = e_sum * old_scale + tl.exp(tlogic - n_e_max)
        e_max = n_e_max
    o = (acc / e_sum).to(DTYPE)
    o_ptr = output + pid_b * STRIDE_OBS + pid_h * STRIDE_OH + offs_d
    tl.store(o_ptr, o)
import logging
import math
import torch
import triton
import triton.language as tl
from vllm.kvprune.utils.triton_compat import (
autotune as triton_autotune,
cuda_capability_geq,
maybe_set_allocator,
)
logger = logging.getLogger(__name__)
def causal_sparse_varlen_with_cache(
    q,
    k,
    v,
    k_cache,
    v_cache,
    seq_lens_bh,
    global_page_table,
    batch_mapping,
    cu_seqlens_q,
    max_seqlen_q: int,
    max_seqlen_k_cache: int,
    HKV: int,
    PAGE_SIZE: int,
    sm_scale=None,
):
    """
    Causal prefill attention over a paged KV cache plus a block of newly
    appended tokens in a packed batch format.

    This function wraps the Triton kernel
    ``_causal_head_sparse_varlen_with_cache`` to compute prefill attention for
    a batch of variable-length sequences, where:
        • Past keys/values are stored in a paged global KV cache
          (``k_cache``, ``v_cache``) and indexed via ``global_page_table``.
        • New tokens for this step are given as K/V blocks (``k``, ``v``)
          together with a packed query block ``q``.

    Grouped-query attention (GQA / MQA) is supported: ``HQ`` must be divisible
    by ``HKV``.

    Args:
        q: Packed queries ``[N, HQ, D]``; ``D`` must be a power of two.
        k, v: Packed new keys/values, viewed as ``[N, HKV, D]``.
        k_cache, v_cache: Global caches ``[CACHE_SIZE, D]`` with ``CACHE_SIZE``
            a multiple of ``PAGE_SIZE``.
        seq_lens_bh: Cached token counts; converted to int32 on ``q``'s device.
        global_page_table: Page table; its last dimension is the page-table
            width handed to the kernel.
        batch_mapping: Launch-batch -> true batch index; converted to int16.
        cu_seqlens_q: Cumulative query lengths (``B + 1`` entries, ``B > 0``);
            converted to int32.
        max_seqlen_q: Longest query sequence; sizes the launch grid.
        max_seqlen_k_cache: Longest cached sequence; used only (after clamping
            to >= 1 and rounding up to a power of two) as an autotune key.
        HKV: Number of KV heads.
        PAGE_SIZE: Tokens per physical page.
        sm_scale: Softmax scale; defaults to ``1 / sqrt(D)``.

    Returns:
        Attention output of shape ``[N, HQ, D]``.
    """
    assert q.ndim == 3, "q should be [N, HQ, D]"
    N, HQ, D = q.shape
    assert (D & (D - 1)) == 0, "D must be power of two"
    B = cu_seqlens_q.numel() - 1
    assert B > 0
    assert HQ % HKV == 0, "Number of KV heads must divide the number of query heads"
    H_g = HQ // HKV
    # view Q as [HKV, N, QUERY_GROUP_SIZE, D]
    out = torch.empty_like(q)
    q = q.view(N, HKV, H_g, D).permute(1, 0, 2, 3)
    out = out.view(N, HKV, H_g, D).permute(1, 0, 2, 3)
    # K_app/V_app: [N, HKV, D] -> [HKV, N, D]
    k_app = k.view(N, HKV, D).permute(1, 0, 2)
    v_app = v.view(N, HKV, D).permute(1, 0, 2)
    cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=q.device)
    seq_lens_bh = seq_lens_bh.to(dtype=torch.int32, device=q.device)
    batch_mapping = batch_mapping.to(dtype=torch.int16, device=q.device)
    N_LOGICAL_PAGES_MAX = global_page_table.shape[-1]
    CACHE_SIZE = k_cache.shape[0]
    assert v_cache.shape[0] == CACHE_SIZE
    assert k_cache.shape[1] == D and v_cache.shape[1] == D
    assert PAGE_SIZE > 0 and CACHE_SIZE % PAGE_SIZE == 0
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)
    # strides for Q [G, N, QUERY_GROUP_SIZE, D]
    STRIDE_Q_G, STRIDE_Q_N, STRIDE_Q_H, STRIDE_Q_D = q.stride()
    STRIDE_KC, STRIDE_VC = k_cache.stride(0), v_cache.stride(0)
    # [G, N, D]
    STRIDE_KA_G, STRIDE_KA_N, STRIDE_KA_D = k_app.stride()
    STRIDE_VA_G, STRIDE_VA_N, STRIDE_VA_D = v_app.stride()
    # OUT [G, N, QUERY_GROUP_SIZE, D]
    STRIDE_OUT_G, STRIDE_OUT_N, STRIDE_OUT_H, STRIDE_OUT_D = out.stride()
    maybe_set_allocator(
        lambda size, align, _: torch.empty(size, dtype=torch.int8, device=q.device)
    )
    # No stride is passed for the last dimension, so it must be 1.
    assert STRIDE_KA_D == STRIDE_VA_D == STRIDE_Q_D == STRIDE_OUT_D == 1, (
        "final dimension must be contiguous"
    )

    # launch grid: (kv head, batch entry, query block)
    def grid(META):
        return HKV, B, triton.cdiv(max_seqlen_q, META["BLOCK_M"])

    # On a fresh batch, max_seqlen_k_cache==0 (no KV prefix yet). Passing
    # `triton.next_power_of_2(0)` into autotune constexpr keys breaks
    # kernel selection / tuning and can yield garbage outputs.
    _k_max_autotune = max(int(max_seqlen_k_cache), 1)
    AUTOTUNE_MAX_Q_LEN = triton.next_power_of_2(max_seqlen_q)
    AUTOTUNE_MAX_K_LEN = triton.next_power_of_2(_k_max_autotune)
    _causal_head_sparse_varlen_with_cache[grid](
        Q=q,
        K_cache=k_cache,
        V_cache=v_cache,
        K_app=k_app,
        V_app=v_app,
        cu_seqlens_qk=cu_seqlens_q,
        seq_lens_bh=seq_lens_bh,
        page_table=global_page_table,
        batch_mapping=batch_mapping,
        OUT=out,
        HKV=HKV,
        QUERY_GROUP_SIZE=H_g,
        PAGE_SIZE=PAGE_SIZE,
        N_LOGICAL_PAGES_MAX=N_LOGICAL_PAGES_MAX,
        STRIDE_Q_G=STRIDE_Q_G,
        STRIDE_Q_N=STRIDE_Q_N,
        STRIDE_Q_H=STRIDE_Q_H,
        STRIDE_KC=STRIDE_KC,
        STRIDE_VC=STRIDE_VC,
        STRIDE_KA_G=STRIDE_KA_G,
        STRIDE_KA_N=STRIDE_KA_N,
        STRIDE_VA_G=STRIDE_VA_G,
        STRIDE_VA_N=STRIDE_VA_N,
        STRIDE_OUT_G=STRIDE_OUT_G,
        STRIDE_OUT_N=STRIDE_OUT_N,
        STRIDE_OUT_H=STRIDE_OUT_H,
        sm_scale=sm_scale,
        D=D,
        AUTOTUNE_MAX_Q_LEN=AUTOTUNE_MAX_Q_LEN,
        AUTOTUNE_MAX_K_LEN=AUTOTUNE_MAX_K_LEN,
    )
    # `out` is a permuted view of the contiguous [N, HQ, D] buffer allocated
    # above, so permuting back makes this view valid without a copy.
    return out.permute(1, 0, 2, 3).view(N, HQ, D)  # already contiguous
# Autotune search space returned by ``get_autotune_configs`` when
# ``cuda_capability_geq(9, 0)`` is true. Each row is
# (BLOCK_N, BLOCK_M, WARPSPEC, num_warps, num_stages); row order is the
# order of the resulting config list.
_CC9_CONFIG_ROWS = (
    (64, 64, True, 16, 3),
    (64, 64, True, 8, 3),
    (64, 32, True, 8, 4),
    (64, 32, True, 8, 3),
    (64, 32, False, 4, 3),
    (64, 16, True, 8, 3),
    (64, 16, True, 8, 4),
    (64, 16, False, 4, 4),
    (32, 32, True, 8, 4),
    (32, 32, False, 8, 4),
    (32, 16, False, 8, 3),
    (32, 16, False, 4, 4),
)
autotune_configs_cc9 = [
    triton.Config(
        {"BLOCK_N": block_n, "BLOCK_M": block_m, "WARPSPEC": warpspec},
        num_warps=warps,
        num_stages=stages,
    )
    for block_n, block_m, warpspec, warps, stages in _CC9_CONFIG_ROWS
]
# Autotune search space returned by ``get_autotune_configs`` when
# ``cuda_capability_geq(9, 0)`` is false: the cross product of the
# BLOCK_N / BLOCK_M / num_warps / num_stages choices below, always with
# WARPSPEC set to True.
autotune_configs_cc8 = [
    triton.Config(
        {"BLOCK_N": block_n, "BLOCK_M": block_m, "WARPSPEC": True},
        num_warps=warps,
        num_stages=stages,
    )
    for block_n in (16, 32)
    for block_m in (64,)
    for warps in (4, 8)
    for stages in (2, 3)
]
def prune_invalid_configs(configs, _, **kwargs):
    """Drop autotune configs that pair ``BLOCK_N == 32`` with ``num_stages == 4``.

    The signature follows the ``(configs, named_args, **kwargs)`` shape of an
    autotuner config-pruning hook; the second positional argument and the
    keyword arguments are accepted but not used.

    ``num_stages`` is a launch option stored as an attribute of the config
    object, not as an entry of its ``kwargs`` meta-parameter dict. Reading it
    only from ``conf.kwargs`` (as this function used to) never matches, so no
    config was ever pruned. The attribute is consulted first; the ``kwargs``
    lookup is kept as a fallback for config-like objects that carry it there.

    :param configs: iterable of config objects exposing ``kwargs`` (a dict).
    :return: list of the configs that are kept, in their original order.
    """

    def _num_stages(conf):
        # Prefer the attribute; fall back to the meta-parameter dict.
        stages = getattr(conf, "num_stages", None)
        if stages is None:
            stages = conf.kwargs.get("num_stages")
        return stages

    return [
        conf
        for conf in configs
        if not (conf.kwargs.get("BLOCK_N") == 32 and _num_stages(conf) == 4)
    ]
def get_autotune_configs():
    """Pick the autotune config list for the current device.

    Returns ``autotune_configs_cc9`` when ``cuda_capability_geq(9, 0)`` is
    true, otherwise ``autotune_configs_cc8``.
    """
    return autotune_configs_cc9 if cuda_capability_geq(9, 0) else autotune_configs_cc8
# Triton kernel: attention for a packed batch of variable-length sequences
# whose past K/V live in a paged cache and whose new K/V are passed as
# K_app/V_app.
#
# One program handles one KV head (program_id 0), one batch entry
# (program_id 1) and one tile of BLOCK_M query positions (program_id 2).
# All QUERY_GROUP_SIZE query heads sharing that KV head are processed together
# as BLOCK_M * QUERY_GROUP_SIZE flattened rows.
#
# The softmax is accumulated online in base 2: ``e_max`` is the running row
# maximum, ``e_sum`` the running denominator and ``acc`` the running weighted
# sum of V. The three phases are:
#   1) attend over the cached K/V, walking logical pages through page_table;
#   2) attend causally over the appended K_app/V_app of this batch entry;
#   3) normalize and store the output tile.
@triton_autotune(
configs=get_autotune_configs(),
key=[
"HKV",
"QUERY_GROUP_SIZE",
"D",
"PAGE_SIZE",
"AUTOTUNE_MAX_K_LEN",
"AUTOTUNE_MAX_Q_LEN",
],
cache_results=True,
)
@triton.jit
def _causal_head_sparse_varlen_with_cache(
Q, # [HKV, N, QUERY_GROUP_SIZE, D] (non-contiguous)
K_cache,
V_cache, # [CACHE_SIZE, D]
K_app,
V_app, # [HKV, N, D]
cu_seqlens_qk, # [B+1]
seq_lens_bh, # [B, HKV]
page_table, # [B_total, HKV, N_LOGICAL_PAGES_MAX]
batch_mapping, # [B], maps local b -> global batch index
OUT, # [HKV, N, QUERY_GROUP_SIZE, D]
#
HKV: tl.constexpr,
QUERY_GROUP_SIZE: tl.constexpr,
PAGE_SIZE: tl.constexpr,
N_LOGICAL_PAGES_MAX,
STRIDE_Q_G,
STRIDE_Q_N,
STRIDE_Q_H,
STRIDE_KC,
STRIDE_VC,
STRIDE_KA_G,
STRIDE_KA_N,
STRIDE_VA_G,
STRIDE_VA_N,
STRIDE_OUT_G,
STRIDE_OUT_N,
STRIDE_OUT_H,
sm_scale,
#
D: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
WARPSPEC: tl.constexpr,
AUTOTUNE_MAX_Q_LEN: tl.constexpr, # used for autotune key
AUTOTUNE_MAX_K_LEN: tl.constexpr, # used for autotune key
):
TOTAL_N_QUERIES: tl.constexpr = BLOCK_M * QUERY_GROUP_SIZE
pid_g = tl.program_id(0) # kv_head id in [0, HKV)
pid_b = tl.program_id(1) # batch id
pid_m = tl.program_id(2) # query-tile id within batch
# batch segment [qb, qe) in N
off_b = tl.load(cu_seqlens_qk + pid_b)
off_b1 = tl.load(cu_seqlens_qk + pid_b + 1)
seq_len_append = off_b1 - off_b
q_start = off_b + pid_m * BLOCK_M
q_end = tl.minimum(q_start + BLOCK_M, off_b1)
# number of queries in this tile for this batch
M = q_end - q_start
# tiles that start past the end of this batch entry have nothing to do
if M <= 0:
return
# cached length for (b, kv_head=pid_g)
L_cache = tl.load(seq_lens_bh + pid_b * HKV + pid_g)
# row indices flattened over [QUERY_GROUP_SIZE, M]
offs_row = tl.arange(0, TOTAL_N_QUERIES)
row_m = offs_row % BLOCK_M
row_h = offs_row // BLOCK_M
# valid rows: only those with row_m < M
row_mask = row_m < M
# global query index per row
q_idx = q_start + row_m
offs_d = tl.arange(0, D)
# Q tile: [TOTAL_N_QUERIES, D]
# Q layout: [HKV, N, QUERY_GROUP_SIZE, D]
q_ptrs = (
Q
+ pid_g * STRIDE_Q_G
+ q_idx[:, None] * STRIDE_Q_N
+ row_h[:, None] * STRIDE_Q_H
+ offs_d[None, :]
)
q = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)
# online-softmax state: running max, running denominator, running numerator
e_max = tl.zeros([TOTAL_N_QUERIES], dtype=tl.float32) - float("inf")
e_sum = tl.zeros([TOTAL_N_QUERIES], dtype=tl.float32)
acc = tl.zeros([TOTAL_N_QUERIES, D], dtype=tl.float32)
offs_block_n = tl.arange(0, BLOCK_N)
# 1.44269504 ~= log2(e): fold it into the scale so exp2 can be used below
qk_scale = sm_scale * 1.44269504
# 1) attend over cached K/V
if L_cache > 0:
# map local (b) to global batch index
mapped_b = tl.load(batch_mapping + pid_b)
pt_base = (mapped_b * HKV + pid_g) * N_LOGICAL_PAGES_MAX
# iterate logical pages
num_lp = tl.cdiv(L_cache, PAGE_SIZE)
for lp in tl.range(0, num_lp):
# can overflow in 32 bits so upcast
phys = tl.load(page_table + pt_base + lp).to(tl.int64)
page_start = phys * PAGE_SIZE
# how many valid tokens in this page for this (b,g)
remain = L_cache - lp * PAGE_SIZE
page_len = tl.minimum(PAGE_SIZE, remain)
# iterate over this page in BLOCK_N chunks
for ks in tl.range(0, page_len, BLOCK_N):
offs_n = ks + offs_block_n
mask_n = offs_n < page_len
key_idx = page_start + offs_n
k_ptrs = K_cache + key_idx[:, None] * STRIDE_KC + offs_d[None, :]
k = tl.load(k_ptrs, mask=mask_n[:, None], other=0.0) # [BN, D]
qk = tl.dot(q, k.T) * qk_scale # [TOTAL_N_QUERIES, BN]
# masked entries get a large negative finite value (not -inf)
qk = tl.where(row_mask[:, None] & mask_n[None, :], qk, -1.0e6)
# softmax update
cur_max = tl.max(qk, 1)
n_e_max = tl.maximum(e_max, cur_max)
re_scale = tl.math.exp2(e_max - n_e_max)
p = tl.math.exp2(qk - n_e_max[:, None])
v_ptrs = V_cache + key_idx[:, None] * STRIDE_VC + offs_d[None, :]
v = tl.load(v_ptrs, mask=mask_n[:, None], other=0.0) # [BN, D]
acc = acc * re_scale[:, None]
acc = tl.dot(p.to(v.dtype), v, acc)
e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max
# 2) attend over appended K_app/V_app (causal)
# appended tokens for batch b are in [off_b, off_b1)
# query tile is [q_start, q_end)
# for each query at index q_idx, valid appended keys k satisfy off_b <= k <= q_idx
if q_end > off_b:
# exactly one appended token
if seq_len_append == 1:
ka_ptrs = K_app + pid_g * STRIDE_KA_G + off_b * STRIDE_KA_N + offs_d
k = tl.load(ka_ptrs) # [D]
qk = tl.sum(q * k[None, :], 1) * qk_scale
qk = tl.where(row_mask, qk, -1.0e6)
n_e_max = tl.maximum(e_max, qk)
re_scale = tl.math.exp2(e_max - n_e_max)
p = tl.math.exp2(qk - n_e_max)
va_ptrs = V_app + pid_g * STRIDE_VA_G + off_b * STRIDE_VA_N + offs_d
v = tl.load(va_ptrs) # [D]
acc = acc * re_scale[:, None] + p[:, None] * v[None, :]
# e_max is not written back here; only acc and e_sum are read afterwards
e_sum = e_sum * re_scale + p
else:
# off-band: k in [off_b, q_start)
# for all queries t in [q_start, q_end), any k < q_start satisfies k <= t.
# so no causal mask needed.
off_band_start = off_b
off_band_end = q_start
if off_band_end > off_band_start:
for ks in tl.range(off_band_start, off_band_end, BLOCK_N):
offs_n = ks + offs_block_n
mask_n = offs_n < off_band_end
ka_ptrs = (
K_app
+ pid_g * STRIDE_KA_G
+ offs_n[:, None] * STRIDE_KA_N
+ offs_d[None, :]
)
k = tl.load(ka_ptrs, mask=mask_n[:, None], other=0.0)
qk = tl.dot(q, k.T) * qk_scale
qk = tl.where(row_mask[:, None] & mask_n[None, :], qk, -1.0e6)
cur_max = tl.max(qk, 1)
n_e_max = tl.maximum(e_max, cur_max)
re_scale = tl.math.exp2(e_max - n_e_max)
p = tl.math.exp2(qk - n_e_max[:, None])
va_ptrs = (
V_app
+ pid_g * STRIDE_VA_G
+ offs_n[:, None] * STRIDE_VA_N
+ offs_d[None, :]
)
v = tl.load(va_ptrs, mask=mask_n[:, None], other=0.0)
acc = acc * re_scale[:, None]
acc = tl.dot(p.to(v.dtype), v, acc)
e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max
# on-band remaining k
on_band_start = tl.maximum(q_start, off_b)
if on_band_start < q_end:
for ks in tl.range(on_band_start, q_end, BLOCK_N):
offs_n = ks + tl.arange(0, BLOCK_N)
mask_n = offs_n < q_end
ka_ptrs = (
K_app
+ pid_g * STRIDE_KA_G
+ offs_n[:, None] * STRIDE_KA_N
+ offs_d[None, :]
)
k = tl.load(ka_ptrs, mask=mask_n[:, None], other=0.0)
qk = tl.dot(q, k.T) * qk_scale
# inside the tile a key is visible only to queries at or after it
caus_mask = offs_n[None, :] <= q_idx[:, None]
full_mask = row_mask[:, None] & mask_n[None, :] & caus_mask
qk = tl.where(full_mask, qk, -1.0e6)
cur_max = tl.max(qk, 1)
n_e_max = tl.maximum(e_max, cur_max)
re_scale = tl.math.exp2(e_max - n_e_max)
p = tl.math.exp2(qk - n_e_max[:, None])
va_ptrs = (
V_app
+ pid_g * STRIDE_VA_G
+ offs_n[:, None] * STRIDE_VA_N
+ offs_d[None, :]
)
v = tl.load(va_ptrs, mask=mask_n[:, None], other=0.0)
acc = acc * re_scale[:, None]
acc = tl.dot(p.to(v.dtype), v, acc)
e_sum = e_sum * re_scale + tl.sum(p, 1)
e_max = n_e_max
# 3) write outputs
# normalize by the softmax denominator and cast back to the query dtype
o = (acc / e_sum[:, None]).to(q.dtype)
out_ptrs = (
OUT
+ pid_g * STRIDE_OUT_G
+ q_idx[:, None] * STRIDE_OUT_N
+ row_h[:, None] * STRIDE_OUT_H
+ offs_d[None, :]
)
tl.store(out_ptrs, o, mask=row_mask[:, None])
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark helpers for kv-prune / compactor kernels.
Upstream snapshot (``compactor-vllm/src/compactor_vllm/benchmark``) contained **only**
an empty ``__init__.py`` — no additional ``.py`` scripts. Those files are merged here
as-is; there is nothing else to list under that directory in upstream.
Use :data:`BENCHMARK_REGISTRY` to register microbenchmarks or CLI entrypoints you
add under ``vllm.kvprune.benchmark``.
"""
from __future__ import annotations
from typing import Any, Callable
# Names of the files found in upstream ``compactor_vllm/benchmark/``
# (paths relative to that directory).
UPSTREAM_BENCHMARK_FILES: tuple[str, ...] = ("__init__.py",)

# Benchmark name -> callable, or an import-path string such as "mymod:main".
# Starts empty; entries are added through ``register_benchmark``.
BENCHMARK_REGISTRY: dict[str, Callable[..., Any] | str] = {}


def list_upstream_benchmark_files() -> tuple[str, ...]:
    """Return the filenames recorded in :data:`UPSTREAM_BENCHMARK_FILES`."""
    return UPSTREAM_BENCHMARK_FILES


def register_benchmark(name: str, target: Callable[..., Any] | str) -> None:
    """Store ``target`` under ``name`` in :data:`BENCHMARK_REGISTRY`.

    ``target`` is either a callable or a ``"module:attr"`` import path.
    Registering an existing name replaces its previous target.
    """
    BENCHMARK_REGISTRY.update({name: target})


def iter_registered_benchmarks() -> list[tuple[str, Callable[..., Any] | str]]:
    """Return a list snapshot of ``(name, target)`` pairs from the registry."""
    return [(name, target) for name, target in BENCHMARK_REGISTRY.items()]
__all__ = [
"BENCHMARK_REGISTRY",
"UPSTREAM_BENCHMARK_FILES",
"iter_registered_benchmarks",
"list_upstream_benchmark_files",
"register_benchmark",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Layout notes: ``vllm/compactor-vllm/src/compactor_vllm`` (or sibling tree) →
``vllm.kvprune.<subdir>``.
The upstream tree is merged into parallel subpackages under ``vllm/kvprune/``
(``attention``, ``kv_cache``, ``compression``, ``config``, ``core``, ``layers``,
``models``, ``triton_kernels``, ``utils``, ``benchmark``). Imports use
``from vllm.kvprune.<module>.*``.
v1 integration (FlashAttention, ``gpu_model_runner``) lives in
``core.runtime``, ``core.flash_integration``, and ``compression/prefill.py``.
**Note:** filenames with hyphens under ``compression/`` are not importable as
Python modules; rename or load via ``importlib`` if needed.
**TP / embedding in vLLM workers:** upstream compactor-vllm used only
``vllm.kvprune`` ``ParallelLMHead`` + ``dist.gather``. When embedded in v1 workers,
prefer ``delegate_kvprune_embed_tokens_to_vllm`` and
``delegate_kvprune_compute_logits_to_vllm`` so token masking and logits match
``vocab_parallel_embedding`` + ``LogitsProcessor`` (garbled text often came from
TP gather / padded-vocab handling, not from the transformer body).
"""
from __future__ import annotations
import pathlib
def kvprune_root() -> pathlib.Path:
    """Return the resolved, absolute directory that contains this module.

    In the installed tree this is expected to be ``vllm/kvprune``.
    """
    this_file = pathlib.Path(__file__).resolve()
    return this_file.parent
def list_py_files() -> list[str]:
    """List every ``.py`` file below :func:`kvprune_root`, sorted.

    Paths are relative to the root and use ``/`` as separator. Any file whose
    path has a ``__pycache__`` component is left out.
    """
    root = kvprune_root()
    rel_paths = []
    for path in root.rglob("*.py"):
        if "__pycache__" in path.parts:
            continue
        rel_paths.append(str(path.relative_to(root)).replace("\\", "/"))
    rel_paths.sort()
    return rel_paths
def format_layout_report() -> str:
    """Build a plain-text report of the Python files under ``kvprune``.

    The report has a title line, the total file count, a separator, and then
    at most the first 250 paths from :func:`list_py_files`. When more files
    exist, a final line states how many were left out.
    """
    files = list_py_files()
    report = [
        "vllm.kvprune — merged compactor layout",
        f"python file count: {len(files)}",
        "=" * 50,
    ]
    report.extend(files[:250])
    hidden = len(files) - 250
    if hidden > 0:
        report.append(f"... and {hidden} more")
    return "\n".join(report)
from vllm.kvprune.compression.common import (
BaseCompressionMethod,
NoCompression,
)
from vllm.kvprune.compression.criticalkv import CriticalAdaKVCompression
from vllm.kvprune.compression.compactor import CompactorCompression
from vllm.kvprune.compression.compression_config import (
BatchCompressionParams,
CompressionMethod,
SequenceCompressionParams,
)
from vllm.kvprune.compression.snapkv import SnapKVCompression
# Maps each CompressionMethod enum member to the class implementing it.
# apply_prerope_compression / apply_postrope_compression index this mapping
# with the method stored on the context and call the class's static scoring
# methods, so the values are classes, not instances.
COMPRESSION_REGISTRY: dict[CompressionMethod, type[BaseCompressionMethod]] = {
CompressionMethod.CRITICALADAKV: CriticalAdaKVCompression,
CompressionMethod.COMPACTOR: CompactorCompression,
CompressionMethod.SNAPKV: SnapKVCompression,
CompressionMethod.NONE: NoCompression,
}
def apply_prerope_compression(q, k, v, context):
    """Run the pre-RoPE scoring phase of the configured compression method.

    The method is read from ``context.compression_context.compression_method``
    and looked up in ``COMPRESSION_REGISTRY``; an unregistered method raises
    ``KeyError``. Returns whatever that class's ``pre_rope_scoring`` returns.
    """
    compression_cls = COMPRESSION_REGISTRY[
        context.compression_context.compression_method
    ]
    return compression_cls.pre_rope_scoring(q, k, v, context=context)
def apply_postrope_compression(q, k, v, prerope_scores, context):
    """Run the post-RoPE scoring phase of the configured compression method.

    The method is read from ``context.compression_context.compression_method``
    and looked up in ``COMPRESSION_REGISTRY``; an unregistered method raises
    ``KeyError``. ``prerope_scores`` is forwarded unchanged to that class's
    ``post_rope_scoring``, whose result is returned.
    """
    compression_cls = COMPRESSION_REGISTRY[
        context.compression_context.compression_method
    ]
    return compression_cls.post_rope_scoring(
        q, k, v, prerope_scores, context=context
    )
__all__ = [
"apply_prerope_compression",
"apply_postrope_compression",
"CompressionMethod",
"BatchCompressionParams",
"SequenceCompressionParams",
"COMPRESSION_REGISTRY"
]
from abc import ABC, abstractmethod
from typing import Optional
import torch
from vllm.kvprune.kv_cache.store_kv_cache import prefill_store_topk_kv
class BaseCompressionMethod(ABC):
    """
    Abstract interface shared by all KV cache compression methods.

    A method contributes up to two scoring phases, placed around the
    application of rotary position embeddings (RoPE):

    1. ``pre_rope_scoring`` works on the pre-RoPE Q/K.
    2. ``post_rope_scoring`` works on the post-RoPE Q/K and may either
       refine / reweight the pre-RoPE scores or compute position-aware
       scores of its own.

    Subclasses implement both static methods. Each returns a single score
    tensor, or ``None`` when the phase is a no-op; the caller then feeds the
    scores into the shared "scores -> top-k indices -> KV extraction"
    pipeline.
    """

    @staticmethod
    @abstractmethod
    def pre_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """
        Score tokens from the pre-RoPE queries and keys.

        Args:
            :param q:
                Pre-RoPE query tensor. Shape ``[total_tokens, HQ, D]``.
            :param k:
                Pre-RoPE key tensor. Shape ``[total_tokens, HKV, D]``.
            :param v:
                Value tensor. Shape ``[total_tokens, HKV, D]``.
            :param context:
                Context object carrying additional metadata, such as batch
                mappings or temporary buffers.
        Returns:
            :return Optional[torch.Tensor]:
                Importance scores (e.g. per token and per head), shape
                ``[total_tokens, HKV]``, to be handed to
                ``post_rope_scoring`` or directly to the top-k selection
                step. Implementations return ``None`` when this phase is a
                no-op.
        """

    @staticmethod
    @abstractmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: Optional[torch.Tensor],
        context,
    ) -> Optional[torch.Tensor]:
        """
        Compute or refine importance scores from post-RoPE queries and keys.

        Called after rotary embeddings have been applied. It may combine the
        post-RoPE Q/K with the scores produced by ``pre_rope_scoring`` to
        obtain the final scores used for token selection. Typical patterns:

        * start from ``pre_rope_scores`` and apply a position-aware
          correction;
        * compute only scores that depend on absolute or relative positions;
        * pass ``pre_rope_scores`` through unchanged.

        Args:
            :param q:
                Post-RoPE query tensor. Shape ``[total_tokens, HQ, D]``.
            :param k:
                Post-RoPE key tensor. Shape ``[total_tokens, HKV, D]``.
            :param v:
                Value tensor. Shape ``[total_tokens, HKV, D]``.
            :param pre_rope_scores:
                Scores returned by ``pre_rope_scoring``; ``None`` when that
                phase returned ``None``.
            :param context:
                Context object carrying additional metadata, such as batch
                mappings or temporary buffers.
        Returns:
            :return Optional[torch.Tensor]:
                Final importance scores consumed by the compression pipeline
                for top-k token selection. A no-op implementation may return
                ``pre_rope_scores``. Returning ``None`` means no compression
                is applied.
        """
class NoCompression(BaseCompressionMethod):
    """
    Compression method that leaves the KV cache uncompressed.

    Neither phase produces scores of its own.
    """

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """Return ``None``: there is no pre-RoPE scoring for this method."""
        return None

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """Return ``pre_rope_scores`` exactly as received."""
        return pre_rope_scores
def extract_and_store_top_kv(
    scores: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_k_len: int,
    top_k: int,
    H: int,
    new_keys: torch.Tensor,  # [N_total, H, D]
    new_vals: torch.Tensor,  # [N_total, H, D]
    num_tokens_to_retain: torch.Tensor,  # [B] int32
    page_table: torch.Tensor,  # [B_total, H, N_LOGICAL_PAGES_MAX] int32
    batch_mapping: torch.Tensor,  # [B] int32 (local -> true batch rows)
    bh_lens: torch.Tensor,  # [B, H] int32 (contiguous), UPDATED atomically
    k_cache: torch.Tensor,  # [N_PAGES * PAGE_SIZE, D]
    v_cache: torch.Tensor,  # [N_PAGES * PAGE_SIZE, D]
    PAGE_SIZE: int,
    PAD_TO_PAGE_SIZE: bool = True,
    K_TILE: int = 16,
    padding: float = -float("inf"),
):
    """Select the top-k (token, head) entries and store them in the KV cache.

    Chains ``scores_to_retain_indices`` and ``prefill_store_topk_kv`` in one
    call so that both steps can be submitted to a single stream together.
    ``scores``, ``max_k_len``, ``top_k``, ``H`` and ``padding`` drive the
    selection; every other argument is forwarded to the store step unchanged.
    Nothing is returned.
    """
    retained = scores_to_retain_indices(
        scores,
        cu_seqlens_k=cu_seqlens_k,
        max_k_len=max_k_len,
        top_k=top_k,
        H=H,
        padding=padding,
    )
    prefill_store_topk_kv(
        new_keys=new_keys,
        new_vals=new_vals,
        indices_topk=retained,
        num_tokens_to_retain=num_tokens_to_retain,
        page_table=page_table,
        batch_mapping=batch_mapping,
        bh_lens=bh_lens,
        k_cache=k_cache,
        v_cache=v_cache,
        cu_seqlens_k=cu_seqlens_k,
        PAGE_SIZE=PAGE_SIZE,
        PAD_TO_PAGE_SIZE=PAD_TO_PAGE_SIZE,
        K_TILE=K_TILE,
    )
def scores_to_retain_indices(
    scores: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_k_len: int,
    top_k: int,
    H: int,
    padding: float = -float("inf"),
) -> torch.Tensor:
    """
    Pick the top-k (token, head) entries of each sequence from packed scores.

    ``scores`` is in packed varlen form, ``[N_total, H]``, where ``H`` is the
    number of KV heads and sequence ``b`` owns the rows
    ``cu_seqlens_k[b] : cu_seqlens_k[b + 1]``. Each sequence is copied into a
    ``[max_k_len, H]`` slab pre-filled with ``padding``, the slab is flattened
    to ``max_k_len * H`` entries and a per-sequence top-k is taken. The local
    indices are then shifted by ``cu_seqlens_k[b] * H`` so that they address
    the flattened packed scores ``scores.view(-1)``:

        global_index = token_global_offset * H + head_index

    Args:
        :param scores:
            ``[N_total, H]`` scores for each (token, head) pair.
        :param cu_seqlens_k:
            ``[B + 1]`` cumulative key lengths; ``N_total = cu_seqlens_k[-1]``.
        :param max_k_len:
            Upper bound on the per-sequence length; sizes the padded buffer.
        :param top_k:
            Entries to keep per sequence; clamped to ``max_k_len * H``.
        :param H:
            Number of key heads; must match ``scores.shape[1]``.
        :param padding:
            Fill value for slots beyond a sequence's length. With the default
            ``-inf`` padded slots rank last, so they are only returned when a
            sequence has fewer than ``k_eff`` real entries.
    Returns:
        :return torch.Tensor:
            ``[B, k_eff]`` indices with ``k_eff = min(top_k, max_k_len * H)``,
            sorted by descending score within each row.
    """
    num_seqs = cu_seqlens_k.numel() - 1
    flat_width = max_k_len * H
    # Dense [B, max_k_len, H] buffer; unused slots keep the padding value.
    padded = torch.full(
        (num_seqs, max_k_len, H),
        fill_value=padding,
        dtype=scores.dtype,
        device=scores.device,
    )
    for row in range(num_seqs):
        beg = int(cu_seqlens_k[row])
        end = int(cu_seqlens_k[row + 1])
        padded[row, : end - beg, :].copy_(scores[beg:end, :])
    k_eff = min(top_k, flat_width)
    local_idx = torch.topk(
        padded.view(num_seqs, flat_width),
        k=k_eff,
        dim=1,
        largest=True,
        sorted=True,
    ).indices
    # Shift per-sequence indices into the packed [N_total * H] layout.
    seq_offsets = (cu_seqlens_k[:-1] * H).unsqueeze(-1)
    return local_idx + seq_offsets
"""
Compactor 压缩:与 kvpress ``CompactorPress`` / ``LeverageScorePress`` / ``NonCausalAttnPress``
算法对齐(Cholesky 杠杆分、右高斯 sketch、非因果分块注意力无 1/sqrt(d) 缩放、×||V||、avg_pool、
全局 z-score、blending 与首尾 sink pad)。
非因果分块注意力与 ``×||V||``+``avg_pool1d(k=3)`` 在 CUDA 上为 Triton;非 CUDA 回退 PyTorch。
"""
from __future__ import annotations
import math
from typing import List, Optional
import torch
import triton
import triton.language as tl
from transformers.models.llama.modeling_llama import repeat_kv
from vllm.kvprune.compression.common import BaseCompressionMethod
from vllm.kvprune.utils.context import get_context
from vllm.kvprune.utils.helpers import maybe_execute_in_stream
def resolve_kvpress_compactor_blending(compression_context) -> float:
    """Resolve the blending weight the way kvpress ``CompactorPress.score`` does.

    Preference order: the context's ``compactor_blending``, then its
    ``compression_ratio``, then the default ``0.35``. An attribute that is
    missing or ``None`` is skipped; a ``None`` context yields the default.
    """
    default_blending = 0.35
    if compression_context is None:
        return default_blending
    for attr in ("compactor_blending", "compression_ratio"):
        value = getattr(compression_context, attr, None)
        if value is not None:
            return float(value)
    return default_blending
class CompactorCompression(BaseCompressionMethod):
    """Compactor compression; ``chunk_size`` matches the kvpress
    ``CompactorPress`` / ``NonCausalAttnPress`` default of 256."""

    chunk_size: int = 256

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
    ) -> Optional[torch.Tensor]:
        """Compute leverage scores from the pre-RoPE keys.

        Calls ``kvpress_leverage_scores_packed`` through
        ``maybe_execute_in_stream`` with ``STORE_STREAM=None``. ``q`` and ``v``
        are not used.
        """
        compression_context = context.compression_context
        # Key rows are indexed by the K packed layout (matches master/peer
        # packed buffers). Compare with ``is None`` rather than ``or``:
        # cu_seqlens_* are tensors and ``bool(tensor)`` is invalid.
        packed_cu = getattr(context, "cu_seqlens_k", None)
        if packed_cu is None:
            packed_cu = context.cu_seqlens_q
        # Host-side copy of the boundaries, read from the global context;
        # falls back to the query-side copy when the key-side one is None.
        global_ctx = get_context()
        host_bounds = global_ctx.cu_seqlens_k_host
        if host_bounds is None:
            host_bounds = global_ctx.cu_seqlens_q_host
        return maybe_execute_in_stream(
            kvpress_leverage_scores_packed,
            k,
            packed_cu,
            compression_context,
            host_bounds,
            STORE_STREAM=None,
        )

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """Combine the pre-RoPE scores with the post-RoPE scoring step.

        Calls ``kvpress_compactor_post_rope`` through
        ``maybe_execute_in_stream`` on ``context.STORE_STREAM``, passing the
        class ``chunk_size`` and the blending weight resolved by
        ``resolve_kvpress_compactor_blending``.
        """
        compression_context = context.compression_context
        blend_weight = resolve_kvpress_compactor_blending(compression_context)
        return maybe_execute_in_stream(
            kvpress_compactor_post_rope,
            q,
            k,
            v,
            context.cu_seqlens_q,
            pre_rope_scores,
            compression_context,
            context.max_seqlen_q,
            chunk_size=CompactorCompression.chunk_size,
            blending=float(blend_weight),
            STORE_STREAM=context.STORE_STREAM,
        )
# ---------------------------------------------------------------------------
# Cholesky 杠杆分(kvpress ``LeverageScorePress``)
# ---------------------------------------------------------------------------
def chol_with_jitter(
    G: torch.Tensor, jitter: float = 0.0, max_tries: int = 5
) -> torch.Tensor:
    """Lower Cholesky factor of ``G + shift * I`` with a growing diagonal shift.

    The first attempt uses ``shift = jitter``. After every failed attempt the
    shift becomes ``1e-2`` if it was zero, otherwise ten times its previous
    value (never below ``1e-8``). An attempt succeeds only when
    ``cholesky_ex`` reports ``info == 0`` for every matrix in the batch.

    :param G: square matrix, or batch of square matrices, to factor.
    :param jitter: initial diagonal shift.
    :param max_tries: number of factorization attempts.
    :return: the lower-triangular factor from the first successful attempt.
    :raises RuntimeError: when all ``max_tries`` attempts fail.
    """
    eye = torch.eye(G.shape[-1], device=G.device, dtype=G.dtype)
    shift = float(jitter)
    for _attempt in range(max_tries):
        factor, info = torch.linalg.cholesky_ex(G + shift * eye, upper=False)
        if bool((info == 0).all()):
            return factor
        next_shift = 1e-2 if shift == 0.0 else 10.0 * shift
        shift = max(1e-8, next_shift)
    raise RuntimeError(f"Cholesky failed after {max_tries} tries.")
def compute_leverage_scores_mid(
    key_states: torch.Tensor, sketch_dimension: int
) -> torch.Tensor:
    """
    Leverage scores of the keys after a random Gaussian sketch, following
    kvpress ``LeverageScorePress.compute_leverage_scores``.

    Input is ``[L, H, D]``; the result is ``[L, H]`` and is not z-scored.

    To line up with the kvpress ``(B, H, S, D)`` axis order the keys are
    first viewed as ``[1, H, L, D]`` and centered along the sequence axis
    (``dim=-2``). They are then multiplied by a Gaussian sketch of shape
    ``(1, H, D, K)`` scaled by ``1 / sqrt(K)``, giving ``[1, H, L, K]``. With
    ``X`` the sketched keys and ``G`` the symmetrized ``X^T X``, the score of
    a row is its dot product with the matching row of ``(G + 1e-2 I)^-1 X^T``
    transposed, clamped at zero.
    """
    feat_dim = key_states.shape[-1]
    num_heads = key_states.shape[1]
    sketch = torch.randn(
        1,
        num_heads,
        feat_dim,
        sketch_dimension,
        device=key_states.device,
        dtype=key_states.dtype,
    ) * (1.0 / math.sqrt(sketch_dimension))
    # [L, H, d] -> [1, H, L, d], the kvprune counterpart of kvpress (B, H, S, d).
    keys_bhsd = key_states.transpose(0, 1).unsqueeze(0).contiguous()
    # ROCm batched GEMM is sensitive to non-contiguous strides after
    # transpose/mean, hence the explicit .contiguous() calls below.
    centered = (keys_bhsd - keys_bhsd.mean(dim=-2, keepdim=True)).contiguous()
    projected = torch.matmul(centered, sketch).to(torch.float32).contiguous()
    projected_t = projected.transpose(-2, -1).contiguous()
    gram = (projected_t @ projected).contiguous()
    gram_sym = 0.5 * (gram + gram.transpose(-2, -1)).contiguous()
    if torch.version.hip is not None:
        # HIP/ROCm: rocBLAS TRSM (used by cholesky_solve and often by
        # linalg.solve for triangular solves) can launch blocks
        # (e.g. 16x64x1) > __launch_bounds__(256). For a small sketch
        # dimension (typically <= 128) inv(G) @ X^T avoids TRSM.
        eye = torch.eye(
            gram_sym.shape[-1],
            device=gram_sym.device,
            dtype=gram_sym.dtype,
            requires_grad=False,
        )
        solved = torch.linalg.inv(gram_sym + 1e-2 * eye) @ projected_t
    else:
        chol = chol_with_jitter(gram_sym, jitter=1e-2, max_tries=5)
        solved = torch.cholesky_solve(projected_t, chol, upper=False)
    scores = (projected * solved.transpose(-2, -1)).sum(dim=-1).clamp_min(0)
    # [1, H, L] -> [L, H]
    return scores.squeeze(0).transpose(0, 1).contiguous()
def kvpress_leverage_scores_packed(
    key_states: torch.Tensor,
    cu_seqlens: torch.Tensor,
    compression_ctx,
    cu_seqlens_host: tuple[int, ...] | None = None,
) -> torch.Tensor:
    """Leverage scores for a packed batch of keys, z-scored over the batch.

    For every sequence the first ``sink_size_start`` and last
    ``sink_size_end`` tokens (context defaults 8 and 4) are excluded; the
    remaining middle tokens are scored with ``compute_leverage_scores_mid``
    using ``sketch_dimension`` (default 48). All middle scores of the batch
    are concatenated, normalized together by ``_zscore_flat_f32_global`` and
    written back to their rows. Excluded rows stay at zero.

    :param key_states: packed keys, ``[N, Hkv, D]``.
    :param cu_seqlens: cumulative sequence boundaries; copied to the CPU only
        when ``cu_seqlens_host`` is not given.
    :param compression_ctx: object read for the three settings above.
    :param cu_seqlens_host: optional host-side copy of the boundaries.
    :return: float32 tensor ``[N, Hkv]`` on the device of ``key_states``.
    :raises RuntimeError: when the last boundary differs from ``N``.
    """
    device = key_states.device
    N, Hkv, _D = key_states.shape
    sketch_dim = int(getattr(compression_ctx, "sketch_dimension", 48))
    sink_start = int(getattr(compression_ctx, "sink_size_start", 8))
    sink_end = int(getattr(compression_ctx, "sink_size_end", 4))

    if cu_seqlens_host is None:
        cu_cpu = cu_seqlens.detach().cpu().view(-1)
        total = int(cu_cpu[-1])
        bounds = cu_cpu.tolist()
    else:
        bounds = list(cu_seqlens_host)
        total = bounds[-1]
    if total != N:
        raise RuntimeError(
            f"kvpress_leverage_scores_packed: cu_seqlens[-1]={total} != key_states "
            f"num_rows={N} (check packed prefill / TP broadcast)."
        )

    out = torch.zeros(N, Hkv, device=device, dtype=torch.float32)
    # One (mid_start, mid_end, flattened raw scores) entry per scored sequence.
    segments: list[tuple[int, int, torch.Tensor]] = []
    for raw_beg, raw_end in zip(bounds[:-1], bounds[1:]):
        k_beg, k_end = int(raw_beg), int(raw_end)
        seq_len = k_end - k_beg
        if seq_len == 0:
            continue
        head_sink = min(sink_start, seq_len)
        tail_sink = min(sink_end, max(0, seq_len - head_sink))
        mid_start = k_beg + head_sink
        mid_end = k_end - tail_sink
        if mid_start >= mid_end:
            continue
        mid_keys = key_states[mid_start:mid_end, :, :].contiguous()
        raw = compute_leverage_scores_mid(mid_keys, sketch_dim)
        segments.append((mid_start, mid_end, raw.reshape(-1)))

    if not segments:
        return out

    # Normalize all middle scores together, then scatter them back per sequence.
    z = _zscore_flat_f32_global(
        torch.cat([flat for _, _, flat in segments], dim=0)
    )
    cursor = 0
    for mid_start, mid_end, flat in segments:
        count = flat.numel()
        out[mid_start:mid_end, :] = z[cursor : cursor + count].view(
            mid_end - mid_start, Hkv
        )
        cursor += count
    return out
# ---------------------------------------------------------------------------
# 非因果分块注意力(kvpress ``NonCausalAttnPress.non_causal_chunked_attn``)— Triton
# ---------------------------------------------------------------------------
def _non_causal_chunked_attn_pytorch(
q: torch.Tensor, k: torch.Tensor, chunk_size: int
) -> torch.Tensor:
"""参考实现:与 kvpress 逐算子一致。"""
assert chunk_size > 0 and q.shape == k.shape
L, H, d = q.shape
B = 1
q = q.permute(1, 0, 2).unsqueeze(0).contiguous()
k = k.permute(1, 0, 2).unsqueeze(0).contiguous()
_B, H, S, _d = k.shape
S_pad = math.ceil(S / chunk_size) * chunk_size
pad_len = S_pad - S
if pad_len > 0:
q_padded = torch.cat(
[q, torch.zeros(B, H, pad_len, d, device=q.device, dtype=q.dtype)], dim=2
)
k_padded = torch.cat(
[k, torch.zeros(B, H, pad_len, d, device=k.device, dtype=k.dtype)], dim=2
)
last_chunk_start = (S // chunk_size) * chunk_size
in_valid = torch.arange(last_chunk_start, S_pad, device=q.device) >= S
query_mask = key_mask = in_valid.view(1, 1, chunk_size).expand(B, H, chunk_size)
else:
q_padded, k_padded = q, k
last_chunk_start = ((S - 1) // chunk_size) * chunk_size
in_valid = torch.arange(last_chunk_start, S_pad, device=q.device) >= S
query_mask = key_mask = in_valid.view(1, 1, chunk_size).expand(B, H, chunk_size)
num_chunks = S_pad // chunk_size
q_chunks = q_padded.view(B, H, num_chunks, chunk_size, d)
k_chunks = k_padded.view(B, H, num_chunks, chunk_size, d)
dots = torch.matmul(q_chunks, k_chunks.transpose(-2, -1))
dots[:, :, -1].masked_fill_(query_mask.unsqueeze(-1), 0)
dots[:, :, -1].masked_fill_(key_mask.unsqueeze(-2), -1e-9)
attn = torch.softmax(dots.to(torch.float32), dim=-1)
out = attn.sum(dim=-2).view(B, H, S_pad)[..., :S]
return out.squeeze(0).transpose(0, 1).contiguous()
@triton.jit
def _non_causal_chunk_row_kernel(
Q_ptr,
K_ptr,
Out_ptr,
stride_qh,
stride_qs,
stride_qd,
stride_kh,
stride_ks,
stride_kd,
stride_oh,
stride_os,
S,
S_pad,
num_chunks,
CHUNK_SIZE: tl.constexpr,
D: tl.constexpr,
BLOCK_D: tl.constexpr,
ND: tl.constexpr,
):
"""
Each program handles one head, one chunk and one query row.
It applies softmax(dim=-1) to that row of logits, then atomic-adds the
probability of every key column j into the output (equivalent to summing over
the queries).
"""
h = tl.program_id(0)
c = tl.program_id(1)
iq = tl.program_id(2)
# global sequence position of this program's query row
g_i = c * CHUNK_SIZE + iq
offs_j = tl.arange(0, CHUNK_SIZE)
logits = tl.zeros([CHUNK_SIZE], dtype=tl.float32)
# accumulate q . k_j over the feature dimension in ND slices of BLOCK_D
for db in range(ND):
offs_d = tl.arange(0, BLOCK_D) + db * BLOCK_D
mask_d = offs_d < D
q_off = (
h * stride_qh + g_i * stride_qs + offs_d * stride_qd
)
qd = tl.load(Q_ptr + q_off, mask=mask_d, other=0.0).to(tl.float32)
g_j = c * CHUNK_SIZE + offs_j
k_row_off = h * stride_kh + g_j[:, None] * stride_ks + offs_d[None, :] * stride_kd
kj = tl.load(K_ptr + k_row_off, mask=mask_d[None, :], other=0.0).to(tl.float32)
logits += tl.sum(qd[None, :] * kj, axis=1)
# positions at or beyond S are padding
row_invalid = g_i >= S
g_j_all = c * CHUNK_SIZE + offs_j
col_invalid = g_j_all >= S
# a padded query row gets all-zero logits ...
logits = tl.where(row_invalid, tl.zeros([CHUNK_SIZE], dtype=tl.float32), logits)
# ... and for real query rows the padded key columns are set to -1e-9
logits = tl.where(
row_invalid,
logits,
tl.where(col_invalid, tl.full([CHUNK_SIZE], -1e-9, dtype=tl.float32), logits),
)
# softmax over the row, with the row maximum subtracted before exp
m = tl.max(logits)
logits = logits - m
exp_v = tl.exp(logits)
denom = tl.sum(exp_v)
p = exp_v / denom
out_base = h * stride_oh + g_j_all * stride_os
# only real key positions (< S) receive a contribution
tl.atomic_add(Out_ptr + out_base, p, mask=g_j_all < S)
def _non_causal_chunked_attn_triton(
    q: torch.Tensor, k: torch.Tensor, chunk_size: int
) -> torch.Tensor:
    """Chunked non-causal attention column sums on CUDA via Triton.

    ``q`` and ``k`` must be CUDA tensors of identical shape ``[L, H, d]``.
    Both are zero-padded along the sequence axis to a multiple of
    ``chunk_size``, moved to a head-major float32 layout, and handed to
    ``_non_causal_chunk_row_kernel`` with one program per
    (head, chunk, query row). Returns a contiguous float32 tensor ``[L, H]``.
    """
    assert q.is_cuda and k.is_cuda and q.shape == k.shape
    seq_len, n_heads, head_dim = q.shape
    assert chunk_size > 0

    padded_len = math.ceil(seq_len / chunk_size) * chunk_size
    n_pad = padded_len - seq_len

    def _pad_and_pack(t: torch.Tensor) -> torch.Tensor:
        # Append zero rows up to padded_len, then go head-major in float32.
        if n_pad > 0:
            filler = torch.zeros(
                n_pad, n_heads, head_dim, device=t.device, dtype=t.dtype, requires_grad=False
            )
            t = torch.cat([t, filler], dim=0)
        return t.transpose(0, 1).contiguous().to(dtype=torch.float32)

    q_packed = _pad_and_pack(q)
    k_packed = _pad_and_pack(k)

    n_chunks = padded_len // chunk_size
    col_sums = torch.zeros(n_heads, padded_len, device=q.device, dtype=torch.float32)
    valid_len = int(seq_len)

    # The head dimension is reduced in slices of block_d inside the kernel.
    block_d = 32 if head_dim <= 128 else 64
    n_d_blocks = (head_dim + block_d - 1) // block_d

    launch_grid = (n_heads, n_chunks, chunk_size)
    _non_causal_chunk_row_kernel[launch_grid](
        q_packed,
        k_packed,
        col_sums,
        q_packed.stride(0),
        q_packed.stride(1),
        q_packed.stride(2),
        k_packed.stride(0),
        k_packed.stride(1),
        k_packed.stride(2),
        col_sums.stride(0),
        col_sums.stride(1),
        valid_len,
        padded_len,
        int(n_chunks),
        CHUNK_SIZE=chunk_size,
        D=head_dim,
        BLOCK_D=block_d,
        ND=n_d_blocks,
        num_warps=4,
    )
    # Drop the padded tail and return token-major [L, H].
    return col_sums[:, :valid_len].transpose(0, 1).contiguous()
def non_causal_chunked_attn(q: torch.Tensor, k: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """Map ``q, k`` of shape ``[L, H, d]`` to scores ``[L, H]`` (no ``1/sqrt(d)`` scaling).

    Uses the Triton implementation when both inputs live on CUDA and the
    PyTorch implementation otherwise.
    """
    both_on_cuda = q.is_cuda and k.is_cuda
    if not both_on_cuda:
        return _non_causal_chunked_attn_pytorch(q, k, chunk_size)
    return _non_causal_chunked_attn_triton(q, k, chunk_size)
# ---------------------------------------------------------------------------
# Multiply by ||V|| and apply avg_pool1d(kernel_size=3) — Triton (CUDA)
# ---------------------------------------------------------------------------
@triton.jit
def _mul_vnorm_avgpool3_kernel(
    A_ptr,
    V_ptr,
    OUT_ptr,
    stride_al,
    stride_ah,
    stride_vl,
    stride_vh,
    stride_vd,
    stride_ol,
    stride_oh,
    L,
    D: tl.constexpr,
):
    """Fused ``a * ||v||_2`` followed by a width-3 average along the token axis.

    One program per (token ``l``, head ``h``). For each of the positions
    ``l-1``, ``l`` and ``l+1`` it computes ``A[pos, h] * sqrt(sum(V[pos, h, :]^2))``,
    using 0 for positions outside ``[0, L)``, and stores the sum of the three
    terms times ``1/3`` into ``OUT[l, h]``.

    Triton does not support nested ``def``s, so the per-position ``t_at`` logic
    is unrolled once for each of ``l-1``, ``l`` and ``l+1``.
    """
    l = tl.program_id(0)
    h = tl.program_id(1)
    offs = tl.arange(0, D)
    # --- position l - 1 ---
    pos_m1 = l - 1
    inb_m1 = (pos_m1 >= 0) & (pos_m1 < L)
    # Clamp the index to 0 when out of bounds so the address stays valid; the
    # load is masked and the term is zeroed below anyway.
    ps_m1 = tl.where(inb_m1, pos_m1, 0)
    a_m1 = tl.load(
        A_ptr + ps_m1 * stride_al + h * stride_ah,
        mask=inb_m1,
        other=0.0,
    ).to(tl.float32)
    v_m1 = tl.load(
        V_ptr + ps_m1 * stride_vl + h * stride_vh + offs * stride_vd,
        mask=inb_m1,
        other=0.0,
    ).to(tl.float32)
    s_m1 = tl.where(inb_m1, a_m1 * tl.sqrt(tl.sum(v_m1 * v_m1)), 0.0)
    # --- position l ---
    inb_0 = (l >= 0) & (l < L)
    ps0 = tl.where(inb_0, l, 0)
    a0 = tl.load(
        A_ptr + ps0 * stride_al + h * stride_ah,
        mask=inb_0,
        other=0.0,
    ).to(tl.float32)
    v0 = tl.load(
        V_ptr + ps0 * stride_vl + h * stride_vh + offs * stride_vd,
        mask=inb_0,
        other=0.0,
    ).to(tl.float32)
    s_0 = tl.where(inb_0, a0 * tl.sqrt(tl.sum(v0 * v0)), 0.0)
    # --- position l + 1 ---
    pos_p1 = l + 1
    inb_p1 = (pos_p1 >= 0) & (pos_p1 < L)
    ps_p1 = tl.where(inb_p1, pos_p1, 0)
    a_p1 = tl.load(
        A_ptr + ps_p1 * stride_al + h * stride_ah,
        mask=inb_p1,
        other=0.0,
    ).to(tl.float32)
    v_p1 = tl.load(
        V_ptr + ps_p1 * stride_vl + h * stride_vh + offs * stride_vd,
        mask=inb_p1,
        other=0.0,
    ).to(tl.float32)
    s_p1 = tl.where(inb_p1, a_p1 * tl.sqrt(tl.sum(v_p1 * v_p1)), 0.0)
    # Always divide by 3, so out-of-range neighbours count as zeros.
    out = (s_m1 + s_0 + s_p1) * (1.0 / 3.0)
    tl.store(OUT_ptr + l * stride_ol + h * stride_oh, out)
def _mul_vnorm_avgpool3_fused(
a: torch.Tensor, v: torch.Tensor, out: torch.Tensor | None = None
) -> torch.Tensor:
assert a.dim() == 2 and v.dim() == 3 and a.shape[0] == v.shape[0] and a.shape[1] == v.shape[1]
L, H, D = v.shape
a = a.contiguous()
v = v.contiguous()
if a.dtype != torch.float32:
a = a.float()
if out is None:
out = torch.empty((L, H), device=v.device, dtype=torch.float32)
if L == 0 or H == 0:
return out
grid = (L, H)
_mul_vnorm_avgpool3_kernel[grid](
a,
v,
out,
a.stride(0),
a.stride(1),
v.stride(0),
v.stride(1),
v.stride(2),
out.stride(0),
out.stride(1),
L,
D=D,
num_warps=4,
)
return out
def _maybe_mul_vnorm_avgpool3_fused(a: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
if not a.is_cuda or not v.is_cuda:
import torch.nn.functional as F
s = a * v.norm(dim=-1)
return (
F.avg_pool1d(s.transpose(0, 1).unsqueeze(0), kernel_size=3, padding=1, stride=1)
.squeeze(0)
.transpose(0, 1)
)
return _mul_vnorm_avgpool3_fused(a, v)
@triton.jit
def _zscore_elem_1d_kernel(
    X_ptr,
    OUT_ptr,
    n,
    mean,
    inv_std,
    BLOCK: tl.constexpr,
):
    """Element-wise ``(x - mean) * inv_std`` over a flat buffer of ``n`` values.

    Each program handles ``BLOCK`` consecutive elements starting at
    ``program_id(0) * BLOCK``; elements at offsets ``>= n`` are masked out of
    both the load and the store. ``mean`` and ``inv_std`` are scalars supplied
    by the caller.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(X_ptr + offs, mask=mask, other=0.0)
    tl.store(OUT_ptr + offs, (x - mean) * inv_std, mask=mask)
def _zscore_flat_f32_global(x: torch.Tensor) -> torch.Tensor:
"""
与 kvpress ``(t - t.mean()) / t.std()`` 一致的一维全局 z-score。
``mean/std`` 用 PyTorch;CUDA 上缩放阶段用 Triton 逐元素写入。
"""
if x.numel() == 0:
return x
mu = x.mean()
sig = x.std().clamp_min(1e-6)
inv = 1.0 / sig
if not x.is_cuda:
return (x - mu) * inv
x = x.contiguous()
out = torch.empty_like(x)
n = x.numel()
BLOCK = 1024
grid = (triton.cdiv(n, BLOCK),)
_zscore_elem_1d_kernel[grid](
x,
out,
n,
float(mu.item()),
float(inv.item()),
BLOCK=BLOCK,
num_warps=4,
)
return out
def _attn_scores_kvpress_middle(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens: torch.Tensor,
sink_start: int,
sink_end: int,
chunk_size: int,
do_zscore: bool = True,
) -> torch.Tensor:
"""仅中间子序列上的非因果分 + ×||V|| + avg_pool;输出全长 ``[N, Hkv]``,非中间为 0。"""
N, HQ, D = q.shape
Hkv = k.shape[1]
G = HQ // Hkv
device = q.device
attn_out = torch.zeros(N, Hkv, device=device, dtype=torch.float32)
parts: list[torch.Tensor] = []
for b in range(cu_seqlens.numel() - 1):
k_beg = int(cu_seqlens[b].item())
k_end = int(cu_seqlens[b + 1].item())
L = k_end - k_beg
if L == 0:
continue
left_keep = min(sink_start, L)
right_keep = min(sink_end, max(0, L - left_keep))
mid_start = k_beg + left_keep
mid_end = k_end - right_keep
if mid_start >= mid_end:
continue
q_m = q[mid_start:mid_end, :, :].contiguous()
k_m = k[mid_start:mid_end, :, :].contiguous()
v_m = v[mid_start:mid_end, :, :].contiguous()
# HF ``repeat_kv`` 约定:``[batch, num_kv_heads, seq_len, head_dim]``
k_4d = k_m.unsqueeze(0).transpose(1, 2).contiguous() # [1, Hkv, Lm, D]
k_rep = repeat_kv(k_4d, G)[0].transpose(0, 1).contiguous() # [Lm, HQ, D]
A = non_causal_chunked_attn(q_m, k_rep, chunk_size)
Lm, HQa = A.shape
assert HQa == HQ
A = A.view(Lm, Hkv, G).mean(dim=-1)
scores = _maybe_mul_vnorm_avgpool3_fused(A, v_m)
parts.append(scores.reshape(-1))
if not parts:
return attn_out
flat_a = torch.cat(parts, dim=0)
if do_zscore:
z_a = _zscore_flat_f32_global(flat_a)
else:
z_a = flat_a
offset = 0
for b in range(cu_seqlens.numel() - 1):
k_beg = int(cu_seqlens[b].item())
k_end = int(cu_seqlens[b + 1].item())
L = k_end - k_beg
if L == 0:
continue
left_keep = min(sink_start, L)
right_keep = min(sink_end, max(0, L - left_keep))
mid_start = k_beg + left_keep
mid_end = k_end - right_keep
if mid_start >= mid_end:
continue
n = (mid_end - mid_start) * Hkv
attn_out[mid_start:mid_end, :] = z_a[offset : offset + n].view(
mid_end - mid_start, Hkv
)
offset += n
return attn_out
def non_causal_attn_scores(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_qk: torch.Tensor,
    max_seqlen_qk: int,
    chunk_size: int,
    sm_scale: Optional[float] = None,
    normalize: bool = True,
    context_lens: Optional[List[int]] = None,
    protected_first_tokens: Optional[List[int]] = None,
    protected_last_tokens: Optional[List[int]] = None,
    *,
    accum_scores: Optional[torch.Tensor] = None,
    accum_blending: Optional[float] = None,
) -> torch.Tensor:
    """
    Non-causal attention scores following the kvpress non-causal branch.

    ``sm_scale`` and ``max_seqlen_qk`` are accepted for interface
    compatibility and ignored (the dot products are not scaled by
    ``1/sqrt(d)``). Scores are computed on the middle of each sequence only,
    with 8 leading and 4 trailing tokens excluded.

    Args:
        q, k, v: Packed queries, keys and values, forwarded to
            ``_attn_scores_kvpress_middle``.
        cu_seqlens_qk: Cumulative sequence boundaries.
        chunk_size: Chunk length for the chunked attention.
        normalize: When True, the concatenated middle scores are z-scored
            globally.
        context_lens: Per-sequence lengths used to locate protected ranges.
        protected_first_tokens / protected_last_tokens: Per-sequence counts of
            leading / trailing tokens whose score is set to ``inf``. Applied
            only when both lists and ``context_lens`` are given.
        accum_scores: Optional scores added as ``out += w * accum_scores``.
        accum_blending: Weight ``w``; defaults to 0.5 when ``None``.

    Returns:
        Score tensor with one row per token and one column per KV head.
    """
    del sm_scale, max_seqlen_qk
    sink_start, sink_end = 8, 4
    out = _attn_scores_kvpress_middle(
        q,
        k,
        v,
        cu_seqlens_qk,
        sink_start,
        sink_end,
        chunk_size,
        do_zscore=normalize,
    )
    if accum_scores is not None:
        w = 0.5 if accum_blending is None else float(accum_blending)
        out = out + w * accum_scores.to(device=out.device, dtype=out.dtype)
    if protected_first_tokens is not None and protected_last_tokens is not None and context_lens:
        start = 0
        for first, last, Lc in zip(
            protected_first_tokens, protected_last_tokens, context_lens
        ):
            seq_len = int(Lc)
            end = start + seq_len
            # Clamp to the sequence length: an unclamped count larger than the
            # sequence would spill into a neighbouring sequence, or turn into
            # a negative (wrap-around) slice start for the first sequence.
            n_first = min(max(int(first), 0), seq_len)
            n_last = min(max(int(last), 0), seq_len)
            out[start : start + n_first].fill_(torch.inf)
            out[end - n_last : end].fill_(torch.inf)
            start = end
    return out
def kvpress_compactor_post_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    pre_rope_scores: torch.Tensor,
    compression_ctx,
    max_seqlen_q: int,
    chunk_size: int,
    blending: float,
) -> torch.Tensor:
    """Blend pre-RoPE scores with post-RoPE non-causal attention scores.

    On the middle of each sequence the result is
    ``blending * pre_rope_scores + attn_scores``. The leading / trailing sink
    tokens are then filled with the global maximum of the blended tensor (or
    1.0 when that maximum is non-finite or zero), and finally the protected
    head / tail ranges are set to ``inf``.

    Args:
        q, k, v: Packed queries, keys and values.
        cu_seqlens: 1-D cumulative sequence boundaries.
        pre_rope_scores: Scores to blend in; converted to float32 on ``q``'s
            device. The output has the same shape.
        compression_ctx: Object optionally providing ``sink_size_start``
            (default 8), ``sink_size_end`` (default 4), ``context_lens``,
            ``protected_first_tokens`` and ``protected_last_tokens``.
        max_seqlen_q: Unused.
        chunk_size: Chunk length for the chunked attention.
        blending: Weight applied to ``pre_rope_scores``.

    Returns:
        Float32 blended score tensor.
    """
    del max_seqlen_q
    device = q.device
    sink_start = int(getattr(compression_ctx, "sink_size_start", 8))
    sink_end = int(getattr(compression_ctx, "sink_size_end", 4))
    context_lens: Optional[List[int]] = getattr(
        compression_ctx, "context_lens", None
    )
    protected_first: Optional[List[int]] = getattr(
        compression_ctx, "protected_first_tokens", None
    )
    protected_last: Optional[List[int]] = getattr(
        compression_ctx, "protected_last_tokens", None
    )
    attn_out = _attn_scores_kvpress_middle(
        q, k, v, cu_seqlens, sink_start, sink_end, chunk_size
    )
    lev = pre_rope_scores.to(device=device, dtype=torch.float32)
    blended = torch.zeros_like(lev)
    if blended.numel() == 0:
        # max() below is undefined on an empty tensor; nothing to score anyway.
        return blended

    # Read the boundaries once (single host transfer) and derive, per
    # non-empty sequence, (begin, middle start, middle end, end).
    bounds = [int(b) for b in cu_seqlens.tolist()]
    segments = []
    for k_beg, k_end in zip(bounds[:-1], bounds[1:]):
        L = k_end - k_beg
        if L == 0:
            continue
        left_keep = min(sink_start, L)
        right_keep = min(sink_end, max(0, L - left_keep))
        segments.append((k_beg, k_beg + left_keep, k_end - right_keep, k_end))

    for _, mid_start, mid_end, _ in segments:
        if mid_start >= mid_end:
            continue
        blended[mid_start:mid_end, :] = (
            blending * lev[mid_start:mid_end, :] + attn_out[mid_start:mid_end, :]
        )

    # The sink fill value is the global max over all middles, so it has to be
    # taken after every middle range has been written.
    pad_val = blended.max()
    if not torch.isfinite(pad_val) or pad_val == 0:
        pad_val = torch.tensor(1.0, device=device, dtype=torch.float32)
    for k_beg, mid_start, mid_end, k_end in segments:
        if mid_start > k_beg:
            blended[k_beg:mid_start, :] = pad_val
        if k_end > mid_end:
            blended[mid_end:k_end, :] = pad_val

    if protected_first is not None and protected_last is not None and context_lens:
        start = 0
        for first, last, Lc in zip(
            protected_first, protected_last, context_lens
        ):
            seq_len = int(Lc)
            end = start + seq_len
            # Clamp so a count larger than the sequence cannot reach into a
            # neighbouring sequence or produce a negative slice start.
            n_first = min(max(int(first), 0), seq_len)
            n_last = min(max(int(last), 0), seq_len)
            blended[start : start + n_first].fill_(torch.inf)
            blended[end - n_last : end].fill_(torch.inf)
            start = end
    return blended
import logging
import math
from typing import List, Optional
import torch
import triton
from tqdm.contrib.logging import logging_redirect_tqdm
from triton import language as tl
from vllm.kvprune.compression.common import BaseCompressionMethod
from vllm.kvprune.utils.helpers import maybe_execute_in_stream
from vllm.kvprune.utils.triton_compat import autotune as triton_autotune
logger = logging.getLogger(__name__)
class CompactorCompression(BaseCompressionMethod):
    """Compactor scoring: leverage scores before RoPE, non-causal attention after.

    Both hooks are static and route their work through
    ``maybe_execute_in_stream``.
    """

    # Chunk length passed to ``non_causal_attn_scores`` in post_rope_scoring.
    chunk_size: int = 128

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
    ) -> Optional[torch.Tensor]:
        """Score the keys ``k`` with ``approximate_leverage_scores`` (normalized).

        Lengths, the projection ``PHI`` and the chunk size are read from
        ``context.compression_context``; ``q`` and ``v`` are not used.
        """
        ctx = context.compression_context
        positional = (k, ctx.context_lens, ctx.PHI)
        keyword = {
            "normalize": True,
            "chunk_size": ctx.compression_chunk_size,
            "STORE_STREAM": context.STORE_STREAM,
        }
        return maybe_execute_in_stream(
            approximate_leverage_scores, *positional, **keyword
        )

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: torch.Tensor,
        context,
    ) -> Optional[torch.Tensor]:
        """Run ``non_causal_attn_scores`` and accumulate ``pre_rope_scores`` into it.

        ``pre_rope_scores`` is passed as ``accum_scores`` with a blending
        weight of 0.5; the protected-token settings come from
        ``context.compression_context``.
        """
        ctx = context.compression_context
        positional = (q, k, v, context.cu_seqlens_q, context.max_seqlen_q)
        keyword = {
            "chunk_size": CompactorCompression.chunk_size,
            "sm_scale": 1.0,
            "normalize": True,
            "accum_scores": pre_rope_scores,
            "context_lens": ctx.context_lens,
            "protected_first_tokens": ctx.protected_first_tokens,
            "protected_last_tokens": ctx.protected_last_tokens,
            "accum_blending": 0.5,
        }
        return maybe_execute_in_stream(
            non_causal_attn_scores, *positional, **keyword
        )
def split_into_chunks(xs, chunk_size):
    """
    Decompose sequence lengths into fixed-size chunks plus optional tails.

    For each length ``n`` in ``xs`` the sequence is cut into
    ``n // chunk_size`` full chunks (the "prologue") and at most one shorter
    tail of ``n % chunk_size`` tokens (the "epilogue"). Two parallel
    descriptions of that decomposition are returned:

        * ``coalesced_chunks`` – segment lengths in the concatenated sequence
          space, where all consecutive full chunks of one sequence are merged
          into a single segment and each tail is its own segment.
        * ``chunks`` – the individual chunk lengths: ``chunk_size`` repeated
          once per full chunk, followed by the tail length when non-zero.

    Example:
        xs = [257, 127], chunk_size = 128
        coalesced_chunks = [256, 1, 127]
        chunks = [128, 128, 1, 127]

    Args:
        :param xs:
            Iterable of non-negative integers
        :param chunk_size:
            Target chunk size
    Returns:
        :return Tuple[List[int], List[int]]:
            ``(coalesced_chunks, chunks)`` as described above.
    """
    coalesced_chunks = []
    chunks = []
    for length in xs:
        full_count = length // chunk_size
        head = full_count * chunk_size
        tail = length - head
        if head > 0:
            coalesced_chunks += [head]
            chunks += [chunk_size] * full_count
        if tail > 0:
            coalesced_chunks += [tail]
            chunks += [tail]
    return coalesced_chunks, chunks
def approximate_leverage_scores(
    key_states: torch.Tensor,  # [N, H, D]
    context_lens: List[int],  # [B]
    PHI: torch.Tensor,  # [D, k]
    regularizer: float = 5e-3,
    normalize: bool = False,
    chunk_size: int = 512,
) -> torch.Tensor:  # returns [N, H]
    """
    Approximate leverage scores for keys via randomized sketching.

    This implements a randomized approximation to per-token leverage scores for
    the key matrix, as described in Compactor: Calibrated Query-Agnostic KV Cache
    Compression with Approximate Leverage Scores (https://arxiv.org/abs/2507.08143).

    Steps: sketch the keys with ``PHI``, center every chunk (in place), form one
    ``k x k`` Gram matrix per chunk and head, regularize its diagonal, take an
    SVD, and score each token by the squared norm of its sketched row projected
    onto ``V * S^{-1/2}``. If the ``gesvda`` SVD fails twice (the second time with
    10x additional regularization) the function falls back to a QR-based path.

    Args:
        :param key_states:
            Tensor of shape ``[N, H, D]`` containing pre-RoPE key states for
            all tokens across the batch, packed along the sequence dimension.
            ``N = sum(context_lens)``.
        :param context_lens:
            List of per-sequence context lengths, length ``B``.
        :param PHI:
            Random projection matrix of shape ``[D, k]`` used to sketch the
            keys into a lower-dimensional subspace (k < D).
        :param regularizer:
            Small positive scalar added to the diagonal of each Gram matrix
            before SVD to improve numerical stability. Defaults to ``5e-3``.
        :param normalize:
            If True, z-score the scores in place with one normalization per
            entry of the chunk list (all heads and tokens of that chunk
            together). With ``chunk_size <= 0`` the chunks are the sequences
            themselves, so this is a per-sequence normalization.
        :param chunk_size:
            Target chunk size along the sequence dimension. If > 0, the
            concatenated sequence is split into chunks of at most this size
            before forming Gram matrices and SVD. If ≤ 0, the entire sequence
            for each context is treated as a single chunk.
    Returns:
        :return torch.Tensor:
            Approximate leverage scores of shape ``[N, H]``, where each row
            corresponds to a token and each column to a head.
    """
    if chunk_size > 0:
        coalesced_chunk_lens, chunks_lens = split_into_chunks(context_lens, chunk_size)
    else:
        coalesced_chunk_lens, chunks_lens = context_lens, context_lens
    # Same device as key_states (avoid bare .cuda() → wrong GPU in multi-device
    # processes); int32 matches Triton zscore kernel expectations for cu_k.
    # The leading 0 makes the later cumsum a list of chunk boundaries.
    chunk_lens_cuda = torch.tensor(
        [0] + chunks_lens,
        device=key_states.device,
        dtype=torch.int32,
    )
    # Sketch: [H, N, D] @ [D, k] -> [H, N, k].
    X = torch.matmul(key_states.transpose(0, 1), PHI)
    H, N, k = X.shape
    # torch.split returns views, so the in-place centering below modifies X.
    chunks = torch.split(X, coalesced_chunk_lens, dim=-2)
    gram_matrices = []
    for i, L in enumerate(coalesced_chunk_lens):
        chunk = chunks[i]
        if chunk_size <= 0 or L % chunk_size != 0:
            # Tail (or whole-sequence) segment: a single chunk of length L.
            chunk.sub_(chunk.mean(dim=-2, keepdim=True))
            g = torch.matmul(chunk.transpose(-1, -2), chunk)  # [H, k, k]
            g = g.unsqueeze(1)
        else:
            # Run of full chunks: batch them along a new chunk axis.
            chunk = chunk.view(H, -1, chunk_size, k)  # [H, num_chunks, chunk_size, k]
            chunk.sub_(chunk.mean(dim=-2, keepdim=True))
            g = torch.matmul(chunk.transpose(-1, -2), chunk)  # [H, num_chunks, k, k]
        gram_matrices.append(g)
    G = torch.cat(gram_matrices, dim=1).to(torch.float32)
    # diagonal() is a view, so add_ regularizes G itself.
    diag = G.diagonal(dim1=-2, dim2=-1)
    diag.add_(regularizer)
    try:
        V, S, Vt = torch.linalg.svd(G, full_matrices=False, driver="gesvda")
    except RuntimeError:
        try:
            # Retry once with a further 10x regularizer on the diagonal.
            diag = G.diagonal(dim1=-2, dim2=-1)
            diag.add_(regularizer * 10)
            V, S, Vt = torch.linalg.svd(G, full_matrices=False, driver="gesvda")
        except RuntimeError:
            with logging_redirect_tqdm():
                logger.warning(
                    "GESVDA failed, falling back to QR decomposition, which will be MUCH slower. "
                    "Try increasing chunk_size if this issue persists."
                )
            # this is over 50 times slower than using GESVDA
            return _approximate_leverage_scores_qr_fallback(
                X=X,
                chunks_lens=chunks_lens,
                chunk_lens_cuda=chunk_lens_cuda,
                normalize=normalize,
                chunk_size=chunk_size,
            )
    # Scale each singular vector (column of V) by S^{-1/2}.
    SV = (V * S.rsqrt().unsqueeze(-2)).to(X.dtype)
    # `start` indexes the chunk axis of SV in the order the Gram matrices were built.
    start = 0
    all_scores = []
    for i, L in enumerate(coalesced_chunk_lens):
        chunk = chunks[i]
        if chunk_size <= 0 or L % chunk_size != 0:
            num_chunks = 1
            sv = SV[:, start]
        else:
            num_chunks = L // chunk_size
            chunk = chunk.view(H, -1, chunk_size, k)  # [H, NC, CS, k]
            sv = SV[:, start : start + num_chunks]
        U = torch.matmul(chunk, sv)
        # Squared row norms, flattened back to one row of tokens per head.
        scores = (U * U).sum(dim=-1).clamp_min_(0.0).view(H, -1)
        all_scores.append(scores.transpose(-1, -2))
        start += num_chunks
    scores = torch.cat(all_scores, dim=0)
    if normalize:
        # One kernel program per chunk; cu_k holds the chunk boundaries.
        # NOTE(review): torch.cumsum on an int32 tensor normally returns int64,
        # so cu_k is probably int64 despite the int32 source — confirm the
        # kernel is fine with that.
        grid = (len(chunks_lens),)
        cu_k = chunk_lens_cuda.cumsum(dim=0)
        _zscore_per_batch_epilogue_no_window[grid](
            scores, cu_k, scores.stride(0), scores.stride(1), H
        )
    return scores
@triton_autotune(
    configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
    key=["HK"],
    cache_results=True,
)
@triton.jit
def _zscore_per_batch_epilogue_no_window(
    OUT,  # [Nk, Hk], float32
    cu_k,  # [B+1] int32
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    HK: tl.constexpr,  # Hk
    BLOCK_K: tl.constexpr,  # e.g., 128
):
    """In-place z-score of ``OUT`` over one ``[k_beg, k_end) x HK`` segment.

    One program per segment ``b``; the segment's row range is read from
    ``cu_k[b]`` and ``cu_k[b + 1]``. A first pass accumulates the sum and sum
    of squares over all rows and heads of the segment, a second pass rewrites
    every value as ``(value - mean) / std`` using the population variance.
    Empty segments are left untouched.
    """
    b = tl.program_id(0)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    if k_end <= k_beg:
        return
    sumv = tl.zeros([], dtype=tl.float32)
    sumsq = tl.zeros([], dtype=tl.float32)
    count = ((k_end - k_beg) * HK).to(tl.float32)
    # Pass 1: accumulate sum and sum of squares over the whole segment.
    for ks in tl.range(k_beg, k_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_end
        for h in tl.range(0, HK):
            ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
            sumv += tl.sum(vals, 0)
            sumsq += tl.sum(vals * vals, 0)
    mean = sumv / count
    # E[x^2] - E[x]^2 can dip below zero from rounding; clamp it.
    var = tl.maximum(sumsq / count - mean * mean, 0.0)
    # Floor the variance so a constant segment (e.g. a one-token chunk whose
    # scores are all zero) normalizes to 0 instead of 0 * inf = NaN.
    invstd = 1.0 / tl.sqrt(tl.maximum(var, 1e-12))
    # Pass 2: normalize in place.
    for ks in tl.range(k_beg, k_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        kmask = nk < k_end
        for h in tl.range(0, HK):
            ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
            vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
            vals = (vals - mean) * invstd
            tl.store(ptrs, vals, mask=kmask)
def _approximate_leverage_scores_qr_fallback(
X: torch.Tensor, # [H, N, k], already sketched (KΦ) and centered in-place
chunks_lens: List[int], # [num_chunks]
chunk_lens_cuda: torch.Tensor, # [num_chunks + 1] (prefix base)
normalize: bool,
chunk_size: int,
) -> torch.Tensor:
H, N, k = X.shape
device, dtype = X.device, X.dtype
offsets: List[int] = []
offset = 0
for L in chunks_lens:
offsets.append(offset)
offset += L
if offset != N:
raise RuntimeError(
f"QR fallback: sum(chunks_lens)={offset} does not match N={N}"
)
blocks = torch.split(X, chunks_lens, dim=-2)
scores = torch.empty(N, H, device=device, dtype=dtype)
if chunk_size > 0:
full_indices = [i for i, L in enumerate(chunks_lens) if L == chunk_size]
epi_indices = [i for i, L in enumerate(chunks_lens) if L != chunk_size]
if full_indices:
# stack full chunks
full_blocks = torch.stack(
[blocks[i] for i in full_indices], dim=0
) # [M, H, CS, k]
M, Hf, Lf, kf = full_blocks.shape
assert Lf == chunk_size
# merge (M, H) into a single batch dim for torch.linalg.q
full_blocks_2d = full_blocks.view(M * Hf, Lf, kf).to(torch.float32)
U_full, _ = torch.linalg.qr(full_blocks_2d, mode="reduced")
U_full = U_full.to(dtype)
scores_full = (U_full * U_full).sum(dim=-1).clamp_min(0.0) # [M * Hf, Lf]
scores_full = scores_full.view(M, Hf, Lf).transpose(-1, -2) # [M, H, CS]
for m, chunk_idx in enumerate(full_indices):
start = offsets[chunk_idx]
Lc = chunks_lens[chunk_idx]
scores[start : start + Lc].copy_(scores_full[m])
else:
epi_indices = list(range(len(chunks_lens)))
for chunk_idx in epi_indices:
block = blocks[chunk_idx]
_, Lc, _ = block.shape
if Lc == 0:
continue
U_epi, _ = torch.linalg.qr(block.to(torch.float32), mode="reduced")
scores_epi = (U_epi * U_epi).sum(dim=-1).to(dtype) # [H, Lc]
start = offsets[chunk_idx]
scores[start : start + Lc] = scores_epi.transpose(0, 1) # [Lc, H]
if normalize:
grid = (len(chunks_lens),)
cu_k = chunk_lens_cuda.cumsum(dim=0)
_zscore_per_batch_epilogue_no_window[grid](
scores, cu_k, scores.stride(0), scores.stride(1), H
)
return scores
@triton_autotune(
    configs=[
        triton.Config(
            {"BLOCK_M": BM, "BLOCK_K": BK, "WARPSPEC": False}, num_warps=w, num_stages=s
        )
        for BM in [64]
        for BK in [64]
        for w in [4]
        for s in [2]
    ],
    key=[
        "QUERY_GROUP_SIZE",
        "D",
        "CHUNK_SIZE",
    ],
    cache_results=True,
)
@triton.jit
def _non_causal_attn_kernel(
    Q,
    K,
    V,
    accum_scores,
    cu_seqlens_qk,
    # strides
    STRIDE_Q_G,
    STRIDE_Q_N,
    STRIDE_Q_H,
    STRIDE_Q_D,
    STRIDE_K_G,
    STRIDE_K_N,
    STRIDE_K_D,
    STRIDE_V_G,
    STRIDE_V_N,
    STRIDE_V_D,
    STRIDE_OUT_N,
    STRIDE_OUT_H,
    sm_scale,
    # compile-time parameters
    CHUNK_SIZE: tl.constexpr,
    QUERY_GROUP_SIZE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    D: tl.constexpr,
    WARPSPEC: tl.constexpr,
):
    """Per-key attention mass from non-causal attention inside one chunk.

    One program per (KV head, sequence, chunk). Queries of the chunk — for all
    ``QUERY_GROUP_SIZE`` query heads that share the KV head — attend to every
    key of the same chunk. For each tile of ``BLOCK_M`` query tokens a first
    pass over the keys computes the softmax row maximum and denominator
    (online, in base 2), a second pass recomputes the probabilities and adds
    their column sums into ``accum_scores[key, kv_head]``.

    ``V``, its strides and ``WARPSPEC`` are accepted but not read in this body.
    """
    TOTAL_QUERIES_PER_BLOCK: tl.constexpr = BLOCK_M * QUERY_GROUP_SIZE
    INVERSE_CHUNK: tl.constexpr = 1.0 / CHUNK_SIZE
    pid_g = tl.program_id(0)  # KV head in [0, HKV)
    pid_b = tl.program_id(1)  # batch id
    pid_m = tl.program_id(2)  # chunk id within batch
    off_b = tl.load(cu_seqlens_qk + pid_b)
    off_b1 = tl.load(cu_seqlens_qk + pid_b + 1)
    chunk_start = off_b + pid_m * CHUNK_SIZE
    chunk_end = tl.minimum(chunk_start + CHUNK_SIZE, off_b1)
    M = chunk_end - chunk_start
    # The grid is sized for the longest sequence; shorter ones have no such chunk.
    if M <= 0:
        return
    offs_d = tl.arange(0, D)
    offs_k = tl.arange(0, BLOCK_K)
    # Flattened query rows inside a [BLOCK_M, QUERY_GROUP_SIZE] tile
    offs_q = tl.arange(0, TOTAL_QUERIES_PER_BLOCK)
    row_m = offs_q % BLOCK_M  # token offset in this tile
    row_h = offs_q // BLOCK_M  # query-group index
    qk_scale = sm_scale * 1.44269504  # convert to log2-domain
    NEG_INF = -1.0e9
    # Iterate over query tiles within this chunk
    for qs in tl.range(chunk_start, chunk_end, BLOCK_M):
        # Global query indices for rows in this tile
        q_idx = qs + row_m  # [TOTAL_QUERIES_PER_BLOCK]
        q_mask = q_idx < chunk_end  # mask for valid rows in this tile
        # Load Q tile: [TOTAL_QUERIES_PER_BLOCK, D]
        q_ptrs = (
            Q
            + pid_g * STRIDE_Q_G
            + q_idx[:, None] * STRIDE_Q_N
            + row_h[:, None] * STRIDE_Q_H
            + offs_d[None, :] * STRIDE_Q_D
        )
        q = tl.load(q_ptrs, mask=q_mask[:, None], other=0.0)
        # ---- Pass 1: per-row max and denominator over all keys in this chunk ----
        row_max = tl.full([TOTAL_QUERIES_PER_BLOCK], NEG_INF, tl.float32)
        row_sum = tl.zeros([TOTAL_QUERIES_PER_BLOCK], dtype=tl.float32)
        for ks in tl.range(chunk_start, chunk_end, BLOCK_K):
            k_idx = ks + offs_k  # [BLOCK_K]
            k_mask = k_idx < chunk_end  # which keys are valid in this tile
            k_ptrs = (
                K
                + pid_g * STRIDE_K_G
                + k_idx[:, None] * STRIDE_K_N
                + offs_d[None, :] * STRIDE_K_D
            )
            k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0)  # [BLOCK_K, D]
            # logits: [TOTAL_QUERIES_PER_BLOCK, BLOCK_K]
            qk = tl.dot(q, k.T) * qk_scale
            qk = tl.where(q_mask[:, None] & k_mask[None, :], qk, NEG_INF)
            cur_max = tl.max(qk, 1)
            new_max = tl.maximum(row_max, cur_max)
            # rescale previous sum to new_max (base 2)
            rescale = tl.math.exp2(row_max - new_max)
            p = tl.math.exp2(qk - new_max[:, None])
            row_sum = row_sum * rescale + tl.sum(p, 1)
            row_max = new_max
        # Avoid division by zero for inactive rows
        denom = tl.where(q_mask, row_sum, 1.0)
        # ---- Pass 2: recompute probabilities and accumulate per-key sums ----
        for ks in tl.range(chunk_start, chunk_end, BLOCK_K):
            k_idx = ks + offs_k
            k_mask = k_idx < chunk_end
            k_ptrs = (
                K
                + pid_g * STRIDE_K_G
                + k_idx[:, None] * STRIDE_K_N
                + offs_d[None, :] * STRIDE_K_D
            )
            k = tl.load(k_ptrs, mask=k_mask[:, None], other=0.0)
            qk = tl.dot(q, k.T) * qk_scale
            qk = tl.where(q_mask[:, None] & k_mask[None, :], qk, NEG_INF)
            # p has shape [TOTAL_QUERIES_PER_BLOCK, BLOCK_K]
            p = tl.math.exp2(qk - row_max[:, None]) / denom[:, None]
            # Inactive query rows are not zeroed: each contributes a flat
            # 1 / CHUNK_SIZE to every key column of the tile.
            p = tl.where(
                q_mask[:, None], p, INVERSE_CHUNK
            )  # preserve attention mass in shorter chunks
            contrib = tl.sum(p, 0)  # [BLOCK_K], sum over queries & query-groups
            out_ptrs = accum_scores + k_idx * STRIDE_OUT_N + pid_g * STRIDE_OUT_H
            # Plain read-modify-write: each (chunk, KV head) output slot is only
            # touched by this program, whose query tiles run sequentially.
            old = tl.load(out_ptrs, mask=k_mask, other=0.0)
            new = old + contrib.to(old.dtype)
            tl.store(out_ptrs, new, mask=k_mask)
def non_causal_attn_scores(
    q: torch.Tensor,  # [N, HQ, D]
    k: torch.Tensor,  # [N, HKV, D]
    v: torch.Tensor,  # [N, HKV, D]
    cu_seqlens_qk: torch.Tensor,  # [B + 1]
    max_seqlen_qk: int,
    chunk_size: int,
    sm_scale: float = None,
    normalize: bool = True,
    context_lens: Optional[List[int]] = None,
    protected_first_tokens: Optional[List[int]] = None,
    protected_last_tokens: Optional[List[int]] = None,
    *,
    accum_scores: torch.Tensor = None,  # [N, HKV] (float32)
    accum_blending: float = None,
) -> torch.Tensor:
    """
    Chunked non-causal attention scores per key and KV head (Triton).

    :param q: Tensor of shape ``[N, HQ, D]`` containing post-rope queries
    :param k: Tensor of shape ``[N, HKV, D]`` containing post-rope keys
    :param v: Tensor of shape ``[N, HKV, D]`` containing values
    :param cu_seqlens_qk: Tensor of shape ``[B + 1]`` demarcating batch boundaries
    :param max_seqlen_qk: int containing the maximum sequence length
    :param chunk_size: int specifying the size of the chunk to perform non-causal attention over
    :param sm_scale: float specifying the scaling factor applied to attention scores (1/sqrt(D) if None)
    :param normalize: bool specifying whether to z-score normalize final attention scores
    :param context_lens: List[int] specifying the context lengths (host-side differences of
        ``cu_seqlens_qk``). Derived from ``cu_seqlens_qk`` when omitted and needed.
    :param protected_first_tokens: List[int] specifying how many tokens should be protected at the
        start of each sequence; treated as all zeros when omitted
    :param protected_last_tokens: List[int] specifying how many tokens should be protected at the
        end of each sequence; treated as all zeros when omitted
    :param accum_scores: Tensor of shape ``[N, HKV]`` containing key scores that are added to the result
    :param accum_blending: float specifying the scaling of ``accum_scores`` prior to adding the new
        non-causal attention scores. Final output is equivalent to ``out + accum_blending * accum_scores``
    :return: float32 tensor of shape ``[N, HKV]``; protected rows are set to ``inf``
    """
    assert q.ndim == 3 and k.ndim == 3
    assert q.shape[0] == k.shape[0] and q.shape[-1] == k.shape[-1]
    N, HQ, D = q.shape
    HKV = k.shape[1]
    assert HQ % HKV == 0, "Number of query heads must divide number of KV heads"
    assert (D & (D - 1)) == 0, "D must be a power of two"
    B = cu_seqlens_qk.numel() - 1
    H_g = HQ // HKV  # query-group size per KV head
    if sm_scale is None:
        sm_scale = 1.0 / math.sqrt(D)
    out = torch.zeros(N, HKV, device=q.device, dtype=torch.float32)
    # KV-head-major views: q -> [HKV, N, H_g, D], k -> [HKV, N, D].
    q = q.view(N, HKV, H_g, D).permute(1, 0, 2, 3)
    k = k.view(N, HKV, D).permute(1, 0, 2)
    # v is passed through unpermuted; the kernel body does not read it.
    if cu_seqlens_qk.device != q.device:
        cu_seqlens_qk = cu_seqlens_qk.to(device=q.device)
    cu_seqlens_qk = cu_seqlens_qk.to(torch.int32)
    STRIDE_Q_G, STRIDE_Q_N, STRIDE_Q_H, STRIDE_Q_D = q.stride()
    STRIDE_K_G, STRIDE_K_N, STRIDE_K_D = k.stride()
    STRIDE_V_G, STRIDE_V_N, STRIDE_V_D = v.stride()
    STRIDE_OUT_N, STRIDE_OUT_H = out.stride()
    assert STRIDE_Q_D == 1 and STRIDE_K_D == 1, "last dim must be contiguous"

    def grid(_):
        # One program per (KV head, sequence, chunk of the longest sequence).
        return (
            HKV,
            B,
            triton.cdiv(max_seqlen_qk, chunk_size),
        )

    _non_causal_attn_kernel[grid](
        q,
        k,
        v,
        out,
        cu_seqlens_qk,
        STRIDE_Q_G,
        STRIDE_Q_N,
        STRIDE_Q_H,
        STRIDE_Q_D,
        STRIDE_K_G,
        STRIDE_K_N,
        STRIDE_K_D,
        STRIDE_V_G,
        STRIDE_V_N,
        STRIDE_V_D,
        STRIDE_OUT_N,
        STRIDE_OUT_H,
        sm_scale,
        CHUNK_SIZE=chunk_size,
        QUERY_GROUP_SIZE=H_g,
        D=D,
    )
    if normalize:
        zscore_grid = (B,)
        _zscore_per_batch_epilogue_no_window[zscore_grid](
            out, cu_seqlens_qk, out.stride(0), out.stride(1), HKV
        )
    if accum_scores is not None:
        if accum_blending is not None:
            out += accum_scores * accum_blending
        else:
            out += accum_scores
    if protected_first_tokens is not None or protected_last_tokens is not None:
        # Either list may be given on its own; fill in what is missing rather
        # than failing inside zip() on a None.
        if context_lens is None:
            bounds = cu_seqlens_qk.tolist()
            context_lens = [hi - lo for lo, hi in zip(bounds, bounds[1:])]
        if protected_first_tokens is None:
            protected_first_tokens = [0] * len(context_lens)
        if protected_last_tokens is None:
            protected_last_tokens = [0] * len(context_lens)
        start = 0
        for first, last, L in zip(
            protected_first_tokens, protected_last_tokens, context_lens
        ):
            seq_len = int(L)
            end = start + seq_len
            # Clamp so a count larger than the sequence cannot reach into a
            # neighbouring sequence or produce a negative slice start.
            n_first = min(max(int(first), 0), seq_len)
            n_last = min(max(int(last), 0), seq_len)
            out[start : start + n_first].fill_(torch.inf)
            out[end - n_last : end].fill_(torch.inf)
            start = end
    return out
import logging
from dataclasses import dataclass
from enum import Enum, auto
logger = logging.getLogger(__name__)
class CompressionMethod(Enum):
    """Available KV-cache compression methods."""

    # Values are the ones ``auto()`` would assign, in declaration order.
    CRITICALADAKV = 1
    COMPACTOR = 2
    SNAPKV = 3
    NONE = 4
# class CachingPolicy(Enum):
# CACHE_PROMPT = auto()
# DONT_CACHE = auto()
# class CompressionType(Enum):
# QUERY_AWARE = auto()
# QUERY_AGNOSTIC = auto()
@dataclass
class SequenceCompressionParams:
    """Compression settings that apply to a single sequence."""

    # Defaults to 1.0 — presumably the fraction of the KV cache to keep, so
    # 1.0 would mean no compression; TODO confirm against the engine.
    compression_ratio: float = 1.0
    # Number of tokens at the start of the sequence that are protected.
    protected_first_tokens: int = 16
    # Number of tokens at the end of the sequence that are protected.
    protected_last_tokens: int = 64
@dataclass
class BatchCompressionParams:
    """Compression settings shared by a whole batch.

    SnapKV cannot be combined with chunked compression, so selecting it forces
    ``do_chunked_compression`` to ``False``.
    """

    # compression_type: CompressionType = CompressionType.QUERY_AGNOSTIC
    # Scoring method used for the batch.
    compression_method: CompressionMethod = CompressionMethod.COMPACTOR
    # Whether compression is performed chunk by chunk.
    do_chunked_compression: bool = True
    # Chunk length used when chunked compression is enabled.
    chunk_size: int = 512

    def __post_init__(self):
        """Disable chunked compression for SnapKV, warning only on an actual change."""
        if (
            self.compression_method == CompressionMethod.SNAPKV
            and self.do_chunked_compression
        ):
            # Only warn when a setting is really being overridden; a caller
            # that already passed do_chunked_compression=False gets no noise.
            self.do_chunked_compression = False
            logger.warning(
                "CompressionMethod.SNAPKV is not compatible with chunked compression. Disabling it."
            )
"""
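CriticalAdaKV: builds on Compactor (pre-RoPE leverage scores blended with post-RoPE
non-causal attention) and adds a Stage-2 re-weighting by the L1 norm of the Value after the
output projection Wo; Stage 1 applies within-budget top-k protection on the Compactor base scores.
The budget matches the compactor_vllm engine: it uses ``compression_context.batch_tokens_to_retain``
(the number of flattened (token, head) pairs) together with the lengths of the protected
first/last segments.
Note: ``compactor_vllm.utils.context`` must not be loaded at import time (it imports
``CompressionMethod`` in turn, which forms a cycle with ``compression/__init__.py`` importing this
module). At runtime only a duck-typed object with the same fields as ``CompressionContext`` is used.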
"""
from __future__ import annotations
from typing import Any, Optional, Tuple
import torch
import triton
from triton import language as tl
from compactor_vllm.compression.common import BaseCompressionMethod
from compactor_vllm.compression.compactor import (
CompactorCompression,
non_causal_attn_scores,
)
from compactor_vllm.compression.snapkv import SnapKVCompression
from compactor_vllm.utils.helpers import maybe_execute_in_stream
from compactor_vllm.utils.triton_compat import autotune as triton_autotune
# ============================================================================
# Triton kernel 1: compute ||Wo @ V||_1 (L1 norm)
# ============================================================================
@triton_autotune(
    configs=[
        triton.Config({"BLOCK_K": bk, "BLOCK_D": bd}, num_warps=nw, num_stages=ns)
        for bk in [32, 64, 128]
        for bd in [32, 64]
        for nw in [4, 8]
        for ns in [3, 4]
    ],
    key=["Hk", "D", "HIDDEN"],
    cache_results=True,
)
@triton.jit
def _compute_wo_v_l1_kernel(
    V,
    WO,
    cu_k,
    OUT,
    STRIDE_V_NK,
    STRIDE_V_HK,
    STRIDE_V_D,
    STRIDE_WO_HQ,
    STRIDE_WO_D,
    STRIDE_WO_HID,
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    Hk: tl.constexpr,
    Hq: tl.constexpr,
    D: tl.constexpr,
    HIDDEN: tl.constexpr,
    QUERY_GROUP_SIZE: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """L1 norm of the value vector projected through ``WO``, per token and KV head.

    One program per (sequence ``b``, KV head ``hk``, block of ``BLOCK_K`` tokens).
    For every query head ``hq`` in the group of ``hk`` it multiplies the token's
    value vector (length ``D``) with ``WO[hq]`` tile by tile over the hidden
    dimension, sums the absolute values of the products, and finally divides by
    ``QUERY_GROUP_SIZE``. Tokens past the end of the sequence are masked out.

    ``Hq`` is accepted but not read in this body.
    """
    b = tl.program_id(0)
    hk = tl.program_id(1)
    ks = tl.program_id(2)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    # Global token indices of this block within sequence b.
    nk_off = ks * BLOCK_K + tl.arange(0, BLOCK_K)
    nk = k_beg + nk_off
    k_mask = nk < k_end
    out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
    l1_sum = tl.zeros([BLOCK_K], dtype=tl.float32)
    for g in range(QUERY_GROUP_SIZE):
        # Query head g of the group that shares KV head hk.
        hq = hk * QUERY_GROUP_SIZE + g
        v_ptrs = (
            V
            + nk[:, None] * STRIDE_V_NK
            + hk * STRIDE_V_HK
            + tl.arange(0, D)[None, :] * STRIDE_V_D
        )
        v_blk = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)
        # Walk the hidden dimension in tiles of BLOCK_D columns.
        for hid_off in range(0, HIDDEN, BLOCK_D):
            hid_idx = hid_off + tl.arange(0, BLOCK_D)
            hid_mask = hid_idx < HIDDEN
            wo_ptrs = (
                WO
                + hq * STRIDE_WO_HQ
                + tl.arange(0, D)[:, None] * STRIDE_WO_D
                + hid_idx[None, :] * STRIDE_WO_HID
            )
            wo_tile = tl.load(wo_ptrs, mask=hid_mask[None, :], other=0.0).to(tl.float32)
            # [BLOCK_K, D] @ [D, BLOCK_D]; masked columns are zero and add nothing.
            wov_tile = tl.dot(v_blk, wo_tile)
            l1_sum += tl.sum(tl.abs(wov_tile), axis=1)
    # Average over the query heads of the group.
    l1_sum = l1_sum / QUERY_GROUP_SIZE
    tl.store(out_ptrs, l1_sum, mask=k_mask)
# ============================================================================
# Triton kernel 2: Stage-1 protection + Stage-2 weighted fusion
# ============================================================================
@triton_autotune(
    configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128, 256]],
    key=["Hk"],
    cache_results=True,
)
@triton.jit
def _critical_ada_fuse_kernel(
    BASE_SCORES,
    WO_V_NORM,
    STAGE1_MASK,
    cu_k,
    OUT,
    STRIDE_BS_NK,
    STRIDE_BS_HK,
    STRIDE_WN_NK,
    STRIDE_WN_HK,
    STRIDE_S1_NK,
    STRIDE_S1_HK,
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    EPSILON: tl.constexpr,
    Hk: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Write ``(base + EPSILON) * wo_v_norm`` per (row, head); Stage-1 rows become ``+inf``.

    Grid layout: axis 0 picks the segment ``b`` bounded by ``cu_k[b]`` and
    ``cu_k[b + 1]``, axis 1 picks the KV head. Each program walks its segment
    in tiles of ``BLOCK_K`` rows.

    ``EPSILON`` is declared after the stride arguments on purpose: the caller
    passes the five tensors and the eight strides positionally and ``EPSILON``
    by keyword. With ``EPSILON`` placed before the strides, the first stride
    would bind to ``EPSILON`` positionally and collide with the keyword.
    """
    b = tl.program_id(0)
    hk = tl.program_id(1)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    for ks in tl.range(k_beg, k_end, BLOCK_K):
        nk = ks + tl.arange(0, BLOCK_K)
        # Masks rows of the last tile that fall past the segment end.
        kmask = nk < k_end
        bs_ptrs = BASE_SCORES + nk * STRIDE_BS_NK + hk * STRIDE_BS_HK
        wn_ptrs = WO_V_NORM + nk * STRIDE_WN_NK + hk * STRIDE_WN_HK
        s1_ptrs = STAGE1_MASK + nk * STRIDE_S1_NK + hk * STRIDE_S1_HK
        base = tl.load(bs_ptrs, mask=kmask, other=0.0)
        wnorm = tl.load(wn_ptrs, mask=kmask, other=1.0)
        stage1_protect = tl.load(s1_ptrs, mask=kmask, other=0).to(tl.int32)
        fused = (base + EPSILON) * wnorm
        # Entries flagged in the Stage-1 mask are forced to +inf.
        fused = tl.where(stage1_protect == 1, float("inf"), fused)
        out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
        tl.store(out_ptrs, fused, mask=kmask)
def critical_ada_key_scores(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    wo_weight: torch.Tensor,
    cu_seqlens: torch.Tensor,
    base_scores: torch.Tensor,
    compression_ctx: Any,
    *,
    store_stream: Optional[torch.cuda.Stream] = None,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
    """
    Apply the two-stage CriticalAda re-scoring on top of ``base_scores``.

    The retention budget is ``compression_ctx.batch_tokens_to_retain`` (one
    count of (token, head) pairs per sequence). Per sequence, inside the window
    left after removing the protected first/last tokens:

    1. Safeguard: each head marks its top ``int(n_kept * alpha_safeguard)``
       base scores in the Stage-1 mask.
    2. Head budgets: the top ``n_kept * Hk`` entries of the flattened
       (token, head) scores (safeguarded entries raised to ``+inf``) are
       counted per head.
    3. Stage 1: each head marks its top ``int(budget * first_stage_ratio)``
       base scores in the Stage-1 mask.
    4. Stage 2: ``(base + eps) * ||Wo @ V||_1`` is computed by the fuse kernel
       (Stage-1 entries become ``+inf``), then each head raises its top
       ``budget`` fused scores to ``+inf``.

    Finally, for every sequence whose budget is below ``k_len * Hk``, the
    lowest-scoring pairs over the whole sequence are selected for pruning.

    Args:
        q: Only its device and shape ``(_, Hq, D)`` are read.
        k: Only its shape ``(N_k, Hk, D)`` is read.
        v: Values passed to the ``||Wo @ V||_1`` kernel.
        wo_weight: Output projection; a 2-D weight is transposed and viewed as
            ``(Hq, D, hidden)``, any other rank is used as given.
        cu_seqlens: 1-D cumulative sequence boundaries (``B + 1`` entries).
        base_scores: Base scores indexed as ``[row, head]``.
        compression_ctx: Duck-typed context. Reads ``batch_tokens_to_retain``,
            ``protected_first_tokens``, ``protected_last_tokens``,
            ``critical_ada_epsilon``, ``critical_ada_first_stage_ratio`` and,
            if present, ``critical_ada_alpha_safeguard`` (default 0.2,
            clamped to [0, 1]).
        store_stream: If given, the returned scores are recorded on it.

    Returns:
        ``(final_scores, masked_key_indices)``, where ``final_scores`` is
        ``(N_k, Hk)`` float32 and ``masked_key_indices`` is ``None`` or a
        ``(batch_idx, head_idx, seq_idx)`` tuple of index tensors.
    """
    assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1
    device = q.device
    _, Hq, D = q.shape
    N_k, Hk, Dk = k.shape
    assert D == Dk and Hq % Hk == 0
    # Must be the same cu_seqlens used to produce base_scores (cu_seqlens_q in
    # prefill) so score rows and kernel segments line up.
    B = cu_seqlens.numel() - 1
    G = Hq // Hk
    btr = compression_ctx.batch_tokens_to_retain
    assert btr is not None and btr.numel() == B

    # One host transfer for each small tensor instead of a device sync per
    # sequence inside the loops below.
    seq_bounds = [int(x) for x in cu_seqlens.tolist()]
    keep_pairs_per_seq = [
        int(x) for x in btr.to(dtype=torch.int32).reshape(-1).tolist()
    ]
    max_k_len = max((seq_bounds[i + 1] - seq_bounds[i] for i in range(B)), default=0)

    prot_first = compression_ctx.protected_first_tokens or [0] * B
    prot_last = compression_ctx.protected_last_tokens or [0] * B
    epsilon = compression_ctx.critical_ada_epsilon
    first_stage_ratio = compression_ctx.critical_ada_first_stage_ratio
    alpha_safeguard = float(getattr(compression_ctx, "critical_ada_alpha_safeguard", 0.2))
    alpha_safeguard = max(0.0, min(1.0, alpha_safeguard))

    if wo_weight.dim() == 2:
        hidden_size, _ = wo_weight.shape
        wo = wo_weight.transpose(0, 1).view(Hq, D, hidden_size).contiguous()
    else:
        wo = wo_weight.contiguous()
        hidden_size = wo.size(-1)

    wo_v_norm = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
    has_rows = B > 0 and max_k_len > 0
    if has_rows:
        # Skipped for an empty batch: there is nothing to compute and a grid
        # with a zero-sized axis would otherwise be requested.
        def grid_wo(META):
            return (B, Hk, triton.cdiv(max_k_len, META["BLOCK_K"]))

        _compute_wo_v_l1_kernel[grid_wo](
            v,
            wo,
            cu_seqlens,
            wo_v_norm,
            *v.stride(),
            *wo.stride(),
            *wo_v_norm.stride(),
            Hk=Hk,
            Hq=Hq,
            D=D,
            HIDDEN=hidden_size,
            QUERY_GROUP_SIZE=G,
        )

    stage1_mask = torch.zeros((N_k, Hk), dtype=torch.int32, device=device)
    # (lo, hi, per-head budgets) for every sequence that has a non-empty
    # compressible window; consumed by Stage 2 after the fuse kernel.
    stage2_plan = []
    for b in range(B):
        k_beg, k_end = seq_bounds[b], seq_bounds[b + 1]
        if k_end - k_beg == 0:
            continue
        s = int(prot_first[b]) if b < len(prot_first) else 0
        e = int(prot_last[b]) if b < len(prot_last) else 0
        lo, hi = k_beg + s, k_end - e
        compressible = max(0, hi - lo)
        if compressible <= 0:
            continue
        # Per-head token budget, capped by the window length.
        n_kept_tokens = min(max(1, keep_pairs_per_seq[b] // Hk), compressible)
        window = base_scores[lo:hi]

        # Safeguard budget: every head keeps at least n_safe tokens.
        n_safe = int(n_kept_tokens * alpha_safeguard)
        if n_safe > 0:
            tk_safe = min(n_safe, compressible)
            for hk in range(Hk):
                safe_idx = torch.topk(window[:, hk], tk_safe, sorted=False).indices
                stage1_mask[lo + safe_idx, hk] = 1

        # Adaptive budgets: top n_kept_tokens * Hk pairs in the flattened
        # (token, head) space, counted per head.
        budget_scores = window.clone()
        if n_safe > 0:
            budget_scores[stage1_mask[lo:hi, :] == 1] = float("inf")
        top_pairs = min(n_kept_tokens * Hk, budget_scores.numel())
        if top_pairs <= 0:
            continue
        top_idx_flat = torch.topk(
            budget_scores.reshape(-1), top_pairs, sorted=False
        ).indices
        # Row-major flattening of [tokens, Hk]: the head is the remainder.
        top_head_idx = top_idx_flat % Hk
        head_budgets = torch.bincount(top_head_idx, minlength=Hk).to(torch.int32)
        budgets = head_budgets.tolist()
        stage2_plan.append((lo, hi, budgets))

        # Stage 1: protect a first_stage_ratio share of each head's budget.
        for hk in range(Hk):
            phase1_budget = int(budgets[hk] * first_stage_ratio)
            if phase1_budget <= 0:
                continue
            tk = min(phase1_budget, compressible)
            top_idx = torch.topk(window[:, hk], tk, sorted=False).indices
            stage1_mask[lo + top_idx, hk] = 1

    final_scores = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
    if has_rows:
        def grid_fuse(_META):
            return (B, Hk)

        # Strides are bound by name so the launch does not depend on where
        # EPSILON sits among the kernel's parameters.
        _critical_ada_fuse_kernel[grid_fuse](
            base_scores,
            wo_v_norm,
            stage1_mask,
            cu_seqlens,
            final_scores,
            STRIDE_BS_NK=base_scores.stride(0),
            STRIDE_BS_HK=base_scores.stride(1),
            STRIDE_WN_NK=wo_v_norm.stride(0),
            STRIDE_WN_HK=wo_v_norm.stride(1),
            STRIDE_S1_NK=stage1_mask.stride(0),
            STRIDE_S1_HK=stage1_mask.stride(1),
            STRIDE_OUT_NK=final_scores.stride(0),
            STRIDE_OUT_HK=final_scores.stride(1),
            EPSILON=epsilon,
            Hk=Hk,
        )

    # Stage 2: on the fused scores, protect each head's top `budget` rows.
    for lo, hi, budgets in stage2_plan:
        region_len = hi - lo
        for hk, budget in enumerate(budgets):
            if budget <= 0:
                continue
            tk = min(budget, region_len)
            idx = torch.topk(final_scores[lo:hi, hk], tk, sorted=False).indices
            final_scores[lo + idx, hk] = float("inf")

    # Select the lowest-scoring (token, head) pairs of each over-budget sequence.
    batch_parts, head_parts, seq_parts = [], [], []
    for b in range(B):
        k_beg, k_end = seq_bounds[b], seq_bounds[b + 1]
        k_len = k_end - k_beg
        if k_len == 0:
            continue
        keep_pairs = keep_pairs_per_seq[b]
        total_pairs = k_len * Hk
        if keep_pairs >= total_pairs:
            continue
        n_prune_pairs = min(total_pairs - keep_pairs, total_pairs)
        if n_prune_pairs <= 0:
            continue
        flat_scores = final_scores[k_beg:k_end, :].reshape(-1)
        prune_idx = torch.topk(
            -flat_scores, min(n_prune_pairs, flat_scores.numel()), sorted=False
        ).indices
        batch_parts.append(torch.full_like(prune_idx, b, dtype=torch.int64))
        head_parts.append(prune_idx % Hk)
        seq_parts.append(prune_idx // Hk + k_beg)

    masked_key_indices = None
    if batch_parts:
        # Single concatenation instead of re-concatenating after every sequence.
        masked_key_indices = (
            torch.cat(batch_parts),
            torch.cat(head_parts),
            torch.cat(seq_parts),
        )

    if store_stream is not None:
        final_scores.record_stream(store_stream)
    return final_scores, masked_key_indices
class CriticalAdaKVCompression(BaseCompressionMethod):
    """
    CriticalAda two-stage weighting layered on a base scorer.

    The base scorer is Compactor unless ``critical_ada_base_scorer`` is
    ``"snapkv"``. The attention layer is expected to set
    ``compression_context.wo_weight`` before post-RoPE scoring; when it is
    ``None`` the base scores are returned unchanged.
    """

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
    ) -> Optional[torch.Tensor]:
        """Forward pre-RoPE scoring to the configured base scorer."""
        cc = context.compression_context
        scorer_name = "compactor"
        if cc is not None:
            scorer_name = getattr(cc, "critical_ada_base_scorer", "compactor")
        if str(scorer_name).lower() == "snapkv":
            delegate = SnapKVCompression
        else:
            delegate = CompactorCompression
        return delegate.pre_rope_scoring(q, k, v, context)

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: Optional[torch.Tensor],
        context,
    ) -> Optional[torch.Tensor]:
        """Compute base post-RoPE scores, then apply CriticalAda when Wo is available."""
        compression_context = context.compression_context
        assert compression_context is not None
        scorer_name = str(
            getattr(compression_context, "critical_ada_base_scorer", "compactor")
        ).lower()

        if scorer_name != "snapkv":
            # This call is meant to mirror CompactorCompression.post_rope_scoring
            # so that scores match a standalone Compactor run; keep the
            # arguments in sync with that method.
            if context.STORE_STREAM is not None:
                torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
            compactor_kwargs = dict(
                chunk_size=CompactorCompression.chunk_size,
                sm_scale=1.0,
                normalize=True,
                accum_scores=pre_rope_scores,
                context_lens=compression_context.context_lens,
                protected_first_tokens=compression_context.protected_first_tokens,
                protected_last_tokens=compression_context.protected_last_tokens,
                accum_blending=0.5,
            )
            base_scores = maybe_execute_in_stream(
                non_causal_attn_scores,
                q,
                k,
                v,
                context.cu_seqlens_q,
                context.max_seqlen_q,
                **compactor_kwargs,
            )
        else:
            base_scores = SnapKVCompression.post_rope_scoring(
                q, k, v, pre_rope_scores, context
            )

        wo_weight = compression_context.wo_weight
        if wo_weight is None:
            return base_scores

        result = maybe_execute_in_stream(
            critical_ada_key_scores,
            q,
            k,
            v,
            wo_weight,
            context.cu_seqlens_q,
            base_scores,
            compression_context,
            STORE_STREAM=context.STORE_STREAM,
            store_stream=context.STORE_STREAM,
        )
        # The second element (indices of pairs to prune) is not used here.
        scores, _masked = result
        return scores

    @staticmethod
    def prepare_layer(module: torch.nn.Module, device: torch.device, dtype: torch.dtype):
        """Optionally cache a float32 ``(Hq, head_dim, hidden)`` view of ``o_proj``.

        Does nothing when the module lacks ``o_proj`` (or its weight is
        ``None``), ``num_heads`` or ``head_dim``. The result is stored on
        ``module._critical_ada_wo_weight``. ``dtype`` is accepted but not used.
        """
        has_projection = hasattr(module, "o_proj") and module.o_proj.weight is not None
        if not has_projection:
            return
        if not (hasattr(module, "num_heads") and hasattr(module, "head_dim")):
            return
        weight = module.o_proj.weight.data
        hidden_size, _in_features = weight.shape
        per_head = weight.transpose(0, 1).view(
            module.num_heads, module.head_dim, hidden_size
        )
        module._critical_ada_wo_weight = per_head.to(device=device, dtype=torch.float32)
"""
CriticalAdaKV: 在 Compactor(pre RoPE 杠杆分 + post RoPE 非因果注意力融合)基础上,
用输出投影 Wo 对 Value 的 L1 范数做 Stage-2 重加权;Stage-1 在 Compactor 基础分上做预算内 top-k 保护。
预算与 vllm.kvprune 引擎一致:使用 ``compression_context.batch_tokens_to_retain``(flatten 的
(token, head) 对数量)。CriticalAda 主链在 **PyTorch** 中与 kvpress ``CriticalAdaKVPress.compress``
对齐;``||Wo@V||_1`` 仍默认用 Triton ``_compute_wo_v_l1_kernel``(与 ``CriticalKVPress.vwl1norm`` 同式)。
将 ``_USE_WO_L1_REFERENCE_BACKEND`` 置为 ``True`` 可改走 ``_vwl1_norm_kvpress_reference``。
注意:不得在 import 时加载 ``vllm.kvprune.utils.context``(其会再 import ``CompressionMethod``,
与 ``compression/__init__.py`` 导入本模块形成环)。运行时只使用与 ``CompressionContext`` 同字段的 duck 对象。
"""
from __future__ import annotations
from typing import Any, Optional, Tuple
import torch
import triton
from triton import language as tl
from transformers.models.llama.modeling_llama import repeat_kv
from vllm.kvprune.compression.common import BaseCompressionMethod
from vllm.kvprune.compression.compactor import (
CompactorCompression,
kvpress_compactor_post_rope,
resolve_kvpress_compactor_blending,
)
from vllm.kvprune.compression.snapkv import SnapKVCompression
from vllm.kvprune.utils.helpers import maybe_execute_in_stream
from vllm.kvprune.utils.triton_compat import autotune as triton_autotune
def _criticalkv_prune_hip_pipeline(configs, _, **kwargs):
"""HIP: TritonHCUGPUStreamPipelineV2 breaks on nested loops + hid_idx arange (see snapkv)."""
if torch.version.hip is None:
return list(configs)
return [c for c in configs if getattr(c, "num_stages", 1) == 1]
def _compute_wo_v_l1_autotune_configs():
    """Return the autotune search space for ``_compute_wo_v_l1_kernel``.

    HIP builds get a single ``num_stages=1`` config; other builds get the full
    BLOCK_K x BLOCK_D x num_warps x num_stages grid.
    """
    if torch.version.hip is not None:
        hip_only = triton.Config(
            {"BLOCK_K": 64, "BLOCK_D": 64}, num_warps=4, num_stages=1
        )
        return [hip_only]

    search_space = []
    for bk in (32, 64, 128):
        for bd in (32, 64):
            for nw in (4, 8):
                for ns in (3, 4):
                    search_space.append(
                        triton.Config(
                            {"BLOCK_K": bk, "BLOCK_D": bd},
                            num_warps=nw,
                            num_stages=ns,
                        )
                    )
    return search_space
# Backend for the Wo@V L1 norm: False = Triton kernel (default),
# True = PyTorch reference implementation (debugging / parity checks).
_USE_WO_L1_REFERENCE_BACKEND = False
def _vwl1_norm_kvpress_reference(
values_seg: torch.Tensor,
wo: torch.Tensor,
num_kv_heads: int,
num_query_groups: int,
) -> torch.Tensor:
"""
与 kvpress ``CriticalKVPress.vwl1norm`` 等价的 **可选参考实现**(PyTorch,仅用于核对;
将 ``_USE_WO_L1_REFERENCE_BACKEND`` 置为 ``True`` 时选用,默认走 Triton)。
算法:repeat_kv → 逐 query 头 ``|V @ Wo_h|_1`` → 在 GQA 组上 mean,与 Triton 路径同一公式。
"""
k_len, Hk, D = values_seg.shape
Hq, D_wo, hidden = wo.shape
assert D == D_wo and Hk == num_kv_heads and Hq == Hk * num_query_groups
# [1, Hk, k_len, D] 与 HF repeat_kv 约定一致
v_4d = values_seg.permute(1, 0, 2).unsqueeze(0).contiguous()
v_rep = repeat_kv(v_4d, num_query_groups) # [1, Hq, k_len, D]
# Wo 在 attention 里注入为 float32,V 常为 bf16/fp16,matmul 前对齐 dtype
wo_f = wo
head_list = []
for head in range(Hq):
v_h = v_rep[0, head, :, :].to(dtype=wo_f.dtype)
head_wov = v_h.matmul(wo_f[head, :, :])
head_wov_norm = torch.norm(head_wov, p=1, dim=-1)
head_list.append(head_wov_norm)
stacked = torch.stack(head_list, dim=0) # [Hq, k_len]
stacked = stacked.view(Hk, num_query_groups, k_len).mean(dim=1)
return stacked.transpose(0, 1).contiguous()
# ============================================================================
# Triton:||Wo @ V||₁ 按 kvpress 定义(GQA 上对 query 组 L1 后取均值)
# ============================================================================
@triton_autotune(
    configs=_compute_wo_v_l1_autotune_configs(),
    key=["Hk", "D", "HIDDEN"],
    cache_results=True,
    prune_configs_by={"early_config_prune": _criticalkv_prune_hip_pipeline},
)
@triton.jit
def _compute_wo_v_l1_kernel(
    V,
    WO,
    cu_k,
    OUT,
    STRIDE_V_NK,
    STRIDE_V_HK,
    STRIDE_V_D,
    STRIDE_WO_HQ,
    STRIDE_WO_D,
    STRIDE_WO_HID,
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    Hk: tl.constexpr,
    Hq: tl.constexpr,
    D: tl.constexpr,
    HIDDEN: tl.constexpr,
    QUERY_GROUP_SIZE: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """For each key row and KV head, store the group-mean of ``sum(|V @ Wo|)``.

    Grid layout: axis 0 picks the segment ``b`` bounded by ``cu_k[b]`` and
    ``cu_k[b + 1]``, axis 1 picks the KV head, axis 2 picks a tile of
    ``BLOCK_K`` consecutive rows inside that segment. Rows at or beyond the
    segment end are masked out of both the load and the final store.

    For every query head ``hq = hk * QUERY_GROUP_SIZE + g`` the ``[BLOCK_K, D]``
    value tile is multiplied by ``WO[hq]`` in column chunks of ``BLOCK_D``,
    absolute values are summed over the hidden axis, and the total is divided
    by ``QUERY_GROUP_SIZE`` before being written to ``OUT[row, hk]``.

    ``tl.arange(0, D)`` is used for the head dimension, so ``D`` is expected to
    be a power of two (Triton requirement for ``tl.arange``).
    """
    b = tl.program_id(0)
    hk = tl.program_id(1)
    ks = tl.program_id(2)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    nk = k_beg + ks * BLOCK_K + tl.arange(0, BLOCK_K)
    k_mask = nk < k_end
    d_idx = tl.arange(0, D)

    # The value tile depends only on (rows, KV head), not on the query head,
    # so it is loaded once here instead of once per group iteration.
    v_ptrs = (
        V
        + nk[:, None] * STRIDE_V_NK
        + hk * STRIDE_V_HK
        + d_idx[None, :] * STRIDE_V_D
    )
    v_blk = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)

    l1_sum = tl.zeros([BLOCK_K], dtype=tl.float32)
    for g in range(QUERY_GROUP_SIZE):
        hq = hk * QUERY_GROUP_SIZE + g
        for hid_off in range(0, HIDDEN, BLOCK_D):
            hid_idx = hid_off + tl.arange(0, BLOCK_D)
            # Guards the last chunk when HIDDEN is not a multiple of BLOCK_D.
            hid_mask = hid_idx < HIDDEN
            wo_ptrs = (
                WO
                + hq * STRIDE_WO_HQ
                + d_idx[:, None] * STRIDE_WO_D
                + hid_idx[None, :] * STRIDE_WO_HID
            )
            wo_tile = tl.load(wo_ptrs, mask=hid_mask[None, :], other=0.0).to(tl.float32)
            wov_tile = tl.dot(v_blk, wo_tile)
            l1_sum += tl.sum(tl.abs(wov_tile), axis=1)
    l1_sum = l1_sum / QUERY_GROUP_SIZE

    out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
    tl.store(out_ptrs, l1_sum, mask=k_mask)
def critical_ada_key_scores(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    wo_weight: torch.Tensor,
    cu_seqlens: torch.Tensor,
    base_scores: torch.Tensor,
    compression_ctx: Any,
    *,
    store_stream: Optional[torch.cuda.Stream] = None,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
    """
    Apply the CriticalAda two-stage re-scoring to ``base_scores``, sequence by sequence.

    The retention budget is ``compression_ctx.batch_tokens_to_retain`` (one
    count of (token, head) pairs per sequence). The steps are written to follow
    the order of kvpress ``CriticalAdaKVPress.compress``:

    1. Safeguard: each head raises its top ``int(n_kept * alpha_safeguard)``
       scores to the float32 maximum.
    2. Head budgets: the top ``n_kept * Hk`` entries of the head-major
       flattened scores are counted per head.
    3. Stage 1: on the already-raised scores, each head raises its top
       ``int(budget * first_stage_ratio)`` entries to the float32 maximum.
    4. Re-weighting: ``(scores + eps) * ||Wo @ V||_1``.
    5. Stage 2: each head raises its top ``budget`` re-weighted entries to the
       float32 maximum.

    Finally, for every sequence whose budget is below ``k_len * Hk``, the
    lowest-scoring pairs in head-major order are selected for pruning.

    ``||Wo @ V||_1`` comes from the Triton kernel unless
    ``_USE_WO_L1_REFERENCE_BACKEND`` is set; the remaining steps run in PyTorch.

    Args:
        q: Only its device and shape ``(_, Hq, D)`` are read.
        k: Only its shape ``(N_k, Hk, D)`` is read.
        v: Values used for the ``||Wo @ V||_1`` term.
        wo_weight: Output projection; a 2-D weight is transposed and viewed as
            ``(Hq, D, hidden)``, any other rank is used as given.
        cu_seqlens: 1-D cumulative sequence boundaries (``B + 1`` entries).
        base_scores: Base scores indexed as ``[row, head]``.
        compression_ctx: Duck-typed context. Reads ``batch_tokens_to_retain``,
            ``critical_ada_epsilon``, ``critical_ada_first_stage_ratio`` and
            ``critical_ada_alpha_safeguard`` (clamped to [0, 1]).
        store_stream: If given, the returned scores are recorded on it.

    Returns:
        ``(final_scores, masked_key_indices)``, where ``final_scores`` is
        ``(N_k, Hk)`` float32 and ``masked_key_indices`` is ``None`` or a
        ``(batch_idx, head_idx, seq_idx)`` tuple of index tensors.
    """
    assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1
    device = q.device
    _, Hq, D = q.shape
    N_k, Hk, Dk = k.shape
    assert D == Dk and Hq % Hk == 0
    # Must be the same cu_seqlens used to produce base_scores (cu_seqlens_q in
    # prefill) so score rows and kernel segments line up.
    B = cu_seqlens.numel() - 1
    G = Hq // Hk
    btr = compression_ctx.batch_tokens_to_retain
    assert btr is not None and btr.numel() == B

    # One host transfer for each small tensor instead of a device sync per
    # sequence inside the loops below.
    seq_bounds = [int(x) for x in cu_seqlens.tolist()]
    keep_pairs_per_seq = [
        int(x) for x in btr.to(dtype=torch.int32).reshape(-1).tolist()
    ]
    max_k_len = max((seq_bounds[i + 1] - seq_bounds[i] for i in range(B)), default=0)

    epsilon = compression_ctx.critical_ada_epsilon
    first_stage_ratio = compression_ctx.critical_ada_first_stage_ratio
    alpha_safeguard = float(compression_ctx.critical_ada_alpha_safeguard)
    alpha_safeguard = max(0.0, min(1.0, alpha_safeguard))

    if wo_weight.dim() == 2:
        hidden_size, _ = wo_weight.shape
        wo = wo_weight.transpose(0, 1).view(Hq, D, hidden_size).contiguous()
    else:
        wo = wo_weight.contiguous()
        hidden_size = wo.size(-1)

    wo_v_norm = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
    if B > 0 and max_k_len > 0:
        if _USE_WO_L1_REFERENCE_BACKEND:
            for b in range(B):
                k_beg, k_end = seq_bounds[b], seq_bounds[b + 1]
                if k_end <= k_beg:
                    continue
                v_seg = v[k_beg:k_end, :, :].contiguous()
                wo_v_norm[k_beg:k_end, :] = _vwl1_norm_kvpress_reference(
                    v_seg, wo, Hk, G
                )
        else:
            def grid_wo(META):
                return (B, Hk, triton.cdiv(max_k_len, META["BLOCK_K"]))

            _compute_wo_v_l1_kernel[grid_wo](
                v,
                wo,
                cu_seqlens,
                wo_v_norm,
                *v.stride(),
                *wo.stride(),
                *wo_v_norm.stride(),
                Hk=Hk,
                Hq=Hq,
                D=D,
                HIDDEN=hidden_size,
                QUERY_GROUP_SIZE=G,
            )

    # Protected entries are raised to the float32 maximum rather than +inf.
    _score_max = float(torch.finfo(torch.float32).max)
    final_scores = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
    for b in range(B):
        k_beg, k_end = seq_bounds[b], seq_bounds[b + 1]
        k_len = k_end - k_beg
        if k_len == 0:
            continue
        scores_seg = base_scores[k_beg:k_end, :].float()
        wn = wo_v_norm[k_beg:k_end, :]
        n_kept_tokens = min(max(1, keep_pairs_per_seq[b] // Hk), k_len)
        # Working copy, layout [k_len, Hk]; top-k runs along the token axis.
        scores_work = scores_seg.clone()

        # --- Alpha safeguard ---
        n_safe = int(n_kept_tokens * alpha_safeguard)
        nk = min(n_safe, k_len) if n_safe > 0 else 0
        if nk > 0:
            for hk in range(Hk):
                top_idx = torch.topk(scores_work[:, hk], nk, dim=0, largest=True).indices
                scores_work[top_idx, hk] = _score_max

        # --- Head budgets (head-major flattening: flat = h * k_len + t) ---
        top_pairs = min(n_kept_tokens * Hk, k_len * Hk)
        if top_pairs <= 0:
            final_scores[k_beg:k_end, :] = (scores_seg + epsilon) * wn
            continue
        budget_flat = scores_work.permute(1, 0).contiguous().reshape(-1)
        top_idx_flat = torch.topk(
            budget_flat, top_pairs, largest=True, sorted=False
        ).indices
        top_head_idx = top_idx_flat // k_len
        head_budgets = torch.bincount(top_head_idx, minlength=Hk).to(torch.int64)
        budgets = head_budgets.tolist()

        # --- Stage 1: top-k along tokens on the already-raised scores ---
        # The product is taken in float32 tensor arithmetic before truncation.
        head_selection_budget_1st = (
            (head_budgets.to(torch.float32) * float(first_stage_ratio))
            .to(torch.int64)
            .tolist()
        )
        M1 = max(head_selection_budget_1st) if head_selection_budget_1st else 0
        mk = min(M1, k_len) if M1 > 0 else 0
        if mk > 0:
            # One sorted top-k for all heads; each head takes its own prefix.
            top_k_index = torch.topk(
                scores_work, mk, dim=0, largest=True, sorted=True
            ).indices
            for hk, phase1_budget in enumerate(head_selection_budget_1st):
                if phase1_budget <= 0:
                    continue
                take = min(phase1_budget, mk)
                scores_work[top_k_index[:take, hk], hk] = _score_max

        # --- Stage 2 re-weighting ---
        scores_fused = (scores_work + epsilon) * wn

        # --- Stage 2 scatter ---
        M2 = max(budgets) if budgets else 0
        mk2 = min(M2, k_len) if M2 > 0 else 0
        if mk2 > 0:
            top_k_index2 = torch.topk(
                scores_fused, mk2, dim=0, largest=True, sorted=True
            ).indices
            for hk, budget in enumerate(budgets):
                if budget <= 0:
                    continue
                take = min(budget, mk2)
                scores_fused[top_k_index2[:take, hk], hk] = _score_max
        final_scores[k_beg:k_end, :] = scores_fused

    # Select the lowest-scoring (token, head) pairs of each over-budget sequence.
    batch_parts, head_parts, seq_parts = [], [], []
    for b in range(B):
        k_beg, k_end = seq_bounds[b], seq_bounds[b + 1]
        k_len = k_end - k_beg
        if k_len == 0:
            continue
        keep_pairs = keep_pairs_per_seq[b]
        total_pairs = k_len * Hk
        if keep_pairs >= total_pairs:
            continue
        n_prune_pairs = min(total_pairs - keep_pairs, total_pairs)
        if n_prune_pairs <= 0:
            continue
        # Head-major flattening, so head = idx // k_len and token = idx % k_len.
        flat_scores = (
            final_scores[k_beg:k_end, :].permute(1, 0).contiguous().reshape(-1)
        )
        prune_idx = torch.topk(
            -flat_scores, min(n_prune_pairs, flat_scores.numel()), sorted=False
        ).indices
        batch_parts.append(torch.full_like(prune_idx, b, dtype=torch.int64))
        head_parts.append(prune_idx // k_len)
        seq_parts.append(prune_idx % k_len + k_beg)

    masked_key_indices = None
    if batch_parts:
        # Single concatenation instead of re-concatenating after every sequence.
        masked_key_indices = (
            torch.cat(batch_parts),
            torch.cat(head_parts),
            torch.cat(seq_parts),
        )

    if store_stream is not None:
        final_scores.record_stream(store_stream)
    return final_scores, masked_key_indices
class CriticalAdaKVCompression(BaseCompressionMethod):
    """
    CriticalAda re-scoring layered on a base scorer.

    When ``critical_ada_base_scorer`` is ``"compactor"`` (the default) the base
    scores come from ``kvpress_compactor_post_rope``; any other value delegates
    to ``SnapKVCompression``. The attention layer is expected to set
    ``compression_context.wo_weight`` before post-RoPE scoring; when it is
    ``None`` the base scores are returned unchanged.
    """

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
    ) -> Optional[torch.Tensor]:
        """Forward pre-RoPE scoring to the configured base scorer."""
        cc = context.compression_context
        scorer_name = "compactor"
        if cc is not None:
            scorer_name = getattr(cc, "critical_ada_base_scorer", "compactor")
        if str(scorer_name).lower() != "compactor":
            return SnapKVCompression.pre_rope_scoring(q, k, v, context)
        return CompactorCompression.pre_rope_scoring(q, k, v, context)

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: Optional[torch.Tensor],
        context,
    ) -> Optional[torch.Tensor]:
        """Compute base post-RoPE scores, then apply CriticalAda when Wo is available."""
        compression_context = context.compression_context
        assert compression_context is not None
        scorer_name = str(
            getattr(compression_context, "critical_ada_base_scorer", "compactor")
        ).lower()

        if scorer_name != "compactor":
            base_scores = SnapKVCompression.post_rope_scoring(
                q, k, v, pre_rope_scores, context
            )
        else:
            # Compactor special case: wait for the store stream, then run the
            # kvpress-style post-RoPE scorer with the resolved blending factor.
            if context.STORE_STREAM is not None:
                torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
            blending = resolve_kvpress_compactor_blending(compression_context)
            positional = (
                q,
                k,
                v,
                context.cu_seqlens_q,
                pre_rope_scores,
                compression_context,
                context.max_seqlen_q,
            )
            base_scores = maybe_execute_in_stream(
                kvpress_compactor_post_rope,
                *positional,
                chunk_size=CompactorCompression.chunk_size,
                blending=float(blending),
                STORE_STREAM=context.STORE_STREAM,
            )

        wo_weight = compression_context.wo_weight
        if wo_weight is None:
            return base_scores

        result = maybe_execute_in_stream(
            critical_ada_key_scores,
            q,
            k,
            v,
            wo_weight,
            context.cu_seqlens_q,
            base_scores,
            compression_context,
            STORE_STREAM=context.STORE_STREAM,
            store_stream=context.STORE_STREAM,
        )
        # The second element (indices of pairs to prune) is not used here.
        scores, _masked = result
        return scores

    @staticmethod
    def prepare_layer(module: torch.nn.Module, device: torch.device, dtype: torch.dtype):
        """Optionally cache a float32 ``(Hq, head_dim, hidden)`` view of ``o_proj``.

        Does nothing when the module lacks ``o_proj`` (or its weight is
        ``None``), ``num_heads`` or ``head_dim``. The result is stored on
        ``module._critical_ada_wo_weight``. ``dtype`` is accepted but not used.
        """
        has_projection = hasattr(module, "o_proj") and module.o_proj.weight is not None
        if not has_projection:
            return
        if not (hasattr(module, "num_heads") and hasattr(module, "head_dim")):
            return
        weight = module.o_proj.weight.data
        hidden_size, _in_features = weight.shape
        per_head = weight.transpose(0, 1).view(
            module.num_heads, module.head_dim, hidden_size
        )
        module._critical_ada_wo_weight = per_head.to(device=device, dtype=torch.float32)
"""
CriticalAdaKV: 在 Compactor(pre RoPE 杠杆分 + post RoPE 非因果注意力融合)基础上,
用输出投影 Wo 对 Value 的 L1 范数做 Stage-2 重加权;Stage-1 在 Compactor 基础分上做预算内 top-k 保护。
预算与 compactor_vllm 引擎一致:使用 ``compression_context.batch_tokens_to_retain``(flatten 的
(token, head) 对数量)。Stage1/2 与 kvpress 论文/实现一致;``||Wo@V||_1`` 在 **算法上** 与
``CriticalKVPress.vwl1norm`` 相同(GQA 上逐 query 头 L1 再对组取均值)。**默认用 Triton**
(``_compute_wo_v_l1_kernel``);若需与 PyTorch 逐行对齐,将模块内 ``_USE_WO_L1_REFERENCE_BACKEND`` 改为 ``True`` 即走 ``_vwl1_norm_kvpress_reference``。
注意:不得在 import 时加载 ``compactor_vllm.utils.context``(其会再 import ``CompressionMethod``,
与 ``compression/__init__.py`` 导入本模块形成环)。运行时只使用与 ``CompressionContext`` 同字段的 duck 对象。
"""
from __future__ import annotations
from typing import Any, Optional, Tuple
import torch
import triton
from triton import language as tl
from transformers.models.llama.modeling_llama import repeat_kv
from compactor_vllm.compression.common import BaseCompressionMethod
from compactor_vllm.compression.compactor import (
CompactorCompression,
non_causal_attn_scores,
)
from compactor_vllm.compression.snapkv import SnapKVCompression
from compactor_vllm.utils.helpers import maybe_execute_in_stream
from compactor_vllm.utils.triton_compat import autotune as triton_autotune
# Backend for the Wo@V L1 norm: False = Triton kernel (default),
# True = PyTorch reference implementation (debugging / parity checks).
_USE_WO_L1_REFERENCE_BACKEND = False
def _vwl1_norm_kvpress_reference(
values_seg: torch.Tensor,
wo: torch.Tensor,
num_kv_heads: int,
num_query_groups: int,
) -> torch.Tensor:
"""
与 kvpress ``CriticalKVPress.vwl1norm`` 等价的 **可选参考实现**(PyTorch,仅用于核对;
将 ``_USE_WO_L1_REFERENCE_BACKEND`` 置为 ``True`` 时选用,默认走 Triton)。
算法:repeat_kv → 逐 query 头 ``|V @ Wo_h|_1`` → 在 GQA 组上 mean,与 Triton 路径同一公式。
"""
k_len, Hk, D = values_seg.shape
Hq, D_wo, hidden = wo.shape
assert D == D_wo and Hk == num_kv_heads and Hq == Hk * num_query_groups
# [1, Hk, k_len, D] 与 HF repeat_kv 约定一致
v_4d = values_seg.permute(1, 0, 2).unsqueeze(0).contiguous()
v_rep = repeat_kv(v_4d, num_query_groups) # [1, Hq, k_len, D]
# Wo 在 attention 里注入为 float32,V 常为 bf16/fp16,matmul 前对齐 dtype
wo_f = wo
head_list = []
for head in range(Hq):
v_h = v_rep[0, head, :, :].to(dtype=wo_f.dtype)
head_wov = v_h.matmul(wo_f[head, :, :])
head_wov_norm = torch.norm(head_wov, p=1, dim=-1)
head_list.append(head_wov_norm)
stacked = torch.stack(head_list, dim=0) # [Hq, k_len]
stacked = stacked.view(Hk, num_query_groups, k_len).mean(dim=1)
return stacked.transpose(0, 1).contiguous()
# ============================================================================
# Triton:||Wo @ V||₁ 按 kvpress 定义(GQA 上对 query 组 L1 后取均值)
# ============================================================================
@triton_autotune(
    configs=[
        triton.Config({"BLOCK_K": bk, "BLOCK_D": bd}, num_warps=nw, num_stages=ns)
        for bk in [32, 64, 128]
        for bd in [32, 64]
        for nw in [4, 8]
        for ns in [3, 4]
    ],
    key=["Hk", "D", "HIDDEN"],
    cache_results=True,
)
@triton.jit
def _compute_wo_v_l1_kernel(
    V,
    WO,
    cu_k,
    OUT,
    STRIDE_V_NK,
    STRIDE_V_HK,
    STRIDE_V_D,
    STRIDE_WO_HQ,
    STRIDE_WO_D,
    STRIDE_WO_HID,
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    Hk: tl.constexpr,
    Hq: tl.constexpr,
    D: tl.constexpr,
    HIDDEN: tl.constexpr,
    QUERY_GROUP_SIZE: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """For each key row and KV head, store the group-mean of ``sum(|V @ Wo|)``.

    Grid layout: axis 0 picks the segment ``b`` bounded by ``cu_k[b]`` and
    ``cu_k[b + 1]``, axis 1 picks the KV head, axis 2 picks a tile of
    ``BLOCK_K`` consecutive rows inside that segment. Rows at or beyond the
    segment end are masked out of both the load and the final store.

    For every query head ``hq = hk * QUERY_GROUP_SIZE + g`` the ``[BLOCK_K, D]``
    value tile is multiplied by ``WO[hq]`` in column chunks of ``BLOCK_D``,
    absolute values are summed over the hidden axis, and the total is divided
    by ``QUERY_GROUP_SIZE`` before being written to ``OUT[row, hk]``.

    ``tl.arange(0, D)`` is used for the head dimension, so ``D`` is expected to
    be a power of two (Triton requirement for ``tl.arange``).
    """
    b = tl.program_id(0)
    hk = tl.program_id(1)
    ks = tl.program_id(2)
    k_beg = tl.load(cu_k + b)
    k_end = tl.load(cu_k + b + 1)
    nk = k_beg + ks * BLOCK_K + tl.arange(0, BLOCK_K)
    k_mask = nk < k_end
    d_idx = tl.arange(0, D)

    # The value tile depends only on (rows, KV head), not on the query head,
    # so it is loaded once here instead of once per group iteration.
    v_ptrs = (
        V
        + nk[:, None] * STRIDE_V_NK
        + hk * STRIDE_V_HK
        + d_idx[None, :] * STRIDE_V_D
    )
    v_blk = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)

    l1_sum = tl.zeros([BLOCK_K], dtype=tl.float32)
    for g in range(QUERY_GROUP_SIZE):
        hq = hk * QUERY_GROUP_SIZE + g
        for hid_off in range(0, HIDDEN, BLOCK_D):
            hid_idx = hid_off + tl.arange(0, BLOCK_D)
            # Guards the last chunk when HIDDEN is not a multiple of BLOCK_D.
            hid_mask = hid_idx < HIDDEN
            wo_ptrs = (
                WO
                + hq * STRIDE_WO_HQ
                + d_idx[:, None] * STRIDE_WO_D
                + hid_idx[None, :] * STRIDE_WO_HID
            )
            wo_tile = tl.load(wo_ptrs, mask=hid_mask[None, :], other=0.0).to(tl.float32)
            wov_tile = tl.dot(v_blk, wo_tile)
            l1_sum += tl.sum(tl.abs(wov_tile), axis=1)
    l1_sum = l1_sum / QUERY_GROUP_SIZE

    out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
    tl.store(out_ptrs, l1_sum, mask=k_mask)
# ============================================================================
# Triton:Stage 1 保护 + Stage 2 加权融合(逐元素)
# ============================================================================
@triton_autotune(
    configs=[triton.Config({"BLOCK_K": block}) for block in (32, 64, 128, 256)],
    key=["Hk"],
    cache_results=True,
)
@triton.jit
def _critical_ada_fuse_kernel(
    BASE_SCORES,
    WO_V_NORM,
    STAGE1_MASK,
    cu_k,
    OUT,
    STRIDE_BS_NK,
    STRIDE_BS_HK,
    STRIDE_WN_NK,
    STRIDE_WN_HK,
    STRIDE_S1_NK,
    STRIDE_S1_HK,
    STRIDE_OUT_NK,
    STRIDE_OUT_HK,
    EPSILON: tl.constexpr,
    Hk: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Write ``(base + EPSILON) * wo_v_norm`` per (row, head); Stage-1 rows become ``+inf``.

    Grid layout: axis 0 picks the segment bounded by ``cu_k[b]`` and
    ``cu_k[b + 1]``, axis 1 picks the KV head. Each program walks its segment
    in tiles of ``BLOCK_K`` rows, masking rows past the segment end.
    """
    seq = tl.program_id(0)
    head = tl.program_id(1)
    seg_start = tl.load(cu_k + seq)
    seg_stop = tl.load(cu_k + seq + 1)
    for tile_start in tl.range(seg_start, seg_stop, BLOCK_K):
        rows = tile_start + tl.arange(0, BLOCK_K)
        in_segment = rows < seg_stop
        base = tl.load(
            BASE_SCORES + rows * STRIDE_BS_NK + head * STRIDE_BS_HK,
            mask=in_segment,
            other=0.0,
        )
        wnorm = tl.load(
            WO_V_NORM + rows * STRIDE_WN_NK + head * STRIDE_WN_HK,
            mask=in_segment,
            other=1.0,
        )
        protect_flag = tl.load(
            STAGE1_MASK + rows * STRIDE_S1_NK + head * STRIDE_S1_HK,
            mask=in_segment,
            other=0,
        ).to(tl.int32)
        weighted = (base + EPSILON) * wnorm
        # Entries flagged in the Stage-1 mask are forced to +inf.
        result = tl.where(protect_flag == 1, float("inf"), weighted)
        tl.store(
            OUT + rows * STRIDE_OUT_NK + head * STRIDE_OUT_HK,
            result,
            mask=in_segment,
        )
def critical_ada_key_scores(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
wo_weight: torch.Tensor,
cu_seqlens: torch.Tensor,
base_scores: torch.Tensor,
compression_ctx: Any,
*,
store_stream: Optional[torch.cuda.Stream] = None,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
"""
使用与引擎一致的保留预算 ``batch_tokens_to_retain``(每条序列的 (token, head) 对数),
在每条序列上对齐 kvpress ``CriticalAdaKVPress.compress``(整段 ``k_len``、与源实现相同的
top-k / scatter 顺序);仅 base 分数来自 compactor_vllm 的 Compactor/SnapKV。
Args:
compression_ctx: 与 ``CompressionContext`` 相同字段即可(duck typing),须含
``batch_tokens_to_retain``;可选 ``critical_ada_epsilon``、
``critical_ada_first_stage_ratio``、``critical_ada_alpha_safeguard``。
"""
assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1
device = q.device
_, Hq, D = q.shape
N_k, Hk, Dk = k.shape
assert D == Dk and Hq % Hk == 0
# 与 non_causal_attn_scores 使用同一 cu(prefill 下即 context.cu_seqlens_q),
# 保证 base_scores 行与 Triton 分段一致;勿与 cu_seqlens_k 混用。
B = cu_seqlens.numel() - 1
G = Hq // Hk
k_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
btr = compression_ctx.batch_tokens_to_retain
assert btr is not None and btr.numel() == B
btr = btr.to(device=device, dtype=torch.int32)
epsilon = compression_ctx.critical_ada_epsilon
first_stage_ratio = compression_ctx.critical_ada_first_stage_ratio
alpha_safeguard = float(compression_ctx.critical_ada_alpha_safeguard)
alpha_safeguard = max(0.0, min(1.0, alpha_safeguard))
if wo_weight.dim() == 2:
hidden_size, _ = wo_weight.shape
wo = wo_weight.transpose(0, 1).view(Hq, D, hidden_size).contiguous()
else:
wo = wo_weight.contiguous()
hidden_size = wo.size(-1)
wo_v_norm = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
if B > 0 and int(k_lengths.max().item()) > 0:
if _USE_WO_L1_REFERENCE_BACKEND:
for b in range(B):
k_beg = int(cu_seqlens[b].item())
k_end = int(cu_seqlens[b + 1].item())
if k_end <= k_beg:
continue
v_seg = v[k_beg:k_end, :, :].contiguous()
wo_v_norm[k_beg:k_end, :] = _vwl1_norm_kvpress_reference(
v_seg, wo, Hk, G
)
else:
def grid_wo(META):
max_k_len = int(k_lengths.max().item())
return (B, Hk, triton.cdiv(max_k_len, META["BLOCK_K"]))
_compute_wo_v_l1_kernel[grid_wo](
v,
wo,
cu_seqlens,
wo_v_norm,
*v.stride(),
*wo.stride(),
*wo_v_norm.stride(),
Hk=Hk,
Hq=Hq,
D=D,
HIDDEN=hidden_size,
QUERY_GROUP_SIZE=G,
)
stage1_mask = torch.zeros((N_k, Hk), dtype=torch.int32, device=device)
head_budgets_by_batch: list[Optional[torch.Tensor]] = []
for b in range(B):
k_len = int(k_lengths[b].item())
if k_len == 0:
head_budgets_by_batch.append(None)
continue
k_beg = int(cu_seqlens[b].item())
k_end = int(cu_seqlens[b + 1].item())
keep_pairs = int(btr[b].item())
scores_seg = base_scores[k_beg:k_end, :]
# 与 kvpress 的 n_kept 一致:每头保留 n_kept 个 token
n_kept_tokens = max(1, keep_pairs // Hk)
n_kept_tokens = min(n_kept_tokens, k_len)
# kvpress:topk 在「未改动的」scores 上取索引,scatter 只写在副本上,供 head_budgets 用;
# Stage1 仍用原始 scores_seg(见下)。
working = scores_seg.clone()
n_safe = int(n_kept_tokens * alpha_safeguard)
if n_safe > 0:
nk = min(n_safe, k_len)
for hk in range(Hk):
top_idx = torch.topk(scores_seg[:, hk], nk, sorted=True).indices
working[:, hk].scatter_(0, top_idx, float("inf"))
top_pairs = min(n_kept_tokens * Hk, working.numel())
if top_pairs <= 0:
head_budgets_by_batch.append(None)
continue
top_idx_flat = torch.topk(working.reshape(-1), top_pairs, sorted=False).indices
top_head_idx = top_idx_flat % Hk
head_budgets = torch.bincount(top_head_idx, minlength=Hk).to(torch.int32)
head_budgets_by_batch.append(head_budgets)
# Stage 1:与 kvpress 相同 — 先 topk(..., M1, sorted=True),再每头取前 phase1 个下标
head_selection_budget_1st = (
(head_budgets.to(torch.float32) * float(first_stage_ratio))
.to(torch.int64)
.tolist()
)
M1 = max(head_selection_budget_1st) if head_selection_budget_1st else 0
if M1 > 0:
mk = min(M1, k_len)
for hk in range(Hk):
phase1_budget = int(head_selection_budget_1st[hk])
if phase1_budget <= 0:
continue
full_idx = torch.topk(scores_seg[:, hk], mk, sorted=True).indices
take = min(phase1_budget, mk)
stage1_mask[k_beg + full_idx[:take], hk] = 1
final_scores = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
def grid_fuse(_META):
return (B, Hk)
_critical_ada_fuse_kernel[grid_fuse](
base_scores,
wo_v_norm,
stage1_mask,
cu_seqlens,
final_scores,
*base_scores.stride(),
*wo_v_norm.stride(),
*stage1_mask.stride(),
*final_scores.stride(),
Hk=Hk,
EPSILON=float(epsilon),
)
# Stage 2(kvpress):对融合后分数先 topk(..., M2, sorted=True),再每头取前 budget 个下标置 inf
for b in range(B):
hb = head_budgets_by_batch[b]
if hb is None:
continue
k_beg = int(cu_seqlens[b].item())
k_end = int(cu_seqlens[b + 1].item())
k_len = k_end - k_beg
if k_len <= 0:
continue
fused_seg = final_scores[k_beg:k_end, :]
M2 = int(hb.max().item())
if M2 <= 0:
continue
mk = min(M2, k_len)
for hk in range(Hk):
budget = int(hb[hk].item())
if budget <= 0:
continue
full_idx = torch.topk(fused_seg[:, hk], mk, sorted=True).indices
take = min(budget, mk)
final_scores[k_beg + full_idx[:take], hk] = float("inf")
masked_key_indices = None
for b in range(B):
k_len = int(k_lengths[b].item())
if k_len == 0:
continue
keep_pairs = int(btr[b].item())
total_pairs = k_len * Hk
if keep_pairs >= total_pairs:
continue
k_beg = int(cu_seqlens[b].item())
k_end = int(cu_seqlens[b + 1].item())
n_prune_pairs = min(total_pairs - keep_pairs, total_pairs)
if n_prune_pairs <= 0:
continue
flat_scores = final_scores[k_beg:k_end, :].reshape(-1)
prune_idx = torch.topk(
-flat_scores, min(n_prune_pairs, flat_scores.numel()), sorted=False
).indices
batch_idx = torch.full_like(prune_idx, b, dtype=torch.int64)
head_idx = prune_idx % Hk
seq_idx = prune_idx // Hk + k_beg
if masked_key_indices is None:
masked_key_indices = (batch_idx, head_idx, seq_idx)
else:
masked_key_indices = (
torch.cat([masked_key_indices[0], batch_idx]),
torch.cat([masked_key_indices[1], head_idx]),
torch.cat([masked_key_indices[2], seq_idx]),
)
if store_stream is not None:
final_scores.record_stream(store_stream)
return final_scores, masked_key_indices
class CriticalAdaKVCompression(BaseCompressionMethod):
    """
    CriticalAda two-stage re-weighting applied on top of a base scorer.

    The base scores come from ``CompactorCompression`` by default (pre-RoPE
    leverage scores blended with post-RoPE non-causal attention scores), or from
    ``SnapKVCompression`` when ``compression_context.critical_ada_base_scorer``
    is ``"snapkv"``. The Attention layer must inject
    ``compression_context.wo_weight`` before post-RoPE scoring; when it is
    ``None`` the base scores are returned unchanged.
    """

    @staticmethod
    def _base_scorer_name(compression_context) -> str:
        """Return the lower-cased base scorer name for ``compression_context``.

        Falls back to ``"compactor"`` when the context is ``None`` or does not
        define ``critical_ada_base_scorer``. Both scoring phases resolve the name
        through this helper so they can never disagree on which base scorer is
        in use.
        """
        return str(
            getattr(compression_context, "critical_ada_base_scorer", "compactor")
        ).lower()

    @staticmethod
    def pre_rope_scoring(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
    ) -> Optional[torch.Tensor]:
        """Delegate pre-RoPE scoring to the configured base scorer."""
        base = CriticalAdaKVCompression._base_scorer_name(context.compression_context)
        if base == "snapkv":
            return SnapKVCompression.pre_rope_scoring(q, k, v, context)
        return CompactorCompression.pre_rope_scoring(q, k, v, context)

    @staticmethod
    def post_rope_scoring(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        pre_rope_scores: Optional[torch.Tensor],
        context,
    ) -> Optional[torch.Tensor]:
        """Compute base post-RoPE scores, then apply the CriticalAda weighting.

        Args:
            q, k, v: Post-RoPE query / key / value tensors, forwarded unchanged to
                the base scorer and to ``critical_ada_key_scores``.
            pre_rope_scores: Output of :meth:`pre_rope_scoring` (may be ``None``).
            context: Forward context; ``compression_context``, ``cu_seqlens_q``,
                ``max_seqlen_q`` and ``STORE_STREAM`` are read from it.

        Returns:
            The CriticalAda scores, or the plain base scores when no ``wo_weight``
            has been injected into the compression context.
        """
        compression_context = context.compression_context
        assert compression_context is not None
        base = CriticalAdaKVCompression._base_scorer_name(compression_context)
        if base == "snapkv":
            base_scores = SnapKVCompression.post_rope_scoring(
                q, k, v, pre_rope_scores, context
            )
        else:
            # Kept verbatim identical to CompactorCompression.post_rope_scoring in
            # compactor.py:
            #   maybe_execute_in_stream(non_causal_attn_scores, q, k, v,
            #                           cu_seqlens_q, max_seqlen_q, ...)
            # Do not switch to another wrapper, otherwise the scores would differ
            # from those produced when COMPACTOR is used on its own.
            if context.STORE_STREAM is not None:
                torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
            base_scores = maybe_execute_in_stream(
                non_causal_attn_scores,
                q,
                k,
                v,
                context.cu_seqlens_q,
                context.max_seqlen_q,
                chunk_size=CompactorCompression.chunk_size,
                sm_scale=1.0,
                normalize=True,
                accum_scores=pre_rope_scores,
                context_lens=compression_context.context_lens,
                protected_first_tokens=compression_context.protected_first_tokens,
                protected_last_tokens=compression_context.protected_last_tokens,
                accum_blending=0.5,
            )
        wo_weight = compression_context.wo_weight
        if wo_weight is None:
            # No output-projection weight injected: CriticalAda cannot be applied.
            return base_scores
        # NOTE(review): ``STORE_STREAM`` is presumably consumed by
        # ``maybe_execute_in_stream`` itself while ``store_stream`` is forwarded to
        # ``critical_ada_key_scores`` -- confirm against the helper's signature.
        scores, _masked = maybe_execute_in_stream(
            critical_ada_key_scores,
            q,
            k,
            v,
            wo_weight,
            context.cu_seqlens_q,
            base_scores,
            compression_context,
            STORE_STREAM=context.STORE_STREAM,
            store_stream=context.STORE_STREAM,
        )
        return scores

    @staticmethod
    def prepare_layer(
        module: torch.nn.Module, device: torch.device, dtype: torch.dtype
    ) -> None:
        """Optionally precompute and cache Wo on ``module``.

        The cached tensor is stored as ``module._critical_ada_wo_weight`` in
        float32 (``dtype`` is not used). Actual inference relies on the
        ``cc.wo_weight`` injected in ``Attention.forward``; this cache is only a
        convenience. Modules without ``o_proj`` / ``num_heads`` / ``head_dim``
        are silently skipped.
        """
        if not hasattr(module, "o_proj") or module.o_proj.weight is None:
            return
        if not hasattr(module, "num_heads") or not hasattr(module, "head_dim"):
            return
        wo_raw = module.o_proj.weight.data
        hidden_size, _ = wo_raw.shape
        Hq = module.num_heads
        head_dim = module.head_dim
        # (hidden, Hq * head_dim) -> (Hq, head_dim, hidden)
        wo = (
            wo_raw.transpose(0, 1)
            .view(Hq, head_dim, hidden_size)
            .to(device=device, dtype=torch.float32)
        )
        module._critical_ada_wo_weight = wo
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Compactor-style sparse prefill: Triton varlen attention + paged KV store.
Migrated kernels: ``sparse_varlen_kernel.causal_sparse_varlen_with_cache`` and
``store_kv_cache.prefill_store_topk_kv``.
Layout: MQA uses ``flatten_kv_cache_plane``; GQA/MHA uses head-major flatten
(see ``layout_bridge``).
Execution order note: vLLM runs ``unified_kv_cache_update`` (writes KV) before
``unified_attention_with_output``. Compactor's sparse attention kernel assumes
the paged cache holds only the prefix *before* the current K/V append, while
K_app carries the new tokens. That differs from vLLM's order (cache already
contains the current step after reshape). Therefore ``try_sparse_prefill_forward``
is provided as a reference / future hook and is not invoked from the default
FlashAttention forward path; prefill KV pruning uses ``prefill_store_topk_kv``
in ``do_kv_cache_update_kv_prune`` instead.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from vllm.forward_context import get_forward_context
from vllm.kvprune.compression.prefill_registry import try_topk_indices_from_registry
from vllm.kvprune.core.compression_bridge import compression_method_id_to_enum
from vllm.kvprune.core.runtime import get_kv_prune_state, layer_index_from_layer_name
from vllm.kvprune.utils.layout_bridge import (
block_table_to_global_page_table,
build_batch_mapping,
build_page_table_head_major,
flatten_kv_cache_head_major,
flatten_kv_cache_plane,
write_head_major_flat_to_interleaved,
)
from vllm.kvprune.attention.sparse_varlen_kernel import causal_sparse_varlen_with_cache
from vllm.kvprune.kv_cache.store_kv_cache import prefill_store_topk_kv
if TYPE_CHECKING:
from vllm.v1.attention.backends.flash_attn import FlashAttentionImpl, FlashAttentionMetadata
# Tolerance used when deciding whether a request is pruned: a request counts as
# pruned only if its compression ratio is below ``1.0 - _RATIO_EPS``.
_RATIO_EPS = 1.0e-6
def _get_flash_attn_metadata(layer_name: str) -> "FlashAttentionMetadata | None":
    """Look up the attention metadata registered for ``layer_name``.

    Returns ``None`` when no forward context is active, when the context carries
    no attention metadata (``None`` or an empty list), when the metadata is not
    a per-layer mapping, or when the layer has no entry.
    """
    try:
        fc = get_forward_context()
    except AssertionError:
        # Treated as "no forward context is currently set".
        return None
    am = fc.attn_metadata
    if isinstance(am, list):
        # A list of per-layer mappings: only the first entry is consulted.
        if not am:
            return None
        am = am[0]
    # ``attn_metadata`` may be None or a non-mapping object; in both cases there
    # is nothing to look up, so report "no metadata" instead of raising
    # AttributeError on the missing ``.get``.
    lookup = getattr(am, "get", None)
    if lookup is None:
        return None
    return lookup(layer_name)
def try_sparse_prefill_forward(
    impl: "FlashAttentionImpl",
    layer: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    attn_metadata: "FlashAttentionMetadata",
    output: torch.Tensor,
    num_actual_tokens: int,
) -> bool:
    """Run compactor ``causal_sparse_varlen_with_cache`` when eligible.

    The sparse path is only taken when all of the following hold:

    * a kv-prune state exists and the step is a prefill;
    * the backend is MQA (``num_kv_heads == 1``), with a non-fp8 KV cache, no
      ALiBi slopes, no sliding window and a power-of-two head size;
    * every request in the batch is pruned (ratio below ``1 - _RATIO_EPS``) and
      all requests use the same compression method;
    * no request has a negative cached length, and the flattened cache is a
      whole number of pages.

    Args:
        impl: FlashAttention implementation whose configuration is inspected.
        layer: Unused; kept for signature compatibility with the caller.
        query, key, value: Current-step tensors; only the first
            ``num_actual_tokens`` rows are used.
        key_cache, value_cache: Paged KV cache planes.
        attn_metadata: Supplies ``seq_lens``, ``block_table``, ``max_query_len``.
        output: Destination; its first ``num_actual_tokens`` rows are overwritten
            on success.
        num_actual_tokens: Number of non-padding tokens in this step.

    Returns:
        ``True`` if the sparse kernel ran and ``output`` was written, ``False``
        if the caller should fall back to the default attention path.
    """
    state = get_kv_prune_state()
    if state is None or not state.is_prefill:
        return False

    # Host-only eligibility checks come first so an ineligible configuration is
    # rejected before any check that has to read values back from the device.
    if impl.num_kv_heads != 1:
        return False
    if impl.kv_cache_dtype.startswith("fp8"):
        return False
    if impl.alibi_slopes is not None:
        return False
    if impl.sliding_window != (-1, -1):
        return False
    d = impl.head_size
    if d <= 0 or (d & (d - 1)) != 0:
        # Head size must be a positive power of two.
        return False

    num_reqs = state.num_reqs
    comp = state.compression_ratio_gpu[:num_reqs]
    pruned = comp < 1.0 - _RATIO_EPS
    if not torch.any(pruned):
        return False
    # Mixed pruned + non-pruned requests: keep default FlashAttention path for now.
    if torch.any(~pruned):
        return False
    mids = state.compression_method_id_gpu[:num_reqs]
    if torch.unique(mids).numel() > 1:
        return False

    cu = state.query_start_loc[: num_reqs + 1].to(
        device=query.device, dtype=torch.int32
    )
    seq_lens = attn_metadata.seq_lens[:num_reqs].to(torch.int32)
    seqlen_q = cu[1:] - cu[:-1]
    # Tokens already in the cache before this step = total length - new tokens.
    cached = seq_lens - seqlen_q
    if torch.any(cached < 0):
        return False
    seq_lens_bh = cached.unsqueeze(1).expand(-1, 1).contiguous()

    block_table = attn_metadata.block_table[:num_reqs]
    max_batches = block_table.shape[0]
    global_page_table = block_table_to_global_page_table(
        block_table, impl.num_kv_heads, max_batches=max_batches
    )
    batch_mapping = build_batch_mapping(num_reqs, query.device)
    try:
        k_flat, v_flat = flatten_kv_cache_plane(
            key_cache, value_cache, impl.num_kv_heads
        )
    except ValueError:
        # The cache layout cannot be flattened for this configuration.
        return False
    page_size = key_cache.shape[1]
    if page_size <= 0 or k_flat.shape[0] % page_size != 0:
        return False

    q3 = query[:num_actual_tokens].view(num_actual_tokens, impl.num_heads, d)
    k3 = key[:num_actual_tokens].view(num_actual_tokens, 1, d)
    v3 = value[:num_actual_tokens].view(num_actual_tokens, 1, d)
    max_seqlen_q = int(attn_metadata.max_query_len)
    max_cached = int(seq_lens_bh.max().item()) if seq_lens_bh.numel() else 0
    out = causal_sparse_varlen_with_cache(
        q3,
        k3,
        v3,
        k_flat,
        v_flat,
        seq_lens_bh,
        global_page_table,
        batch_mapping,
        cu,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k_cache=max_cached,
        HKV=1,
        PAGE_SIZE=page_size,
        sm_scale=None,
    )
    output[:num_actual_tokens].copy_(
        out.reshape(num_actual_tokens, impl.num_heads * d)
    )
    return True
def _build_tail_topk_indices(
cu_seqlens: torch.Tensor,
num_reqs: int,
hkv: int,
compression_ratio: float | torch.Tensor,
max_sel: int,
device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Return (indices [B, max_sel], num_pairs_to_retain [B]) for tail tokens × heads."""
indices = torch.zeros(num_reqs, max_sel, dtype=torch.int32, device=device)
n_pairs = torch.zeros(num_reqs, dtype=torch.int32, device=device)
cu_cpu = cu_seqlens[: num_reqs + 1].detach()
for b in range(num_reqs):
start = int(cu_cpu[b].item())
end = int(cu_cpu[b + 1].item())
chunk_len = end - start
if chunk_len <= 0:
continue
if isinstance(compression_ratio, torch.Tensor):
r_b = float(compression_ratio[b].item())
else:
r_b = compression_ratio
k_tok = max(1, int(round(chunk_len * r_b)))
k_tok = min(k_tok, chunk_len)
pairs: list[int] = []
for tok in range(end - k_tok, end):
for h in range(hkv):
pairs.append(tok * hkv + h)
if len(pairs) >= max_sel:
break
if len(pairs) >= max_sel:
break
n = len(pairs)
if n > 0:
indices[b, :n] = torch.tensor(pairs, dtype=torch.int32, device=device)
n_pairs[b] = n
return indices, n_pairs
def try_prefill_kv_store(
    layer: torch.nn.Module,
    key: torch.Tensor,
    value: torch.Tensor,
    kv_cache: torch.Tensor,
) -> bool:
    """Top-k or full compactor prefill KV store; updates per-layer logical lengths.

    The pruned store is only performed when a kv-prune state exists, the step is
    a prefill, every request is pruned with the same compression method,
    attention metadata is available for the layer, the head size is a positive
    power of two and no request has a negative cached length.

    Args:
        layer: Attention layer; ``layer.layer_name`` selects the metadata entry
            and the layer index.
        key, value: New keys / values; ``key.shape[1]`` is read as the number of
            KV heads and ``key.shape[2]`` as the head size.
        kv_cache: Stacked cache; ``unbind(0)`` yields the key and value planes.

    Returns:
        ``True`` if the selected pairs were written into the cache and
        ``state.logical_seq_lens_gpu`` was updated, ``False`` if the caller
        should use the default KV-cache update.
    """
    state = get_kv_prune_state()
    if state is None or not state.is_prefill:
        return False
    num_reqs = state.num_reqs
    comp = state.compression_ratio_gpu[:num_reqs]
    pruned = comp < 1.0 - _RATIO_EPS
    if not torch.any(pruned):
        return False
    # Mixed pruned + non-pruned batches are not handled here.
    if torch.any(~pruned):
        return False
    mids = state.compression_method_id_gpu[:num_reqs]
    if torch.unique(mids).numel() > 1:
        return False
    meta = _get_flash_attn_metadata(layer.layer_name)
    if meta is None:
        return False
    num_kv_heads = key.shape[1]
    d = key.shape[2]
    if d <= 0 or (d & (d - 1)) != 0:
        # Head size must be a positive power of two.
        return False

    cu = state.query_start_loc[: num_reqs + 1].to(
        device=key.device, dtype=torch.int32
    )
    seq_lens = meta.seq_lens[:num_reqs].to(torch.int32)
    seqlen_q = cu[1:] - cu[:-1]
    # Tokens already cached before this step = total length - new tokens.
    cached_per_req = seq_lens - seqlen_q
    # Same guard as ``try_sparse_prefill_forward``: a negative cached length
    # would be handed to the store kernel as a starting offset, so bail out
    # before anything is flattened or written.
    if torch.any(cached_per_req < 0):
        return False
    cached = cached_per_req.unsqueeze(1).expand(-1, num_kv_heads).contiguous()

    key_cache, value_cache = kv_cache.unbind(0)
    num_blocks = key_cache.shape[0]
    # The same cache dimension serves as both block size and page size.
    page_size = key_cache.shape[1]
    head_major = num_kv_heads > 1
    try:
        if head_major:
            k_flat, v_flat = flatten_kv_cache_head_major(key_cache, value_cache)
        else:
            k_flat, v_flat = flatten_kv_cache_plane(
                key_cache, value_cache, num_kv_heads
            )
    except ValueError:
        # The cache layout cannot be flattened for this configuration.
        return False

    block_table = meta.block_table[:num_reqs]
    max_batches = block_table.shape[0]
    if head_major:
        global_page_table = build_page_table_head_major(
            block_table,
            num_kv_heads,
            num_blocks=num_blocks,
            block_size=page_size,
            page_size=page_size,
            max_batches=max_batches,
        )
    else:
        global_page_table = block_table_to_global_page_table(
            block_table, num_kv_heads, max_batches=max_batches
        )
    batch_mapping = build_batch_mapping(num_reqs, key.device)

    layer_idx = layer_index_from_layer_name(layer.layer_name)
    max_seqlen_k = int(seqlen_q.max().item()) if seqlen_q.numel() else 0
    # Selection width: at most 8192 (token, head) pairs per request, at least 1.
    max_sel = min(max_seqlen_k * num_kv_heads, 8192)
    max_sel = max(max_sel, 1)

    # All requests share one method (checked above), so the first id is enough.
    mid = int(state.compression_method_id_gpu[0].item())
    method_enum = compression_method_id_to_enum(mid)
    registry_out = try_topk_indices_from_registry(
        method_enum, key, value, cu, num_reqs, comp, max_sel, key.device
    )
    if registry_out is not None:
        indices, n_pairs = registry_out
    else:
        # No registered scorer produced indices: keep the tail of each sequence.
        indices, n_pairs = _build_tail_topk_indices(
            cu, num_reqs, num_kv_heads, comp, max_sel, key.device
        )

    # NOTE(review): ``bh`` is read back as the new per-head lengths below, so the
    # kernel presumably updates it in place -- confirm against
    # ``prefill_store_topk_kv``.
    bh = cached.clone()
    prefill_store_topk_kv(
        new_keys=key,
        new_vals=value,
        indices_topk=indices,
        num_tokens_to_retain=n_pairs,
        page_table=global_page_table,
        batch_mapping=batch_mapping,
        bh_lens=bh,
        k_cache=k_flat,
        v_cache=v_flat,
        PAGE_SIZE=page_size,
        PAD_TO_PAGE_SIZE=False,
        cu_seqlens_k=None,
    )
    if head_major:
        # Copy the head-major flat buffers back into the interleaved cache.
        write_head_major_flat_to_interleaved(
            k_flat, v_flat, key_cache, value_cache
        )

    new_lens = bh.to(torch.int32)
    if state.logical_seq_lens_gpu.dim() == 3:
        # Per-head logical lengths are tracked.
        state.logical_seq_lens_gpu[layer_idx, :num_reqs, :] = new_lens
    else:
        # Only one length per request is tracked: use the longest head.
        state.logical_seq_lens_gpu[layer_idx, :num_reqs] = new_lens.max(
            dim=1
        ).values
    return True
# Names exported by ``from <module> import *``.
__all__ = [
"try_sparse_prefill_forward",
"try_prefill_kv_store",
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment