Commit 2676ad00 authored by laibao's avatar laibao
Browse files

v1: add SnapKV Triton KV compression

Introduce v1 KV compression modules (budget + SnapKV Triton kernel) and integrate with scheduler/cache managers.
parent 155c8a13
......@@ -141,6 +141,34 @@ if TYPE_CHECKING:
VLLM_USE_NVFP4_CT_EMULATIONS: bool = False
VLLM_ROCM_QUICK_REDUCE_QUANTIZATION: str = "NONE"
VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16: bool = True
# KV compression (token-shared) for v1 paged attention.
# When enabled, vLLM decouples logical positions from KV cache positions
# and keeps only a subset of prompt tokens in KV cache during prefill.
VLLM_ENABLE_KV_COMPRESSION: bool = False
# KV compression policy for selecting which prompt KV entries to retain.
# Currently only "topk" is supported.
VLLM_KV_COMPRESSION_POLICY: str = "topk"
# Target prompt KV budget for token-shared compression.
# If PROMPT_BUDGET >= 0, it takes precedence over PROMPT_RATIO.
# The budget/ratio applies to non-protected prompt tokens only.
VLLM_KV_COMPRESSION_PROMPT_RATIO: float = 1.0
VLLM_KV_COMPRESSION_PROMPT_BUDGET: int = -1
VLLM_KV_COMPRESSION_PROTECTED_PREFIX: int = 0
VLLM_KV_COMPRESSION_PROTECTED_SUFFIX: int = 0
VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN: bool = True
# SnapKV-like scoring wi这个ndow used by the "topk" policy.
VLLM_KV_COMPRESSION_SNAPKV_WINDOW: int = 32
# Use Triton SnapKV scoring on ROCm (experimental). Set to 0 to force the
# PyTorch reference implementation.
VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM: bool = True
# If set, compute token-shared Top-K selection per attention layer instead
# of sharing a single selection across all layers in a forward pass.
VLLM_KV_COMPRESSION_TOPK_PER_LAYER: bool = False
# Run KV compaction writeback (reshape_and_cache_*) on a separate CUDA
# stream to overlap with compute (experimental).
VLLM_KV_COMPRESSION_ASYNC_WRITEBACK: bool = False
# Free unused tail KV cache blocks after prompt compaction (experimental).
VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS: bool = True
VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB: Optional[int] = None
# add envs
......@@ -1055,6 +1083,50 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_TRITON_PREFIX_FLASH_ATTN", "False").lower() in
("true", "1")),
# Enable token-shared KV compression for v1 paged attention (experimental).
# This feature currently targets long-prompt prefill memory reduction.
"VLLM_ENABLE_KV_COMPRESSION":
lambda: bool(int(os.getenv("VLLM_ENABLE_KV_COMPRESSION", "0"))),
# KV compression policy ("topk").
"VLLM_KV_COMPRESSION_POLICY":
lambda: os.getenv("VLLM_KV_COMPRESSION_POLICY", "topk").lower(),
# Target fraction of non-protected prompt tokens to keep in KV cache.
"VLLM_KV_COMPRESSION_PROMPT_RATIO":
lambda: float(os.getenv("VLLM_KV_COMPRESSION_PROMPT_RATIO", "1.0")),
# Target number of non-protected prompt tokens to keep in KV cache.
# If >= 0, this takes precedence over VLLM_KV_COMPRESSION_PROMPT_RATIO.
"VLLM_KV_COMPRESSION_PROMPT_BUDGET":
lambda: int(os.getenv("VLLM_KV_COMPRESSION_PROMPT_BUDGET", "-1")),
# Always keep the first N prompt tokens in KV cache (e.g. BOS/system).
"VLLM_KV_COMPRESSION_PROTECTED_PREFIX":
lambda: int(os.getenv("VLLM_KV_COMPRESSION_PROTECTED_PREFIX", "0")),
# Always keep the last N prompt tokens in KV cache.
"VLLM_KV_COMPRESSION_PROTECTED_SUFFIX":
lambda: int(os.getenv("VLLM_KV_COMPRESSION_PROTECTED_SUFFIX", "0")),
# Always keep the last prompt token (prompt_len - 1) when it is scheduled.
"VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN":
lambda: bool(int(os.getenv("VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN", "1"))),
# SnapKV-like scoring window size for the "topk" policy.
"VLLM_KV_COMPRESSION_SNAPKV_WINDOW":
lambda: int(os.getenv("VLLM_KV_COMPRESSION_SNAPKV_WINDOW", "32")),
# Enable Triton SnapKV scoring on ROCm (experimental).
"VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM":
lambda: bool(
int(os.getenv("VLLM_KV_COMPRESSION_SNAPKV_USE_TRITON_ROCM", "1"))),
# If set, compute token-shared Top-K selection per attention layer instead
# of sharing one selection across layers in a forward pass.
"VLLM_KV_COMPRESSION_TOPK_PER_LAYER":
lambda: bool(int(os.getenv("VLLM_KV_COMPRESSION_TOPK_PER_LAYER", "0"))),
# If set, run KV compaction writeback on a separate CUDA stream to overlap
# cache writes with compute (experimental).
"VLLM_KV_COMPRESSION_ASYNC_WRITEBACK":
lambda: bool(
int(os.getenv("VLLM_KV_COMPRESSION_ASYNC_WRITEBACK", "0"))),
# If set, free unused tail KV cache blocks after prompt compaction.
"VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS":
lambda: bool(
int(os.getenv("VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS", "0"))),
# If set, vLLM will use optimized MLA attention optimizations.
"VLLM_USE_TRITON_OPT_MLA":
lambda: bool(int(os.getenv("VLLM_USE_TRITON_OPT_MLA", "0"))),
......
This diff is collapsed.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import math
from typing import Optional, Union
import torch
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
import triton
import triton.language as tl
if HAS_TRITON:
@triton.autotune(
configs=[
triton.Config(
{"BLOCK_Q": bq, "BLOCK_K": bk},
num_warps=num_warps,
num_stages=num_stages,
) for bq in [32, 64] for bk in [32, 64] for num_warps in [4, 8]
for num_stages in [3, 4]
],
key=["QUERY_GROUP_SIZE", "D", "ROWS_MAX"],
)
@triton.jit
def _lse_and_store_logits_kernel(
Q,
K,
cu_q,
cu_k,
w_b,
out_m,
out_S,
LOGITS,
sm_scale,
QUERY_GROUP_SIZE: tl.constexpr,
D: tl.constexpr,
STRIDE_Q_NQ,
STRIDE_Q_HQ,
STRIDE_K_NK,
STRIDE_K_HK,
STRIDE_M_B,
STRIDE_M_H,
STRIDE_M_R,
STRIDE_S_B,
STRIDE_S_H,
STRIDE_S_R,
STRIDE_LG_NK,
STRIDE_LG_HK,
STRIDE_LG_R,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
ROWS_MAX,
):
b = tl.program_id(0)
hk = tl.program_id(1)
rid = tl.program_id(2)
q_end = tl.load(cu_q + b + 1)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
q_win_beg = q_end - win
k_eff_end = k_end - win
if (win <= 0) or (k_eff_end <= k_beg):
return
rows_b = win * QUERY_GROUP_SIZE
row0 = rid * BLOCK_Q
if row0 >= rows_b:
return
qk_scale = sm_scale * 1.4426950408889634 # exp -> exp2.
offs_qrow = row0 + tl.arange(0, BLOCK_Q)
row_mask = offs_qrow < rows_b
hq_local = offs_qrow % QUERY_GROUP_SIZE
q_off = offs_qrow // QUERY_GROUP_SIZE
q_idx = q_win_beg + q_off
hq_glob = hk * QUERY_GROUP_SIZE + hq_local
offs_d = tl.arange(0, D)
q_ptrs = (Q + q_idx[:, None] * STRIDE_Q_NQ +
hq_glob[:, None] * STRIDE_Q_HQ + offs_d[None, :])
q_rows = tl.load(q_ptrs, mask=row_mask[:, None], other=0.0)
m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
S = tl.zeros([BLOCK_Q], dtype=tl.float32)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
k_ptrs = (K + nk[:, None] * STRIDE_K_NK +
hk * STRIDE_K_HK + offs_d[None, :])
k_blk = tl.load(k_ptrs, mask=kmask[:, None], other=0.0)
s = tl.dot(q_rows, k_blk.T) * qk_scale
s = tl.where(kmask[None, :], s, -float("inf"))
log_ptrs = (LOGITS + nk[:, None] * STRIDE_LG_NK +
hk * STRIDE_LG_HK +
(row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R)
tl.store(log_ptrs, s.T, mask=kmask[:, None] & row_mask[None, :])
cur_max = tl.max(s, 1)
n_m = tl.maximum(m, cur_max)
rescale = tl.math.exp2(m - n_m)
S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
m = n_m
m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R, m, mask=row_mask)
tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R, S, mask=row_mask)
@triton.jit
def _lse_and_store_logits_kernel_rocm_safe(
Q,
K,
cu_q,
cu_k,
w_b,
out_m,
out_S,
LOGITS,
sm_scale,
QUERY_GROUP_SIZE: tl.constexpr,
D: tl.constexpr,
STRIDE_Q_NQ,
STRIDE_Q_HQ,
STRIDE_K_NK,
STRIDE_K_HK,
STRIDE_M_B,
STRIDE_M_H,
STRIDE_M_R,
STRIDE_S_B,
STRIDE_S_H,
STRIDE_S_R,
STRIDE_LG_NK,
STRIDE_LG_HK,
STRIDE_LG_R,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
BLOCK_D: tl.constexpr,
):
"""ROCm-safe variant of `_lse_and_store_logits_kernel`.
On some ROCm + Triton (HIP) stacks we have observed memory corruption
from the tl.dot-based implementation. This variant avoids `tl.dot` and
instead computes dot-products via explicit outer-product accumulation.
"""
b = tl.program_id(0)
hk = tl.program_id(1)
rid = tl.program_id(2)
q_end = tl.load(cu_q + b + 1)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
q_win_beg = q_end - win
k_eff_end = k_end - win
if (win <= 0) or (k_eff_end <= k_beg):
return
rows_b = win * QUERY_GROUP_SIZE
row0 = rid * BLOCK_Q
if row0 >= rows_b:
return
qk_scale = sm_scale * 1.4426950408889634 # exp -> exp2.
offs_qrow = row0 + tl.arange(0, BLOCK_Q)
row_mask = offs_qrow < rows_b
hq_local = offs_qrow % QUERY_GROUP_SIZE
q_off = offs_qrow // QUERY_GROUP_SIZE
q_idx = q_win_beg + q_off
hq_glob = hk * QUERY_GROUP_SIZE + hq_local
m = tl.zeros([BLOCK_Q], dtype=tl.float32) + (-float("inf"))
S = tl.zeros([BLOCK_Q], dtype=tl.float32)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
# Accumulate s = Q @ K^T in fp32 via outer products.
s = tl.zeros([BLOCK_Q, BLOCK_K], dtype=tl.float32)
for ds in tl.static_range(0, D, BLOCK_D):
offs_d = ds + tl.arange(0, BLOCK_D)
dmask = offs_d < D
q_ptrs = (Q + q_idx[:, None] * STRIDE_Q_NQ +
hq_glob[:, None] * STRIDE_Q_HQ + offs_d[None, :])
q_chunk = tl.load(
q_ptrs,
mask=row_mask[:, None] & dmask[None, :],
other=0.0,
).to(tl.float32) # [BQ, BD]
k_ptrs = (K + nk[:, None] * STRIDE_K_NK +
hk * STRIDE_K_HK + offs_d[None, :])
k_chunk = tl.load(
k_ptrs,
mask=kmask[:, None] & dmask[None, :],
other=0.0,
).to(tl.float32) # [BK, BD]
s += tl.sum(q_chunk[:, None, :] * k_chunk[None, :, :], axis=2)
s = s * qk_scale
s = tl.where(kmask[None, :], s, -float("inf"))
log_ptrs = (LOGITS + nk[:, None] * STRIDE_LG_NK +
hk * STRIDE_LG_HK +
(row0 + tl.arange(0, BLOCK_Q))[None, :] * STRIDE_LG_R)
tl.store(log_ptrs, s.T, mask=kmask[:, None] & row_mask[None, :])
cur_max = tl.max(s, 1)
n_m = tl.maximum(m, cur_max)
rescale = tl.math.exp2(m - n_m)
S = S * rescale + tl.sum(tl.math.exp2(s - n_m[:, None]), 1)
m = n_m
m_base = out_m + b * STRIDE_M_B + hk * STRIDE_M_H + row0 * STRIDE_M_R
S_base = out_S + b * STRIDE_S_B + hk * STRIDE_S_H + row0 * STRIDE_S_R
tl.store(m_base + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
m,
mask=row_mask)
tl.store(S_base + tl.arange(0, BLOCK_Q) * STRIDE_S_R,
S,
mask=row_mask)
@triton.autotune(
configs=[
triton.Config({"BLOCK_Q": bq, "BLOCK_K": bk})
for bq in [16, 32, 64] for bk in [32, 64, 128]
],
key=["HK", "HQ"],
)
@triton.jit
def _scores_from_logits_kernel(
cu_k,
w_b,
in_m,
in_S,
LOGITS,
OUT,
QUERY_GROUP_SIZE: tl.constexpr,
STRIDE_M_B,
STRIDE_M_H,
STRIDE_M_R,
STRIDE_S_B,
STRIDE_S_H,
STRIDE_S_R,
STRIDE_LG_NK,
STRIDE_LG_HK,
STRIDE_LG_R,
STRIDE_OUT_NK,
STRIDE_OUT_HK,
BLOCK_Q: tl.constexpr,
BLOCK_K: tl.constexpr,
DO_POOL: tl.constexpr,
KPOOL: tl.constexpr,
PROTECT_LAST: tl.constexpr,
):
b = tl.program_id(0)
hk = tl.program_id(1)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
k_eff_end = k_end - win
if (win <= 0) or (k_eff_end <= k_beg):
return
rows_b = win * QUERY_GROUP_SIZE
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
scores = tl.zeros([BLOCK_K], dtype=tl.float32)
for row0 in tl.range(0, rows_b, BLOCK_Q):
r_idx = row0 + tl.arange(0, BLOCK_Q)
rmask = r_idx < rows_b
m_ptr = (in_m + b * STRIDE_M_B + hk * STRIDE_M_H +
row0 * STRIDE_M_R)
S_ptr = (in_S + b * STRIDE_S_B + hk * STRIDE_S_H +
row0 * STRIDE_S_R)
m = tl.load(m_ptr + tl.arange(0, BLOCK_Q) * STRIDE_M_R,
mask=rmask,
other=-float("inf"))
S = tl.load(S_ptr + tl.arange(0, BLOCK_Q) * STRIDE_S_R,
mask=rmask,
other=0.0)
valid_row = S > 0
m = tl.where(valid_row, m, 0.0)
S = tl.where(valid_row, S, 1.0)
log_ptrs = (LOGITS + nk[:, None] * STRIDE_LG_NK +
hk * STRIDE_LG_HK +
(row0 + tl.arange(0, BLOCK_Q))[None, :] *
STRIDE_LG_R)
s_T = tl.load(log_ptrs,
mask=kmask[:, None] & rmask[None, :],
other=-float("inf"))
probs_T = tl.math.exp2(s_T - m[None, :]) / S[None, :]
probs_T = tl.where(valid_row[None, :], probs_T, 0.0)
scores += tl.sum(probs_T, 1)
if DO_POOL and (KPOOL > 1):
i = tl.arange(0, BLOCK_K)[:, None]
j = tl.arange(0, BLOCK_K)[None, :]
band = (j <= i) & ((i - j) < KPOOL)
band = band & kmask[None, :]
sums = tl.sum(tl.where(band, scores[None, :], 0.0), 1)
denom = tl.sum(band, 1).to(tl.float32)
denom = tl.where(denom > 0, denom, 1.0)
scores = sums / denom
out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
tl.store(out_ptrs, scores, mask=kmask)
if PROTECT_LAST:
pad_beg = k_eff_end
pad_end = k_end
if pad_end > pad_beg:
for ks in tl.range(pad_beg, pad_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < pad_end
out_ptrs = OUT + nk * STRIDE_OUT_NK + hk * STRIDE_OUT_HK
tl.store(out_ptrs,
tl.full([BLOCK_K], float("inf"), dtype=tl.float32),
mask=kmask)
@triton.autotune(
configs=[triton.Config({"BLOCK_K": bk}) for bk in [32, 64, 128]],
key=["HK"],
)
@triton.jit
def _zscore_per_batch_epilogue(
OUT,
cu_k,
w_b,
STRIDE_OUT_NK,
STRIDE_OUT_HK,
HK: tl.constexpr,
EPS: tl.constexpr,
BLOCK_K: tl.constexpr,
):
b = tl.program_id(0)
k_beg = tl.load(cu_k + b)
k_end = tl.load(cu_k + b + 1)
win = tl.load(w_b + b)
k_eff_end = k_end - win
if k_eff_end <= k_beg:
return
sumv = tl.zeros([], dtype=tl.float32)
sumsq = tl.zeros([], dtype=tl.float32)
count = ((k_eff_end - k_beg) * HK).to(tl.float32)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
for h in tl.range(0, HK):
ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
sumv += tl.sum(vals, 0)
sumsq += tl.sum(vals * vals, 0)
mean = sumv / count
var = tl.maximum(sumsq / count - mean * mean, 0.0)
invstd = 1.0 / tl.sqrt(var + EPS)
for ks in tl.range(k_beg, k_eff_end, BLOCK_K):
nk = ks + tl.arange(0, BLOCK_K)
kmask = nk < k_eff_end
for h in tl.range(0, HK):
ptrs = OUT + nk * STRIDE_OUT_NK + h * STRIDE_OUT_HK
vals = tl.load(ptrs, mask=kmask, other=0.0).to(tl.float32)
vals = (vals - mean) * invstd
tl.store(ptrs, vals, mask=kmask)
def query_aware_key_scores(
q: torch.Tensor, # [N_q, Hq, D]
k: torch.Tensor, # [N_k, Hk, D]
cu_seqlens_q: torch.Tensor, # [B+1] int32
cu_seqlens_k: torch.Tensor, # [B+1] int32
w: Union[int, torch.Tensor], # [B] int32 or scalar
sm_scale: Optional[float] = None,
*,
pool: bool = True,
kpool: int = 5,
protect_last: bool = True,
normalize: bool = False,
) -> torch.Tensor:
"""SnapKV query-aware key scores (Triton), returns [N_k, Hk] float32."""
if not HAS_TRITON:
raise RuntimeError("Triton is not available.")
if q.device.type != "cuda" or k.device.type != "cuda":
raise RuntimeError("Triton SnapKV requires CUDA/ROCm tensors.")
if q.ndim != 3 or k.ndim != 3:
raise ValueError("q and k must be 3D tensors.")
if q.stride(-1) != 1 or k.stride(-1) != 1:
raise ValueError("Last dim must be contiguous for Triton SnapKV.")
device = q.device
N_q, Hq, D = q.shape
N_k, Hk, Dk = k.shape
if D != Dk:
raise ValueError("q and k must have the same head size.")
if (Hq % Hk) != 0:
raise ValueError("Hq must be a multiple of Hk.")
if sm_scale is None:
sm_scale = 1.0 / math.sqrt(D)
B = int(cu_seqlens_q.numel() - 1)
if B != int(cu_seqlens_k.numel() - 1):
raise ValueError("cu_seqlens_q and cu_seqlens_k must match.")
G = Hq // Hk
if isinstance(w, int):
max_w = int(w)
w = torch.full((B, ),
fill_value=max_w,
device=device,
dtype=torch.int32)
else:
if w.numel() != B:
raise ValueError("w must have shape [B].")
w = w.to(device=device, dtype=torch.int32)
max_w = int(w.max().item())
rows_max = max_w * G
if rows_max <= 0:
return torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
if kpool < 1:
raise ValueError("kpool must be >= 1.")
out = torch.zeros((N_k, Hk), dtype=torch.float32, device=device)
m_scratch = torch.empty((B, Hk, rows_max),
dtype=torch.float32,
device=device)
S_scratch = torch.empty((B, Hk, rows_max),
dtype=torch.float32,
device=device)
logits_buf = torch.empty((N_k, Hk, rows_max),
dtype=torch.float32,
device=device)
STRIDE_Q_NQ, STRIDE_Q_HQ, _ = q.stride()
STRIDE_K_NK, STRIDE_K_HK, _ = k.stride()
STRIDE_M_B, STRIDE_M_H, STRIDE_M_R = m_scratch.stride()
STRIDE_S_B, STRIDE_S_H, STRIDE_S_R = S_scratch.stride()
STRIDE_LG_NK, STRIDE_LG_HK, STRIDE_LG_R = logits_buf.stride()
STRIDE_OUT_NK, STRIDE_OUT_HK = out.stride()
def grid(meta):
return B, Hk, triton.cdiv(rows_max, meta["BLOCK_Q"])
cu_q = cu_seqlens_q.to(device=device, dtype=torch.int32)
cu_k = cu_seqlens_k.to(device=device, dtype=torch.int32)
# NOTE: On ROCm/HIP, we prefer a dot-free kernel variant to avoid known
# correctness issues (silent memory corruption) observed with the tl.dot
# implementation on some stacks.
is_rocm = getattr(torch.version, "hip", None) is not None
if is_rocm:
_lse_and_store_logits_kernel_rocm_safe[grid](
q,
k,
cu_q,
cu_k,
w,
m_scratch,
S_scratch,
logits_buf,
sm_scale,
QUERY_GROUP_SIZE=G,
D=D,
STRIDE_Q_NQ=STRIDE_Q_NQ,
STRIDE_Q_HQ=STRIDE_Q_HQ,
STRIDE_K_NK=STRIDE_K_NK,
STRIDE_K_HK=STRIDE_K_HK,
STRIDE_M_B=STRIDE_M_B,
STRIDE_M_H=STRIDE_M_H,
STRIDE_M_R=STRIDE_M_R,
STRIDE_S_B=STRIDE_S_B,
STRIDE_S_H=STRIDE_S_H,
STRIDE_S_R=STRIDE_S_R,
STRIDE_LG_NK=STRIDE_LG_NK,
STRIDE_LG_HK=STRIDE_LG_HK,
STRIDE_LG_R=STRIDE_LG_R,
BLOCK_Q=32,
BLOCK_K=32,
BLOCK_D=16,
num_warps=4,
num_stages=1,
)
else:
_lse_and_store_logits_kernel[grid](
q,
k,
cu_q,
cu_k,
w,
m_scratch,
S_scratch,
logits_buf,
sm_scale,
QUERY_GROUP_SIZE=G,
D=D,
STRIDE_Q_NQ=STRIDE_Q_NQ,
STRIDE_Q_HQ=STRIDE_Q_HQ,
STRIDE_K_NK=STRIDE_K_NK,
STRIDE_K_HK=STRIDE_K_HK,
STRIDE_M_B=STRIDE_M_B,
STRIDE_M_H=STRIDE_M_H,
STRIDE_M_R=STRIDE_M_R,
STRIDE_S_B=STRIDE_S_B,
STRIDE_S_H=STRIDE_S_H,
STRIDE_S_R=STRIDE_S_R,
STRIDE_LG_NK=STRIDE_LG_NK,
STRIDE_LG_HK=STRIDE_LG_HK,
STRIDE_LG_R=STRIDE_LG_R,
ROWS_MAX=rows_max,
)
_scores_from_logits_kernel[(B, Hk)](
cu_k,
w,
m_scratch,
S_scratch,
logits_buf,
out,
QUERY_GROUP_SIZE=G,
STRIDE_M_B=STRIDE_M_B,
STRIDE_M_H=STRIDE_M_H,
STRIDE_M_R=STRIDE_M_R,
STRIDE_S_B=STRIDE_S_B,
STRIDE_S_H=STRIDE_S_H,
STRIDE_S_R=STRIDE_S_R,
STRIDE_LG_NK=STRIDE_LG_NK,
STRIDE_LG_HK=STRIDE_LG_HK,
STRIDE_LG_R=STRIDE_LG_R,
STRIDE_OUT_NK=STRIDE_OUT_NK,
STRIDE_OUT_HK=STRIDE_OUT_HK,
DO_POOL=pool,
KPOOL=kpool,
PROTECT_LAST=protect_last,
)
if normalize:
_zscore_per_batch_epilogue[(B, )](
out,
cu_k,
w,
STRIDE_OUT_NK,
STRIDE_OUT_HK,
HK=Hk,
EPS=1e-12,
)
return out
......@@ -154,6 +154,17 @@ class KVCacheCoordinator(ABC):
for manager in self.single_type_managers:
manager.remove_skipped_blocks(request_id, num_computed_tokens)
def truncate_to_num_tokens(self, request_id: str, num_tokens: int) -> bool:
"""Truncate a request's allocated blocks to fit `num_tokens` slots.
Returns True if any blocks were freed.
"""
truncated = False
for manager in self.single_type_managers:
truncated = manager.truncate_to_num_tokens(request_id,
num_tokens) or truncated
return truncated
def get_blocks(self, request_id: str) -> tuple[list[KVCacheBlock], ...]:
"""
Get the blocks for the request.
......
......@@ -7,6 +7,8 @@ from typing import Optional
from vllm.distributed.kv_events import KVCacheEvent
from vllm.logger import init_logger
import vllm.envs as envs
from vllm.platforms import current_platform
from vllm.utils import sha256
from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
......@@ -251,6 +253,14 @@ class KVCacheManager:
# the new prefix caching hits
num_computed_tokens = (request.num_computed_tokens +
num_new_computed_tokens)
if envs.VLLM_ENABLE_KV_COMPRESSION and not current_platform.is_tpu():
# KV compression decouples logical positions from KV cache
# positions. Allocate based on the KV cache length (plus the tokens
# scheduled for this step, which are temporarily written to cache).
num_tokens_need_slot = min(
request.num_kv_tokens + num_new_tokens + num_lookahead_tokens,
self.max_model_len)
else:
num_tokens_need_slot = min(
num_computed_tokens + num_new_tokens + num_lookahead_tokens,
self.max_model_len)
......@@ -385,6 +395,14 @@ class KVCacheManager:
return KVCacheBlocks(
self.coordinator.get_blocks(request_id)).get_block_ids()
def truncate_to_num_tokens(self, request_id: str, num_tokens: int) -> bool:
"""Truncate a request's allocated blocks to fit `num_tokens` slots.
This is a best-effort operation that may free blocks back to the pool.
Returns True if any blocks were freed.
"""
return self.coordinator.truncate_to_num_tokens(request_id, num_tokens)
def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
"""Cache the blocks for the request, if enabled."""
if self.enable_caching:
......
......@@ -31,6 +31,7 @@ class NewRequestData:
pooling_params: Optional[PoolingParams]
block_ids: tuple[list[int], ...]
num_computed_tokens: int
num_kv_tokens: int
lora_request: Optional[LoRARequest]
@classmethod
......@@ -49,6 +50,7 @@ class NewRequestData:
pooling_params=request.pooling_params,
block_ids=block_ids,
num_computed_tokens=request.num_computed_tokens,
num_kv_tokens=request.num_kv_tokens,
lora_request=request.lora_request,
)
......@@ -62,6 +64,7 @@ class NewRequestData:
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
f"num_kv_tokens={self.num_kv_tokens},"
f"lora_request={self.lora_request}"
")")
......@@ -76,6 +79,7 @@ class NewRequestData:
f"sampling_params={self.sampling_params},"
f"block_ids={self.block_ids},"
f"num_computed_tokens={self.num_computed_tokens},"
f"num_kv_tokens={self.num_kv_tokens},"
f"lora_request={self.lora_request}"
")")
......@@ -93,6 +97,7 @@ class CachedRequestData:
new_token_ids: list[list[int]]
new_block_ids: list[tuple[list[int], ...]]
num_computed_tokens: list[int]
num_kv_tokens: list[int]
@property
def num_reqs(self) -> int:
......@@ -106,6 +111,7 @@ class CachedRequestData:
new_token_ids=[],
new_block_ids=[],
num_computed_tokens=[],
num_kv_tokens=[],
)
......
......@@ -28,12 +28,15 @@ from vllm.v1.core.sched.request_queue import (SchedulingPolicy,
from vllm.v1.core.sched.utils import check_stop
from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
EngineCoreOutputs)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.kv_cache_interface import KVCacheConfig, SlidingWindowSpec
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.kv_compression.budget import (compute_topk_budget_step,
count_prompt_must_keep_in_range)
from vllm.platforms import current_platform
from vllm import envs
logger = init_logger(__name__)
......@@ -156,6 +159,50 @@ class Scheduler(SchedulerInterface):
self.compilation_config = vllm_config.compilation_config
self.full_cuda_graph = self.compilation_config.full_cuda_graph
self.use_mla = vllm_config.model_config.use_mla
# KV compression is a GPU-only feature in this fork; ignore it on TPU.
self.kv_compression_enabled = (envs.VLLM_ENABLE_KV_COMPRESSION
and not current_platform.is_tpu())
if envs.VLLM_ENABLE_KV_COMPRESSION and current_platform.is_tpu():
logger.warning_once(
"KV compression is not supported on TPU; ignoring "
"VLLM_ENABLE_KV_COMPRESSION=1.")
if self.kv_compression_enabled:
if envs.VLLM_KV_COMPRESSION_POLICY != "topk":
raise ValueError(
"VLLM_KV_COMPRESSION_POLICY must be 'topk'.")
if any(
isinstance(group.kv_cache_spec, SlidingWindowSpec)
for group in kv_cache_config.kv_cache_groups):
raise ValueError(
"KV compression is incompatible with sliding window "
"attention.")
if self.cache_config.enable_prefix_caching:
raise ValueError(
"KV compression is incompatible with prefix caching. "
"Disable prefix caching to enable KV compression.")
if self.full_cuda_graph:
raise ValueError(
"KV compression is currently incompatible with full CUDA "
"graph mode.")
if self.speculative_config is not None:
raise ValueError(
"KV compression is currently incompatible with "
"speculative decoding.")
if envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET < -1:
raise ValueError(
"VLLM_KV_COMPRESSION_PROMPT_BUDGET must be >= -1.")
if not (0.0 <= envs.VLLM_KV_COMPRESSION_PROMPT_RATIO <= 1.0):
raise ValueError(
"VLLM_KV_COMPRESSION_PROMPT_RATIO must be in [0, 1].")
if envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX < 0:
raise ValueError(
"VLLM_KV_COMPRESSION_PROTECTED_PREFIX must be >= 0.")
if envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX < 0:
raise ValueError(
"VLLM_KV_COMPRESSION_PROTECTED_SUFFIX must be >= 0.")
if envs.VLLM_KV_COMPRESSION_SNAPKV_WINDOW < 1:
raise ValueError(
"VLLM_KV_COMPRESSION_SNAPKV_WINDOW must be >= 1.")
# Create the KV cache manager.
self.kv_cache_manager = KVCacheManager(
......@@ -207,6 +254,8 @@ class Scheduler(SchedulerInterface):
encoder_budget = self.max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
# Requests whose block IDs must be replaced (not appended) in workers.
force_replace_block_ids: set[str] = set()
# For logging.
scheduled_timestamp = time.monotonic()
......@@ -274,6 +323,13 @@ class Scheduler(SchedulerInterface):
num_new_tokens + request.num_computed_tokens -
request.num_tokens, 0)
if (self.kv_compression_enabled
and envs.VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS
and request.num_computed_tokens == request.num_prompt_tokens
and self.kv_cache_manager.truncate_to_num_tokens(
request.request_id, request.num_kv_tokens)):
force_replace_block_ids.add(request.request_id)
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request,
......@@ -295,6 +351,7 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
preempted_req.num_kv_tokens = 0
if self.log_stats:
preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp)
......@@ -321,6 +378,10 @@ class Scheduler(SchedulerInterface):
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[request.request_id] = req_index
if request.request_id in force_replace_block_ids:
req_to_new_block_ids[request.request_id] = (
self.kv_cache_manager.get_block_ids(request.request_id))
else:
req_to_new_block_ids[request.request_id] = (
new_blocks.get_block_ids())
num_scheduled_tokens[request.request_id] = num_new_tokens
......@@ -532,6 +593,8 @@ class Scheduler(SchedulerInterface):
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
if not self.kv_compression_enabled:
request.num_kv_tokens = num_computed_tokens
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
......@@ -586,6 +649,7 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens,
scheduled_spec_decode_tokens,
req_to_new_block_ids,
force_replace_block_ids=force_replace_block_ids,
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
......@@ -645,6 +709,16 @@ class Scheduler(SchedulerInterface):
encoder_budget = self.max_num_encoder_input_tokens
# Spec decode-related.
scheduled_spec_decode_tokens: dict[str, list[int]] = {}
# Requests whose block IDs must be replaced (not appended) in workers.
force_replace_block_ids: set[str] = set()
# Track the LoRAs in this step to respect max_loras when scheduling
# waiting requests first.
scheduled_loras: set[int] = set()
if self.lora_config:
scheduled_loras = set(
req.lora_request.lora_int_id for req in self.running
if req.lora_request and req.lora_request.lora_int_id > 0)
assert len(scheduled_loras) <= self.lora_config.max_loras
# For logging.
scheduled_timestamp = time.monotonic()
......@@ -826,6 +900,8 @@ class Scheduler(SchedulerInterface):
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
if not self.kv_compression_enabled:
request.num_kv_tokens = num_computed_tokens
# Count the number of prefix cached tokens.
if request.num_cached_tokens < 0:
request.num_cached_tokens = num_computed_tokens
......@@ -894,6 +970,14 @@ class Scheduler(SchedulerInterface):
num_new_tokens + request.num_computed_tokens -
request.num_tokens, 0)
if (self.kv_compression_enabled
and envs.VLLM_KV_COMPRESSION_FREE_TAIL_BLOCKS
and request.num_computed_tokens
== request.num_prompt_tokens
and self.kv_cache_manager.truncate_to_num_tokens(
request.request_id, request.num_kv_tokens)):
force_replace_block_ids.add(request.request_id)
while True:
new_blocks = self.kv_cache_manager.allocate_slots(
request,
......@@ -915,6 +999,7 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
preempted_req.num_kv_tokens = 0
if self.log_stats:
preempted_req.record_event(
EngineCoreEventType.PREEMPTED, scheduled_timestamp)
......@@ -941,6 +1026,10 @@ class Scheduler(SchedulerInterface):
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[request.request_id] = req_index
if request.request_id in force_replace_block_ids:
req_to_new_block_ids[request.request_id] = (
self.kv_cache_manager.get_block_ids(request.request_id))
else:
req_to_new_block_ids[request.request_id] = (
new_blocks.get_block_ids())
num_scheduled_tokens[request.request_id] = num_new_tokens
......@@ -1014,6 +1103,7 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens,
scheduled_spec_decode_tokens,
req_to_new_block_ids,
force_replace_block_ids=force_replace_block_ids,
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
......@@ -1076,7 +1166,50 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
for req_id, num_scheduled_token in num_scheduled_tokens.items():
request = self.requests[req_id]
start_pos = request.num_computed_tokens
request.num_computed_tokens += num_scheduled_token
if not self.kv_compression_enabled:
# Keep KV length in sync with logical length when compression
# is disabled (default vLLM behavior).
request.num_kv_tokens += num_scheduled_token
continue
# When KV compression is enabled, only keep a subset of prompt
# tokens. Decode tokens are always kept.
prompt_ratio = envs.VLLM_KV_COMPRESSION_PROMPT_RATIO
prompt_budget = envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET
protected_prefix = envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX
protected_suffix = envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
keep_last = envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN
end_pos = request.num_computed_tokens
prompt_end = request.num_prompt_tokens
# Decode token(s): keep all.
decode_start = max(start_pos, prompt_end)
kept_decode = max(0, end_pos - decode_start)
kept_prompt_must_keep = count_prompt_must_keep_in_range(
prompt_len=prompt_end,
start_pos=start_pos,
end_pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
)
kept_prompt_topk = compute_topk_budget_step(
prompt_len=prompt_end,
start_pos=start_pos,
end_pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
request.num_kv_tokens += (
kept_decode + kept_prompt_must_keep + kept_prompt_topk)
# Clear the finished request IDs.
......@@ -1091,11 +1224,16 @@ class Scheduler(SchedulerInterface):
num_scheduled_tokens: dict[str, int],
spec_decode_tokens: dict[str, list[int]],
req_to_new_block_ids: dict[str, tuple[list[int], ...]],
*,
force_replace_block_ids: Optional[set[str]] = None,
) -> CachedRequestData:
req_ids: list[str] = []
new_token_ids: list[list[int]] = []
new_block_ids: list[tuple[list[int], ...]] = []
num_computed_tokens: list[int] = []
num_kv_tokens: list[int] = []
resumed_from_preemption: list[bool] = []
force_replace_block_ids = force_replace_block_ids or set()
for req in itertools.chain(running_reqs, resumed_reqs):
req_id = req.request_id
......@@ -1111,10 +1249,9 @@ class Scheduler(SchedulerInterface):
new_token_ids.append(token_ids)
new_block_ids.append(req_to_new_block_ids[req_id])
num_computed_tokens.append(req.num_computed_tokens)
# Because resumed_reqs is usually empty, it is more efficient to do
# in-place appending so that we don't need to allocate a new list.
resumed_from_preemption = [False] * len(running_reqs)
resumed_from_preemption += [True] * len(resumed_reqs)
num_kv_tokens.append(req.num_kv_tokens)
resumed_from_preemption.append(
(req in resumed_reqs) or (req_id in force_replace_block_ids))
return CachedRequestData(
req_ids=req_ids,
......@@ -1122,6 +1259,7 @@ class Scheduler(SchedulerInterface):
new_token_ids=new_token_ids,
new_block_ids=new_block_ids,
num_computed_tokens=num_computed_tokens,
num_kv_tokens=num_kv_tokens,
)
def _try_schedule_encoder_inputs(
......@@ -1567,6 +1705,7 @@ class Scheduler(SchedulerInterface):
# Update the request state for scheduling.
request.num_computed_tokens = num_computed_tokens
request.num_kv_tokens = num_computed_tokens
# Return that we are ready.
self.finished_recving_kv_req_ids.remove(request.request_id)
......
......@@ -174,6 +174,15 @@ class SingleTypeKVCacheManager(ABC):
self.block_pool.free_blocks(ordered_blocks)
self.num_cached_block.pop(request_id, None)
def truncate_to_num_tokens(self, request_id: str, num_tokens: int) -> bool:
"""Truncate a request's allocated blocks to fit `num_tokens` slots.
This is a best-effort optimization hook. Subclasses may override this
to free no-longer-needed blocks (e.g., after KV compaction). The default
implementation is a no-op.
"""
return False
@abstractmethod
def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int:
......@@ -283,6 +292,24 @@ class FullAttentionManager(SingleTypeKVCacheManager):
# No need to remove blocks for full attention.
pass
def truncate_to_num_tokens(self, request_id: str, num_tokens: int) -> bool:
num_tokens = max(int(num_tokens), 0)
blocks = self.req_to_blocks.get(request_id)
if not blocks:
return False
num_required_blocks = cdiv(num_tokens, self.block_size)
if num_required_blocks >= len(blocks):
return False
removed_blocks = blocks[num_required_blocks:]
del blocks[num_required_blocks:]
self.block_pool.free_blocks(reversed(removed_blocks))
if request_id in self.num_cached_block:
self.num_cached_block[request_id] = min(
self.num_cached_block[request_id], len(blocks))
return True
def get_num_common_prefix_blocks(self, request_id: str,
num_running_requests: int) -> int:
blocks = self.req_to_blocks[request_id]
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from .budget import ( # noqa: F401
compute_topk_budget_step,
count_prompt_must_keep_in_range,
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import math
def _clamp_int(value: int, lo: int, hi: int) -> int:
if value < lo:
return lo
if value > hi:
return hi
return value
def _intersection_len(a0: int, a1: int, b0: int, b1: int) -> int:
start = a0 if a0 > b0 else b0
end = a1 if a1 < b1 else b1
return max(0, end - start)
def _protected_prefix_len(prompt_len: int, protected_prefix: int) -> int:
return min(max(protected_prefix, 0), max(prompt_len, 0))
def _protected_suffix_start(prompt_len: int, protected_suffix: int) -> int:
prompt_len = max(prompt_len, 0)
suffix = min(max(protected_suffix, 0), prompt_len)
return prompt_len - suffix
def count_prompt_must_keep_in_range(
*,
prompt_len: int,
start_pos: int,
end_pos: int,
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
) -> int:
"""Count prompt tokens in [start_pos, end_pos) that are always kept."""
prompt_len = max(prompt_len, 0)
if prompt_len == 0:
return 0
start = _clamp_int(start_pos, 0, prompt_len)
end = _clamp_int(end_pos, 0, prompt_len)
if end <= start:
return 0
prefix_len = _protected_prefix_len(prompt_len, protected_prefix)
suffix_start = _protected_suffix_start(prompt_len, protected_suffix)
keep_prefix = _intersection_len(start, end, 0, prefix_len)
keep_suffix = _intersection_len(start, end, suffix_start, prompt_len)
overlap = _intersection_len(start, end, suffix_start, prefix_len)
kept = keep_prefix + keep_suffix - overlap
if keep_last_token:
last = prompt_len - 1
if start <= last < end:
already_kept = (last < prefix_len) or (last >= suffix_start)
if not already_kept:
kept += 1
return kept
def _count_prompt_candidates_upto(
*,
prompt_len: int,
pos: int,
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
) -> int:
"""Count prompt candidates in [0, pos) eligible for Top-K selection."""
prompt_len = max(prompt_len, 0)
if prompt_len == 0:
return 0
x = _clamp_int(pos, 0, prompt_len)
prefix_len = _protected_prefix_len(prompt_len, protected_prefix)
suffix_start = _protected_suffix_start(prompt_len, protected_suffix)
mid_end = min(x, suffix_start)
cand = max(0, mid_end - min(prefix_len, mid_end))
if keep_last_token:
last = prompt_len - 1
if prefix_len <= last < mid_end:
cand -= 1
return max(cand, 0)
def _candidate_total(
*,
prompt_len: int,
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
) -> int:
return _count_prompt_candidates_upto(
prompt_len=prompt_len,
pos=prompt_len,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last_token,
)
def _candidate_keep_total(
*,
candidate_total: int,
prompt_ratio: float,
prompt_budget: int,
) -> int:
if candidate_total <= 0:
return 0
if prompt_budget >= 0:
return min(prompt_budget, candidate_total)
ratio = max(0.0, min(float(prompt_ratio), 1.0))
keep = int(math.floor(candidate_total * ratio + 0.5))
return _clamp_int(keep, 0, candidate_total)
def compute_topk_budget_step(
*,
prompt_len: int,
start_pos: int,
end_pos: int,
protected_prefix: int,
protected_suffix: int,
keep_last_token: bool,
prompt_ratio: float,
prompt_budget: int,
) -> int:
"""Compute how many prompt candidate tokens to select for this step.
The budget applies to the *non-protected* prompt region and is distributed
across multiple prefill steps using a prefix-proportional rule:
budget_upto(x) = floor(total_keep * candidates_upto(x) / candidates_total)
The step's budget is the delta between its end and start positions.
"""
total = _candidate_total(
prompt_len=prompt_len,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last_token,
)
if total <= 0:
return 0
total_keep = _candidate_keep_total(
candidate_total=total,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
if total_keep <= 0:
return 0
cand_upto_start = _count_prompt_candidates_upto(
prompt_len=prompt_len,
pos=start_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last_token,
)
cand_upto_end = _count_prompt_candidates_upto(
prompt_len=prompt_len,
pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last_token,
)
step_total = max(0, cand_upto_end - cand_upto_start)
if step_total == 0:
return 0
bud_upto_start = (total_keep * cand_upto_start) // total
bud_upto_end = (total_keep * cand_upto_end) // total
step_keep = bud_upto_end - bud_upto_start
return _clamp_int(step_keep, 0, step_total)
......@@ -79,6 +79,10 @@ class Request:
self._all_token_ids: list[int] = self.prompt_token_ids.copy()
self.spec_token_ids: list[int] = []
self.num_computed_tokens = 0
# Number of tokens currently stored in the KV cache for this request.
# This can be different from `num_computed_tokens` when KV compression
# is enabled (e.g., token-shared prefill compression).
self.num_kv_tokens = 0
self.num_generated_token_ids = 0
self.cache_salt: Optional[str] = cache_salt
......
......@@ -63,6 +63,11 @@ class BlockTable:
def add_row(self, block_ids: list[int], row_idx: int) -> None:
self.num_blocks_per_row[row_idx] = 0
# Keep the invariant that "unused" entries map to the null block (id=0).
# This matters when we *shrink* a request's block list (e.g. KV
# compression tail-block truncation) and later re-use freed blocks for
# other requests.
self.block_table_np[row_idx, :].fill(0)
self.append_row(block_ids, row_idx)
def move_row(self, src: int, tgt: int) -> None:
......
......@@ -38,6 +38,7 @@ class CachedRequestState:
block_ids: tuple[list[int], ...]
num_computed_tokens: int
num_kv_tokens: int
output_token_ids: list[int]
spec_token_ids: list[int] = None
......@@ -114,6 +115,13 @@ class InputBatch:
)
self.num_computed_tokens_cpu = \
self.num_computed_tokens_cpu_tensor.numpy()
self.num_kv_tokens_cpu_tensor = torch.zeros(
(max_num_reqs, ),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.num_kv_tokens_cpu = self.num_kv_tokens_cpu_tensor.numpy()
# Block table.
self.block_table = MultiGroupBlockTable(
......@@ -348,6 +356,7 @@ class InputBatch:
self.num_tokens_no_spec[req_index] = request.num_tokens
self.num_computed_tokens_cpu[req_index] = request.num_computed_tokens
self.num_kv_tokens_cpu[req_index] = request.num_kv_tokens
self.block_table.add_row(request.block_ids, req_index)
if sampling_params := request.sampling_params:
......@@ -504,6 +513,8 @@ class InputBatch:
self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
self.num_kv_tokens_cpu[i1], self.num_kv_tokens_cpu[i2] =\
self.num_kv_tokens_cpu[i2], self.num_kv_tokens_cpu[i1]
self.temperature_cpu[i1], self.temperature_cpu[i2] =\
self.temperature_cpu[i2], self.temperature_cpu[i1]
self.top_p_cpu[i1], self.top_p_cpu[i2] =\
......@@ -602,6 +613,8 @@ class InputBatch:
last_req_index]
self.num_computed_tokens_cpu[
empty_index] = self.num_computed_tokens_cpu[last_req_index]
self.num_kv_tokens_cpu[
empty_index] = self.num_kv_tokens_cpu[last_req_index]
self.block_table.move_row(last_req_index, empty_index)
self.temperature_cpu[empty_index] = self.temperature_cpu[
last_req_index]
......
......@@ -55,6 +55,7 @@ from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec,
KVCacheConfig, KVCacheSpec, MambaSpec,
SlidingWindowSpec)
from vllm.v1.kv_compression.budget import compute_topk_budget_step
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
......@@ -146,6 +147,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self.attention_chunk_size = model_config.attention_chunk_size
self.cascade_attn_enabled = not self.model_config.disable_cascade_attn
if envs.VLLM_ENABLE_KV_COMPRESSION:
# KV compression changes the effective KV sequence layout and
# invalidates cascade attention assumptions (common-prefix blocks).
self.cascade_attn_enabled = False
# Whether the current step needs KV compaction work (score/topk/dst).
# This is set per-step in `_prepare_inputs`.
self.kv_compression_needs_compaction: bool = False
# Multi-modal data support
self.mm_registry = MULTIMODAL_REGISTRY
......@@ -313,6 +321,14 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
device="cpu",
pin_memory=self.pin_memory)
self.positions_np = self.positions_cpu.numpy()
# KV positions are decoupled from logical positions when KV compression
# is enabled. We keep a separate buffer to avoid recomputing or
# overwriting `positions_np` (used for RoPE / input token lookup).
self.kv_positions_cpu = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device="cpu",
pin_memory=self.pin_memory)
self.kv_positions_np = self.kv_positions_cpu.numpy()
self.query_start_loc_cpu = torch.zeros(self.max_num_reqs + 1,
dtype=torch.int32,
device="cpu",
......@@ -323,6 +339,34 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
device="cpu",
pin_memory=self.pin_memory)
self.seq_lens_np = self.seq_lens_cpu.numpy()
# KV compression metadata buffers (used by the "topk" policy).
# Per-token: whether this scheduled token must be kept in KV cache.
self.kv_compression_must_keep_cpu = torch.zeros(
self.max_num_tokens,
dtype=torch.bool,
device="cpu",
pin_memory=self.pin_memory,
)
self.kv_compression_must_keep_np = self.kv_compression_must_keep_cpu.numpy()
self.kv_compression_must_keep = torch.zeros(
self.max_num_tokens,
dtype=torch.bool,
device=self.device,
)
# Per-request: how many additional prompt tokens to keep among
# non-protected candidates (budget from env; selection uses scores).
self.kv_compression_topk_budget_cpu = torch.zeros(
self.max_num_reqs,
dtype=torch.int32,
device="cpu",
pin_memory=self.pin_memory,
)
self.kv_compression_topk_budget_np = self.kv_compression_topk_budget_cpu.numpy()
self.kv_compression_topk_budget = torch.zeros(
self.max_num_reqs,
dtype=torch.int32,
device=self.device,
)
# Layer pairings for cross-layer KV sharing.
# If an Attention layer `layer_name` is in the keys of this dict, it
......@@ -448,6 +492,7 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
generator=generator,
block_ids=new_req_data.block_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
num_kv_tokens=new_req_data.num_kv_tokens,
output_token_ids=[],
lora_request=new_req_data.lora_request,
)
......@@ -497,11 +542,13 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
for i, req_id in enumerate(req_data.req_ids):
req_state = self.requests[req_id]
num_computed_tokens = req_data.num_computed_tokens[i]
num_kv_tokens = req_data.num_kv_tokens[i]
new_block_ids = req_data.new_block_ids[i]
resumed_from_preemption = req_data.resumed_from_preemption[i]
# Update the cached states.
req_state.num_computed_tokens = num_computed_tokens
req_state.num_kv_tokens = num_kv_tokens
spec_token_ids = (
scheduler_output.scheduled_spec_decode_tokens.get(req_id, ()))
......@@ -545,6 +592,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# Update the persistent batch.
self.input_batch.num_computed_tokens_cpu[req_index] = (
num_computed_tokens)
self.input_batch.num_kv_tokens_cpu[req_index] = num_kv_tokens
if resumed_from_preemption:
self.input_batch.block_table.add_row(new_block_ids, req_index)
else:
self.input_batch.block_table.append_row(new_block_ids, req_index)
# For the last rank, we don't need to update the token_ids_cpu
......@@ -658,6 +709,78 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
np.add(self.input_batch.num_computed_tokens_cpu[req_indices],
arange,
out=positions_np)
# KV positions (where the KV for each scheduled token is temporarily
# written). When KV compression is enabled, KV positions are decoupled
# from logical positions.
use_kv_compression = envs.VLLM_ENABLE_KV_COMPRESSION
if use_kv_compression:
kv_positions_np = self.kv_positions_np[:total_num_scheduled_tokens]
np.add(self.input_batch.num_kv_tokens_cpu[req_indices],
arange,
out=kv_positions_np)
else:
kv_positions_np = None
if use_kv_compression:
prompt_ratio = envs.VLLM_KV_COMPRESSION_PROMPT_RATIO
prompt_budget = envs.VLLM_KV_COMPRESSION_PROMPT_BUDGET
protected_prefix = envs.VLLM_KV_COMPRESSION_PROTECTED_PREFIX
protected_suffix = envs.VLLM_KV_COMPRESSION_PROTECTED_SUFFIX
keep_last = envs.VLLM_KV_COMPRESSION_KEEP_LAST_TOKEN
must_keep_np = self.kv_compression_must_keep_np[
:total_num_scheduled_tokens]
must_keep_np.fill(False)
topk_budget_np = self.kv_compression_topk_budget_np[:num_reqs]
topk_budget_np.fill(0)
for req_idx in range(num_reqs):
qlen = int(num_scheduled_tokens[req_idx])
if qlen <= 0:
continue
start = 0 if req_idx == 0 else int(cu_num_tokens[req_idx - 1])
end = int(cu_num_tokens[req_idx])
assert end - start == qlen
base_pos = int(
self.input_batch.num_computed_tokens_cpu[req_idx])
prompt_len = int(self.input_batch.num_prompt_tokens[req_idx])
end_pos = base_pos + qlen
pos = base_pos + np.arange(qlen, dtype=np.int64)
prompt_mask = pos < prompt_len
# Decode tokens are always kept.
must_keep = ~prompt_mask
if np.any(prompt_mask):
suffix_start = max(prompt_len - protected_suffix, 0)
must_keep |= prompt_mask & (pos < protected_prefix)
must_keep |= prompt_mask & (pos >= suffix_start)
if keep_last:
last = prompt_len - 1
if base_pos <= last < end_pos:
must_keep[last - base_pos] = True
topk_budget_np[req_idx] = compute_topk_budget_step(
prompt_len=prompt_len,
start_pos=base_pos,
end_pos=end_pos,
protected_prefix=protected_prefix,
protected_suffix=protected_suffix,
keep_last_token=keep_last,
prompt_ratio=prompt_ratio,
prompt_budget=prompt_budget,
)
must_keep_np[start:end] = must_keep
# Decode-only fast path: if all scheduled tokens are unconditionally
# kept and there is no Top-K budget, KV compaction is a no-op and we
# can skip score/topk/dst entirely in the attention backend.
self.kv_compression_needs_compaction = (not must_keep_np.all()) or (
topk_budget_np > 0).any()
else:
self.kv_compression_needs_compaction = False
# Calculate M-RoPE positions.
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
......@@ -685,6 +808,8 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
block_size = kv_cache_group_spec.kv_cache_spec.block_size
block_table: BlockTable = self.input_batch.block_table[
kv_cache_group_id]
slot_positions_np = (kv_positions_np
if use_kv_compression else positions_np)
# E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2]
# -> [0, 0, K, K, K + 1, K + 1, K + 2, 2 * K, 2 * K, 2 * K + 1]
# where K is the max_num_blocks_per_req and the block size is 2.
......@@ -693,11 +818,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
# block_size.
block_table_indices = (
req_indices * block_table.max_num_blocks_per_req +
positions_np // block_size)
slot_positions_np // block_size)
block_table_cpu = block_table.get_cpu_tensor()
block_numbers = block_table_cpu.flatten(
)[block_table_indices].numpy()
block_offsets = positions_np % block_size
block_offsets = slot_positions_np % block_size
np.add(
block_numbers * block_size,
block_offsets,
......@@ -707,6 +832,11 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self.query_start_loc_np[0] = 0
self.query_start_loc_np[1:num_reqs + 1] = cu_num_tokens
if use_kv_compression:
self.seq_lens_np[:num_reqs] = (
self.input_batch.num_kv_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
else:
self.seq_lens_np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens)
......@@ -729,6 +859,15 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
self.query_start_loc_cpu[:num_reqs + 1], non_blocking=True)
self.seq_lens[:num_reqs].copy_(self.seq_lens_cpu[:num_reqs],
non_blocking=True)
if use_kv_compression:
self.kv_compression_must_keep[:total_num_scheduled_tokens].copy_(
self.kv_compression_must_keep_cpu[:total_num_scheduled_tokens],
non_blocking=True,
)
self.kv_compression_topk_budget[:num_reqs].copy_(
self.kv_compression_topk_budget_cpu[:num_reqs],
non_blocking=True,
)
# Fill unused with -1. Needed for reshape_and_cache
self.seq_lens[num_reqs:].fill_(0)
......@@ -2532,6 +2671,10 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
assert len(self.attn_backends) == 0 and len(
self.attn_metadata_builders
) == 0, "Attention backends are already initialized"
if envs.VLLM_ENABLE_KV_COMPRESSION and self.full_cuda_graph:
raise ValueError(
"KV compression is currently incompatible with full CUDA "
"graph mode.")
for i, kv_cache_group_spec in enumerate(
kv_cache_config.kv_cache_groups):
kv_cache_spec = kv_cache_group_spec.kv_cache_spec
......@@ -2555,7 +2698,16 @@ class GPUModelRunnerBase(LoRAModelRunnerMixin):
raise NotImplementedError(
"Non-Attention backend is not supported by V1 "
"GPUModelRunner.")
if (envs.VLLM_ENABLE_KV_COMPRESSION
and attn_backend_i.get_name() != "FLASH_ATTN_VLLM_V1"):
raise ValueError(
"KV compression currently requires "
"VLLM_ATTENTION_BACKEND=FLASH_ATTN_VLLM_V1.")
elif isinstance(kv_cache_spec, MambaSpec):
if envs.VLLM_ENABLE_KV_COMPRESSION:
raise ValueError(
"KV compression is currently only supported for "
"Transformer attention layers.")
attn_backend_i = Mamba2AttentionBackend
else:
raise ValueError(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment