Commit d29c39ca authored by chenzk's avatar chenzk
Browse files

vllm kvprune wo:v1.1.0

parent f81ce56b
"""
CriticalAdaKV: 在 Compactor(pre RoPE 杠杆分 + post RoPE 非因果注意力融合)基础上,
用输出投影 Wo 对 Value 的 L1 范数做 Stage-2 重加权;Stage-1 在 Compactor 基础分上做预算内 top-k 保护。
预算与 compactor_vllm 引擎一致:使用 ``compression_context.batch_tokens_to_retain``(flatten 的
(token, head) 对数量)及首/尾保护段长度。
注意:不得在 import 时加载 ``compactor_vllm.utils.context``(其会再 import ``CompressionMethod``,
与 ``compression/__init__.py`` 导入本模块形成环)。运行时只使用与 ``CompressionContext`` 同字段的 duck 对象。
"""
from __future__ import annotations
from typing import Any, Optional, Tuple
import torch
import triton
from triton import language as tl
from compactor_vllm.compression.common import BaseCompressionMethod
from compactor_vllm.compression.compactor import (
CompactorCompression,
non_causal_attn_scores,
)
from compactor_vllm.compression.snapkv import SnapKVCompression
from compactor_vllm.utils.helpers import maybe_execute_in_stream
from compactor_vllm.utils.triton_compat import autotune as triton_autotune
# ============================================================================
# Triton Kernel 1: 计算 ||Wo @ V||₁ (L1 范数)
# ============================================================================
@triton_autotune(
configs=[
triton.Config({"BLOCK_K": bk, "BLOCK_D": bd}, num_warps=nw, num_stages=ns)
for bk in [32, 64, 128]
for bd in [32, 64]
for nw in [4, 8]
for ns in [3, 4]
],
key=["Hk", "D", "HIDDEN"],
cache_results=True,
)
@triton.jit
def _compute_wo_v_l1_kernel(
V,
WO,
cu_k,
OUT,
STRIDE_V_NK,
STRIDE_V_HK,
STRIDE_V_D,
STRIDE_WO_HQ,
STRIDE_WO_D,
STRIDE_WO_HID,
STRIDE_OUT_NK,
STRIDE_OUT_HK,
Hk: tl.constexpr,
Hq: tl.constexpr,
D: tl.constexpr,
HIDDEN: tl.constexpr,
QUERY_GROUP_SIZE: tl.constexpr,
BLOCK_K: tl.constexpr,
BLOCK_D: tl.constexpr,
):
b = tl.program_id(0)
hk = tl.program_id(1)
ks = tl.program_id(2)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
nk_off = ks * BLOCK_K + tl.arange(0, BLOCK_K)
nk = k_beg + nk_off
k_mask = nk < k_end
out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
l1_sum = tl.zeros([BLOCK_K], dtype=tl.float32)
for g in range(QUERY_GROUP_SIZE):
hq = hk * QUERY_GROUP_SIZE + g
v_ptrs = (
V
+ nk[:, None] * STRIDE_V_NK
+ hk * STRIDE_V_HK
+ tl.arange(0, D)[None, :] * STRIDE_V_D
)
v_blk = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)
for hid_off in range(0, HIDDEN, BLOCK_D):
hid_idx = hid_off + tl.arange(0, BLOCK_D)
hid_mask = hid_idx < HIDDEN
wo_ptrs = (
WO
+ hq * STRIDE_WO_HQ
+ tl.arange(0, D)[:, None] * STRIDE_WO_D
+ hid_idx[None, :] * STRIDE_WO_HID
)
wo_tile = tl.load(wo_ptrs, mask=hid_mask[None, :], other=0.0).to(tl.float32)
wov_tile = tl.dot(v_blk, wo_tile)
l1_sum += tl.sum(tl.abs(wov_tile), axis=1)
l1_sum = l1_sum / QUERY_GROUP_SIZE
tl.store(out_ptrs, l1_sum, mask=k_mask)
# ============================================================================
# Triton Kernel 2: Stage 1 保护 + Stage 2 加权融合
# ============================================================================
@triton_autotune(
configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128, 256]],
key=["Hk"],
cache_results=True,
)
@triton.jit
def _critical_ada_fuse_kernel(
BASE_SCORES,
WO_V_NORM,
STAGE1_MASK,
cu_k,
OUT,
EPSILON: tl.constexpr,
STRIDE_BS_NK,
STRIDE_BS_HK,
STRIDE_WN_NK,
STRIDE_WN_HK,
STRIDE_S1_NK,
STRIDE_S1_HK,
STRIDE_OUT_NK,
STRIDE_OUT_HK,
Hk: tl.constexpr,
BLOCK_K: tl.constexpr,
):
b = tl.program_id(0)
hk = tl.program_id(1)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
for ks in tl.range(k_beg, k_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_end
bs_ptrs = BASE_SCORES + nk * STRIDE_BS_NK + hk * STRIDE_BS_HK
wn_ptrs = WO_V_NORM + nk * STRIDE_WN_NK + hk * STRIDE_WN_HK
s1_ptrs = STAGE1_MASK + nk * STRIDE_S1_NK + hk * STRIDE_S1_HK
base = tl.load(bs_ptrs, mask=kmask, other=0.0)
wnorm = tl.load(wn_ptrs, mask=kmask, other=1.0)
stage1_protect = tl.load(s1_ptrs, mask=kmask, other=0).to(tl.int32)
fused = (base + EPSILON) * wnorm
fused = tl.where(stage1_protect == 1, float("inf"), fused)
out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
tl.store(out_ptrs, fused, mask=kmask)
def critical_ada_key_scores(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
wo_weight: torch.Tensor,
cu_seqlens: torch.Tensor,
base_scores: torch.Tensor,
compression_ctx: Any,
*,
store_stream: Optional[torch.cuda.Stream] = None,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
"""
使用与引擎一致的保留预算 ``batch_tokens_to_retain``(每条序列的 (token, head) 对数),
在每条序列上尽量贴近 kvpress 的 CriticalAdaKV 语义:
1) alpha_safeguard 安全预算(每头至少保留一部分);
2) 基于 base_scores 的 head-wise 自适应预算分配(head_budgets);
3) Stage-1 按 head_budgets * first_stage_ratio 保护;
4) Stage-2 计算 ``(base + eps) * ||Wo@V||_1``,再按 head_budgets 做每头 top-k 保护。
Args:
compression_ctx: 与 ``CompressionContext`` 相同字段即可(duck typing),须含
``batch_tokens_to_retain``、``protected_first_tokens``、``protected_last_tokens``;
可选 ``critical_ada_epsilon``、``critical_ada_first_stage_ratio``、
``critical_ada_alpha_safeguard``。
"""
assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1
device = q.device
_, Hq, D = q.shape
N_k, Hk, Dk = k.shape
assert D == Dk and Hq % Hk == 0
# 与 non_causal_attn_scores 使用同一 cu(prefill 下即 context.cu_seqlens_q),
# 保证 base_scores 行与 Triton 分段一致;勿与 cu_seqlens_k 混用。
B = cu_seqlens.numel() - 1
G = Hq // Hk
k_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
btr = compression_ctx.batch_tokens_to_retain
assert btr is not None and btr.numel() == B
btr = btr.to(device=device, dtype=torch.int32)
prot_first = compression_ctx.protected_first_tokens or [0] * B
prot_last = compression_ctx.protected_last_tokens or [0] * B
epsilon = compression_ctx.critical_ada_epsilon
first_stage_ratio = compression_ctx.critical_ada_first_stage_ratio
alpha_safeguard = float(getattr(compression_ctx, "critical_ada_alpha_safeguard", 0.2))
alpha_safeguard = max(0.0, min(1.0, alpha_safeguard))
if wo_weight.dim() == 2:
hidden_size, _ = wo_weight.shape
wo = wo_weight.transpose(0, 1).view(Hq, D, hidden_size).contiguous()
else:
wo = wo_weight.contiguous()
hidden_size = wo.size(-1)
wo_v_norm = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
def grid_wo(META):
max_k_len = int(k_lengths.max().item())
return (B, Hk, triton.cdiv(max_k_len, META["BLOCK_K"]))
_compute_wo_v_l1_kernel[grid_wo](
v,
wo,
cu_seqlens,
wo_v_norm,
*v.stride(),
*wo.stride(),
*wo_v_norm.stride(),
Hk=Hk,
Hq=Hq,
D=D,
HIDDEN=hidden_size,
QUERY_GROUP_SIZE=G,
)
stage1_mask = torch.zeros((N_k, Hk), dtype=torch.int32, device=device)
# kvpress 风格的每头预算(按序列自适应),用于 Stage-1/Stage-2。
head_budgets_by_batch = []
for b in range(B):
k_len = int(k_lengths[b].item())
if k_len == 0:
head_budgets_by_batch.append(None)
continue
k_beg = int(cu_seqlens[b].item())
k_end = int(cu_seqlens[b + 1].item())
s = int(prot_first[b]) if b < len(prot_first) else 0
e = int(prot_last[b]) if b < len(prot_last) else 0
lo, hi = k_beg + s, k_end - e
compressible = max(0, hi - lo)
keep_pairs = int(btr[b].item())
if compressible <= 0:
head_budgets_by_batch.append(None)
continue
# 每头 token 预算(kvpress 的 n_kept)
n_kept_tokens = max(1, keep_pairs // Hk)
n_kept_tokens = min(n_kept_tokens, compressible)
# 安全预算(每头至少保留 n_safe)
n_safe = int(n_kept_tokens * alpha_safeguard)
if n_safe > 0:
tk_safe = min(n_safe, compressible)
for hk in range(Hk):
safe_idx = torch.topk(base_scores[lo:hi, hk], tk_safe, sorted=False).indices
stage1_mask[lo + safe_idx, hk] = 1
# 自适应预算分配:在扁平 (token, head) 空间取 top n_kept_tokens*Hk,统计每个 head 的预算
budget_scores = base_scores[lo:hi, :].clone()
if n_safe > 0:
budget_scores[stage1_mask[lo:hi, :] == 1] = float("inf")
top_pairs = min(n_kept_tokens * Hk, budget_scores.numel())
if top_pairs <= 0:
head_budgets_by_batch.append(None)
continue
top_idx_flat = torch.topk(
budget_scores.reshape(-1), top_pairs, sorted=False
).indices
top_head_idx = top_idx_flat % Hk
head_budgets = torch.bincount(top_head_idx, minlength=Hk).to(torch.int32)
head_budgets_by_batch.append(head_budgets)
# Stage-1:按 head_budgets 的 first_stage_ratio 分头保护(kvpress 语义)
for hk in range(Hk):
phase1_budget = int(head_budgets[hk].item() * first_stage_ratio)
if phase1_budget <= 0:
continue
tk = min(phase1_budget, compressible)
top_idx = torch.topk(base_scores[lo:hi, hk], tk, sorted=False).indices
stage1_mask[lo + top_idx, hk] = 1
final_scores = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
def grid_fuse(_META):
return (B, Hk)
_critical_ada_fuse_kernel[grid_fuse](
base_scores,
wo_v_norm,
stage1_mask,
cu_seqlens,
final_scores,
EPSILON=epsilon,
*base_scores.stride(),
*wo_v_norm.stride(),
*stage1_mask.stride(),
*final_scores.stride(),
Hk=Hk,
)
# Stage-2(kvpress 语义):在融合后按每头预算再做一次 top-k 保护。
for b in range(B):
hb = head_budgets_by_batch[b]
if hb is None:
continue
k_beg = int(cu_seqlens[b].item())
k_end = int(cu_seqlens[b + 1].item())
s = int(prot_first[b]) if b < len(prot_first) else 0
e = int(prot_last[b]) if b < len(prot_last) else 0
lo, hi = k_beg + s, k_end - e
if hi <= lo:
continue
region_len = hi - lo
for hk in range(Hk):
budget = int(hb[hk].item())
if budget <= 0:
continue
tk = min(budget, region_len)
idx = torch.topk(final_scores[lo:hi, hk], tk, sorted=False).indices
final_scores[lo + idx, hk] = float("inf")
masked_key_indices = None
for b in range(B):
k_len = int(k_lengths[b].item())
if k_len == 0:
continue
keep_pairs = int(btr[b].item())
total_pairs = k_len * Hk
if keep_pairs >= total_pairs:
continue
k_beg = int(cu_seqlens[b].item())
k_end = int(cu_seqlens[b + 1].item())
n_prune_pairs = min(total_pairs - keep_pairs, total_pairs)
if n_prune_pairs <= 0:
continue
flat_scores = final_scores[k_beg:k_end, :].reshape(-1)
prune_idx = torch.topk(
-flat_scores, min(n_prune_pairs, flat_scores.numel()), sorted=False
).indices
batch_idx = torch.full_like(prune_idx, b, dtype=torch.int64)
head_idx = prune_idx % Hk
seq_idx = prune_idx // Hk + k_beg
if masked_key_indices is None:
masked_key_indices = (batch_idx, head_idx, seq_idx)
else:
masked_key_indices = (
torch.cat([masked_key_indices[0], batch_idx]),
torch.cat([masked_key_indices[1], head_idx]),
torch.cat([masked_key_indices[2], seq_idx]),
)
if store_stream is not None:
final_scores.record_stream(store_stream)
return final_scores, masked_key_indices
class CriticalAdaKVCompression(BaseCompressionMethod):
"""
以 CompactorCompression 为基分(pre RoPE 杠杆 + post RoPE 非因果融合),
再应用 CriticalAda 两阶段加权;须由 Attention 在 post-RoPE 前注入 ``compression_context.wo_weight``。
"""
@staticmethod
def pre_rope_scoring(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
) -> Optional[torch.Tensor]:
cc = context.compression_context
base = getattr(cc, "critical_ada_base_scorer", "compactor") if cc is not None else "compactor"
if str(base).lower() == "snapkv":
return SnapKVCompression.pre_rope_scoring(q, k, v, context)
return CompactorCompression.pre_rope_scoring(q, k, v, context)
@staticmethod
def post_rope_scoring(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
pre_rope_scores: Optional[torch.Tensor],
context,
) -> Optional[torch.Tensor]:
compression_context = context.compression_context
assert compression_context is not None
base = str(getattr(compression_context, "critical_ada_base_scorer", "compactor")).lower()
if base == "snapkv":
base_scores = SnapKVCompression.post_rope_scoring(q, k, v, pre_rope_scores, context)
else:
# 与 compactor.py 中 CompactorCompression.post_rope_scoring 逐字一致:
# maybe_execute_in_stream(non_causal_attn_scores, q,k,v, cu_seqlens_q, max_seqlen_q, ...)
# 不得改为其它封装,否则与单独使用 COMPACTOR 时分数字不一致。
if context.STORE_STREAM is not None:
torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
base_scores = maybe_execute_in_stream(
non_causal_attn_scores,
q,
k,
v,
context.cu_seqlens_q,
context.max_seqlen_q,
chunk_size=CompactorCompression.chunk_size,
sm_scale=1.0,
normalize=True,
accum_scores=pre_rope_scores,
context_lens=compression_context.context_lens,
protected_first_tokens=compression_context.protected_first_tokens,
protected_last_tokens=compression_context.protected_last_tokens,
accum_blending=0.5,
)
wo_weight = compression_context.wo_weight
if wo_weight is None:
return base_scores
scores, _masked = maybe_execute_in_stream(
critical_ada_key_scores,
q,
k,
v,
wo_weight,
context.cu_seqlens_q,
base_scores,
compression_context,
STORE_STREAM=context.STORE_STREAM,
store_stream=context.STORE_STREAM,
)
return scores
@staticmethod
def prepare_layer(module: torch.nn.Module, device: torch.device, dtype: torch.dtype):
"""可选:预计算并缓存 Wo;实际推理以 Attention.forward 中注入的 ``cc.wo_weight`` 为准。"""
if not hasattr(module, "o_proj") or module.o_proj.weight is None:
return
if not hasattr(module, "num_heads") or not hasattr(module, "head_dim"):
return
wo_raw = module.o_proj.weight.data
hidden_size, _ = wo_raw.shape
Hq = module.num_heads
head_dim = module.head_dim
wo = (
wo_raw.transpose(0, 1)
.view(Hq, head_dim, hidden_size)
.to(device=device, dtype=torch.float32)
)
module._critical_ada_wo_weight = wo
"""
CriticalAdaKV: 在 Compactor(pre RoPE 杠杆分 + post RoPE 非因果注意力融合)基础上,
用输出投影 Wo 对 Value 的 L1 范数做 Stage-2 重加权;Stage-1 在 Compactor 基础分上做预算内 top-k 保护。
预算与 vllm.kvprune 引擎一致:使用 ``compression_context.batch_tokens_to_retain``(flatten 的
(token, head) 对数量)。CriticalAda 主链在 **PyTorch** 中与 kvpress ``CriticalAdaKVPress.compress``
对齐;``||Wo@V||_1`` 仍默认用 Triton ``_compute_wo_v_l1_kernel``(与 ``CriticalKVPress.vwl1norm`` 同式)。
将 ``_USE_WO_L1_REFERENCE_BACKEND`` 置为 ``True`` 可改走 ``_vwl1_norm_kvpress_reference``。
注意:不得在 import 时加载 ``vllm.kvprune.utils.context``(其会再 import ``CompressionMethod``,
与 ``compression/__init__.py`` 导入本模块形成环)。运行时只使用与 ``CompressionContext`` 同字段的 duck 对象。
"""
from __future__ import annotations
from typing import Any, Optional, Tuple
import torch
import triton
from triton import language as tl
from transformers.models.llama.modeling_llama import repeat_kv
from vllm.kvprune.compression.common import BaseCompressionMethod
from vllm.kvprune.compression.compactor import (
CompactorCompression,
kvpress_compactor_post_rope,
resolve_kvpress_compactor_blending,
)
from vllm.kvprune.compression.snapkv import SnapKVCompression
from vllm.kvprune.utils.helpers import maybe_execute_in_stream
from vllm.kvprune.utils.triton_compat import autotune as triton_autotune
def _criticalkv_prune_hip_pipeline(configs, _, **kwargs):
"""HIP: TritonHCUGPUStreamPipelineV2 breaks on nested loops + hid_idx arange (see snapkv)."""
if torch.version.hip is None:
return list(configs)
return [c for c in configs if getattr(c, "num_stages", 1) == 1]
def _compute_wo_v_l1_autotune_configs():
"""CUDA: full autotune. HIP: single num_stages=1 config (avoids pipeliner + long autotune)."""
if torch.version.hip is not None:
return [
triton.Config(
{"BLOCK_K": 64, "BLOCK_D": 64}, num_warps=4, num_stages=1
),
]
return [
triton.Config({"BLOCK_K": bk, "BLOCK_D": bd}, num_warps=nw, num_stages=ns)
for bk in [32, 64, 128]
for bd in [32, 64]
for nw in [4, 8]
for ns in [3, 4]
]
# Wo@V 的 L1:False = Triton(默认),True = PyTorch 参考(调试/对齐)
_USE_WO_L1_REFERENCE_BACKEND = False
def _vwl1_norm_kvpress_reference(
values_seg: torch.Tensor,
wo: torch.Tensor,
num_kv_heads: int,
num_query_groups: int,
) -> torch.Tensor:
"""
与 kvpress ``CriticalKVPress.vwl1norm`` 等价的 **可选参考实现**(PyTorch,仅用于核对;
将 ``_USE_WO_L1_REFERENCE_BACKEND`` 置为 ``True`` 时选用,默认走 Triton)。
算法:repeat_kv → 逐 query 头 ``|V @ Wo_h|_1`` → 在 GQA 组上 mean,与 Triton 路径同一公式。
"""
k_len, Hk, D = values_seg.shape
Hq, D_wo, hidden = wo.shape
assert D == D_wo and Hk == num_kv_heads and Hq == Hk * num_query_groups
# [1, Hk, k_len, D] 与 HF repeat_kv 约定一致
v_4d = values_seg.permute(1, 0, 2).unsqueeze(0).contiguous()
v_rep = repeat_kv(v_4d, num_query_groups) # [1, Hq, k_len, D]
# Wo 在 attention 里注入为 float32,V 常为 bf16/fp16,matmul 前对齐 dtype
wo_f = wo
head_list = []
for head in range(Hq):
v_h = v_rep[0, head, :, :].to(dtype=wo_f.dtype)
head_wov = v_h.matmul(wo_f[head, :, :])
head_wov_norm = torch.norm(head_wov, p=1, dim=-1)
head_list.append(head_wov_norm)
stacked = torch.stack(head_list, dim=0) # [Hq, k_len]
stacked = stacked.view(Hk, num_query_groups, k_len).mean(dim=1)
return stacked.transpose(0, 1).contiguous()
# ============================================================================
# Triton:||Wo @ V||₁ 按 kvpress 定义(GQA 上对 query 组 L1 后取均值)
# ============================================================================
@triton_autotune(
configs=_compute_wo_v_l1_autotune_configs(),
key=["Hk", "D", "HIDDEN"],
cache_results=True,
prune_configs_by={"early_config_prune": _criticalkv_prune_hip_pipeline},
)
@triton.jit
def _compute_wo_v_l1_kernel(
V,
WO,
cu_k,
OUT,
STRIDE_V_NK,
STRIDE_V_HK,
STRIDE_V_D,
STRIDE_WO_HQ,
STRIDE_WO_D,
STRIDE_WO_HID,
STRIDE_OUT_NK,
STRIDE_OUT_HK,
Hk: tl.constexpr,
Hq: tl.constexpr,
D: tl.constexpr,
HIDDEN: tl.constexpr,
QUERY_GROUP_SIZE: tl.constexpr,
BLOCK_K: tl.constexpr,
BLOCK_D: tl.constexpr,
):
"""对每个 KV 头:对 G 个 query 头分别算 ``sum(|V @ Wo|)``,再除以 G(与 kvpress mean 一致)。"""
b = tl.program_id(0)
hk = tl.program_id(1)
ks = tl.program_id(2)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
nk_off = ks * BLOCK_K + tl.arange(0, BLOCK_K)
nk = k_beg + nk_off
k_mask = nk < k_end
out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
l1_sum = tl.zeros([BLOCK_K], dtype=tl.float32)
for g in range(QUERY_GROUP_SIZE):
hq = hk * QUERY_GROUP_SIZE + g
v_ptrs = (
V
+ nk[:, None] * STRIDE_V_NK
+ hk * STRIDE_V_HK
+ tl.arange(0, D)[None, :] * STRIDE_V_D
)
v_blk = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)
for hid_off in range(0, HIDDEN, BLOCK_D):
hid_idx = hid_off + tl.arange(0, BLOCK_D)
hid_mask = hid_idx < HIDDEN
wo_ptrs = (
WO
+ hq * STRIDE_WO_HQ
+ tl.arange(0, D)[:, None] * STRIDE_WO_D
+ hid_idx[None, :] * STRIDE_WO_HID
)
wo_tile = tl.load(wo_ptrs, mask=hid_mask[None, :], other=0.0).to(tl.float32)
wov_tile = tl.dot(v_blk, wo_tile)
l1_sum += tl.sum(tl.abs(wov_tile), axis=1)
l1_sum = l1_sum / QUERY_GROUP_SIZE
tl.store(out_ptrs, l1_sum, mask=k_mask)
def critical_ada_key_scores(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
wo_weight: torch.Tensor,
cu_seqlens: torch.Tensor,
base_scores: torch.Tensor,
compression_ctx: Any,
*,
store_stream: Optional[torch.cuda.Stream] = None,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
"""
使用与引擎一致的保留预算 ``batch_tokens_to_retain``(每条序列的 (token, head) 对数),
按 kvpress ``CriticalAdaKVPress.compress`` 的顺序实现:safeguard scatter →
head-major 展平做 head_budgets → Stage1 在 **已抬高** 的分数上 top-k →
``(scores + ε) * ||WoV||₁`` → Stage2 scatter → 最终按 head-major 展平做 bottom-k。
``||Wo@V||₁`` 仍用 Triton(``_compute_wo_v_l1_kernel``);中间 CriticalAda 步骤用 PyTorch
与 kvpress 逐句对齐。仅 base 分数来自 Compactor/SnapKV。
Args:
compression_ctx: 与 ``CompressionContext`` 相同字段即可(duck typing),须含
``batch_tokens_to_retain``;可选 ``critical_ada_epsilon``、
``critical_ada_first_stage_ratio``、``critical_ada_alpha_safeguard``。
"""
assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1
device = q.device
_, Hq, D = q.shape
N_k, Hk, Dk = k.shape
assert D == Dk and Hq % Hk == 0
# 与 non_causal_attn_scores 使用同一 cu(prefill 下即 context.cu_seqlens_q),
# 保证 base_scores 行与 Triton 分段一致;勿与 cu_seqlens_k 混用。
B = cu_seqlens.numel() - 1
G = Hq // Hk
k_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
btr = compression_ctx.batch_tokens_to_retain
assert btr is not None and btr.numel() == B
btr = btr.to(device=device, dtype=torch.int32)
epsilon = compression_ctx.critical_ada_epsilon
first_stage_ratio = compression_ctx.critical_ada_first_stage_ratio
alpha_safeguard = float(compression_ctx.critical_ada_alpha_safeguard)
alpha_safeguard = max(0.0, min(1.0, alpha_safeguard))
if wo_weight.dim() == 2:
hidden_size, _ = wo_weight.shape
wo = wo_weight.transpose(0, 1).view(Hq, D, hidden_size).contiguous()
else:
wo = wo_weight.contiguous()
hidden_size = wo.size(-1)
wo_v_norm = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
if B > 0 and int(k_lengths.max().item()) > 0:
if _USE_WO_L1_REFERENCE_BACKEND:
for b in range(B):
k_beg = int(cu_seqlens[b].item())
k_end = int(cu_seqlens[b + 1].item())
if k_end <= k_beg:
continue
v_seg = v[k_beg:k_end, :, :].contiguous()
wo_v_norm[k_beg:k_end, :] = _vwl1_norm_kvpress_reference(
v_seg, wo, Hk, G
)
else:
def grid_wo(META):
max_k_len = int(k_lengths.max().item())
return (B, Hk, triton.cdiv(max_k_len, META["BLOCK_K"]))
_compute_wo_v_l1_kernel[grid_wo](
v,
wo,
cu_seqlens,
wo_v_norm,
*v.stride(),
*wo.stride(),
*wo_v_norm.stride(),
Hk=Hk,
Hq=Hq,
D=D,
HIDDEN=hidden_size,
QUERY_GROUP_SIZE=G,
)
# kvpress 用 finfo.max 抬高分数;与 inf 混用时 topk 行为一致
_score_max = float(torch.finfo(torch.float32).max)
final_scores = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
head_budgets_by_batch: list[Optional[torch.Tensor]] = []
for b in range(B):
k_len = int(k_lengths[b].item())
k_beg = int(cu_seqlens[b].item())
k_end = int(cu_seqlens[b + 1].item())
if k_len == 0:
head_budgets_by_batch.append(None)
continue
scores_seg = base_scores[k_beg:k_end, :].float()
keep_pairs = int(btr[b].item())
n_kept_tokens = max(1, keep_pairs // Hk)
n_kept_tokens = min(n_kept_tokens, k_len)
# scores_work: 布局 [k_len, Hk],对应 kvpress [bsz=1, H, k_len] 的 transpose(0,2) 视角下沿 token 维的 topk
scores_work = scores_seg.clone()
# --- Alpha safeguard(kvpress L148–152)---
n_safe = int(n_kept_tokens * alpha_safeguard)
nk = min(n_safe, k_len) if n_safe > 0 else 0
if nk > 0:
for hk in range(Hk):
top_idx = torch.topk(scores_work[:, hk], nk, dim=0, largest=True).indices
scores_work[top_idx, hk] = _score_max
# --- Head budgets:kvpress L158–164,展平顺序与 [bsz, H, k_len] 一致(head-major:h*K + t)---
top_pairs = min(n_kept_tokens * Hk, k_len * Hk)
if top_pairs <= 0:
head_budgets_by_batch.append(None)
wn = wo_v_norm[k_beg:k_end, :]
final_scores[k_beg:k_end, :] = (scores_seg + epsilon) * wn
continue
budget_flat = scores_work.permute(1, 0).contiguous().reshape(-1)
top_idx_flat = torch.topk(
budget_flat, top_pairs, largest=True, sorted=False
).indices
top_head_idx = top_idx_flat // k_len
head_budgets = torch.bincount(top_head_idx, minlength=Hk).to(torch.int64)
head_budgets_by_batch.append(head_budgets)
# --- Stage 1(kvpress L166–171):在已 safeguard 的 scores_work 上沿 token 维 top-k ---
head_selection_budget_1st = (
(head_budgets.to(torch.float32) * float(first_stage_ratio))
.to(torch.int64)
.tolist()
)
M1 = max(head_selection_budget_1st) if head_selection_budget_1st else 0
mk = min(M1, k_len) if M1 > 0 else 0
if mk > 0:
top_k_index = torch.topk(scores_work, mk, dim=0, largest=True, sorted=True).indices
for hk in range(Hk):
phase1_budget = int(head_selection_budget_1st[hk])
if phase1_budget <= 0:
continue
take = min(phase1_budget, mk)
scores_work[top_k_index[:take, hk], hk] = _score_max
# --- Stage 2 重加权(kvpress L173–175)---
wn = wo_v_norm[k_beg:k_end, :]
scores_fused = (scores_work + epsilon) * wn
# --- Stage 2 scatter(kvpress L176–179)---
M2 = int(head_budgets.max().item())
mk2 = min(M2, k_len) if M2 > 0 else 0
if mk2 > 0:
top_k_index2 = torch.topk(
scores_fused, mk2, dim=0, largest=True, sorted=True
).indices
for hk in range(Hk):
budget = int(head_budgets[hk].item())
if budget <= 0:
continue
take = min(budget, mk2)
scores_fused[top_k_index2[:take, hk], hk] = _score_max
final_scores[k_beg:k_end, :] = scores_fused
masked_key_indices = None
for b in range(B):
k_len = int(k_lengths[b].item())
if k_len == 0:
continue
keep_pairs = int(btr[b].item())
total_pairs = k_len * Hk
if keep_pairs >= total_pairs:
continue
k_beg = int(cu_seqlens[b].item())
k_end = int(cu_seqlens[b + 1].item())
n_prune_pairs = min(total_pairs - keep_pairs, total_pairs)
if n_prune_pairs <= 0:
continue
# kvpress L187:``scores.reshape(bsz, -1)`` 即 [H, K] 按 head-major 展平(flat = h*K + t)
flat_scores = (
final_scores[k_beg:k_end, :].permute(1, 0).contiguous().reshape(-1)
)
prune_idx = torch.topk(
-flat_scores, min(n_prune_pairs, flat_scores.numel()), sorted=False
).indices
batch_idx = torch.full_like(prune_idx, b, dtype=torch.int64)
head_idx = prune_idx // k_len
seq_idx = prune_idx % k_len + k_beg
if masked_key_indices is None:
masked_key_indices = (batch_idx, head_idx, seq_idx)
else:
masked_key_indices = (
torch.cat([masked_key_indices[0], batch_idx]),
torch.cat([masked_key_indices[1], head_idx]),
torch.cat([masked_key_indices[2], seq_idx]),
)
if store_stream is not None:
final_scores.record_stream(store_stream)
return final_scores, masked_key_indices
class CriticalAdaKVCompression(BaseCompressionMethod):
"""
仅 ``critical_ada_base_scorer == "compactor"`` 时与 kvpress ``CompactorPress.score`` 一致
(``kvpress_compactor_post_rope``:``blending * l_scores + attn_scores``);其它 base(如 SnapKV)
走对应单一 ScorerPress,再叠 CriticalAda。须由 Attention 在 post-RoPE 前注入 ``compression_context.wo_weight``。
"""
@staticmethod
def pre_rope_scoring(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
) -> Optional[torch.Tensor]:
cc = context.compression_context
base = (
getattr(cc, "critical_ada_base_scorer", "compactor")
if cc is not None
else "compactor"
)
if str(base).lower() == "compactor":
return CompactorCompression.pre_rope_scoring(q, k, v, context)
return SnapKVCompression.pre_rope_scoring(q, k, v, context)
@staticmethod
def post_rope_scoring(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
pre_rope_scores: Optional[torch.Tensor],
context,
) -> Optional[torch.Tensor]:
compression_context = context.compression_context
assert compression_context is not None
base = str(getattr(compression_context, "critical_ada_base_scorer", "compactor")).lower()
if base == "compactor":
# 特例:与 ``CompactorPress.score`` / ``CompactorCompression.post_rope_scoring`` 一致。
if context.STORE_STREAM is not None:
torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
blending = resolve_kvpress_compactor_blending(compression_context)
base_scores = maybe_execute_in_stream(
kvpress_compactor_post_rope,
q,
k,
v,
context.cu_seqlens_q,
pre_rope_scores,
compression_context,
context.max_seqlen_q,
chunk_size=CompactorCompression.chunk_size,
blending=float(blending),
STORE_STREAM=context.STORE_STREAM,
)
else:
base_scores = SnapKVCompression.post_rope_scoring(
q, k, v, pre_rope_scores, context
)
wo_weight = compression_context.wo_weight
if wo_weight is None:
return base_scores
scores, _masked = maybe_execute_in_stream(
critical_ada_key_scores,
q,
k,
v,
wo_weight,
context.cu_seqlens_q,
base_scores,
compression_context,
STORE_STREAM=context.STORE_STREAM,
store_stream=context.STORE_STREAM,
)
return scores
@staticmethod
def prepare_layer(module: torch.nn.Module, device: torch.device, dtype: torch.dtype):
"""可选:预计算并缓存 Wo;实际推理以 Attention.forward 中注入的 ``cc.wo_weight`` 为准。"""
if not hasattr(module, "o_proj") or module.o_proj.weight is None:
return
if not hasattr(module, "num_heads") or not hasattr(module, "head_dim"):
return
wo_raw = module.o_proj.weight.data
hidden_size, _ = wo_raw.shape
Hq = module.num_heads
head_dim = module.head_dim
wo = (
wo_raw.transpose(0, 1)
.view(Hq, head_dim, hidden_size)
.to(device=device, dtype=torch.float32)
)
module._critical_ada_wo_weight = wo
"""
CriticalAdaKV: 在 Compactor(pre RoPE 杠杆分 + post RoPE 非因果注意力融合)基础上,
用输出投影 Wo 对 Value 的 L1 范数做 Stage-2 重加权;Stage-1 在 Compactor 基础分上做预算内 top-k 保护。
预算与 compactor_vllm 引擎一致:使用 ``compression_context.batch_tokens_to_retain``(flatten 的
(token, head) 对数量)。Stage1/2 与 kvpress 论文/实现一致;``||Wo@V||_1`` 在 **算法上** 与
``CriticalKVPress.vwl1norm`` 相同(GQA 上逐 query 头 L1 再对组取均值)。**默认用 Triton**
(``_compute_wo_v_l1_kernel``);若需与 PyTorch 逐行对齐,将模块内 ``_USE_WO_L1_REFERENCE_BACKEND`` 改为 ``True`` 即走 ``_vwl1_norm_kvpress_reference``。
注意:不得在 import 时加载 ``compactor_vllm.utils.context``(其会再 import ``CompressionMethod``,
与 ``compression/__init__.py`` 导入本模块形成环)。运行时只使用与 ``CompressionContext`` 同字段的 duck 对象。
"""
from __future__ import annotations
from typing import Any, Optional, Tuple
import torch
import triton
from triton import language as tl
from transformers.models.llama.modeling_llama import repeat_kv
from compactor_vllm.compression.common import BaseCompressionMethod
from compactor_vllm.compression.compactor import (
CompactorCompression,
non_causal_attn_scores,
)
from compactor_vllm.compression.snapkv import SnapKVCompression
from compactor_vllm.utils.helpers import maybe_execute_in_stream
from compactor_vllm.utils.triton_compat import autotune as triton_autotune
# Wo@V 的 L1:False = Triton(默认),True = PyTorch 参考(调试/对齐)
_USE_WO_L1_REFERENCE_BACKEND = False
def _vwl1_norm_kvpress_reference(
values_seg: torch.Tensor,
wo: torch.Tensor,
num_kv_heads: int,
num_query_groups: int,
) -> torch.Tensor:
"""
与 kvpress ``CriticalKVPress.vwl1norm`` 等价的 **可选参考实现**(PyTorch,仅用于核对;
将 ``_USE_WO_L1_REFERENCE_BACKEND`` 置为 ``True`` 时选用,默认走 Triton)。
算法:repeat_kv → 逐 query 头 ``|V @ Wo_h|_1`` → 在 GQA 组上 mean,与 Triton 路径同一公式。
"""
k_len, Hk, D = values_seg.shape
Hq, D_wo, hidden = wo.shape
assert D == D_wo and Hk == num_kv_heads and Hq == Hk * num_query_groups
# [1, Hk, k_len, D] 与 HF repeat_kv 约定一致
v_4d = values_seg.permute(1, 0, 2).unsqueeze(0).contiguous()
v_rep = repeat_kv(v_4d, num_query_groups) # [1, Hq, k_len, D]
# Wo 在 attention 里注入为 float32,V 常为 bf16/fp16,matmul 前对齐 dtype
wo_f = wo
head_list = []
for head in range(Hq):
v_h = v_rep[0, head, :, :].to(dtype=wo_f.dtype)
head_wov = v_h.matmul(wo_f[head, :, :])
head_wov_norm = torch.norm(head_wov, p=1, dim=-1)
head_list.append(head_wov_norm)
stacked = torch.stack(head_list, dim=0) # [Hq, k_len]
stacked = stacked.view(Hk, num_query_groups, k_len).mean(dim=1)
return stacked.transpose(0, 1).contiguous()
# ============================================================================
# Triton:||Wo @ V||₁ 按 kvpress 定义(GQA 上对 query 组 L1 后取均值)
# ============================================================================
@triton_autotune(
configs=[
triton.Config({"BLOCK_K": bk, "BLOCK_D": bd}, num_warps=nw, num_stages=ns)
for bk in [32, 64, 128]
for bd in [32, 64]
for nw in [4, 8]
for ns in [3, 4]
],
key=["Hk", "D", "HIDDEN"],
cache_results=True,
)
@triton.jit
def _compute_wo_v_l1_kernel(
V,
WO,
cu_k,
OUT,
STRIDE_V_NK,
STRIDE_V_HK,
STRIDE_V_D,
STRIDE_WO_HQ,
STRIDE_WO_D,
STRIDE_WO_HID,
STRIDE_OUT_NK,
STRIDE_OUT_HK,
Hk: tl.constexpr,
Hq: tl.constexpr,
D: tl.constexpr,
HIDDEN: tl.constexpr,
QUERY_GROUP_SIZE: tl.constexpr,
BLOCK_K: tl.constexpr,
BLOCK_D: tl.constexpr,
):
"""对每个 KV 头:对 G 个 query 头分别算 ``sum(|V @ Wo|)``,再除以 G(与 kvpress mean 一致)。"""
b = tl.program_id(0)
hk = tl.program_id(1)
ks = tl.program_id(2)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
nk_off = ks * BLOCK_K + tl.arange(0, BLOCK_K)
nk = k_beg + nk_off
k_mask = nk < k_end
out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
l1_sum = tl.zeros([BLOCK_K], dtype=tl.float32)
for g in range(QUERY_GROUP_SIZE):
hq = hk * QUERY_GROUP_SIZE + g
v_ptrs = (
V
+ nk[:, None] * STRIDE_V_NK
+ hk * STRIDE_V_HK
+ tl.arange(0, D)[None, :] * STRIDE_V_D
)
v_blk = tl.load(v_ptrs, mask=k_mask[:, None], other=0.0).to(tl.float32)
for hid_off in range(0, HIDDEN, BLOCK_D):
hid_idx = hid_off + tl.arange(0, BLOCK_D)
hid_mask = hid_idx < HIDDEN
wo_ptrs = (
WO
+ hq * STRIDE_WO_HQ
+ tl.arange(0, D)[:, None] * STRIDE_WO_D
+ hid_idx[None, :] * STRIDE_WO_HID
)
wo_tile = tl.load(wo_ptrs, mask=hid_mask[None, :], other=0.0).to(tl.float32)
wov_tile = tl.dot(v_blk, wo_tile)
l1_sum += tl.sum(tl.abs(wov_tile), axis=1)
l1_sum = l1_sum / QUERY_GROUP_SIZE
tl.store(out_ptrs, l1_sum, mask=k_mask)
# ============================================================================
# Triton:Stage 1 保护 + Stage 2 加权融合(逐元素)
# ============================================================================
@triton_autotune(
configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128, 256]],
key=["Hk"],
cache_results=True,
)
@triton.jit
def _critical_ada_fuse_kernel(
BASE_SCORES,
WO_V_NORM,
STAGE1_MASK,
cu_k,
OUT,
STRIDE_BS_NK,
STRIDE_BS_HK,
STRIDE_WN_NK,
STRIDE_WN_HK,
STRIDE_S1_NK,
STRIDE_S1_HK,
STRIDE_OUT_NK,
STRIDE_OUT_HK,
EPSILON: tl.constexpr,
Hk: tl.constexpr,
BLOCK_K: tl.constexpr,
):
b = tl.program_id(0)
hk = tl.program_id(1)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
for ks in tl.range(k_beg, k_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_end
bs_ptrs = BASE_SCORES + nk * STRIDE_BS_NK + hk * STRIDE_BS_HK
wn_ptrs = WO_V_NORM + nk * STRIDE_WN_NK + hk * STRIDE_WN_HK
s1_ptrs = STAGE1_MASK + nk * STRIDE_S1_NK + hk * STRIDE_S1_HK
base = tl.load(bs_ptrs, mask=kmask, other=0.0)
wnorm = tl.load(wn_ptrs, mask=kmask, other=1.0)
stage1_protect = tl.load(s1_ptrs, mask=kmask, other=0).to(tl.int32)
fused = (base + EPSILON) * wnorm
fused = tl.where(stage1_protect == 1, float("inf"), fused)
out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
tl.store(out_ptrs, fused, mask=kmask)
def critical_ada_key_scores(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
wo_weight: torch.Tensor,
cu_seqlens: torch.Tensor,
base_scores: torch.Tensor,
compression_ctx: Any,
*,
store_stream: Optional[torch.cuda.Stream] = None,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
"""
使用与引擎一致的保留预算 ``batch_tokens_to_retain``(每条序列的 (token, head) 对数),
在每条序列上对齐 kvpress ``CriticalAdaKVPress.compress``(整段 ``k_len``、与源实现相同的
top-k / scatter 顺序);仅 base 分数来自 compactor_vllm 的 Compactor/SnapKV。
Args:
compression_ctx: 与 ``CompressionContext`` 相同字段即可(duck typing),须含
``batch_tokens_to_retain``;可选 ``critical_ada_epsilon``、
``critical_ada_first_stage_ratio``、``critical_ada_alpha_safeguard``。
"""
assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1
device = q.device
_, Hq, D = q.shape
N_k, Hk, Dk = k.shape
assert D == Dk and Hq % Hk == 0
# 与 non_causal_attn_scores 使用同一 cu(prefill 下即 context.cu_seqlens_q),
# 保证 base_scores 行与 Triton 分段一致;勿与 cu_seqlens_k 混用。
B = cu_seqlens.numel() - 1
G = Hq // Hk
k_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
btr = compression_ctx.batch_tokens_to_retain
assert btr is not None and btr.numel() == B
btr = btr.to(device=device, dtype=torch.int32)
epsilon = compression_ctx.critical_ada_epsilon
first_stage_ratio = compression_ctx.critical_ada_first_stage_ratio
alpha_safeguard = float(compression_ctx.critical_ada_alpha_safeguard)
alpha_safeguard = max(0.0, min(1.0, alpha_safeguard))
if wo_weight.dim() == 2:
hidden_size, _ = wo_weight.shape
wo = wo_weight.transpose(0, 1).view(Hq, D, hidden_size).contiguous()
else:
wo = wo_weight.contiguous()
hidden_size = wo.size(-1)
wo_v_norm = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
if B > 0 and int(k_lengths.max().item()) > 0:
if _USE_WO_L1_REFERENCE_BACKEND:
for b in range(B):
k_beg = int(cu_seqlens[b].item())
k_end = int(cu_seqlens[b + 1].item())
if k_end <= k_beg:
continue
v_seg = v[k_beg:k_end, :, :].contiguous()
wo_v_norm[k_beg:k_end, :] = _vwl1_norm_kvpress_reference(
v_seg, wo, Hk, G
)
else:
def grid_wo(META):
max_k_len = int(k_lengths.max().item())
return (B, Hk, triton.cdiv(max_k_len, META["BLOCK_K"]))
_compute_wo_v_l1_kernel[grid_wo](
v,
wo,
cu_seqlens,
wo_v_norm,
*v.stride(),
*wo.stride(),
*wo_v_norm.stride(),
Hk=Hk,
Hq=Hq,
D=D,
HIDDEN=hidden_size,
QUERY_GROUP_SIZE=G,
)
stage1_mask = torch.zeros((N_k, Hk), dtype=torch.int32, device=device)
head_budgets_by_batch: list[Optional[torch.Tensor]] = []
for b in range(B):
k_len = int(k_lengths[b].item())
if k_len == 0:
head_budgets_by_batch.append(None)
continue
k_beg = int(cu_seqlens[b].item())
k_end = int(cu_seqlens[b + 1].item())
keep_pairs = int(btr[b].item())
scores_seg = base_scores[k_beg:k_end, :]
# 与 kvpress 的 n_kept 一致:每头保留 n_kept 个 token
n_kept_tokens = max(1, keep_pairs // Hk)
n_kept_tokens = min(n_kept_tokens, k_len)
# kvpress:topk 在「未改动的」scores 上取索引,scatter 只写在副本上,供 head_budgets 用;
# Stage1 仍用原始 scores_seg(见下)。
working = scores_seg.clone()
n_safe = int(n_kept_tokens * alpha_safeguard)
if n_safe > 0:
nk = min(n_safe, k_len)
for hk in range(Hk):
top_idx = torch.topk(scores_seg[:, hk], nk, sorted=True).indices
working[:, hk].scatter_(0, top_idx, float("inf"))
top_pairs = min(n_kept_tokens * Hk, working.numel())
if top_pairs <= 0:
head_budgets_by_batch.append(None)
continue
top_idx_flat = torch.topk(working.reshape(-1), top_pairs, sorted=False).indices
top_head_idx = top_idx_flat % Hk
head_budgets = torch.bincount(top_head_idx, minlength=Hk).to(torch.int32)
head_budgets_by_batch.append(head_budgets)
# Stage 1:与 kvpress 相同 — 先 topk(..., M1, sorted=True),再每头取前 phase1 个下标
head_selection_budget_1st = (
(head_budgets.to(torch.float32) * float(first_stage_ratio))
.to(torch.int64)
.tolist()
)
M1 = max(head_selection_budget_1st) if head_selection_budget_1st else 0
if M1 > 0:
mk = min(M1, k_len)
for hk in range(Hk):
phase1_budget = int(head_selection_budget_1st[hk])
if phase1_budget <= 0:
continue
full_idx = torch.topk(scores_seg[:, hk], mk, sorted=True).indices
take = min(phase1_budget, mk)
stage1_mask[k_beg + full_idx[:take], hk] = 1
final_scores = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
def grid_fuse(_META):
return (B, Hk)
_critical_ada_fuse_kernel[grid_fuse](
base_scores,
wo_v_norm,
stage1_mask,
cu_seqlens,
final_scores,
*base_scores.stride(),
*wo_v_norm.stride(),
*stage1_mask.stride(),
*final_scores.stride(),
Hk=Hk,
EPSILON=float(epsilon),
)
# Stage 2(kvpress):对融合后分数先 topk(..., M2, sorted=True),再每头取前 budget 个下标置 inf
for b in range(B):
hb = head_budgets_by_batch[b]
if hb is None:
continue
k_beg = int(cu_seqlens[b].item())
k_end = int(cu_seqlens[b + 1].item())
k_len = k_end - k_beg
if k_len <= 0:
continue
fused_seg = final_scores[k_beg:k_end, :]
M2 = int(hb.max().item())
if M2 <= 0:
continue
mk = min(M2, k_len)
for hk in range(Hk):
budget = int(hb[hk].item())
if budget <= 0:
continue
full_idx = torch.topk(fused_seg[:, hk], mk, sorted=True).indices
take = min(budget, mk)
final_scores[k_beg + full_idx[:take], hk] = float("inf")
masked_key_indices = None
for b in range(B):
k_len = int(k_lengths[b].item())
if k_len == 0:
continue
keep_pairs = int(btr[b].item())
total_pairs = k_len * Hk
if keep_pairs >= total_pairs:
continue
k_beg = int(cu_seqlens[b].item())
k_end = int(cu_seqlens[b + 1].item())
n_prune_pairs = min(total_pairs - keep_pairs, total_pairs)
if n_prune_pairs <= 0:
continue
flat_scores = final_scores[k_beg:k_end, :].reshape(-1)
prune_idx = torch.topk(
-flat_scores, min(n_prune_pairs, flat_scores.numel()), sorted=False
).indices
batch_idx = torch.full_like(prune_idx, b, dtype=torch.int64)
head_idx = prune_idx % Hk
seq_idx = prune_idx // Hk + k_beg
if masked_key_indices is None:
masked_key_indices = (batch_idx, head_idx, seq_idx)
else:
masked_key_indices = (
torch.cat([masked_key_indices[0], batch_idx]),
torch.cat([masked_key_indices[1], head_idx]),
torch.cat([masked_key_indices[2], seq_idx]),
)
if store_stream is not None:
final_scores.record_stream(store_stream)
return final_scores, masked_key_indices
class CriticalAdaKVCompression(BaseCompressionMethod):
"""
以 CompactorCompression 为基分(pre RoPE 杠杆 + post RoPE 非因果融合),
再应用 CriticalAda 两阶段加权;须由 Attention 在 post-RoPE 前注入 ``compression_context.wo_weight``。
"""
@staticmethod
def pre_rope_scoring(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
) -> Optional[torch.Tensor]:
cc = context.compression_context
base = getattr(cc, "critical_ada_base_scorer", "snapkv") if cc is not None else "compactor"
if str(base).lower() == "snapkv":
return SnapKVCompression.pre_rope_scoring(q, k, v, context)
return CompactorCompression.pre_rope_scoring(q, k, v, context)
@staticmethod
def post_rope_scoring(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
pre_rope_scores: Optional[torch.Tensor],
context,
) -> Optional[torch.Tensor]:
compression_context = context.compression_context
assert compression_context is not None
base = str(getattr(compression_context, "critical_ada_base_scorer", "compactor")).lower()
if base == "snapkv":
base_scores = SnapKVCompression.post_rope_scoring(q, k, v, pre_rope_scores, context)
else:
# 与 compactor.py 中 CompactorCompression.post_rope_scoring 逐字一致:
# maybe_execute_in_stream(non_causal_attn_scores, q,k,v, cu_seqlens_q, max_seqlen_q, ...)
# 不得改为其它封装,否则与单独使用 COMPACTOR 时分数字不一致。
if context.STORE_STREAM is not None:
torch.cuda.current_stream().wait_stream(context.STORE_STREAM)
base_scores = maybe_execute_in_stream(
non_causal_attn_scores,
q,
k,
v,
context.cu_seqlens_q,
context.max_seqlen_q,
chunk_size=CompactorCompression.chunk_size,
sm_scale=1.0,
normalize=True,
accum_scores=pre_rope_scores,
context_lens=compression_context.context_lens,
protected_first_tokens=compression_context.protected_first_tokens,
protected_last_tokens=compression_context.protected_last_tokens,
accum_blending=0.5,
)
wo_weight = compression_context.wo_weight
if wo_weight is None:
return base_scores
scores, _masked = maybe_execute_in_stream(
critical_ada_key_scores,
q,
k,
v,
wo_weight,
context.cu_seqlens_q,
base_scores,
compression_context,
STORE_STREAM=context.STORE_STREAM,
store_stream=context.STORE_STREAM,
)
return scores
@staticmethod
def prepare_layer(module: torch.nn.Module, device: torch.device, dtype: torch.dtype):
"""可选:预计算并缓存 Wo;实际推理以 Attention.forward 中注入的 ``cc.wo_weight`` 为准。"""
if not hasattr(module, "o_proj") or module.o_proj.weight is None:
return
if not hasattr(module, "num_heads") or not hasattr(module, "head_dim"):
return
wo_raw = module.o_proj.weight.data
hidden_size, _ = wo_raw.shape
Hq = module.num_heads
head_dim = module.head_dim
wo = (
wo_raw.transpose(0, 1)
.view(Hq, head_dim, hidden_size)
.to(device=device, dtype=torch.float32)
)
module._critical_ada_wo_weight = wo
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Compactor-style sparse prefill: Triton varlen attention + paged KV store.
Migrated kernels: ``sparse_varlen_kernel.causal_sparse_varlen_with_cache`` and
``store_kv_cache.prefill_store_topk_kv``.
Layout: MQA uses ``flatten_kv_cache_plane``; GQA/MHA uses head-major flatten
(see ``layout_bridge``).
Execution order note: vLLM runs ``unified_kv_cache_update`` (writes KV) before
``unified_attention_with_output``. Compactor's sparse attention kernel assumes
the paged cache holds only the prefix *before* the current K/V append, while
K_app carries the new tokens. That differs from vLLM's order (cache already
contains the current step after reshape). Therefore ``try_sparse_prefill_forward``
is provided as a reference / future hook and is not invoked from the default
FlashAttention forward path; prefill KV pruning uses ``prefill_store_topk_kv``
in ``do_kv_cache_update_kv_prune`` instead.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import torch
from vllm.forward_context import get_forward_context
from vllm.kvprune.compression.prefill_registry import try_topk_indices_from_registry
from vllm.kvprune.core.compression_bridge import compression_method_id_to_enum
from vllm.kvprune.core.runtime import get_kv_prune_state, layer_index_from_layer_name
from vllm.kvprune.utils.layout_bridge import (
block_table_to_global_page_table,
build_batch_mapping,
build_page_table_head_major,
flatten_kv_cache_head_major,
flatten_kv_cache_plane,
write_head_major_flat_to_interleaved,
)
from vllm.kvprune.attention.sparse_varlen_kernel import causal_sparse_varlen_with_cache
from vllm.kvprune.kv_cache.store_kv_cache import prefill_store_topk_kv
if TYPE_CHECKING:
from vllm.v1.attention.backends.flash_attn import FlashAttentionImpl, FlashAttentionMetadata
_RATIO_EPS = 1.0e-6
def _get_flash_attn_metadata(layer_name: str) -> "FlashAttentionMetadata | None":
try:
fc = get_forward_context()
except AssertionError:
return None
am = fc.attn_metadata
if isinstance(am, list):
if not am:
return None
am = am[0]
meta = am.get(layer_name)
return meta
def try_sparse_prefill_forward(
impl: "FlashAttentionImpl",
layer: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
attn_metadata: "FlashAttentionMetadata",
output: torch.Tensor,
num_actual_tokens: int,
) -> bool:
"""Run compactor ``causal_sparse_varlen_with_cache`` when eligible. Returns True if ran."""
state = get_kv_prune_state()
if state is None or not state.is_prefill:
return False
comp = state.compression_ratio_gpu[: state.num_reqs]
pruned = comp < 1.0 - _RATIO_EPS
if not torch.any(pruned):
return False
mids = state.compression_method_id_gpu[: state.num_reqs]
if torch.unique(mids).numel() > 1:
return False
# Mixed pruned + non-pruned requests: keep default FlashAttention path for now.
if torch.any(pruned) and torch.any(~pruned):
return False
if impl.num_kv_heads != 1:
return False
if impl.kv_cache_dtype.startswith("fp8"):
return False
if impl.alibi_slopes is not None:
return False
if impl.sliding_window != (-1, -1):
return False
d = impl.head_size
if d <= 0 or (d & (d - 1)) != 0:
return False
num_reqs = state.num_reqs
cu = state.query_start_loc[: num_reqs + 1].to(device=query.device, dtype=torch.int32)
seq_lens = attn_metadata.seq_lens[:num_reqs].to(torch.int32)
seqlen_q = cu[1:] - cu[:-1]
cached = seq_lens - seqlen_q
if torch.any(cached < 0):
return False
seq_lens_bh = cached.unsqueeze(1).expand(-1, 1).contiguous()
block_table = attn_metadata.block_table[:num_reqs]
max_batches = block_table.shape[0]
n_lp = block_table.shape[1]
global_page_table = block_table_to_global_page_table(
block_table, impl.num_kv_heads, max_batches=max_batches
)
batch_mapping = build_batch_mapping(num_reqs, query.device)
try:
k_flat, v_flat = flatten_kv_cache_plane(key_cache, value_cache, impl.num_kv_heads)
except ValueError:
return False
page_size = key_cache.shape[1]
if page_size <= 0 or k_flat.shape[0] % page_size != 0:
return False
q3 = query[:num_actual_tokens].view(num_actual_tokens, impl.num_heads, d)
k3 = key[:num_actual_tokens].view(num_actual_tokens, 1, d)
v3 = value[:num_actual_tokens].view(num_actual_tokens, 1, d)
max_seqlen_q = int(attn_metadata.max_query_len)
max_cached = int(seq_lens_bh.max().item()) if seq_lens_bh.numel() else 0
out = causal_sparse_varlen_with_cache(
q3,
k3,
v3,
k_flat,
v_flat,
seq_lens_bh,
global_page_table,
batch_mapping,
cu,
max_seqlen_q=max_seqlen_q,
max_seqlen_k_cache=max_cached,
HKV=1,
PAGE_SIZE=page_size,
sm_scale=None,
)
output[:num_actual_tokens].copy_(out.reshape(num_actual_tokens, impl.num_heads * d))
return True
def _build_tail_topk_indices(
cu_seqlens: torch.Tensor,
num_reqs: int,
hkv: int,
compression_ratio: float | torch.Tensor,
max_sel: int,
device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Return (indices [B, max_sel], num_pairs_to_retain [B]) for tail tokens × heads."""
indices = torch.zeros(num_reqs, max_sel, dtype=torch.int32, device=device)
n_pairs = torch.zeros(num_reqs, dtype=torch.int32, device=device)
cu_cpu = cu_seqlens[: num_reqs + 1].detach()
for b in range(num_reqs):
start = int(cu_cpu[b].item())
end = int(cu_cpu[b + 1].item())
chunk_len = end - start
if chunk_len <= 0:
continue
if isinstance(compression_ratio, torch.Tensor):
r_b = float(compression_ratio[b].item())
else:
r_b = compression_ratio
k_tok = max(1, int(round(chunk_len * r_b)))
k_tok = min(k_tok, chunk_len)
pairs: list[int] = []
for tok in range(end - k_tok, end):
for h in range(hkv):
pairs.append(tok * hkv + h)
if len(pairs) >= max_sel:
break
if len(pairs) >= max_sel:
break
n = len(pairs)
if n > 0:
indices[b, :n] = torch.tensor(pairs, dtype=torch.int32, device=device)
n_pairs[b] = n
return indices, n_pairs
def try_prefill_kv_store(
layer: torch.nn.Module,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
) -> bool:
"""Top-k or full compactor prefill KV store; updates per-layer logical lengths."""
state = get_kv_prune_state()
if state is None or not state.is_prefill:
return False
num_reqs = state.num_reqs
comp = state.compression_ratio_gpu[:num_reqs]
pruned = comp < 1.0 - _RATIO_EPS
if not torch.any(pruned):
return False
if torch.any(pruned) and torch.any(~pruned):
return False
mids = state.compression_method_id_gpu[:num_reqs]
if torch.unique(mids).numel() > 1:
return False
meta = _get_flash_attn_metadata(layer.layer_name)
if meta is None:
return False
num_kv_heads = key.shape[1]
d = key.shape[2]
if d <= 0 or (d & (d - 1)) != 0:
return False
key_cache, value_cache = kv_cache.unbind(0)
page_size = key_cache.shape[1]
nb = key_cache.shape[0]
bs = key_cache.shape[1]
head_major = num_kv_heads > 1
try:
if head_major:
k_flat, v_flat = flatten_kv_cache_head_major(key_cache, value_cache)
else:
k_flat, v_flat = flatten_kv_cache_plane(
key_cache, value_cache, num_kv_heads
)
except ValueError:
return False
block_table = meta.block_table[:num_reqs]
max_batches = block_table.shape[0]
if head_major:
global_page_table = build_page_table_head_major(
block_table,
num_kv_heads,
num_blocks=nb,
block_size=bs,
page_size=page_size,
max_batches=max_batches,
)
else:
global_page_table = block_table_to_global_page_table(
block_table, num_kv_heads, max_batches=max_batches
)
batch_mapping = build_batch_mapping(num_reqs, key.device)
cu = state.query_start_loc[: num_reqs + 1].to(device=key.device, dtype=torch.int32)
seq_lens = meta.seq_lens[:num_reqs].to(torch.int32)
seqlen_q = cu[1:] - cu[:-1]
cached = (seq_lens - seqlen_q).unsqueeze(1).expand(-1, num_kv_heads).contiguous()
layer_idx = layer_index_from_layer_name(layer.layer_name)
max_seqlen_k = int(seqlen_q.max().item()) if seqlen_q.numel() else 0
max_sel = min(max_seqlen_k * num_kv_heads, 8192)
max_sel = max(max_sel, 1)
mid = int(state.compression_method_id_gpu[0].item())
method_enum = compression_method_id_to_enum(mid)
registry_out = try_topk_indices_from_registry(
method_enum, key, value, cu, num_reqs, comp, max_sel, key.device
)
if registry_out is not None:
indices, n_pairs = registry_out
else:
indices, n_pairs = _build_tail_topk_indices(
cu, num_reqs, num_kv_heads, comp, max_sel, key.device
)
bh = cached.clone()
prefill_store_topk_kv(
new_keys=key,
new_vals=value,
indices_topk=indices,
num_tokens_to_retain=n_pairs,
page_table=global_page_table,
batch_mapping=batch_mapping,
bh_lens=bh,
k_cache=k_flat,
v_cache=v_flat,
PAGE_SIZE=page_size,
PAD_TO_PAGE_SIZE=False,
cu_seqlens_k=None,
)
if head_major:
write_head_major_flat_to_interleaved(
k_flat, v_flat, key_cache, value_cache
)
new_lens = bh.to(torch.int32)
if state.logical_seq_lens_gpu.dim() == 3:
state.logical_seq_lens_gpu[layer_idx, :num_reqs, :] = new_lens
else:
state.logical_seq_lens_gpu[layer_idx, :num_reqs] = new_lens.max(
dim=1
).values
return True
__all__ = [
"try_sparse_prefill_forward",
"try_prefill_kv_store",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Map COMPRESSION_REGISTRY scoring to prefill top-k indices (with tail fallback)."""
from __future__ import annotations
import logging
import torch
from vllm.kvprune.compression import COMPRESSION_REGISTRY
from vllm.kvprune.compression.compression_config import CompressionMethod
from vllm.kvprune.utils.context import CompressionContext, Context
logger = logging.getLogger(__name__)
def _scores_to_topk_pair_indices(
cu_seqlens: torch.Tensor,
num_reqs: int,
hkv: int,
scores: torch.Tensor,
compression_ratio: float | torch.Tensor,
max_sel: int,
device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Select (token, head) pairs with highest scores per request up to budget."""
if scores.dim() == 1:
scores = scores.unsqueeze(-1).expand(-1, hkv)
elif scores.dim() > 2:
scores = scores.reshape(scores.shape[0], -1)[:, :hkv]
indices = torch.zeros(num_reqs, max_sel, dtype=torch.int32, device=device)
n_pairs = torch.zeros(num_reqs, dtype=torch.int32, device=device)
cu_cpu = cu_seqlens[: num_reqs + 1].detach()
for b in range(num_reqs):
start = int(cu_cpu[b].item())
end = int(cu_cpu[b + 1].item())
chunk_len = end - start
if chunk_len <= 0:
continue
if isinstance(compression_ratio, torch.Tensor):
r_b = float(compression_ratio[b].item())
else:
r_b = compression_ratio
k_tok = max(1, int(round(chunk_len * r_b)))
k_tok = min(k_tok, chunk_len)
budget = min(k_tok * hkv, max_sel)
flat_scores: list[tuple[float, int]] = []
for tok in range(start, end):
for h in range(hkv):
if scores.dim() == 2:
s = float(scores[tok, h].item())
else:
s = float(scores[tok].item())
idx = tok * hkv + h
flat_scores.append((s, idx))
flat_scores.sort(key=lambda x: -x[0])
n = min(budget, len(flat_scores))
if n > 0:
chosen = [x[1] for x in flat_scores[:n]]
indices[b, :n] = torch.tensor(chosen, dtype=torch.int32, device=device)
n_pairs[b] = n
return indices, n_pairs
def try_topk_indices_from_registry(
method: CompressionMethod,
key: torch.Tensor,
value: torch.Tensor,
cu: torch.Tensor,
num_reqs: int,
compression_ratio: torch.Tensor,
max_sel: int,
device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor] | None:
"""Return (indices, n_pairs) using registry scoring, or None to use tail fallback."""
if method == CompressionMethod.NONE:
return None
num_kv_heads = key.shape[1]
n_tokens, hkv, d = key.shape
if n_tokens <= 0 or hkv <= 0:
return None
k_flat = key.reshape(n_tokens, hkv, d)
v_flat = value.reshape(n_tokens, hkv, d)
context_lens = []
cu_cpu = cu[: num_reqs + 1].detach().cpu()
for b in range(num_reqs):
context_lens.append(int(cu_cpu[b + 1].item() - cu_cpu[b].item()))
max_seqlen_q = int((cu_cpu[1 : num_reqs + 1] - cu_cpu[:num_reqs]).max().item())
if method == CompressionMethod.COMPACTOR:
try:
k_proj = min(64, d)
phi = torch.randn(d, k_proj, device=key.device, dtype=torch.float32)
cc = CompressionContext(
compression_method=CompressionMethod.COMPACTOR,
context_lens=context_lens,
PHI=phi,
compression_chunk_size=512,
protected_first_tokens=[0] * num_reqs,
protected_last_tokens=[0] * num_reqs,
)
ctx = Context(
is_prefill=True,
do_compression=True,
cu_seqlens_q=cu,
max_seqlen_q=max_seqlen_q,
compression_context=cc,
)
cls = COMPRESSION_REGISTRY[CompressionMethod.COMPACTOR]
q_dummy = torch.zeros_like(k_flat)
scores = cls.pre_rope_scoring(
q_dummy,
k_flat,
v_flat,
context=ctx,
)
if scores is None:
return None
return _scores_to_topk_pair_indices(
cu, num_reqs, hkv, scores, compression_ratio, max_sel, device
)
except Exception:
logger.debug("Compactor pre_rope scoring failed; using tail fallback", exc_info=True)
return None
if method == CompressionMethod.CRITICALADAKV:
try:
k_proj = min(64, d)
phi = torch.randn(d, k_proj, device=key.device, dtype=torch.float32)
cc = CompressionContext(
compression_method=CompressionMethod.CRITICALADAKV,
context_lens=context_lens,
PHI=phi,
compression_chunk_size=512,
protected_first_tokens=[0] * num_reqs,
protected_last_tokens=[0] * num_reqs,
)
ctx = Context(
is_prefill=True,
do_compression=True,
cu_seqlens_q=cu,
max_seqlen_q=max_seqlen_q,
compression_context=cc,
)
cls_ada = COMPRESSION_REGISTRY[CompressionMethod.CRITICALADAKV]
q_dummy = torch.zeros_like(k_flat)
pre_scores = cls_ada.pre_rope_scoring(
q_dummy, k_flat, v_flat, context=ctx
)
scores = cls_ada.post_rope_scoring(
q_dummy, k_flat, v_flat, pre_scores, context=ctx
)
if scores is None:
return None
return _scores_to_topk_pair_indices(
cu, num_reqs, hkv, scores, compression_ratio, max_sel, device
)
except Exception:
logger.debug(
"CriticalAdaKV registry path failed; using tail fallback", exc_info=True
)
return None
if method == CompressionMethod.SNAPKV:
try:
cc = CompressionContext(compression_method=CompressionMethod.SNAPKV)
ctx = Context(
is_prefill=True,
do_compression=True,
cu_seqlens_q=cu,
cu_seqlens_k=cu,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_q,
compression_context=cc,
)
cls = COMPRESSION_REGISTRY[CompressionMethod.SNAPKV]
q_dummy = torch.zeros_like(k_flat)
scores = cls.post_rope_scoring(
q_dummy,
k_flat,
v_flat,
None,
context=ctx,
)
if scores is None:
return None
return _scores_to_topk_pair_indices(
cu, num_reqs, hkv, scores, compression_ratio, max_sel, device
)
except Exception:
logger.debug("SnapKV registry path failed; using tail fallback", exc_info=True)
return None
return None
__all__ = ["try_topk_indices_from_registry"]
import math
from typing import Optional
import torch
import triton
from triton import language as tl
from vllm.kvprune.compression.common import BaseCompressionMethod
from vllm.kvprune.utils.helpers import maybe_execute_in_stream
from vllm.kvprune.utils.triton_compat import autotune as triton_autotune
# SnapKV defaults aligned with kvpress `SnapKVPress` (snapkv_press.py).
DEFAULT_SNAPKV_WINDOW_SIZE = 64
DEFAULT_SNAPKV_KERNEL_SIZE = 5
class SnapKVCompression(BaseCompressionMethod):
@staticmethod
def pre_rope_scoring(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
) -> Optional[torch.Tensor]:
return None
@staticmethod
def post_rope_scoring(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
pre_rope_scores: torch.Tensor,
context,
) -> Optional[torch.Tensor]:
scores = maybe_execute_in_stream(
query_aware_key_scores,
q,
k,
context.cu_seqlens_q,
context.cu_seqlens_k,
w=DEFAULT_SNAPKV_WINDOW_SIZE,
kernel_size=DEFAULT_SNAPKV_KERNEL_SIZE,
STORE_STREAM=context.STORE_STREAM,
)
return scores
@triton_autotune(
configs=[
triton.Config(
{"BLOCK_Q": bq, "BLOCK_K": bk}, num_warps=num_warps, num_stages=num_stages
)
for bq in [32, 64]
for bk in [32, 64]
for num_warps in [4, 8]
for num_stages in [3, 4]
],
key=["QUERY_GROUP_SIZE", "D", "ROWS_MAX"],
cache_results=True,
)
@triton.jit
def _lse_and_store_logits_kernel(
Q,
K,
cu_q,
cu_k,
w_b, # int32 pointers
out_m,
out_S, # [B, Hk, ROWS_MAX] float32
LOGITS, # [Nk, Hk, ROWS_MAX] float32
sm_scale, # float
QUERY_GROUP_SIZE: tl.constexpr,
D: tl.constexpr,
STRIDE_Q_NQ,
STRIDE_Q_HQ,
STRIDE_K_NK,
STRIDE_K_HK,
STRIDE_M_B,
STRIDE_M_H,
STRIDE_M_R,
STRIDE_S_B,
STRIDE_S_H,
STRIDE_S_R,
STRIDE_LG_NK,
STRIDE_LG_HK,
STRIDE_LG_R,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
ROWS_MAX,
):
# program ids
b = tl.program_id(0)
hk = tl.program_id(1)
rid = tl.program_id(2) # row-tile id
# batch segment bounds
q_end = tl.load(cu_q + b + 1)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
q_win_beg = q_end - win
k_eff_end = k_end - win
if (win <= 0) or (k_eff_end <= k_beg):
return
# rows for this (b,hk)
rows_b = win * QUERY_GROUP_SIZE
row0 = rid * BLOCK_Q
if row0 >= rows_b:
return
# exp(x) = exp2(x * 1/ln2)
qk_scale = sm_scale * 1.4426950408889634
offs_qrow = row0 + tl.arange(0, BLOCK_Q)
row_mask = offs_qrow < rows_b
# map row -> (q_idx, hq_local)
hq_local = offs_qrow % QUERY_GROUP_SIZE
q_off = offs_qrow // QUERY_GROUP_SIZE
q_idx = q_win_beg + q_off
hq_glob = hk * QUERY_GROUP_SIZE + hq_local
offs_d = tl.arange(0, D)
q_ptrs = (
Q
+ q_idx[:, None] * STRIDE_Q_NQ
+ hq_glob[:, None] * STRIDE_Q_HQ
+ offs_d[None, :]
)
q_rows = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)
m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
S = tl.zeros([BLOCK_Q], dtype=tl.float32)
# Full-sequence causal attention (matches kvpress softmax), then use prefix columns only.
for ks in tl.range(k_beg, k_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_end
k_ptrs = K + nk[:, None] * STRIDE_K_NK + hk * STRIDE_K_HK + offs_d[None, :]
k_blk = tl.load(k_ptrs, mask=kmask[:, None], other=0.0) # [BK, D]
s = tl.dot(q_rows, k_blk.T) * qk_scale # [BQ, BK]
s = tl.where(kmask[None, :], s, -float("inf"))
# Causal: key j only if j <= q_idx (same as kvpress triu mask on the window×k_len grid).
causal_ok = nk[None, :] <= q_idx[:, None]
s = tl.where(causal_ok, s, -float("inf"))
# store prefix logits only (for marginal probs on prefix keys)
log_ptrs = (
LOGITS
+ nk[:, None] * STRIDE_LG_NK
+ hk * STRIDE_LG_HK
+ (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
)
store_mask = kmask & (nk < k_eff_end)
tl.store(log_ptrs, s.T, mask=store_mask[:, None] & row_mask[None, :])
# log2 streaming LSE over all keys in [k_beg, k_end) (after causal mask)
cur_max = tl.max(s, 1) # [BQ]
n_m = tl.maximum(m, cur_max)
rescale = tl.math.exp2(m - n_m)
S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
m = n_m
# store m,S for these rows
m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R, m, mask=row_mask)
tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R, S, mask=row_mask)
@triton_autotune(
configs=[
triton.Config({"BLOCK_Q": bq, "BLOCK_K": bk})
for bq in [16, 32, 64]
for bk in [32, 64, 128]
],
key=["HK", "HQ"],
cache_results=True,
)
@triton.jit
def _prefix_probs_kernel(
cu_k,
w_b,
in_m,
in_S, # [B, Hk, ROWS_MAX] f32
LOGITS, # [Nk, Hk, ROWS_MAX] f32, base-2 logits (prefix keys only)
PROBS, # [Nk, Hk, ROWS_MAX] f32 — per-row prefix marginal probs
#
QUERY_GROUP_SIZE: tl.constexpr,
STRIDE_M_B,
STRIDE_M_H,
STRIDE_M_R,
STRIDE_S_B,
STRIDE_S_H,
STRIDE_S_R,
STRIDE_LG_NK,
STRIDE_LG_HK,
STRIDE_LG_R,
STRIDE_PB_NK,
STRIDE_PB_HK,
STRIDE_PB_R,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
):
b = tl.program_id(0)
hk = tl.program_id(1)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
k_eff_end = k_end - win
if (win <= 0) or (k_eff_end <= k_beg):
return
rows_b = win * QUERY_GROUP_SIZE
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
for row0 in tl.range(0, rows_b, BLOCK_Q):
r_idx = row0 + tl.arange(0, BLOCK_Q)
rmask = r_idx < rows_b
m_ptr = in_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
S_ptr = in_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
m = tl.load(
m_ptr + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
mask=rmask,
other=-float("inf"),
)
S = tl.load(
S_ptr + tl.arange(0, BLOCK_Q) * STRIDE_S_R, mask=rmask, other=0.0
)
valid_row = S > 0
m = tl.where(valid_row, m, 0.0)
S = tl.where(valid_row, S, 1.0)
log_ptrs = (
LOGITS
+ nk[:, None] * STRIDE_LG_NK
+ hk * STRIDE_LG_HK
+ (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
)
s_T = tl.load(
log_ptrs, mask=kmask[:, None] & rmask[None, :], other=-float("inf")
) # [BK, BQ]
probs_T = tl.math.exp2(s_T - m[None, :]) / S[None, :]
probs_T = tl.where(valid_row[None, :], probs_T, 0.0)
prob_ptrs = (
PROBS
+ nk[:, None] * STRIDE_PB_NK
+ hk * STRIDE_PB_HK
+ (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_PB_R
)
tl.store(prob_ptrs, probs_T, mask=kmask[:, None] & rmask[None, :])
@triton_autotune(
configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
key=["HK"],
cache_results=True,
)
@triton.jit
def _zscore_per_batch_epilogue(
OUT, # [Nk, Hk], float32
cu_k,
w_b, # [B+1], [B] int32
STRIDE_OUT_NK,
STRIDE_OUT_HK,
HK: tl.constexpr, # Hk
EPS: tl.constexpr, # e.g., 1e-12
BLOCK_K: tl.constexpr, # e.g., 128
):
b = tl.program_id(0)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
k_eff_end = k_end - win
if k_eff_end <= k_beg:
return
sumv = tl.zeros([], dtype=tl.float32)
sumsq = tl.zeros([], dtype=tl.float32)
count = ((k_eff_end - k_beg) * HK).to(tl.float32)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
for h in tl.range(0, HK):
ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
sumv += tl.sum(vals, 0)
sumsq += tl.sum(vals * vals, 0)
mean = sumv / count
var = tl.maximum(sumsq / count - mean * mean, 0.0)
invstd = 1.0 / tl.sqrt(var + EPS)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
for h in tl.range(0, HK):
ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
vals = (vals - mean) * invstd
tl.store(ptrs, vals, mask=kmask)
@triton_autotune(
configs=[triton.Config({"BLOCK_T": bt}) for bt in [32, 64, 128, 256]],
key=["KERNEL_SIZE"],
cache_results=True,
)
@triton.jit
def _snapkv_avg_pool1d_kernel(
IN,
OUT,
Lp,
STRIDE_IN_C,
STRIDE_IN_L,
STRIDE_OUT_C,
STRIDE_OUT_L,
KERNEL_SIZE: tl.constexpr,
PAD: tl.constexpr,
BLOCK_T: tl.constexpr,
):
"""
Symmetric 1D average pool on the last dimension, matching
`F.avg_pool1d(x, kernel_size=K, padding=K//2, stride=1)` on `x` shaped [C, Lp]
(equivalent to PyTorch [C, 1, Lp] avg_pool1d with divisor = kernel size).
"""
c = tl.program_id(0)
t0 = tl.program_id(1) * BLOCK_T + tl.arange(0, BLOCK_T)
mask = t0 < Lp
acc = tl.zeros([BLOCK_T], dtype=tl.float32)
for j in tl.static_range(KERNEL_SIZE):
idx = t0 - PAD + j
valid = (idx >= 0) & (idx < Lp)
ptrs = IN + c * STRIDE_IN_C + idx * STRIDE_IN_L
v = tl.load(ptrs, mask=valid & mask, other=0.0).to(tl.float32)
acc += v
acc = acc / tl.cast(KERNEL_SIZE, tl.float32)
out_ptrs = OUT + c * STRIDE_OUT_C + t0 * STRIDE_OUT_L
tl.store(out_ptrs, acc, mask=mask)
def _snapkv_avg_pool1d_triton(x: torch.Tensor, kernel_size: int) -> torch.Tensor:
"""
kvpress-equivalent smoothing: same as `F.avg_pool1d` on [Hk*G, 1, Lp].
`x` must be float32 and contiguous along Lp (shape [Hk, G, Lp]).
"""
assert x.dtype == torch.float32
Hk, G, Lp = x.shape
if Lp == 0:
return x
pad = kernel_size // 2
x2 = x.reshape(Hk * G, Lp).contiguous()
out = torch.empty_like(x2)
C = Hk * G
si_c, si_l = x2.stride()
so_c, so_l = out.stride()
def grid(meta):
return (C, triton.cdiv(Lp, meta["BLOCK_T"]))
_snapkv_avg_pool1d_kernel[grid](
x2,
out,
Lp,
si_c,
si_l,
so_c,
so_l,
KERNEL_SIZE=kernel_size,
PAD=pad,
)
return out.view(Hk, G, Lp)
def _snapkv_kvpress_epilogue(
probs_buf: torch.Tensor,
out: torch.Tensor,
cu_seqlens_k: torch.Tensor,
w: torch.Tensor,
G: int,
Hk: int,
kernel_size: int,
) -> None:
"""
Match kvpress SnapKV order: mean over window queries → symmetric avg_pool1d
→ mean over GQA groups → pad tail with global max of prefix scores.
"""
B = cu_seqlens_k.numel() - 1
for b in range(B):
k_beg = int(cu_seqlens_k[b].item())
k_end = int(cu_seqlens_k[b + 1].item())
win = int(w[b].item())
k_eff_end = k_end - win
if win <= 0 or k_eff_end <= k_beg:
continue
Lp = k_eff_end - k_beg
rows_b = win * G
p = probs_buf[k_beg:k_eff_end, :, :rows_b]
# [Lp, Hk, win, G] — rows are (q_off, g) order per Triton row layout
x = p.view(Lp, Hk, win, G).mean(dim=2)
x = x.permute(1, 2, 0).contiguous() # [Hk, G, Lp]
x = _snapkv_avg_pool1d_triton(x, kernel_size)
x = x.mean(dim=1)
seg = x.permute(1, 0).contiguous()
out[k_beg:k_eff_end, :] = seg
pad_val = seg.max()
out[k_eff_end:k_end, :] = pad_val
def query_aware_key_scores(
q: torch.Tensor, # [N_q, Hq, D]
k: torch.Tensor, # [N_k, Hk, D]
cu_seqlens_q: torch.Tensor, # [B+1], int32
cu_seqlens_k: torch.Tensor, # [B+1], int32
w: torch.Tensor | int, # [B], int32
sm_scale: float = None, # defaults to 1/sqrt(D)
*,
kernel_size: int = DEFAULT_SNAPKV_KERNEL_SIZE,
accum_scores: torch.Tensor = None,
accum_blending: float = None,
normalize: bool = False,
) -> Optional[torch.Tensor]:
assert q.stride(-1) == 1 and k.stride(-1) == 1, "last dim must be contiguous"
device = q.device
N_q, Hq, D = q.shape
N_k, Hk, Dk = k.shape
assert (Hq % Hk) == 0, "Hq must be a multiple of Hk"
if sm_scale is None:
sm_scale = 1.0 / math.sqrt(D)
B = cu_seqlens_q.numel() - 1
assert B == cu_seqlens_k.numel() - 1
G = Hq // Hk
if type(w) is int:
max_w = w
w = torch.full((B,), fill_value=w, device=device, dtype=torch.int32)
else:
max_w = int(w.max().item())
assert w.numel() == B
ROWS_MAX = max_w * G
if ROWS_MAX == 0:
return torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
out = torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
m_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
S_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
logits_buf = torch.empty((N_k, Hk, ROWS_MAX), dtype=torch.float32, device=device)
probs_buf = torch.empty((N_k, Hk, ROWS_MAX), dtype=torch.float32, device=device)
# strides
STRIDE_Q_NQ, STRIDE_Q_HQ, _ = q.stride()
STRIDE_K_NK, STRIDE_K_HK, _ = k.stride()
STRIDE_M_B, STRIDE_M_H, STRIDE_M_R = m_scratch.stride()
STRIDE_S_B, STRIDE_S_H, STRIDE_S_R = S_scratch.stride()
STRIDE_LG_NK, STRIDE_LG_HK, STRIDE_LG_R = logits_buf.stride()
STRIDE_PB_NK, STRIDE_PB_HK, STRIDE_PB_R = probs_buf.stride()
STRIDE_OUT_NK, STRIDE_OUT_HK = out.stride()
def grid(META):
return B, Hk, triton.cdiv(ROWS_MAX, META["BLOCK_Q"])
_lse_and_store_logits_kernel[grid](
q,
k,
cu_seqlens_q,
cu_seqlens_k,
w,
m_scratch,
S_scratch,
logits_buf,
sm_scale,
QUERY_GROUP_SIZE=Hq // Hk,
D=D,
STRIDE_Q_NQ=STRIDE_Q_NQ,
STRIDE_Q_HQ=STRIDE_Q_HQ,
STRIDE_K_NK=STRIDE_K_NK,
STRIDE_K_HK=STRIDE_K_HK,
STRIDE_M_B=STRIDE_M_B,
STRIDE_M_H=STRIDE_M_H,
STRIDE_M_R=STRIDE_M_R,
STRIDE_S_B=STRIDE_S_B,
STRIDE_S_H=STRIDE_S_H,
STRIDE_S_R=STRIDE_S_R,
STRIDE_LG_NK=STRIDE_LG_NK,
STRIDE_LG_HK=STRIDE_LG_HK,
STRIDE_LG_R=STRIDE_LG_R,
ROWS_MAX=ROWS_MAX,
)
_prefix_probs_kernel[(B, Hk)](
cu_seqlens_k,
w,
m_scratch,
S_scratch,
logits_buf,
probs_buf,
QUERY_GROUP_SIZE=Hq // Hk,
STRIDE_M_B=STRIDE_M_B,
STRIDE_M_H=STRIDE_M_H,
STRIDE_M_R=STRIDE_M_R,
STRIDE_S_B=STRIDE_S_B,
STRIDE_S_H=STRIDE_S_H,
STRIDE_S_R=STRIDE_S_R,
STRIDE_LG_NK=STRIDE_LG_NK,
STRIDE_LG_HK=STRIDE_LG_HK,
STRIDE_LG_R=STRIDE_LG_R,
STRIDE_PB_NK=STRIDE_PB_NK,
STRIDE_PB_HK=STRIDE_PB_HK,
STRIDE_PB_R=STRIDE_PB_R,
)
_snapkv_kvpress_epilogue(
probs_buf, out, cu_seqlens_k, w, G, Hk, kernel_size
)
if normalize:
_zscore_per_batch_epilogue[(B,)](
out,
cu_seqlens_k,
w,
STRIDE_OUT_NK,
STRIDE_OUT_HK,
HK=Hk,
EPS=1e-12,
)
if accum_scores is not None:
if accum_blending is not None:
accum_scores.mul_(accum_blending)
accum_scores.add_(out)
return accum_scores
else:
return out
import math
from typing import Optional
import torch
import triton
from triton import language as tl
from compactor_vllm.compression.common import BaseCompressionMethod
from compactor_vllm.utils.helpers import maybe_execute_in_stream
from compactor_vllm.utils.triton_compat import autotune as triton_autotune
class SnapKVCompression(BaseCompressionMethod):
@staticmethod
def pre_rope_scoring(
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context
) -> Optional[torch.Tensor]:
return None
@staticmethod
def post_rope_scoring(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
pre_rope_scores: torch.Tensor,
context,
) -> Optional[torch.Tensor]:
scores = maybe_execute_in_stream(
query_aware_key_scores,
q,
k,
context.cu_seqlens_q,
context.cu_seqlens_k,
w=32,
STORE_STREAM=context.STORE_STREAM,
)
return scores
@triton_autotune(
configs=[
triton.Config(
{"BLOCK_Q": bq, "BLOCK_K": bk}, num_warps=num_warps, num_stages=num_stages
)
for bq in [32, 64]
for bk in [32, 64]
for num_warps in [4, 8]
for num_stages in [3, 4]
],
key=["QUERY_GROUP_SIZE", "D", "ROWS_MAX"],
cache_results=True,
)
@triton.jit
def _lse_and_store_logits_kernel(
Q,
K,
cu_q,
cu_k,
w_b, # int32 pointers
out_m,
out_S, # [B, Hk, ROWS_MAX] float32
LOGITS, # [Nk, Hk, ROWS_MAX] float32
sm_scale, # float
QUERY_GROUP_SIZE: tl.constexpr,
D: tl.constexpr,
STRIDE_Q_NQ,
STRIDE_Q_HQ,
STRIDE_K_NK,
STRIDE_K_HK,
STRIDE_M_B,
STRIDE_M_H,
STRIDE_M_R,
STRIDE_S_B,
STRIDE_S_H,
STRIDE_S_R,
STRIDE_LG_NK,
STRIDE_LG_HK,
STRIDE_LG_R,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
ROWS_MAX,
):
# program ids
b = tl.program_id(0)
hk = tl.program_id(1)
rid = tl.program_id(2) # row-tile id
# batch segment bounds
q_end = tl.load(cu_q + b + 1)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
q_win_beg = q_end - win
k_eff_end = k_end - win
if (win <= 0) or (k_eff_end <= k_beg):
return
# rows for this (b,hk)
rows_b = win * QUERY_GROUP_SIZE
row0 = rid * BLOCK_Q
if row0 >= rows_b:
return
# exp(x) = exp2(x * 1/ln2)
qk_scale = sm_scale * 1.4426950408889634
offs_qrow = row0 + tl.arange(0, BLOCK_Q)
row_mask = offs_qrow < rows_b
# map row -> (q_idx, hq_local)
hq_local = offs_qrow % QUERY_GROUP_SIZE
q_off = offs_qrow // QUERY_GROUP_SIZE
q_idx = q_win_beg + q_off
hq_glob = hk * QUERY_GROUP_SIZE + hq_local
offs_d = tl.arange(0, D)
q_ptrs = (
Q
+ q_idx[:, None] * STRIDE_Q_NQ
+ hq_glob[:, None] * STRIDE_Q_HQ
+ offs_d[None, :]
)
q_rows = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)
m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
S = tl.zeros([BLOCK_Q], dtype=tl.float32)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
k_ptrs = K + nk[:, None] * STRIDE_K_NK + hk * STRIDE_K_HK + offs_d[None, :]
k_blk = tl.load(k_ptrs, mask=kmask[:, None], other=0.0) # [BK, D]
s = tl.dot(q_rows, k_blk.T) * qk_scale # [BQ, BK]
s = tl.where(kmask[None, :], s, -float("inf"))
# store into LOGITS[nk, hk, row] -> [BK, BQ]
log_ptrs = (
LOGITS
+ nk[:, None] * STRIDE_LG_NK
+ hk * STRIDE_LG_HK
+ (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
)
tl.store(log_ptrs, s.T, mask=kmask[:, None] & row_mask[None, :])
# log2 streaming LSE update
cur_max = tl.max(s, 1) # [BQ]
n_m = tl.maximum(m, cur_max)
rescale = tl.math.exp2(m - n_m)
S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
m = n_m
# store m,S for these rows
m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R, m, mask=row_mask)
tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R, S, mask=row_mask)
@triton_autotune(
configs=[
triton.Config({"BLOCK_Q": bq, "BLOCK_K": bk})
for bq in [16, 32, 64]
for bk in [32, 64, 128]
],
key=["HK", "HQ"],
cache_results=True,
)
@triton.jit
def _scores_from_logits_kernel(
cu_k,
w_b,
in_m,
in_S, # [B, Hk, ROWS_MAX] f32
LOGITS, # [Nk, Hk, ROWS_MAX] f32, base-2 logits
OUT, # [Nk, Hk] f32
#
QUERY_GROUP_SIZE: tl.constexpr,
STRIDE_M_B,
STRIDE_M_H,
STRIDE_M_R,
STRIDE_S_B,
STRIDE_S_H,
STRIDE_S_R,
STRIDE_LG_NK,
STRIDE_LG_HK,
STRIDE_LG_R,
STRIDE_OUT_NK,
STRIDE_OUT_HK,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
#
DO_POOL: tl.constexpr, # set True to enable in-place avg pool
KPOOL: tl.constexpr, # kernel size for avg pool (stride=1)
):
b = tl.program_id(0)
hk = tl.program_id(1)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
k_eff_end = k_end - win
if (win <= 0) or (k_eff_end <= k_beg):
return
rows_b = win * QUERY_GROUP_SIZE
# === scores over computed region ===
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
scores = tl.zeros([BLOCK_K], dtype=tl.float32)
for row0 in tl.range(0, rows_b, BLOCK_Q):
r_idx = row0 + tl.arange(0, BLOCK_Q)
rmask = r_idx < rows_b
# load m, S for rows
m_ptr = in_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
S_ptr = in_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
m = tl.load(
m_ptr + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
mask=rmask,
other=-float("inf"),
)
S = tl.load(
S_ptr + tl.arange(0, BLOCK_Q) * STRIDE_S_R, mask=rmask, other=0.0
)
valid_row = S > 0
m = tl.where(valid_row, m, 0.0)
S = tl.where(valid_row, S, 1.0)
# load stored logits^T: [BK, BQ]
log_ptrs = (
LOGITS
+ nk[:, None] * STRIDE_LG_NK
+ hk * STRIDE_LG_HK
+ (row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R
)
s_T = tl.load(
log_ptrs, mask=kmask[:, None] & rmask[None, :], other=-float("inf")
) # [BK, BQ]
# probs^T = exp2(s_T - m) / S, sum over rows
probs_T = tl.math.exp2(s_T - m[None, :]) / S[None, :]
probs_T = tl.where(valid_row[None, :], probs_T, 0.0)
scores += tl.sum(probs_T, 1) # [BK]
if DO_POOL and (KPOOL > 1):
i = tl.arange(0, BLOCK_K)[:, None]
j = tl.arange(0, BLOCK_K)[None, :]
band = (j <= i) & ((i - j) < KPOOL)
band = band & kmask[None, :]
# sum within band
sums = tl.sum(tl.where(band, scores[None, :], 0.0), 1) # [BK]
denom = tl.sum(band, 1).to(tl.float32) # [BK]
denom = tl.where(denom > 0, denom, 1.0)
scores = sums / denom
out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
tl.store(out_ptrs, scores, mask=kmask)
pad_beg = k_eff_end
pad_end = k_end
if pad_end > pad_beg:
for ks in tl.range(pad_beg, pad_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < pad_end
out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
tl.store(
out_ptrs, tl.full([BLOCK_K], float("inf"), dtype=tl.float32), mask=kmask
)
@triton_autotune(
configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
key=["HK"],
cache_results=True,
)
@triton.jit
def _zscore_per_batch_epilogue(
OUT, # [Nk, Hk], float32
cu_k,
w_b, # [B+1], [B] int32
STRIDE_OUT_NK,
STRIDE_OUT_HK,
HK: tl.constexpr, # Hk
EPS: tl.constexpr, # e.g., 1e-12
BLOCK_K: tl.constexpr, # e.g., 128
):
b = tl.program_id(0)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
k_eff_end = k_end - win
if k_eff_end <= k_beg:
return
sumv = tl.zeros([], dtype=tl.float32)
sumsq = tl.zeros([], dtype=tl.float32)
count = ((k_eff_end - k_beg) * HK).to(tl.float32)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
for h in tl.range(0, HK):
ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
sumv += tl.sum(vals, 0)
sumsq += tl.sum(vals * vals, 0)
mean = sumv / count
var = tl.maximum(sumsq / count - mean * mean, 0.0)
invstd = 1.0 / tl.sqrt(var + EPS)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
for h in tl.range(0, HK):
ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
vals = (vals - mean) * invstd
tl.store(ptrs, vals, mask=kmask)
def query_aware_key_scores(
q: torch.Tensor, # [N_q, Hq, D]
k: torch.Tensor, # [N_k, Hk, D]
cu_seqlens_q: torch.Tensor, # [B+1], int32
cu_seqlens_k: torch.Tensor, # [B+1], int32
w: torch.Tensor | int, # [B], int32
sm_scale: float = None, # defaults to 1/sqrt(D)
*,
accum_scores: torch.Tensor = None,
accum_blending: float = None,
normalize: bool = False,
) -> Optional[torch.Tensor]:
assert q.stride(-1) == 1 and k.stride(-1) == 1, "last dim must be contiguous"
device = q.device
N_q, Hq, D = q.shape
N_k, Hk, Dk = k.shape
assert (Hq % Hk) == 0, "Hq must be a multiple of Hk"
if sm_scale is None:
sm_scale = 1.0 / math.sqrt(D)
B = cu_seqlens_q.numel() - 1
assert B == cu_seqlens_k.numel() - 1
G = Hq // Hk
if type(w) is int:
max_w = w
w = torch.full((B,), fill_value=w, device=device, dtype=torch.int32)
else:
max_w = int(w.max().item())
assert w.numel() == B
ROWS_MAX = max_w * G
if ROWS_MAX == 0:
return torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
out = torch.empty((N_k, Hk), dtype=torch.float32, device=device)
m_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
S_scratch = torch.empty((B, Hk, ROWS_MAX), dtype=torch.float32, device=device)
logits_buf = torch.empty((N_k, Hk, ROWS_MAX), dtype=torch.float32, device=device)
# strides
STRIDE_Q_NQ, STRIDE_Q_HQ, _ = q.stride()
STRIDE_K_NK, STRIDE_K_HK, _ = k.stride()
STRIDE_M_B, STRIDE_M_H, STRIDE_M_R = m_scratch.stride()
STRIDE_S_B, STRIDE_S_H, STRIDE_S_R = S_scratch.stride()
STRIDE_LG_NK, STRIDE_LG_HK, STRIDE_LG_R = logits_buf.stride()
STRIDE_OUT_NK, STRIDE_OUT_HK = out.stride()
def grid(META):
return B, Hk, triton.cdiv(ROWS_MAX, META["BLOCK_Q"])
_lse_and_store_logits_kernel[grid](
q,
k,
cu_seqlens_q,
cu_seqlens_k,
w,
m_scratch,
S_scratch,
logits_buf,
sm_scale,
QUERY_GROUP_SIZE=Hq // Hk,
D=D,
STRIDE_Q_NQ=STRIDE_Q_NQ,
STRIDE_Q_HQ=STRIDE_Q_HQ,
STRIDE_K_NK=STRIDE_K_NK,
STRIDE_K_HK=STRIDE_K_HK,
STRIDE_M_B=STRIDE_M_B,
STRIDE_M_H=STRIDE_M_H,
STRIDE_M_R=STRIDE_M_R,
STRIDE_S_B=STRIDE_S_B,
STRIDE_S_H=STRIDE_S_H,
STRIDE_S_R=STRIDE_S_R,
STRIDE_LG_NK=STRIDE_LG_NK,
STRIDE_LG_HK=STRIDE_LG_HK,
STRIDE_LG_R=STRIDE_LG_R,
ROWS_MAX=ROWS_MAX,
)
_scores_from_logits_kernel[(B, Hk)](
cu_seqlens_k,
w,
m_scratch,
S_scratch,
logits_buf,
out,
QUERY_GROUP_SIZE=Hq // Hk,
STRIDE_M_B=STRIDE_M_B,
STRIDE_M_H=STRIDE_M_H,
STRIDE_M_R=STRIDE_M_R,
STRIDE_S_B=STRIDE_S_B,
STRIDE_S_H=STRIDE_S_H,
STRIDE_S_R=STRIDE_S_R,
STRIDE_LG_NK=STRIDE_LG_NK,
STRIDE_LG_HK=STRIDE_LG_HK,
STRIDE_LG_R=STRIDE_LG_R,
STRIDE_OUT_NK=STRIDE_OUT_NK,
STRIDE_OUT_HK=STRIDE_OUT_HK,
DO_POOL=True,
KPOOL=5,
)
if normalize:
_zscore_per_batch_epilogue[(B,)](
out,
cu_seqlens_k,
w,
STRIDE_OUT_NK,
STRIDE_OUT_HK,
HK=Hk,
EPS=1e-12,
)
if accum_scores is not None:
if accum_blending is not None:
accum_scores.mul_(accum_blending)
accum_scores.add_(out)
return accum_scores
else:
return out
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Engine / sampling / kernel constants (compactor-compatible)."""
from vllm.kvprune.config.constants import RESERVED_BATCH, TRITON_RESERVED_BATCH
__all__ = ["RESERVED_BATCH", "TRITON_RESERVED_BATCH"]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
RESERVED_BATCH = 0
# NOTE: Triton `tl.constexpr` is intended for use in kernel signatures/annotations.
# Some Triton builds reject passing `tl.constexpr(...)` objects as constexpr values.
# Keep the runtime value as a plain int and let kernel signatures declare constexpr.
TRITON_RESERVED_BATCH = RESERVED_BATCH
import os
from dataclasses import dataclass
from enum import Enum, auto
from typing import List, Optional
from transformers import AutoConfig
class AttentionBackend(Enum):
"""Legacy coarse backend toggle (prefer :class:`KvpruneAttentionSchedule`)."""
FLASH_ATTENTION = auto()
COMPACTOR_TRITON = auto()
class KvpruneAttentionSchedule(Enum):
"""FlashAttention vs Triton split for prefill / decode (KV **writes** stay Triton)."""
# Default: FA varlen prefill; decode uses ``head_sparse_decode_attention`` (Triton).
FA_PREFILL_TRITON_DECODE = auto()
# Prefill attention uses ``causal_sparse_varlen_with_cache`` (Triton); decode Triton.
TRITON_PREFILL_TRITON_DECODE = auto()
# "PDFA": FA prefill + FA decode; paged KV **storage** (incl. pruned top-k) unchanged.
PDFA = auto()
@dataclass
class LLMConfig:
"""Configuration for the :class:`LLM` engine.
Parameters
----------
model : str
Hugging Face model identifier (e.g. ``"meta-llama/Meta-Llama-3-8B"``) or
a local model name that can be resolved by
:func:`transformers.AutoConfig.from_pretrained`.
path : str, optional
Local directory containing the model weights. If ``None``, the engine
will attempt to resolve a local snapshot for ``model`` using
:func:`huggingface_hub.snapshot_download`.
max_num_seqs : int, default 256
Upper bound on the number of concurrent batches that the scheduler and
KV-cache manager are allowed to handle. This affects the size of the
page table and some internal buffers.
max_model_len : int, default 40960
Maximum context length (in tokens) that the engine will allocate KV cache
and CUDA graphs for. During initialization this value is clamped to
``hf_config.max_position_embeddings`` for the chosen model.
gpu_memory_utilization : float, default 0.9
Fraction of the total GPU memory that may be used for KV cache and model
activations. Values should be in ``(0, 1]``. If this budget is too small,
the KV-cache manager may raise an error at warmup time due
to insufficient memory.
tensor_parallel_size : int, default 1
Number of tensor-parallel workers to shard the model
across. Must be between 1 and 8, and must evenly divide the model's
number of key/value heads.
enforce_eager : bool, default False
If ``True``, disable CUDA graph capture and always run the model in
eager mode during decoding. This reduces throughput. When ``False``,
the engine will capture and reuse CUDA graphs for supported
batch sizes and sequence lengths.
hf_config : transformers.AutoConfig, optional
Pre-loaded Hugging Face configuration for the model. If ``None``,
it will then be populated automatically based on ``model``.
eos : int, default -1
Primary stop token id (warmup / single-id paths). If ``-1``, the
:class:`LLM` constructor fills this and :attr:`eos_token_ids` from the
tokenizer.
eos_token_ids : list of int, optional
All token ids that terminate generation (e.g. HF tokenizers may expose
``eos_token_id`` as a list for chat models). If ``None``, inferred in
:class:`LLM` from the tokenizer and model type.
kvcache_page_size : int, default 128
Number of tokens stored in a single KV-cache page. Smaller pages improve
allocation flexibility but increase page-table overhead; larger pages
reduce overhead but have coarser granularity.
leverage_sketch_size : int, default 48
Sketch dimension used by the Compactor leverage-score estimator.
attention_schedule : KvpruneAttentionSchedule, default FA_PREFILL_TRITON_DECODE
Which **attention** implementation runs on prefill vs decode. KV **writes**
(``prefill_store_*``, ``decode_store_kv``, pruned top-k) always use the
existing Triton store kernels. Env ``VLLM_KVPRUNE_ATTENTION_SCHEDULE`` uses
short names: ``fa_triton`` (default), ``pdtriton``, ``pdfa``. Enum values:
``FA_PREFILL_TRITON_DECODE`` — FA prefill, Triton decode;
``TRITON_PREFILL_TRITON_DECODE`` — Triton prefill + decode;
``PDFA`` — FA prefill + FA decode (still Triton KV I/O).
attention_backend : AttentionBackend, optional
Deprecated. Ignored if ``attention_schedule`` is set; otherwise mapped
for backward compatibility.
"""
model: str
path: Optional[str] = None
nccl_port: Optional[int] = 1218
max_num_seqs: int = 256
max_model_len: int = 40960
gpu_memory_utilization: float = 0.9
tensor_parallel_size: int = 1
enforce_eager: bool = False
hf_config: AutoConfig | None = None
eos: int = -1
eos_token_ids: Optional[List[int]] = None
kvcache_page_size: int = 128
leverage_sketch_size: int = 48
attention_schedule: KvpruneAttentionSchedule = (
KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE
)
attention_backend: AttentionBackend | None = None
show_progress_bar: bool = True
def __post_init__(self):
if self.attention_backend is not None:
if self.attention_backend == AttentionBackend.FLASH_ATTENTION:
self.attention_schedule = KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE
else:
self.attention_schedule = (
KvpruneAttentionSchedule.TRITON_PREFILL_TRITON_DECODE
)
if self.path is not None and not os.path.isdir(self.path):
raise NotADirectoryError(f"Engine config dir {self.path} does not exist")
if self.tensor_parallel_size <= 0 or self.tensor_parallel_size > 8:
assert 1 <= self.tensor_parallel_size <= 8
raise ValueError("tensor_parallel_size must be >= 1 and <= 8")
if self.hf_config is None:
self.hf_config = AutoConfig.from_pretrained(self.model)
self.max_model_len = min(
self.max_model_len, self.hf_config.max_position_embeddings
)
from dataclasses import dataclass
@dataclass
class SamplingParams:
temperature: float = 1.0
max_new_tokens: int = 256
def __post_init__(self):
if self.temperature < 0:
raise ValueError("Temperature cannot be negative")
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Core: compactor ``LLMEngine`` stack (``llm_engine``, ``scheduler``, …) plus helpers
(``runtime``, ``flash_integration``, ``block_budget``) used **inside** the compactor path.
v1 does not import these; use :meth:`vllm.LLM.generate` with ``compression=`` for the
``LLM`` + compactor integration.
"""
from vllm.kvprune.core.block_budget import (
TailReclaimHint,
build_tail_reclaim_hint,
tail_blocks_if_logical_shorter,
)
from vllm.kvprune.core.compression_bridge import (
VALID_ALIASES_FOR_SAMPLING,
compression_method_id_to_enum,
compression_method_str_to_id,
)
from vllm.kvprune.core.flash_integration import (
do_kv_cache_update_kv_prune,
merge_seq_lens_with_kv_prune,
)
from vllm.kvprune.core.runtime import (
KVPruneForwardState,
build_kv_prune_forward_state,
get_kv_prune_state,
layer_index_from_layer_name,
)
__all__ = [
"KVPruneForwardState",
"TailReclaimHint",
"VALID_ALIASES_FOR_SAMPLING",
"build_kv_prune_forward_state",
"build_tail_reclaim_hint",
"compression_method_id_to_enum",
"compression_method_str_to_id",
"do_kv_cache_update_kv_prune",
"get_kv_prune_state",
"layer_index_from_layer_name",
"merge_seq_lens_with_kv_prune",
"tail_blocks_if_logical_shorter",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Block budget helpers for compactor KV pruning (logical vs physical length).
Used by the **compactor** ``LLMEngine`` path (``PagedKVCache`` / logical lengths),
not by v1's scheduler. The helpers compare logical KV length to a physical token
count and return how many full tail blocks can be reclaimed when logical shrinks.
"""
from __future__ import annotations
from dataclasses import dataclass
@dataclass(frozen=True)
class TailReclaimHint:
"""How many tail blocks could be freed if logical KV shrinks below allocation."""
request_id: str
allocated_tokens: int
logical_tokens: int
block_size: int
reclaimable_tail_blocks: int
def tail_blocks_if_logical_shorter(
allocated_tokens: int,
logical_tokens: int,
block_size: int,
) -> int:
"""Return count of fully-unused tail blocks when ``logical < allocated``.
Block-granular: only counts whole blocks past the last block that still
contains a retained logical token index.
"""
if block_size <= 0:
return 0
if logical_tokens >= allocated_tokens:
return 0
# Last logical token occupies block index floor((logical-1)/bs) if logical>0
if logical_tokens <= 0:
return (allocated_tokens + block_size - 1) // block_size
last_logical_block = (logical_tokens - 1) // block_size
last_alloc_block = (allocated_tokens - 1) // block_size
return max(0, last_alloc_block - last_logical_block)
def build_tail_reclaim_hint(
request_id: str,
allocated_tokens: int,
logical_tokens: int,
block_size: int,
) -> TailReclaimHint:
n = tail_blocks_if_logical_shorter(allocated_tokens, logical_tokens, block_size)
return TailReclaimHint(
request_id=request_id,
allocated_tokens=allocated_tokens,
logical_tokens=logical_tokens,
block_size=block_size,
reclaimable_tail_blocks=n,
)
__all__ = [
"TailReclaimHint",
"build_tail_reclaim_hint",
"tail_blocks_if_logical_shorter",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Map compression method strings (e.g. from :class:`~vllm.kvprune.integration.CompressionParams`) to kvprune GPU / enum IDs."""
from __future__ import annotations
from vllm.kvprune.compression.compression_config import CompressionMethod
# IDs stored on device [num_reqs_padded] (int32). Order is stable for kernels.
COMPRESSION_METHOD_ID_NONE = 0
COMPRESSION_METHOD_ID_CRITICALADAKV = 1
COMPRESSION_METHOD_ID_COMPACTOR = 2
COMPRESSION_METHOD_ID_SNAPKV = 3
# Aliases accepted for method strings (case-insensitive after strip).
VALID_ALIASES_FOR_SAMPLING: frozenset[str] = frozenset(
{"none", "criticaladakv", "compactor", "snapkv"}
)
_STR_TO_ID: dict[str, int] = {
"none": COMPRESSION_METHOD_ID_NONE,
"criticaladakv": COMPRESSION_METHOD_ID_CRITICALADAKV,
"compactor": COMPRESSION_METHOD_ID_COMPACTOR,
"snapkv": COMPRESSION_METHOD_ID_SNAPKV,
}
_ID_TO_COMPRESSION_METHOD: dict[int, CompressionMethod] = {
COMPRESSION_METHOD_ID_NONE: CompressionMethod.NONE,
COMPRESSION_METHOD_ID_CRITICALADAKV: CompressionMethod.CRITICALADAKV,
COMPRESSION_METHOD_ID_COMPACTOR: CompressionMethod.COMPACTOR,
COMPRESSION_METHOD_ID_SNAPKV: CompressionMethod.SNAPKV,
}
def compression_method_str_to_id(s: str) -> int:
"""Normalize and map user string to a stable int id (0..3)."""
key = (s or "none").strip().lower()
if key not in _STR_TO_ID:
raise ValueError(
f"Unknown compression_method {s!r}; expected one of "
f"{sorted(VALID_ALIASES_FOR_SAMPLING)}"
)
return _STR_TO_ID[key]
def compression_method_id_to_enum(method_id: int) -> CompressionMethod:
if method_id not in _ID_TO_COMPRESSION_METHOD:
return CompressionMethod.NONE
return _ID_TO_COMPRESSION_METHOD[method_id]
__all__ = [
"COMPRESSION_METHOD_ID_NONE",
"COMPRESSION_METHOD_ID_CRITICALADAKV",
"COMPRESSION_METHOD_ID_COMPACTOR",
"COMPRESSION_METHOD_ID_SNAPKV",
"VALID_ALIASES_FOR_SAMPLING",
"compression_method_id_to_enum",
"compression_method_str_to_id",
]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""FlashAttention + KV cache hooks for kvprune."""
from __future__ import annotations
import torch
from vllm.kvprune.core.runtime import KVPruneForwardState, get_kv_prune_state
_RATIO_ONE = 1.0 - 1e-6
def merge_seq_lens_with_kv_prune(
base_seq_lens: torch.Tensor,
layer_name: str,
max_query_len: int,
) -> torch.Tensor:
"""Blend scheduler seq_lens with per-layer logical lengths when pruning."""
state = get_kv_prune_state()
if state is None:
return base_seq_lens
# Prefill: only scheduler lengths are reliable unless compactor store ran for
# every layer (try_prefill_kv_store); when pruning is requested but ineligible
# (e.g. unsupported dtype), logical buffers may still be zero — do not override.
if max_query_len > 1:
return base_seq_lens
layer_idx = _layer_idx(layer_name)
num_reqs = state.num_reqs
comp = state.compression_ratio_gpu[:num_reqs]
logical = state.logical_seq_lens_gpu[layer_idx, :num_reqs]
if logical.dim() == 2:
logical = logical.max(dim=-1).values
out = base_seq_lens.clone()
use_logical = comp < _RATIO_ONE
out[:num_reqs] = torch.where(
use_logical,
logical.to(out.dtype),
base_seq_lens[:num_reqs],
)
return out
def _layer_idx(layer_name: str) -> int:
from vllm.kvprune.core.runtime import layer_index_from_layer_name
return layer_index_from_layer_name(layer_name)
def do_kv_cache_update_kv_prune(
layer: torch.nn.Module,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
slot_mapping: torch.Tensor,
reshape_and_cache_flash,
kv_cache_dtype: str,
) -> bool:
"""If kvprune handles this step, return True (caller skips default path)."""
state = get_kv_prune_state()
if state is None:
return False
layer_idx = _layer_idx(layer.layer_name)
num_reqs = state.num_reqs
if state.is_prefill:
from vllm.kvprune.compression.prefill import try_prefill_kv_store
if try_prefill_kv_store(layer, key, value, kv_cache):
return True
return False
key_cache, value_cache = kv_cache.unbind(0)
reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
slot_mapping,
kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
comp = state.compression_ratio_gpu[:num_reqs]
mask = (comp < _RATIO_ONE).to(torch.int32)
layer_buf = state.logical_seq_lens_gpu[layer_idx, :num_reqs]
if layer_buf.dim() == 2:
layer_buf += mask.unsqueeze(-1)
else:
layer_buf += mask
return True
from __future__ import annotations
import atexit
import inspect
import logging
from pathlib import Path
from typing import Any, List, Optional, Union
import torch.nn as nn
import torch.multiprocessing as mp
from vllm.kvprune.compression.compression_config import (
BatchCompressionParams,
SequenceCompressionParams,
)
from vllm.kvprune.config.engine_config import LLMConfig
from vllm.kvprune.config.sampling_params import SamplingParams
from vllm.kvprune.core.model_runner import ModelRunner
from vllm.kvprune.models import MODEL_REGISTRY
from vllm.kvprune.utils.sequence import Sequence
from transformers import AutoTokenizer
logger = logging.getLogger(__name__)
PromptLike = Union[str, List[int]]
def _infer_stop_token_ids(tokenizer, hf_config) -> list[int]:
"""
Build the set of token ids that should end generation.
Newer HF chat tokenizers often expose ``eos_token_id`` as a *list* of ids.
The engine must not compare generated ids to that list as a single ``int``;
see :attr:`LLMConfig.eos_token_ids` and decode-time ``torch.isin``.
Qwen chat uses ``</think>`` (im_end) as the assistant turn boundary; include it
when present in ``additional_special_tokens`` / ``added_tokens_encoder``. We
avoid loose substring matches like ``\"end\"`` that can tag unrelated tokens.
"""
raw = tokenizer.eos_token_id
ids: list[int] = []
if isinstance(raw, (list, tuple)):
ids.extend(int(x) for x in raw)
elif raw is not None:
ids.append(int(raw))
unk_id = getattr(tokenizer, "unk_token_id", None)
def _maybe_add_tid(tid: int) -> None:
if not isinstance(tid, int) or tid < 0:
return
if unk_id is not None and tid == unk_id:
return
if tid not in ids:
ids.append(tid)
model_type = getattr(hf_config, "model_type", None)
if model_type in ("qwen2", "qwen3", "qwen2_moe", "qwen3_moe"):
enc = getattr(tokenizer, "added_tokens_encoder", None)
if isinstance(enc, dict):
for key, tid in enc.items():
if isinstance(key, str) and "im_end" in key:
_maybe_add_tid(int(tid))
for extra in getattr(tokenizer, "additional_special_tokens", []) or []:
if not isinstance(extra, str) or "im_end" not in extra:
continue
try:
tid = tokenizer.convert_tokens_to_ids(extra)
except (TypeError, ValueError, KeyError):
continue
_maybe_add_tid(tid)
if not ids:
raise ValueError(
"Could not infer stop token ids from the tokenizer; set "
"LLMConfig(eos_token_ids=[...]) explicitly."
)
return ids
def _merge_apply_chat_template_kwargs(
tokenizer,
user_kwargs: Optional[dict[str, Any]],
) -> dict[str, Any]:
"""
Merge user kwargs with defaults for HF chat templates that support them.
Qwen3 (and similar) instruct models expect `add_generation_prompt=True` so
the first generated token continues the assistant turn; without it, output
can repeat punctuation / template fragments. `enable_thinking=False` avoids
the Qwen3 reasoning channel when the tokenizer supports it.
"""
out = dict(user_kwargs or {})
try:
sig = inspect.signature(tokenizer.apply_chat_template)
except (TypeError, ValueError):
return out
if "add_generation_prompt" in sig.parameters and "add_generation_prompt" not in out:
out["add_generation_prompt"] = True
if "enable_thinking" in sig.parameters and "enable_thinking" not in out:
out["enable_thinking"] = False
return out
def _runner_entry(config: LLMConfig, rank: int, evt):
runner = None
try:
runner = ModelRunner(config, rank, evt)
runner.loop()
except Exception as e:
logging.exception(f"Rank {rank}: {repr(e)}")
finally:
if runner is not None:
runner.exit()
class LLMEngine:
"""High-level engine coordinating model runners and scheduling"""
def __init__(self, config: LLMConfig, external_model: nn.Module | None = None):
self.config = config
if self.config.hf_config.model_type not in MODEL_REGISTRY:
raise ValueError(f"Unknown model {self.config.model}")
if config.path is None:
# Local directory: use it directly (no Hub round-trip).
try:
mp = Path(config.model)
if mp.is_dir() and (mp / "config.json").is_file():
self.config.path = str(mp.resolve())
logger.info("Using local model directory for tokenizer: %s", self.config.path)
except OSError:
pass
if config.path is None:
from huggingface_hub import snapshot_download
# Hub repo id: allow downloading missing shards/tokenizer files when cache
# is incomplete (local_files_only=False). Local dirs are handled above.
self.config.path = snapshot_download(
repo_id=config.model,
local_files_only=False,
)
logger.info(
"Resolved Hugging Face snapshot for %s @ %s",
self.config.model,
self.config.path,
)
assert self.config.path is not None
_trust = bool(getattr(self.config.hf_config, "trust_remote_code", False))
# Always load tokenizer from the resolved on-disk tree so we do not re-hit
# the Hub with the repo id (can re-download tokenizer / LFS shards).
self.tokenizer = AutoTokenizer.from_pretrained(
self.config.path,
use_fast=True,
trust_remote_code=_trust,
)
if self.config.eos_token_ids is None:
if self.config.eos != -1:
self.config.eos_token_ids = [int(self.config.eos)]
else:
self.config.eos_token_ids = _infer_stop_token_ids(
self.tokenizer, self.config.hf_config
)
else:
self.config.eos_token_ids = [int(x) for x in self.config.eos_token_ids]
self.config.eos_token_ids = sorted(set(self.config.eos_token_ids))
if self.config.eos == -1:
self.config.eos = int(self.config.eos_token_ids[0])
else:
self.config.eos = int(self.config.eos)
if self.config.eos not in self.config.eos_token_ids:
self.config.eos_token_ids = sorted(
self.config.eos_token_ids + [self.config.eos]
)
if external_model is not None and int(self.config.tensor_parallel_size) != 1:
raise ValueError(
"external_model (shared-weight compactor path) only supports "
"tensor_parallel_size=1"
)
self.ps = []
world_size = int(self.config.tensor_parallel_size)
self.events = []
if world_size > 1:
ctx = mp.get_context("spawn")
for r in range(1, world_size):
event = ctx.Event()
p = ctx.Process(
target=_runner_entry,
args=(self.config, r, event),
daemon=True,
)
p.start()
self.ps.append(p)
self.events.append(event)
self.master_model_runner = ModelRunner(
self.config,
rank=0,
peer_events=self.events,
external_model=external_model,
)
atexit.register(self.exit)
def exit(self):
if getattr(self, "_exited", False):
return
self._exited = True
runner = getattr(self, "master_model_runner", None)
if runner is not None:
try:
runner.exit()
except Exception:
logger.exception("Failed to exit master ModelRunner cleanly")
for p in self.ps:
if p.is_alive():
p.terminate()
p.join(timeout=1.0)
if hasattr(self, "events"):
self.events.clear()
def tokenize_prompt(self, prompt: PromptLike, **tokenizer_kwargs) -> List[int]:
"""
Turn a raw prompt into token IDs.
"""
if isinstance(prompt, str):
return self.tokenizer(prompt, **tokenizer_kwargs)["input_ids"]
else:
return list(prompt)
def detokenize_prompt(
self, sequences: List[Sequence], **detokenizer_kwargs
) -> List[str]:
"""
Turn completed Sequences into strings.
"""
defaults: dict[str, Any] = {"skip_special_tokens": True}
merged = {**defaults, **detokenizer_kwargs}
return self.tokenizer.batch_decode(
[s.completion_token_ids for s in sequences], **merged
)
def _build_sequences(
self,
prompts: List[PromptLike] | PromptLike,
sampling_params: SamplingParams | List[SamplingParams],
per_sequence_compression_params: Optional[
SequenceCompressionParams | List[SequenceCompressionParams]
] = None,
tokenizer_kwargs: Optional[dict[str, Any]] = None,
) -> List[Sequence]:
"""
Build Sequence objects from prompts, sampling params, and optional
per-sequence compression parameters.
"""
tokenizer_kwargs = {} if tokenizer_kwargs is None else tokenizer_kwargs
if not isinstance(prompts, list):
prompts = [prompts]
if isinstance(sampling_params, SamplingParams):
sampling_params_list: List[SamplingParams] = [sampling_params] * len(
prompts
)
else:
sampling_params_list = sampling_params
assert len(sampling_params_list) == len(prompts), (
"sampling_params list must match prompts length"
)
if per_sequence_compression_params is None:
compression_params_list: List[SequenceCompressionParams] = [
SequenceCompressionParams(1.0) for _ in prompts
]
elif isinstance(per_sequence_compression_params, SequenceCompressionParams):
compression_params_list = [per_sequence_compression_params] * len(prompts)
else:
# list-like
assert len(per_sequence_compression_params) == len(prompts), (
"per_sequence_compression_params list must match prompts length"
)
compression_params_list = list(per_sequence_compression_params)
seqs: List[Sequence] = []
for prompt, sparams, cparams in zip(
prompts, sampling_params_list, compression_params_list
):
token_ids = self.tokenize_prompt(prompt, **tokenizer_kwargs)
if cparams.protected_first_tokens + cparams.protected_last_tokens >= len(token_ids):
cparams.compression_ratio = 1.0
seqs.append(
Sequence(
prompt_token_ids=token_ids,
sampling_params=sparams,
compression_params=cparams,
)
)
return seqs
def generate(
self,
prompts: List[PromptLike] | PromptLike,
sampling_params: SamplingParams | List[SamplingParams],
batch_compression_params: BatchCompressionParams,
*,
per_sequence_compression_params: Union[
List[SequenceCompressionParams], SequenceCompressionParams
] = None,
tokenizer_kwargs: Optional[dict[str, Any]] = None,
detokenizer_kwargs: Optional[dict[str, Any]] = None,
return_sequences: bool = False,
) -> List[str] | tuple[List[str], List[Sequence]]:
"""
Accept prompts and return completed Sequences.
Args:
:param prompts:
Single prompt or list of prompts, each either a raw text prompt,
or pre-tokenized input IDs.
:param sampling_params:
A single SamplingParams for all prompts in this batch or a list of
SamplingParams with the same length as ``prompts``.
:param batch_compression_params:
Compression settings for this batch.
:param per_sequence_compression_params:
Per-sequence compression parameters, including the compression
ratio to be applied and the size of the protected regions of the
sequence (how many start tokens and end tokens to keep uncompressed).
If a SequenceCompressionParams instance, the same params will be
applied to all sequences in this batch; if a list is provided,
each SequenceCompressionParams will be attached to the corresponding
prompt in the batch.
:param tokenizer_kwargs:
Extra kwargs forwarded to ``tokenizer(...)`` when tokenizing
string prompts.
:param detokenizer_kwargs:
Passed through to `tokenizer.batch_decode`.
:param return_sequences:
Whether to return sequence objects or not
Returns:
:return List[Sequence]:
One Sequence per input prompt, with `completion_token_ids`
filled in after generation.
"""
tokenizer_kwargs = {} if tokenizer_kwargs is None else tokenizer_kwargs
detokenizer_kwargs = {} if detokenizer_kwargs is None else detokenizer_kwargs
seqs = self._build_sequences(
prompts,
sampling_params=sampling_params,
per_sequence_compression_params=per_sequence_compression_params,
tokenizer_kwargs=tokenizer_kwargs,
)
self.master_model_runner.generate(seqs, batch_compression_params)
output_strings = self.detokenize_prompt(seqs, **detokenizer_kwargs)
if return_sequences:
return output_strings, seqs
return output_strings
def generate_chat(
self,
messages_batch: List[List[dict]],
sampling_params: SamplingParams | List[SamplingParams],
batch_compression_params: BatchCompressionParams,
per_sequence_compression_params: Union[
SequenceCompressionParams, List[SequenceCompressionParams]
],
*,
tokenizer_kwargs: Optional[dict[str, Any]] = None,
detokenizer_kwargs: Optional[dict[str, Any]] = None,
return_sequences: bool = False,
) -> List[str] | tuple[List[str], List[Sequence]]:
"""
Convenience API for chat-style prompts using HF `apply_chat_template`.
Args:
:param messages_batch:
List of conversations, where each conversation is a list of
message dicts like:
{"role": "system" | "user" | "assistant", "content": str}
:param sampling_params:
A single SamplingParams for all prompts in this batch or a list of
SamplingParams with the same length as ``prompts``.
:param batch_compression_params:
Batch Level compression settings. Can set compression_method.
:param per_sequence_compression_params:
Per-sequence compression parameters, including the compression
ratio to be applied and the size of the protected regions of the
sequence (how many start tokens and end tokens to keep uncompressed).
If a SequenceCompressionParams instance, the same params will be
applied to all sequences in this batch; if a list is provided,
each SequenceCompressionParams will be attached to the corresponding
conversation in the batch.
:param tokenizer_kwargs:
Passed through to `tokenizer.apply_chat_template`.
:param detokenizer_kwargs:
Passed through to `tokenizer.batch_decode`.
:param return_sequences:
Whether to return sequence objects or not
Returns:
:return List[str] or tuple[List[str], List[Sequence]]:
One string per conversation.
"""
prompts_token_ids: List[List[int]] = []
tokenizer_kwargs = _merge_apply_chat_template_kwargs(
self.tokenizer, tokenizer_kwargs
)
detokenizer_kwargs = {} if detokenizer_kwargs is None else detokenizer_kwargs
for messages in messages_batch:
input_ids = self.tokenizer.apply_chat_template(
messages,
tokenize=True,
**tokenizer_kwargs,
)
if hasattr(input_ids, "tolist"):
input_ids = input_ids.tolist()
prompts_token_ids.append(input_ids)
return self.generate(
prompts_token_ids,
sampling_params=sampling_params,
batch_compression_params=batch_compression_params,
per_sequence_compression_params=per_sequence_compression_params,
tokenizer_kwargs=tokenizer_kwargs,
detokenizer_kwargs=detokenizer_kwargs,
return_sequences=return_sequences,
)
def generate_from_sequences(
self,
seqs: List[Sequence],
batch_compression_params: BatchCompressionParams,
) -> List[Sequence]:
"""
Args:
:param seqs:
List of Sequence instances
:param batch_compression_params:
Compression settings.
Returns:
:return List[Sequence]:
Same list, mutated in-place with completions.
"""
self.master_model_runner.generate(seqs, batch_compression_params)
return seqs
import logging
import os
from typing import Iterable, List, Optional
import torch
from vllm.kvprune.config.engine_config import LLMConfig
from vllm.kvprune.kv_cache.page_table import KVAllocationStatus, PagedKVCache
from vllm.kvprune.utils.tp_utils import kv_heads_shard_divisor
from torch import nn
logger = logging.getLogger(__name__)
class KVCacheManager:
def __init__(
self,
rank: int,
config: LLMConfig,
*,
device: str | None = None,
):
super().__init__()
hf_config = config.hf_config
self.rank = rank
self.gpu_frac = config.gpu_memory_utilization
self.page_size = config.kvcache_page_size
self.world_size = config.tensor_parallel_size
self.max_num_batches = config.max_num_seqs
self.max_model_len = config.max_model_len
self.num_layers = hf_config.num_hidden_layers
self.model_dtype = hf_config.torch_dtype
self.head_dim = getattr(hf_config, "head_dim", None)
self.max_pages_per_batch = (
self.max_model_len + self.page_size - 1
) // self.page_size
_ws = kv_heads_shard_divisor()
self.num_kv_heads = hf_config.num_key_value_heads // _ws
assert hf_config.num_key_value_heads % _ws == 0, (
"tensor-parallel world size needs to divide num_kv_heads"
)
self._cache_device = device if device is not None else f"cuda:{self.rank}"
self.num_pages = None
self.paged_cache: Optional[PagedKVCache] = None
self.max_batched_tokens = None
self.seq_id_to_batch = {}
def allocate_sequences(
self, seq_ids: List[int], max_positions: List[int]
) -> (bool, Optional[torch.Tensor]):
batch_mapping = []
for seq_id, len_to_alloc in zip(seq_ids, max_positions):
if seq_id not in self.seq_id_to_batch:
batch_id = self.paged_cache.new_batch()
if batch_id is None:
logger.warning("Failed to allocate batch!")
return False, None
self.seq_id_to_batch[seq_id] = int(batch_id)
batch_mapping.append(self.seq_id_to_batch[seq_id])
if (
alloc_status := self.paged_cache.reserve_tokens(
self.seq_id_to_batch[seq_id], len_to_alloc
)
) != KVAllocationStatus.SUCCESS:
logger.warning(f"Failed to allocate pages ({alloc_status})!")
return False, None
batch_mapping = torch.as_tensor(batch_mapping, dtype=torch.int32, device="cuda")
return True, batch_mapping
def free_sequences(self, seq_ids: Iterable[int]):
for seq_id in seq_ids:
global_batch_id = self.seq_id_to_batch.pop(seq_id, None)
self.paged_cache.free_batch(global_batch_id)
def init_cache(self, model: nn.Module):
self.num_pages = self.get_num_pages(self.gpu_frac, self.max_pages_per_batch)
self.paged_cache = PagedKVCache(
num_layers=self.num_layers,
H_kv=self.num_kv_heads,
head_dim=self.head_dim,
page_size=self.page_size,
num_pages=int(self.num_pages),
max_num_batches=self.max_num_batches,
device=self._cache_device,
dtype=self.model_dtype,
max_logical_pages_per_head=int(self.max_pages_per_batch),
)
self._assign_cache_to_layers(model)
def _assign_cache_to_layers(self, model) -> None:
for layer_index, layer in enumerate(model.model.layers):
attn = layer.self_attn.attn
k, v, pt, bh = self.paged_cache.layer_slices(layer_index)
attn.k_cache = k
attn.v_cache = v
attn.page_table = pt
attn.bh_seq_lens = bh
attn.page_size = self.page_size
def get_num_pages(self, frac: float, n_logical_pages_max: int):
free, total = torch.cuda.mem_get_info()
used = total - free
stats = torch.cuda.memory_stats()
peak = int(stats["allocated_bytes.all.peak"])
current = int(stats["allocated_bytes.all.current"])
bytes_for_kv_budget = int(total * frac * 0.9) - used - peak + current
if bytes_for_kv_budget <= 0:
# Standalone compactor: ``frac`` is a fraction of total VRAM. When a second
# engine shares the GPU with vLLM (shared weights), most VRAM is already
# committed; the formula above goes negative. Fall back to a slice of
# *currently free* memory for the compactor KV pool.
free_frac = float(
os.environ.get("VLLM_KVPRUNE_COMPACTOR_KV_FREE_FRAC", "0.55")
)
free_frac = max(0.05, min(free_frac, 0.95))
bytes_for_kv_budget = int(free * free_frac)
logger.warning(
"KV cache budget from gpu_memory_utilization (%.2f) is exhausted "
"(%.2f MiB free on device); using %.0f%% of free memory (~%.2f MiB) "
"for compactor KV (set VLLM_KVPRUNE_COMPACTOR_KV_FREE_FRAC to adjust).",
frac,
free / (1024**2),
free_frac * 100,
bytes_for_kv_budget / (1024**2),
)
if bytes_for_kv_budget <= 0:
raise RuntimeError(
"Insufficient memory for compactor KV cache: no free GPU memory left "
"after the primary vLLM engine. Lower vLLM gpu_memory_utilization or "
"max_model_len, shorten prompts, or run compactor-only / vLLM-only "
"sessions. Raising gpu_memory_utilization here does not help."
)
# page_table[L, B, H_kv, N_LOGICAL_PAGES_MAX] + bh_seq_lens[L, B, H_kv]
int32_sz = torch.empty((), dtype=torch.int32).element_size() # 4
page_table_bytes_per_layer = (
self.max_num_batches
* self.num_kv_heads
* n_logical_pages_max
* int32_sz # page_table
+ self.max_num_batches * self.num_kv_heads * int32_sz
)
total_page_table_bytes = self.num_layers * page_table_bytes_per_layer
kv_bytes_net = bytes_for_kv_budget - total_page_table_bytes
if kv_bytes_net <= 0:
# Tight VRAM: metadata alone can exceed the first budget; reserve page
# tables plus a slice of remaining free for KV tensors.
bytes_for_kv_budget = min(
int(free * 0.95),
total_page_table_bytes + max(int(free * 0.25), 8 * 1024 * 1024),
)
kv_bytes_net = bytes_for_kv_budget - total_page_table_bytes
if kv_bytes_net <= 0:
raise RuntimeError(
"page-table footprint exceeds available GPU memory for compactor KV. "
f"Reduce vLLM max_num_seqs (compactor uses {self.max_num_batches}) "
f"or max_model_len ({self.max_model_len}), or free GPU memory."
)
dtype_sz = torch.empty((), dtype=self.model_dtype).element_size()
bytes_per_page_across_layers = self.num_layers * (
2 * self.page_size * self.head_dim * dtype_sz
)
return max(1, kv_bytes_net // bytes_per_page_across_layers)
def estimate_max_batched_tokens(
self,
warmup_tokens: int,
bytes_used_before_warmup: int,
bytes_peak_after_warmup: int,
) -> int:
"""
Estimate the max total number of tokens that can be processed concurrently
without OOM.
"""
assert warmup_tokens > 0, "warmup_tokens must be > 0"
# activation bytes per token
warmup_delta = max(
0, int(bytes_peak_after_warmup) - int(bytes_used_before_warmup)
)
bytes_per_token = max(1, (warmup_delta + warmup_tokens - 1) // warmup_tokens)
free, total = torch.cuda.mem_get_info()
target = int(total * self.gpu_frac)
used_now = int(total - free)
# reserve headroom equal to the gap between peak and current allocations seen so far
stats = torch.cuda.memory_stats()
peak_cur = int(stats.get("allocated_bytes.all.peak", 0))
cur_now = int(stats.get("allocated_bytes.all.current", 0))
cushion = max(0, peak_cur - cur_now)
activation_budget = int(max(0, target - used_now - cushion) * 0.95)
max_tokens_per_batch = activation_budget // bytes_per_token
max_tokens_in_cache = (self.num_pages * self.page_size) // self.num_kv_heads
# round to lower multiple of page size
max_tokens_per_batch = (max_tokens_per_batch // self.page_size) * self.page_size
max_tokens_in_cache = (max_tokens_in_cache // self.page_size) * self.page_size
# When vLLM shares the same GPU, ``used_now`` often exceeds ``target`` (same
# situation as ``get_num_pages``), so activation_budget is ~0 and
# ``max_tokens_per_batch`` rounds to 0 or one page. The min(...) would then
# cap prefill at ~page_size tokens (e.g. 32) even though the compactor KV pool
# is large — no prompt longer than that can be scheduled. Prefer KV capacity
# (capped by max_model_len) whenever activation math yields only a token or two.
if (
max_tokens_in_cache > 0
and max_tokens_per_batch <= self.page_size
and max_tokens_in_cache > max_tokens_per_batch
):
max_tokens_per_batch = min(max_tokens_in_cache, self.max_model_len)
self.max_batched_tokens = min(max_tokens_in_cache, max_tokens_per_batch)
# Last resort: allow at least one page when KV exists but min(...) is still 0.
if self.max_batched_tokens == 0 and self.num_pages > 0 and max_tokens_in_cache > 0:
self.max_batched_tokens = min(max_tokens_in_cache, self.page_size)
return self.max_batched_tokens
@property
def num_free_batches(self) -> int:
return len(self.paged_cache.free_batches)
@property
def num_free_pages(self) -> int:
return min(len(fp) for fp in self.paged_cache.free_pages)
def reclaim_pages(
self,
seq_ids_to_reclaim: Iterable[int],
future_reserved_buffer: List[int] | torch.Tensor,
) -> int:
approximate_bytes_freed = 0
for i, seq_id in enumerate(seq_ids_to_reclaim):
batch_idx = self.seq_id_to_batch[seq_id]
approximate_bytes_freed += self.paged_cache.reclaim_pages(
batch_idx, future_reserved_buffer[i]
)
return approximate_bytes_freed
import atexit
import logging
import os
import inspect
from typing import Any, List, Optional
import torch
import torch.nn as nn
import torch.distributed as dist
from vllm.kvprune.attention.sparse_decode_kernel import num_splits_heuristic
from vllm.kvprune.compression.compression_config import BatchCompressionParams
from vllm.kvprune.config.constants import RESERVED_BATCH
from vllm.kvprune.config.engine_config import LLMConfig, KvpruneAttentionSchedule
from vllm.kvprune.core.memory_manager import KVCacheManager
from vllm.kvprune.core.scheduler import Scheduler
from vllm.kvprune.layers.sampler import Sampler
from vllm.kvprune.models import MODEL_REGISTRY
from vllm.kvprune.utils.arguments import (
DecodeBatchArguments,
DecodeBatchOutput,
PackedTensorArguments,
PrefillBatchArguments,
)
from vllm.kvprune.utils.context import CompressionContext, reset_context, set_context
from vllm.kvprune.utils.kv_dist import barrier_sync, broadcast_from_tp_rank0
from vllm.kvprune.utils.sequence import Sequence
from torch.multiprocessing import Event
from tqdm import tqdm
logger = logging.getLogger(__name__)
class ModelRunner:
"""Per-rank execution loop. Manages model, sampler, KV cache, and warmup"""
def __init__(
self,
config: LLMConfig,
rank: int,
batch_ready: Optional[Event] = None,
peer_events: List[Event] = None,
external_model: Optional[nn.Module] = None,
*,
embedded_in_vllm_worker: bool = False,
device: Optional[torch.device] = None,
):
self.config = config
self.embedded_in_vllm_worker = embedded_in_vllm_worker
if embedded_in_vllm_worker:
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
tp_ws = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
if tp_ws != config.tensor_parallel_size:
raise RuntimeError(
f"tensor parallel world size {tp_ws} != "
f"LLMConfig.tensor_parallel_size {config.tensor_parallel_size}"
)
self.rank = tp_rank
_dev = device if device is not None else torch.device(
f"cuda:{torch.cuda.current_device()}"
)
if not dist.is_initialized():
raise RuntimeError(
"embedded_in_vllm_worker requires torch.distributed to be "
"initialized (vLLM worker)."
)
if dist.get_world_size() != tp_ws:
raise NotImplementedError(
"KV-prune compactor embedded in vLLM currently requires "
"dist.get_world_size() == tensor_parallel_size "
"(pipeline_parallel_size=1, data_parallel_size=1). "
f"Got dist.get_world_size()={dist.get_world_size()}, "
f"tp_ws={tp_ws}."
)
else:
self.rank = rank
_dev = device if device is not None else torch.device(f"cuda:{rank}")
self._device = _dev
assert config.eos_token_ids is not None and len(config.eos_token_ids) > 0, (
"LLMConfig.eos_token_ids must be set (filled in LLMEngine from tokenizer)."
)
self._stop_token_ids = torch.tensor(
config.eos_token_ids, dtype=torch.int64, device=_dev
)
hf_config = config.hf_config
self.enforce_eager = config.enforce_eager
if config.attention_schedule == KvpruneAttentionSchedule.PDFA:
if not self.enforce_eager and self.rank == 0:
logger.info(
"attention_schedule=PDFA: disabling compactor decode CUDA graphs "
"(FlashAttention decode path)."
)
self.enforce_eager = True
# Embedded in vLLM worker (TP>1): respect :attr:`LLMConfig.enforce_eager` from
# ``v1_tp_runner._apply_compactor_env_overrides``. Set
# ``VLLM_KVPRUNE_TP_EMBEDDED_GRAPH=0`` to force eager if graph replay is unstable
# with shared vLLM VRAM / streams / NCCL on your stack.
if embedded_in_vllm_worker:
_tp_graph = os.environ.get(
"VLLM_KVPRUNE_TP_EMBEDDED_GRAPH", "1"
).strip().lower()
if _tp_graph in ("0", "false", "no"):
if not self.enforce_eager:
logger.info(
"embedded_in_vllm_worker: VLLM_KVPRUNE_TP_EMBEDDED_GRAPH=0 → "
"forcing compactor enforce_eager=True (skip compactor CUDA graph "
"capture)."
)
self.enforce_eager = True
self.world_size = config.tensor_parallel_size
self.leverage_sketch_size = config.leverage_sketch_size
self.show_progress_bar = config.show_progress_bar
self.max_num_batches = config.max_num_seqs
self.max_model_len = config.max_model_len
self.num_layers = hf_config.num_hidden_layers
self.model_dtype = hf_config.torch_dtype
self.head_dim = getattr(hf_config, "head_dim", None)
init_kwargs = {}
if not embedded_in_vllm_worker:
if "device_id" in inspect.signature(dist.init_process_group).parameters:
init_kwargs["device_id"] = torch.device(f"cuda:{rank}")
if not dist.is_initialized():
dist.init_process_group(
"nccl",
f"tcp://localhost:{config.nccl_port}",
world_size=self.world_size,
rank=rank,
**init_kwargs,
)
else:
ws = dist.get_world_size()
if ws != self.world_size:
raise RuntimeError(
"torch.distributed is already initialized with "
f"world_size={ws}, but compactor ModelRunner expects "
f"tensor_parallel_size={self.world_size}. "
"Use tensor_parallel_size matching the active process group "
"(typically 1 when sharing weights with vLLM)."
)
torch.cuda.set_device(_dev)
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(hf_config.torch_dtype)
torch.set_default_device("cuda")
model_type = hf_config.model_type
if external_model is not None:
self.model = external_model
else:
self.model = MODEL_REGISTRY[model_type](hf_config)
self.model.load_model(
config.path, use_tqdm=self.is_master and self.show_progress_bar
)
self.sampler = Sampler()
pre_warmup_mem = torch.cuda.memory_stats().get("allocated_bytes.all.current", 0)
# No paged KV yet: FA-only varlen path (see :meth:`warmup`).
self.warmup(num_warmup_tokens=self.max_model_len, with_kv=False)
post_warmup_peak = torch.cuda.memory_stats().get("allocated_bytes.all.peak", 0)
self.kv_manager = KVCacheManager(
self.rank, config, device=str(self._device)
)
self.kv_manager.init_cache(self.model)
self.store_stream: Optional[torch.cuda.Stream] = torch.cuda.Stream()
torch.set_default_device("cpu")
torch.set_default_dtype(default_dtype)
self.batch_ready = batch_ready
self.peer_events = peer_events if peer_events is not None else []
# Embedded TP peers: session end is signaled via TP-group broadcast in
# maybe_release_peers (no multiprocessing.Event — not pickleable over RPC).
self._embedded_peer_continue = True
self.captured_graphs = {}
self.min_captured_len = {}
self.max_batched_tokens = self.kv_manager.estimate_max_batched_tokens(
self.max_model_len, pre_warmup_mem, post_warmup_peak
)
if self.is_master:
logger.info(f"Estimated max batched tokens of {self.max_batched_tokens}")
self.warmup(num_warmup_tokens=self.max_model_len, with_kv=True)
if not self.enforce_eager:
bs = [1 << i for i in range(self.max_num_batches.bit_length())]
for bs in (
tqdm(bs, desc="Capturing CUDA Graphs")
if self.is_master and self.show_progress_bar
else bs
):
for seq_len in [1024, 4096, 8192, 16384]:
self.capture_cudagraph(bs, seq_len)
if not self.captured_graphs:
logger.warning(
"No compactor CUDA graphs were captured (KV budget tight or "
"allocate_sequences failed during capture). Using eager decode "
"for this session."
)
self.enforce_eager = True
self.packed_args = PackedTensorArguments(
rank=self.rank,
max_batched_tokens=self.max_batched_tokens,
config=self.config,
device=self._device,
use_tp_group_for_collectives=embedded_in_vllm_worker,
)
atexit.register(self.exit)
@torch.inference_mode()
def warmup(self, num_warmup_tokens: int, *, with_kv: bool):
sched = (
self.config.attention_schedule
if with_kv
else KvpruneAttentionSchedule.FA_PREFILL_TRITON_DECODE
)
if self.rank == 0:
logger.info(
"Warming up compactor attention (%s KV init): schedule=%s",
"after" if with_kv else "before",
sched.name,
)
device = self._device
input_ids = torch.tensor(
[self.config.eos] * num_warmup_tokens, device=device, dtype=torch.int64
)
positions = torch.arange(num_warmup_tokens, device=device, dtype=torch.int64)
cu_seqlens_q = torch.tensor(
[0, num_warmup_tokens], device=device, dtype=torch.int32
)
cu_seqlens_k = torch.tensor(
[0, num_warmup_tokens], device=device, dtype=torch.int32
)
if with_kv:
success, batch_mapping = self.kv_manager.allocate_sequences(
[-1], [num_warmup_tokens]
)
assert success
else:
batch_mapping = None
set_context(
is_prefill=True,
do_compression=False,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
cu_seqlens_q_host=(0, num_warmup_tokens),
cu_seqlens_k_host=(0, num_warmup_tokens),
max_seqlen_q=num_warmup_tokens,
max_seqlen_k=num_warmup_tokens,
batch_mapping=batch_mapping,
attention_schedule=sched,
)
for _ in range(2):
torch.cuda.reset_peak_memory_stats()
h = self.model(input_ids, positions)
self.model.compute_logits(h)
barrier_sync(use_tp_group=self.embedded_in_vllm_worker)
if with_kv:
self.kv_manager.paged_cache.bh_seq_lens.index_fill_(
1, batch_mapping.to(torch.long), 0
)
reset_context()
if with_kv:
self.kv_manager.free_sequences([-1])
def exit(self):
if getattr(self, "_exited", False):
return
self._exited = True
try:
if hasattr(self, "captured_graphs"):
self.captured_graphs.clear()
finally:
if getattr(self, "embedded_in_vllm_worker", False):
return
if dist.is_initialized():
dist.destroy_process_group()
def loop(self):
while True:
if self.batch_ready.wait(1.0):
self._process_batches_peer()
@torch.inference_mode()
def run_prefill(
self, prefill_args: PrefillBatchArguments, batch_mapping: torch.Tensor
):
assert prefill_args.B > 0 and prefill_args.N > 0
max_bh_len = (
self.kv_manager.paged_cache.bh_seq_lens.index_select(1, index=batch_mapping)
.max()
.item()
)
compression_context = CompressionContext(
compression_method=prefill_args.compression_method,
compression_chunk_size=prefill_args.compression_chunk_size,
batch_tokens_to_retain=prefill_args.batch_tokens_to_retain,
max_tokens_to_retain=prefill_args.max_tokens_to_retain,
context_lens=prefill_args.context_lens.tolist(),
PHI=prefill_args.PHI,
sketch_dimension=self.leverage_sketch_size,
protected_first_tokens=prefill_args.protected_first,
protected_last_tokens=prefill_args.protected_last,
compression_ratio=prefill_args.compression_ratio,
)
cu_q_host = tuple(
int(x) for x in prefill_args.cu_seqlens_q.detach().cpu().view(-1).tolist()
)
cu_k_host = tuple(
int(x) for x in prefill_args.cu_seqlens_k.detach().cpu().view(-1).tolist()
)
set_context(
is_prefill=True,
do_compression=prefill_args.do_compression,
cu_seqlens_q=prefill_args.cu_seqlens_q,
cu_seqlens_k=prefill_args.cu_seqlens_k,
cu_seqlens_q_host=cu_q_host,
cu_seqlens_k_host=cu_k_host,
max_seqlen_q=prefill_args.max_seqlen_q,
max_seqlen_k=prefill_args.max_seqlen_k,
batch_mapping=batch_mapping,
max_bh_len=max_bh_len,
compression_context=compression_context,
STORE_STREAM=self.store_stream,
attention_schedule=self.config.attention_schedule,
)
# int32 token ids break vLLM-delegated embedding (expects long indices) on some paths.
_iid = (
prefill_args.input_ids
if prefill_args.input_ids.dtype == torch.int64
else prefill_args.input_ids.long()
)
_pos = (
prefill_args.positions
if prefill_args.positions.dtype == torch.int64
else prefill_args.positions.long()
)
hidden = self.model(_iid, _pos)
logits = self.model.compute_logits(hidden)
reset_context()
return logits
def maybe_broadcast(self, tensor: torch.Tensor, *, label: str = "tensor") -> None:
if self.world_size > 1:
broadcast_from_tp_rank0(
tensor, use_tp_group=self.embedded_in_vllm_worker
)
return None
def maybe_release_peers(self, do_release=False):
if self.world_size <= 1:
return
if self.embedded_in_vllm_worker:
flag = torch.zeros(1, dtype=torch.int32, device=self._device)
if self.is_master:
flag[0] = 0 if do_release else 1
broadcast_from_tp_rank0(flag, use_tp_group=True)
if not self.is_master:
self._embedded_peer_continue = bool(flag[0].item())
barrier_sync(use_tp_group=True)
return
if self.is_master:
if do_release:
for event in self.peer_events:
event.clear()
barrier_sync(use_tp_group=False)
else:
barrier_sync(use_tp_group=False)
def _peer_outer_loop_active(self) -> bool:
if self.batch_ready is not None:
return self.batch_ready.is_set()
if self.embedded_in_vllm_worker:
return self._embedded_peer_continue
return False
@torch.inference_mode()
def generate(
self,
all_sequences: List[Sequence],
batch_compression_params: Optional[BatchCompressionParams] = None,
):
assert self.is_master, "generate can only be called on the master process"
if not self.embedded_in_vllm_worker:
for begin_execution_event in self.peer_events:
begin_execution_event.set()
if batch_compression_params is None:
batch_compression_params = BatchCompressionParams()
self._process_batches_master(all_sequences, batch_compression_params)
@property
def is_master(self):
return self.rank == 0
@torch.inference_mode()
def _process_batches_master(
self,
all_sequences: List[Sequence],
batch_compression_params: BatchCompressionParams,
):
assert self.is_master
compression_details = f"Applying Compression Method: {batch_compression_params.compression_method}"
if any(seq.compression_params.compression_ratio < 1.0 for seq in all_sequences):
logger.info(compression_details)
scheduler = Scheduler(
all_sequences=all_sequences,
kv_manager=self.kv_manager,
use_tqdm=self.show_progress_bar,
)
decode_batch = DecodeBatchArguments()
decode_flags = torch.empty(2, dtype=torch.int32, device=self._device)
while not scheduler.is_finished():
sequences = scheduler.get_prefill_batch()
if not sequences:
if scheduler.pending_sequence_ids:
raise RuntimeError(
"KV-prune compactor cannot schedule any prefill (KV/token budget). "
f"max_batched_tokens={self.kv_manager.max_batched_tokens}, "
f"pending_sequences={len(scheduler.pending_sequence_ids)}. "
"Lower v1 gpu_memory_utilization / max_model_len, set "
"VLLM_KVPRUNE_RELEASE_V1_KV=1 to discard v1 KV (sleep+wake), "
"or free GPU memory. Diagnostics: "
f"{scheduler.diagnose_prefill_failure()}"
)
# Pending is empty: either finished or decode-only continuation.
if decode_batch.token_ids is None:
break
run_decode = True
occupancy = -1
else:
seq_ids_cpu = [seq.seq_id for seq in sequences]
scheduler.add_running_sequence_ids(seq_ids_cpu, update_status=True)
temps = torch.tensor(
[s.sampling_params.temperature for s in sequences],
dtype=torch.float32,
pin_memory=True,
).to(device=self._device, non_blocking=True)
prefill_arguments = self.packed_args.build_prefill_args(
sequences, batch_compression_params=batch_compression_params
)
max_ctx_lens = (
prefill_arguments.max_new_tokens + prefill_arguments.context_lens
)
success, batch_mapping = self.kv_manager.allocate_sequences(
seq_ids_cpu, max_ctx_lens.tolist()
)
assert success, "failed to allocate pages for sequences"
logits = self.run_prefill(prefill_arguments, batch_mapping)
# Must match prefill `positions` dtype (int64). `context_lens` is int32
# from the packed buffer; using int32 here breaks RoPE indexing
# (`cos_sin_cache[positions]`) on CUDA for decode vs prefill.
positions = prefill_arguments.context_lens.to(dtype=torch.int64)
token_ids = self.sampler(logits, temps)
# Prefill KV writes + bh_seq_lens updates run on STORE_STREAM; reclaim
# reads bh_seq_lens on the default stream and must not race.
if self.store_stream is not None:
torch.cuda.default_stream().wait_stream(self.store_stream)
# TODO: synchronize page counts accross dist
if self.world_size == 1:
self.kv_manager.reclaim_pages(
seq_ids_cpu, prefill_arguments.max_new_tokens
)
# with logging_redirect_tqdm():
# logger.info(
# f"Reclaimed {reclaimed_bytes / 1e6:.2f} MB from the KV cache"
# )
if scheduler.any_pending_sequences():
num_pending_batches = (
0
if decode_batch.token_ids is None
else decode_batch.token_ids.shape[0]
)
occupancy = int((num_pending_batches + len(seq_ids_cpu)) * 0.66)
else:
occupancy = -1
run_decode = not scheduler.can_prefill_another_batch()
decode_batch = decode_batch.update(
batch_mapping,
token_ids,
positions,
max_ctx_lens,
prefill_arguments.seq_ids,
temps,
occupancy,
)
if self.world_size > 1:
decode_flags[0] = int(run_decode)
decode_flags[1] = occupancy
self.maybe_broadcast(decode_flags, label="decode_flags")
if not run_decode:
continue
if self.store_stream is not None:
torch.cuda.default_stream().wait_stream(self.store_stream)
decode_output, decode_batch = self.run_decode_loop(decode_batch)
finished_sequence_ids = scheduler.get_finished_sequence_ids_from_unfinished(
decode_batch.seq_ids.tolist()
)
scheduler.record_finished_sequence_ids(
finished_sequence_ids, update_status=True
)
self.kv_manager.free_sequences(finished_sequence_ids)
self.maybe_release_peers(scheduler.is_finished())
scheduler.update_sequences(
decode_output.output_tokens.tolist(),
decode_output.output_seq_ids.tolist(),
)
scheduler.close()
@torch.inference_mode()
def run_peer_session(self) -> None:
"""Non-master TP ranks: run one peer session (used when embedded in vLLM)."""
if self.embedded_in_vllm_worker:
self._embedded_peer_continue = True
self._process_batches_peer()
@torch.inference_mode()
def _process_batches_peer(self):
assert not self.is_master
scheduler = Scheduler([], kv_manager=self.kv_manager)
decode_batch = DecodeBatchArguments()
decode_flags = torch.empty(2, dtype=torch.int32, device=self._device)
while self._peer_outer_loop_active():
prefill_arguments = self.packed_args.build_prefill_args()
B = prefill_arguments.B
max_ctx_lens = (
prefill_arguments.max_new_tokens + prefill_arguments.context_lens
)
seq_ids_cpu = prefill_arguments.seq_ids.tolist()
scheduler.add_running_sequence_ids(seq_ids_cpu)
success, batch_mapping = self.kv_manager.allocate_sequences(
seq_ids_cpu, max_ctx_lens.tolist()
)
assert success, "failed to allocate pages for sequences"
self.run_prefill(prefill_arguments, batch_mapping)
positions = prefill_arguments.context_lens.to(dtype=torch.int64)
self.maybe_broadcast(decode_flags, label="decode_flags")
run_decode = bool(decode_flags[0].item())
occupancy = int(decode_flags[1].item())
token_ids = torch.empty(B, dtype=torch.int64, device=self._device)
decode_batch = decode_batch.update(
batch_mapping,
token_ids,
positions,
max_ctx_lens,
prefill_arguments.seq_ids,
None, # temps not used in peer process
occupancy,
)
if not run_decode:
continue
if self.store_stream is not None:
torch.cuda.default_stream().wait_stream(self.store_stream)
_, decode_batch = self.run_decode_loop(decode_batch)
finished_sequence_ids = scheduler.get_finished_sequence_ids_from_unfinished(
decode_batch.seq_ids.tolist()
)
scheduler.record_finished_sequence_ids(finished_sequence_ids)
self.kv_manager.free_sequences(finished_sequence_ids)
self.maybe_release_peers()
scheduler.close()
@torch.inference_mode()
def run_decode_loop(
self,
decode_batch: DecodeBatchArguments,
) -> tuple[DecodeBatchOutput, DecodeBatchArguments]:
if self.is_master:
num_stashed_batches = decode_batch.num_stashed_batches
tok_buffer = [
decode_batch.token_ids[num_stashed_batches:].to(
"cpu", non_blocking=True
)
]
seq_buffer = [
decode_batch.seq_ids[num_stashed_batches:].to("cpu", non_blocking=True)
]
while True:
self.maybe_broadcast(decode_batch.token_ids, label="decode_token_ids")
not_stopped = ~torch.isin(decode_batch.token_ids, self._stop_token_ids)
running_batches = (decode_batch.positions < decode_batch.max_ctx_lens) & (
not_stopped
)
decode_batch.token_ids = torch.masked_select(
decode_batch.token_ids, running_batches
)
decode_batch.positions = torch.masked_select(
decode_batch.positions, running_batches
)
decode_batch.batch_mapping = torch.masked_select(
decode_batch.batch_mapping, running_batches
)
decode_batch.max_ctx_lens = torch.masked_select(
decode_batch.max_ctx_lens, running_batches
)
decode_batch.seq_ids = torch.masked_select(
decode_batch.seq_ids, running_batches
)
if self.is_master:
decode_batch.temps = torch.masked_select(
decode_batch.temps, running_batches
)
num_remaining = decode_batch.token_ids.numel()
if (
num_remaining == 0
or num_remaining <= decode_batch.desired_batch_occupancy
):
decode_batch.num_stashed_batches = num_remaining
break
logits = self._decode_step_logits(decode_batch)
if self.is_master:
decode_batch.token_ids = self.sampler(logits, decode_batch.temps)
tok_buffer.append(decode_batch.token_ids.to("cpu", non_blocking=True))
seq_buffer.append(decode_batch.seq_ids.to("cpu", non_blocking=True))
decode_batch.positions += 1
if self.is_master:
# non_blocking D2H copies must finish before cat/tolist read CPU data.
torch.cuda.synchronize()
output = DecodeBatchOutput(
output_tokens=torch.cat(tok_buffer),
output_seq_ids=torch.cat(seq_buffer),
)
else:
output = DecodeBatchOutput(None, None)
return output, decode_batch
def _decode_logits_eager(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
batch_mapping: torch.Tensor,
):
set_context(
is_prefill=False,
do_compression=False,
batch_mapping=batch_mapping,
attention_schedule=self.config.attention_schedule,
)
_iid = input_ids if input_ids.dtype == torch.int64 else input_ids.long()
_pos = positions if positions.dtype == torch.int64 else positions.long()
hidden = self.model(_iid, _pos)
return self.model.compute_logits(hidden)
@torch.inference_mode()
def _decode_step_logits(self, decode_batch: DecodeBatchArguments):
"""Graph decode when possible; otherwise eager (never raises on missing graph)."""
if self.enforce_eager or not self.captured_graphs:
return self._decode_logits_eager(
decode_batch.token_ids,
decode_batch.positions,
decode_batch.batch_mapping,
)
try:
return self.run_graph_decode(
decode_batch.token_ids,
decode_batch.positions,
decode_batch.batch_mapping,
)
except Exception as e:
logger.warning(
"CUDA graph decode failed (%s); switching to eager decode for "
"remaining steps.",
e,
)
self.enforce_eager = True
return self._decode_logits_eager(
decode_batch.token_ids,
decode_batch.positions,
decode_batch.batch_mapping,
)
@torch.inference_mode()
def run_graph_decode(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
batch_mapping: torch.Tensor,
):
bs = input_ids.shape[0]
max_k = int(positions.max())
graph_dict = self.get_cuda_graph(bs, max_k)
if graph_dict is None:
return self._decode_logits_eager(input_ids, positions, batch_mapping)
set_context(
is_prefill=False,
do_compression=False,
batch_mapping=batch_mapping,
attention_schedule=self.config.attention_schedule,
)
graph_dict["input_ids"][:bs] = input_ids
graph_dict["positions"][:bs] = positions
graph_dict["batch_mapping"].fill_(RESERVED_BATCH)
graph_dict["batch_mapping"][:bs] = batch_mapping
graph_dict["graph"].replay()
logits_out = graph_dict["logits"]
return logits_out[:bs].contiguous()
@torch.inference_mode()
def capture_cudagraph(self, batch_size: int, max_seqlen_k: int):
barrier_sync(use_tp_group=self.embedded_in_vllm_worker)
device = torch.device("cuda")
logger.debug(
f"Capturing CUDA graph for batch size {batch_size} ({max_seqlen_k} tokens)"
)
_g_input_ids = torch.zeros(batch_size, dtype=torch.int32, device=device)
_g_positions = torch.zeros(batch_size, dtype=torch.int64, device=device)
_g_hidden = None
key_split = num_splits_heuristic(
batch_size * self.kv_manager.num_kv_heads,
max_seq_len=max_seqlen_k,
num_sms=torch.cuda.get_device_properties(device).multi_processor_count,
max_splits=12,
)
success, _g_batch_mapping = self.kv_manager.allocate_sequences(
list(range(batch_size)), [256] * batch_size
)
if not success:
# Shared GPU with vLLM: compactor KV pool is small; large batch capture
# often cannot reserve [256]*batch_size per sequence. Skip this graph.
logger.warning(
"Skipping CUDA graph capture for batch_size=%s max_seqlen_k=%s "
"(KV allocate_sequences failed; decode will use eager or other graphs).",
batch_size,
max_seqlen_k,
)
barrier_sync(use_tp_group=self.embedded_in_vllm_worker)
return
set_context(
is_prefill=False,
do_compression=False,
batch_mapping=_g_batch_mapping,
key_split=key_split,
attention_schedule=self.config.attention_schedule,
)
_gw = self.model(_g_input_ids, _g_positions)
self.model.compute_logits(_gw)
barrier_sync(use_tp_group=self.embedded_in_vllm_worker)
decode_graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(decode_graph):
_g_hidden = self.model(_g_input_ids, _g_positions)
_g_logits = self.model.compute_logits(_g_hidden)
graph_vars = {
"graph": decode_graph,
"input_ids": _g_input_ids,
"positions": _g_positions,
"batch_mapping": _g_batch_mapping,
"hidden": _g_hidden,
"logits": _g_logits,
"key_split": key_split,
}
if batch_size not in self.captured_graphs:
self.captured_graphs[batch_size] = {}
self.min_captured_len[batch_size] = float("inf")
self.captured_graphs[batch_size][max_seqlen_k] = graph_vars
self.min_captured_len[batch_size] = min(
max_seqlen_k, self.min_captured_len[batch_size]
)
self.kv_manager.free_sequences(list(range(batch_size)))
def get_cuda_graph(
self, batch_size: int, max_seqlen_k: int
) -> Optional[dict[str, Any]]:
"""Return a captured graph dict, or None if no compatible capture exists."""
if not self.captured_graphs:
return None
eligible_bs = [x for x in self.captured_graphs.keys() if x >= batch_size]
if not eligible_bs:
return None
bs_key = min(eligible_bs)
batch_size_graphs = self.captured_graphs[bs_key]
candidates = [sl for sl in batch_size_graphs.keys() if sl <= max_seqlen_k]
if not candidates:
return None
best_sl = max(candidates)
return batch_size_graphs[best_sl]
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
from dataclasses import dataclass
import torch
from vllm.forward_context import get_forward_context
from vllm.kvprune.core.compression_bridge import (
COMPRESSION_METHOD_ID_NONE,
compression_method_str_to_id,
)
@dataclass
class KVPruneForwardState:
"""Per-forward-pass state for KV pruning (per-layer logical lengths)."""
active: bool
compression_ratio_gpu: torch.Tensor
"""[num_reqs_padded] ratio in (0,1], 1.0 means no pruning for that row."""
compression_method_id_gpu: torch.Tensor
"""[num_reqs_padded] int32 — see ``compression_bridge`` ids (0=none)."""
query_start_loc: torch.Tensor
"""[num_reqs_padded + 1] int32 on device."""
num_reqs: int
num_reqs_padded: int
num_layers: int
logical_seq_lens_gpu: torch.Tensor
"""Logical KV length per layer (and optionally per KV head).
Shape ``[num_layers, num_reqs_padded]`` or, when ``num_kv_heads > 1``,
``[num_layers, num_reqs_padded, num_kv_heads]`` for per-head lengths.
"""
is_prefill: bool
device: torch.device
def logical_seq_lens_for_layer(self, layer_idx: int) -> torch.Tensor:
sl = self.logical_seq_lens_gpu[layer_idx]
if sl.dim() == 2:
return sl.max(dim=-1).values
return sl
def build_kv_prune_forward_state(
*,
req_ids: list[str],
requests: dict[str, object],
query_start_loc: torch.Tensor,
num_reqs: int,
num_reqs_padded: int,
num_layers: int,
max_num_scheduled_tokens: int,
device: torch.device,
logical_seq_lens_gpu: torch.Tensor,
) -> KVPruneForwardState | None:
"""Build pruning state when any request uses compression_ratio < 1.0."""
if num_reqs <= 0 or num_layers <= 0:
return None
ratios = []
method_ids: list[int] = []
active_req = False
for rid in req_ids[:num_reqs]:
req = requests.get(rid)
sp = getattr(req, "sampling_params", None) if req is not None else None
r = 1.0 if sp is None else float(getattr(sp, "compression_ratio", 1.0))
if r < 1.0 - 1e-6:
active_req = True
ratios.append(r)
if sp is None or r >= 1.0 - 1e-6:
mid = COMPRESSION_METHOD_ID_NONE
else:
cm = getattr(sp, "compression_method", "none") or "none"
mid = compression_method_str_to_id(str(cm))
method_ids.append(mid)
if not active_req:
return None
compression_ratio_gpu = torch.ones(
(num_reqs_padded,), dtype=torch.float32, device=device
)
compression_ratio_gpu[:num_reqs] = torch.tensor(
ratios, dtype=torch.float32, device=device
)
compression_method_id_gpu = torch.zeros(
(num_reqs_padded,), dtype=torch.int32, device=device
)
compression_method_id_gpu[:num_reqs] = torch.tensor(
method_ids, dtype=torch.int32, device=device
)
is_prefill = max_num_scheduled_tokens > 1
return KVPruneForwardState(
active=True,
compression_ratio_gpu=compression_ratio_gpu,
compression_method_id_gpu=compression_method_id_gpu,
query_start_loc=query_start_loc,
num_reqs=num_reqs,
num_reqs_padded=num_reqs_padded,
num_layers=num_layers,
logical_seq_lens_gpu=logical_seq_lens_gpu,
is_prefill=is_prefill,
device=device,
)
def layer_index_from_layer_name(layer_name: str) -> int:
from vllm.model_executor.models.utils import extract_layer_index
return extract_layer_index(layer_name)
def get_kv_prune_state() -> KVPruneForwardState | None:
try:
fc = get_forward_context()
except AssertionError:
return None
state = fc.additional_kwargs.get("kv_prune")
if state is None or not isinstance(state, KVPruneForwardState) or not state.active:
return None
return state
import time
from typing import Iterable, List
from vllm.kvprune.core.memory_manager import KVCacheManager
from vllm.kvprune.utils.sequence import Sequence, SequenceStatus
from tqdm import tqdm
def cdiv(a, b):
"""ceiling division"""
return (a + b - 1) // b
class Scheduler:
"""
Simple sequence scheduler for prefill + decode with a paged KV cache.
The scheduler tracks three disjoint sets of sequence IDs:
* ``pending_sequence_ids`` – sequences that have not yet been started.
* ``active_sequence_ids`` – sequences currently running.
* ``finished_sequence_ids`` – sequences that have generated all tokens.
At prefill time, :meth:`get_prefill_batch` selects a subset of pending
sequences that can fit into the available KV cache and per-step token
budget, given the constraints from the associated :class:`KVCacheManager`.
The class also handles basic bookkeeping of sequence statuses.
Args:
:param all_sequences:
Iterable of :class:`Sequence` objects to be scheduled. Each
sequence must have a unique ``seq_id``.
:param kv_manager:
A :class:`KVCacheManager` instance that this scheduler will use
to determine whether additional batches can be scheduled.
:param use_tqdm:
If True, two progress bars are created:
* "Started Batches" – increments when a sequence moves from
pending to running.
* "Finished Batches" – increments when a sequence finishes.
"""
def __init__(
self,
all_sequences: Iterable[Sequence],
kv_manager: KVCacheManager,
*,
use_tqdm=False,
):
self.allseq_mapping: dict[int, Sequence] = {s.seq_id: s for s in all_sequences}
self.pending_sequence_ids: set[int] = set([s.seq_id for s in all_sequences])
self.active_sequence_ids: set[int] = set()
self.finished_sequence_ids: set[int] = set()
self.manager = kv_manager
self.use_tqdm = use_tqdm
self.start_time = time.perf_counter()
self.total_tokens_generated = 0
self.total_tokens_input = 0
self.pbar = None
if use_tqdm:
self.pbar = tqdm(
total=len(self.pending_sequence_ids),
desc="Completed Batches",
)
def get_prefill_batch(self) -> List[Sequence]:
"""
Select a batch of pending sequences to prefill under KV/memory constraints.
The selection is greedy over ``pending_sequence_ids`` in iteration order.
A sequence is added to the batch if:
* The sum of its prompt length and the total prompt tokens selected so
far does not exceed ``manager.max_batched_tokens``, and
* There is at least one free KV "batch slot" left
(``manager.num_free_batches``), and
* The total number of KV pages required by the sequence's prompt +
max_new_tokens does not exceed the remaining free pages.
Returns:
:return List[Sequence]:
The list of :class:`Sequence` objects chosen for prefill in
this step. The caller is responsible for marking them as
active via :meth:`add_running_sequence_ids`.
"""
total_tok, sequences = 0, []
num_free_batches, num_free_pages = (
self.manager.num_free_batches,
self.manager.num_free_pages,
)
for seq_id in self.pending_sequence_ids:
seq = self.allseq_mapping[seq_id]
prompt_length = seq.prompt_len
pages_needed = (
cdiv(
prompt_length + seq.sampling_params.max_new_tokens,
self.manager.page_size,
)
* self.manager.num_kv_heads
)
if (
prompt_length + total_tok <= self.manager.max_batched_tokens
and num_free_batches > 0
and pages_needed <= num_free_pages
):
sequences.append(seq)
total_tok += prompt_length
num_free_pages -= pages_needed
num_free_batches -= 1
return sequences
def diagnose_prefill_failure(self) -> str:
"""Explain why :meth:`get_prefill_batch` may return empty (debugging)."""
num_free_batches = self.manager.num_free_batches
num_free_pages = self.manager.num_free_pages
parts = [
f"num_free_batches={num_free_batches}",
f"num_free_pages={num_free_pages}",
f"num_pages_per_layer={getattr(self.manager, 'num_pages', None)}",
]
seq_id = next(iter(self.pending_sequence_ids), None)
if seq_id is None:
return "; ".join(parts)
seq = self.allseq_mapping[seq_id]
pl = seq.prompt_len
mn = seq.sampling_params.max_new_tokens
pages_needed = (
cdiv(pl + mn, self.manager.page_size) * self.manager.num_kv_heads
)
parts.append(
f"first_pending seq_id={seq_id} prompt_len={pl} max_new_tokens={mn} "
f"pages_needed~={pages_needed}"
)
if num_free_batches == 0:
parts.append(
"likely_cause=no free batch slots (compactor max_num_seqs exhausted)"
)
elif pl > self.manager.max_batched_tokens:
parts.append(
f"likely_cause=prompt_len ({pl}) > max_batched_tokens "
f"({self.manager.max_batched_tokens})"
)
elif pages_needed > num_free_pages:
parts.append(
"likely_cause=KV pool too small: pages_needed exceeds num_free_pages "
"(raise VLLM_KVPRUNE_COMPACTOR_KV_FREE_FRAC / lower v1 memory, or cap "
"compactor max_num_seqs to shrink page-table overhead)"
)
else:
parts.append(
"likely_cause=batched token sum or greedy order (another sequence may "
"block first in set iteration)"
)
return "; ".join(parts)
def is_finished(self) -> bool:
"""
Check whether all sequences have completed.
"""
return (
len(self.pending_sequence_ids) == 0 and len(self.active_sequence_ids) == 0
)
def any_pending_sequences(self) -> bool:
"""
Check whether any sequences are still pending (not yet started).
"""
return len(self.pending_sequence_ids) != 0
def add_running_sequence_ids(
self, active_sequence_ids: Iterable[int], *, update_status: bool = False
):
"""
Mark a set of sequences as active / running. This moves sequence IDs
from ``pending_sequence_ids`` into ``active_sequence_ids``. Optionally,
it also updates the per-sequence status and progress bar.
Args:
:param active_sequence_ids:
Iterable of sequence IDs that have been scheduled for prefill
or decode and should now be considered running.
:param update_status:
If True, set each corresponding :class:`Sequence`'s
``status = SequenceStatus.RUNNING`` and increment the
"Started Batches" progress bar if ``use_tqdm`` is enabled.
"""
self.active_sequence_ids.update(active_sequence_ids)
self.pending_sequence_ids.difference_update(self.active_sequence_ids)
if update_status:
for seq_id in active_sequence_ids:
self.allseq_mapping[seq_id].status = SequenceStatus.RUNNING
self.total_tokens_input += self.allseq_mapping[seq_id].prompt_len
def get_finished_sequence_ids_from_unfinished(
self, unfinished_sequence_ids: Iterable[int]
) -> set[int]:
"""
Infer which active sequences have finished given the
unfinished set (for decode steps where the caller knows
which sequences are still generating but not necessarily
which have just completed).
Args:
:param unfinished_sequence_ids:
Iterable of sequence IDs that are still running
Returns:
:return set[int]:
The inferred set of sequence IDs that transitioned from active
to finished.
"""
return self.active_sequence_ids.difference(unfinished_sequence_ids)
def record_finished_sequence_ids(
self, finished_sequence_ids: Iterable[int], *, update_status: bool = False
):
"""
Record that a set of sequences has finished generation.
This moves IDs from ``active_sequence_ids`` into
``finished_sequence_ids``.
Args:
:param finished_sequence_ids:
Iterable of sequence IDs that have completed generation and
no longer require KV cache.
:param update_status:
If True, set each corresponding :class:`Sequence`'s
``status = SequenceStatus.FINISHED``
"""
self.active_sequence_ids.difference_update(finished_sequence_ids)
self.finished_sequence_ids.update(finished_sequence_ids)
if update_status:
for seq_id in finished_sequence_ids:
self.allseq_mapping[seq_id].status = SequenceStatus.FINISHED
if self.pbar is not None:
self.pbar.update(1)
def update_sequences(self, tokens: Iterable[int], seq_ids: Iterable[int]):
"""
Append newly generated tokens to their corresponding sequences.
Args:
:param tokens:
Iterable of generated token IDs, one per sequence.
:param seq_ids:
Iterable of sequence IDs aligned with ``tokens``.
"""
cur_time = time.perf_counter()
for tok, seq_id in zip(tokens, seq_ids):
self.allseq_mapping[seq_id].add_new_token(tok)
self.total_tokens_generated += 1
if self.pbar is not None:
self.pbar.set_description(
f"Throughput: {(self.total_tokens_generated + self.total_tokens_input) / (cur_time - self.start_time):.2f} tok/s"
)
def close(self):
if self.pbar is not None:
self.pbar.close()
def can_prefill_another_batch(self) -> bool:
return len(self.get_prefill_batch()) > 0
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment